mindspore 2.5.0__cp311-cp311-win_amd64.whl → 2.6.0rc1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (491)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +6 -4
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -33
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parse/__init__.py +6 -7
  14. mindspore/_extends/parse/compile_config.py +19 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +24 -193
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +97 -74
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +4 -4
  27. mindspore/atlprov.dll +0 -0
  28. mindspore/avcodec-59.dll +0 -0
  29. mindspore/avdevice-59.dll +0 -0
  30. mindspore/avfilter-8.dll +0 -0
  31. mindspore/avformat-59.dll +0 -0
  32. mindspore/avutil-57.dll +0 -0
  33. mindspore/boost/__init__.py +2 -2
  34. mindspore/boost/base.py +3 -7
  35. mindspore/boost/boost_cell_wrapper.py +2 -2
  36. mindspore/c1.dll +0 -0
  37. mindspore/c1xx.dll +0 -0
  38. mindspore/c2.dll +0 -0
  39. mindspore/common/__init__.py +4 -3
  40. mindspore/common/_grad_function.py +56 -0
  41. mindspore/common/_pijit_context.py +14 -5
  42. mindspore/common/_register_for_tensor.py +1 -1
  43. mindspore/common/_stub_tensor.py +5 -10
  44. mindspore/common/_tensor_cpp_method.py +1 -1
  45. mindspore/common/_tensor_docs.py +1915 -3287
  46. mindspore/common/api.py +341 -354
  47. mindspore/common/auto_dynamic_shape.py +41 -44
  48. mindspore/common/dtype.py +5 -2
  49. mindspore/common/dump.py +7 -5
  50. mindspore/common/file_system.py +3 -0
  51. mindspore/common/hook_handle.py +5 -3
  52. mindspore/common/initializer.py +10 -6
  53. mindspore/common/jit_begin_end.py +94 -0
  54. mindspore/common/jit_config.py +6 -1
  55. mindspore/common/jit_context.py +76 -0
  56. mindspore/common/jit_trace.py +378 -0
  57. mindspore/common/lazy_inline.py +2 -2
  58. mindspore/common/mutable.py +5 -4
  59. mindspore/common/parameter.py +106 -39
  60. mindspore/common/seed.py +2 -2
  61. mindspore/common/sparse_tensor.py +23 -17
  62. mindspore/common/tensor.py +297 -714
  63. mindspore/communication/__init__.py +7 -5
  64. mindspore/communication/_comm_helper.py +47 -2
  65. mindspore/communication/comm_func.py +70 -53
  66. mindspore/communication/management.py +83 -17
  67. mindspore/context.py +214 -560
  68. mindspore/dataset/__init__.py +44 -20
  69. mindspore/dataset/audio/__init__.py +2 -8
  70. mindspore/dataset/audio/transforms.py +3 -17
  71. mindspore/dataset/core/config.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +1 -1
  73. mindspore/dataset/engine/datasets.py +102 -120
  74. mindspore/dataset/engine/datasets_audio.py +22 -22
  75. mindspore/dataset/engine/datasets_standard_format.py +43 -24
  76. mindspore/dataset/engine/datasets_text.py +78 -85
  77. mindspore/dataset/engine/datasets_user_defined.py +108 -76
  78. mindspore/dataset/engine/datasets_vision.py +111 -108
  79. mindspore/dataset/engine/iterators.py +5 -3
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/samplers.py +279 -57
  82. mindspore/dataset/engine/serializer_deserializer.py +2 -1
  83. mindspore/dataset/engine/validators.py +10 -0
  84. mindspore/dataset/text/__init__.py +7 -6
  85. mindspore/dataset/text/transforms.py +6 -5
  86. mindspore/dataset/text/utils.py +3 -3
  87. mindspore/dataset/transforms/__init__.py +0 -9
  88. mindspore/dataset/transforms/transforms.py +3 -3
  89. mindspore/dataset/utils/browse_dataset.py +1 -1
  90. mindspore/dataset/vision/__init__.py +2 -9
  91. mindspore/dataset/vision/transforms.py +202 -158
  92. mindspore/dataset/vision/utils.py +7 -5
  93. mindspore/device_context/ascend/op_debug.py +60 -1
  94. mindspore/device_context/ascend/op_tuning.py +0 -4
  95. mindspore/device_manager.py +39 -3
  96. mindspore/dnnl.dll +0 -0
  97. mindspore/dpcmi.dll +0 -0
  98. mindspore/experimental/es/embedding_service.py +35 -27
  99. mindspore/experimental/map_parameter.py +4 -4
  100. mindspore/experimental/optim/adadelta.py +22 -26
  101. mindspore/experimental/optim/adagrad.py +4 -4
  102. mindspore/experimental/optim/adam.py +4 -0
  103. mindspore/experimental/optim/adamax.py +4 -4
  104. mindspore/experimental/optim/adamw.py +4 -0
  105. mindspore/experimental/optim/asgd.py +1 -1
  106. mindspore/experimental/optim/lr_scheduler.py +40 -22
  107. mindspore/experimental/optim/radam.py +5 -5
  108. mindspore/experimental/optim/rprop.py +1 -1
  109. mindspore/experimental/optim/sgd.py +1 -1
  110. mindspore/hal/contiguous_tensors_handle.py +6 -10
  111. mindspore/hal/device.py +55 -81
  112. mindspore/hal/event.py +38 -55
  113. mindspore/hal/memory.py +93 -144
  114. mindspore/hal/stream.py +81 -125
  115. mindspore/include/dataset/constants.h +7 -4
  116. mindspore/include/dataset/execute.h +2 -2
  117. mindspore/jpeg62.dll +0 -0
  118. mindspore/log.py +40 -2
  119. mindspore/mindrecord/__init__.py +20 -7
  120. mindspore/mindspore_backend_common.dll +0 -0
  121. mindspore/mindspore_backend_manager.dll +0 -0
  122. mindspore/mindspore_common.dll +0 -0
  123. mindspore/mindspore_core.dll +0 -0
  124. mindspore/mindspore_dump.dll +0 -0
  125. mindspore/mindspore_frontend.dll +0 -0
  126. mindspore/mindspore_glog.dll +0 -0
  127. mindspore/mindspore_memory_pool.dll +0 -0
  128. mindspore/mindspore_ms_backend.dll +0 -0
  129. mindspore/mindspore_ops.dll +0 -0
  130. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  131. mindspore/mindspore_ops_kernel_common.dll +0 -0
  132. mindspore/mindspore_profiler.dll +0 -0
  133. mindspore/mindspore_pyboost.dll +0 -0
  134. mindspore/mindspore_pynative.dll +0 -0
  135. mindspore/mindspore_res_manager.dll +0 -0
  136. mindspore/mindspore_runtime_pipeline.dll +0 -0
  137. mindspore/mint/__init__.py +131 -700
  138. mindspore/mint/distributed/__init__.py +5 -1
  139. mindspore/mint/distributed/distributed.py +194 -109
  140. mindspore/mint/linalg/__init__.py +2 -0
  141. mindspore/mint/nn/__init__.py +280 -18
  142. mindspore/mint/nn/functional.py +282 -64
  143. mindspore/mint/nn/layer/__init__.py +4 -0
  144. mindspore/mint/nn/layer/_functions.py +7 -3
  145. mindspore/mint/nn/layer/activation.py +120 -13
  146. mindspore/mint/nn/layer/conv.py +218 -24
  147. mindspore/mint/nn/layer/normalization.py +15 -16
  148. mindspore/mint/nn/layer/padding.py +1 -1
  149. mindspore/mint/nn/layer/pooling.py +66 -1
  150. mindspore/mint/optim/__init__.py +2 -1
  151. mindspore/mint/optim/sgd.py +171 -0
  152. mindspore/msobj140.dll +0 -0
  153. mindspore/mspdb140.dll +0 -0
  154. mindspore/mspdbcore.dll +0 -0
  155. mindspore/mspdbst.dll +0 -0
  156. mindspore/mspft140.dll +0 -0
  157. mindspore/msvcdis140.dll +0 -0
  158. mindspore/msvcp140_1.dll +0 -0
  159. mindspore/msvcp140_2.dll +0 -0
  160. mindspore/msvcp140_atomic_wait.dll +0 -0
  161. mindspore/msvcp140_codecvt_ids.dll +0 -0
  162. mindspore/nn/__init__.py +4 -1
  163. mindspore/nn/cell.py +1250 -176
  164. mindspore/nn/layer/activation.py +23 -21
  165. mindspore/nn/layer/basic.py +22 -16
  166. mindspore/nn/layer/container.py +1 -1
  167. mindspore/nn/layer/conv.py +22 -17
  168. mindspore/nn/layer/embedding.py +9 -8
  169. mindspore/nn/layer/normalization.py +48 -42
  170. mindspore/nn/layer/pooling.py +75 -31
  171. mindspore/nn/layer/transformer.py +11 -10
  172. mindspore/nn/learning_rate_schedule.py +4 -2
  173. mindspore/nn/loss/loss.py +27 -19
  174. mindspore/nn/optim/ada_grad.py +6 -5
  175. mindspore/nn/optim/adadelta.py +9 -7
  176. mindspore/nn/optim/adafactor.py +1 -1
  177. mindspore/nn/optim/adam.py +16 -12
  178. mindspore/nn/optim/adamax.py +8 -7
  179. mindspore/nn/optim/adasum.py +5 -5
  180. mindspore/nn/optim/asgd.py +1 -1
  181. mindspore/nn/optim/ftrl.py +11 -9
  182. mindspore/nn/optim/lamb.py +1 -1
  183. mindspore/nn/optim/lazyadam.py +12 -10
  184. mindspore/nn/optim/momentum.py +7 -6
  185. mindspore/nn/optim/optimizer.py +2 -2
  186. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  187. mindspore/nn/optim/rmsprop.py +13 -12
  188. mindspore/nn/optim/rprop.py +9 -7
  189. mindspore/nn/optim/sgd.py +9 -6
  190. mindspore/nn/optim/tft_wrapper.py +5 -2
  191. mindspore/nn/probability/bijector/bijector.py +17 -11
  192. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  193. mindspore/nn/probability/bijector/invert.py +2 -2
  194. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  195. mindspore/nn/probability/bijector/softplus.py +3 -2
  196. mindspore/nn/probability/distribution/beta.py +3 -3
  197. mindspore/nn/probability/distribution/categorical.py +1 -1
  198. mindspore/nn/probability/distribution/cauchy.py +4 -2
  199. mindspore/nn/probability/distribution/exponential.py +6 -7
  200. mindspore/nn/probability/distribution/gamma.py +2 -2
  201. mindspore/nn/probability/distribution/gumbel.py +2 -2
  202. mindspore/nn/probability/distribution/half_normal.py +5 -3
  203. mindspore/nn/probability/distribution/logistic.py +5 -3
  204. mindspore/nn/probability/distribution/poisson.py +1 -1
  205. mindspore/nn/probability/distribution/uniform.py +5 -3
  206. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  207. mindspore/nn/reinforcement/tensor_array.py +1 -1
  208. mindspore/nn/wrap/__init__.py +6 -6
  209. mindspore/nn/wrap/cell_wrapper.py +178 -117
  210. mindspore/nn/wrap/grad_reducer.py +45 -36
  211. mindspore/nn/wrap/loss_scale.py +3 -3
  212. mindspore/numpy/array_creations.py +3 -3
  213. mindspore/numpy/array_ops.py +1 -1
  214. mindspore/numpy/math_ops.py +4 -4
  215. mindspore/numpy/utils.py +1 -2
  216. mindspore/numpy/utils_const.py +1 -2
  217. mindspore/opencv_core452.dll +0 -0
  218. mindspore/opencv_imgcodecs452.dll +0 -0
  219. mindspore/opencv_imgproc452.dll +0 -0
  220. mindspore/ops/__init__.py +3 -2
  221. mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
  222. mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
  223. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  224. mindspore/ops/_register_for_op.py +0 -11
  225. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  226. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
  227. mindspore/ops/_vmap/vmap_array_ops.py +7 -6
  228. mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
  229. mindspore/ops/_vmap/vmap_math_ops.py +4 -7
  230. mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
  231. mindspore/ops/auto_generate/__init__.py +4 -3
  232. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +102 -49
  233. mindspore/ops/auto_generate/gen_extend_func.py +281 -135
  234. mindspore/ops/auto_generate/gen_ops_def.py +2574 -2326
  235. mindspore/ops/auto_generate/gen_ops_prim.py +8566 -2755
  236. mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
  237. mindspore/ops/composite/__init__.py +2 -1
  238. mindspore/ops/composite/base.py +19 -24
  239. mindspore/ops/composite/math_ops.py +6 -16
  240. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  241. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -3
  242. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  243. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  244. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  245. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  246. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  247. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  248. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  249. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  250. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  251. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  252. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  253. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  254. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  255. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  256. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  257. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  258. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  259. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  260. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  261. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  262. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  263. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  264. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  265. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  266. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  267. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
  268. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  270. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  271. mindspore/ops/function/__init__.py +28 -2
  272. mindspore/ops/function/_add_attr_func.py +58 -0
  273. mindspore/ops/function/array_func.py +1629 -2345
  274. mindspore/ops/function/clip_func.py +38 -45
  275. mindspore/ops/function/debug_func.py +36 -44
  276. mindspore/ops/function/grad/__init__.py +1 -0
  277. mindspore/ops/function/grad/grad_func.py +104 -71
  278. mindspore/ops/function/image_func.py +1 -1
  279. mindspore/ops/function/linalg_func.py +46 -78
  280. mindspore/ops/function/math_func.py +3035 -3705
  281. mindspore/ops/function/nn_func.py +676 -241
  282. mindspore/ops/function/other_func.py +159 -1
  283. mindspore/ops/function/parameter_func.py +17 -30
  284. mindspore/ops/function/random_func.py +204 -361
  285. mindspore/ops/function/reshard_func.py +4 -70
  286. mindspore/ops/function/sparse_func.py +3 -3
  287. mindspore/ops/function/sparse_unary_func.py +5 -5
  288. mindspore/ops/function/spectral_func.py +25 -58
  289. mindspore/ops/function/vmap_func.py +24 -17
  290. mindspore/ops/functional.py +6 -4
  291. mindspore/ops/functional_overload.py +547 -4
  292. mindspore/ops/op_info_register.py +32 -244
  293. mindspore/ops/operations/__init__.py +10 -5
  294. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  295. mindspore/ops/operations/_grad_ops.py +1 -10
  296. mindspore/ops/operations/_inner_ops.py +5 -76
  297. mindspore/ops/operations/_ms_kernel.py +4 -10
  298. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  299. mindspore/ops/operations/_scalar_ops.py +3 -2
  300. mindspore/ops/operations/_sequence_ops.py +1 -1
  301. mindspore/ops/operations/_tensor_array.py +1 -1
  302. mindspore/ops/operations/array_ops.py +37 -22
  303. mindspore/ops/operations/comm_ops.py +150 -107
  304. mindspore/ops/operations/custom_ops.py +221 -23
  305. mindspore/ops/operations/debug_ops.py +115 -16
  306. mindspore/ops/operations/inner_ops.py +1 -1
  307. mindspore/ops/operations/linalg_ops.py +1 -58
  308. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  309. mindspore/ops/operations/manually_defined/ops_def.py +746 -79
  310. mindspore/ops/operations/math_ops.py +21 -18
  311. mindspore/ops/operations/nn_ops.py +65 -191
  312. mindspore/ops/operations/other_ops.py +62 -9
  313. mindspore/ops/operations/random_ops.py +13 -7
  314. mindspore/ops/operations/reshard_ops.py +1 -1
  315. mindspore/ops/operations/sparse_ops.py +2 -2
  316. mindspore/ops/primitive.py +43 -32
  317. mindspore/ops/tensor_method.py +232 -13
  318. mindspore/ops_generate/__init__.py +0 -5
  319. mindspore/ops_generate/aclnn/__init__.py +0 -0
  320. mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
  321. mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
  322. mindspore/ops_generate/api/__init__.py +0 -0
  323. mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
  324. mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
  325. mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
  326. mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
  327. mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
  328. mindspore/ops_generate/api/gen_api.py +103 -0
  329. mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
  330. mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
  331. mindspore/ops_generate/common/__init__.py +0 -0
  332. mindspore/ops_generate/common/gen_constants.py +91 -0
  333. mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
  334. mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
  335. mindspore/ops_generate/{template.py → common/template.py} +96 -84
  336. mindspore/ops_generate/gen_ops.py +23 -325
  337. mindspore/ops_generate/op_def/__init__.py +0 -0
  338. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  339. mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
  340. mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -7
  341. mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
  342. mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
  343. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  344. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  345. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  346. mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
  347. mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
  348. mindspore/ops_generate/pyboost/__init__.py +0 -0
  349. mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
  350. mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
  351. mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
  352. mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
  353. mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
  354. mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
  355. mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
  356. mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
  357. mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
  358. mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
  359. mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
  360. mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
  361. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
  362. mindspore/ops_generate/resources/__init__.py +0 -0
  363. mindspore/ops_generate/resources/resource_list.py +30 -0
  364. mindspore/ops_generate/resources/resource_loader.py +36 -0
  365. mindspore/ops_generate/resources/resource_manager.py +64 -0
  366. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  367. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  368. mindspore/parallel/__init__.py +6 -2
  369. mindspore/parallel/_auto_parallel_context.py +133 -6
  370. mindspore/parallel/_cell_wrapper.py +130 -15
  371. mindspore/parallel/_parallel_serialization.py +95 -4
  372. mindspore/parallel/_ps_context.py +1 -1
  373. mindspore/parallel/_recovery_context.py +7 -2
  374. mindspore/parallel/_tensor.py +142 -18
  375. mindspore/parallel/_utils.py +198 -25
  376. mindspore/parallel/algo_parameter_config.py +3 -3
  377. mindspore/parallel/auto_parallel.py +732 -0
  378. mindspore/parallel/checkpoint_convert.py +159 -0
  379. mindspore/parallel/checkpoint_transform.py +656 -37
  380. mindspore/parallel/cluster/process_entity/_api.py +151 -19
  381. mindspore/parallel/cluster/run.py +1 -1
  382. mindspore/parallel/function/__init__.py +24 -0
  383. mindspore/parallel/function/reshard_func.py +259 -0
  384. mindspore/parallel/nn/__init__.py +25 -0
  385. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  386. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  387. mindspore/parallel/parameter_broadcast.py +24 -13
  388. mindspore/parallel/shard.py +137 -61
  389. mindspore/parallel/transform_safetensors.py +287 -95
  390. mindspore/pgodb140.dll +0 -0
  391. mindspore/pgort140.dll +0 -0
  392. mindspore/profiler/__init__.py +9 -5
  393. mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
  394. mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
  395. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
  396. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +22 -0
  397. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  398. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
  399. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
  400. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
  401. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
  402. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
  403. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
  404. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
  405. mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
  406. mindspore/profiler/common/constant.py +12 -0
  407. mindspore/profiler/common/msprof_cmd_tool.py +42 -23
  408. mindspore/profiler/common/path_manager.py +24 -0
  409. mindspore/profiler/common/profiler_context.py +26 -2
  410. mindspore/profiler/common/profiler_meta_data.py +74 -0
  411. mindspore/profiler/common/profiler_parameters.py +59 -18
  412. mindspore/profiler/common/profiler_path_manager.py +66 -7
  413. mindspore/profiler/dynamic_profiler.py +112 -79
  414. mindspore/profiler/envprofiler.py +26 -1
  415. mindspore/profiler/experimental_config.py +197 -0
  416. mindspore/profiler/mstx.py +57 -14
  417. mindspore/profiler/platform/npu_profiler.py +33 -7
  418. mindspore/profiler/profiler.py +541 -45
  419. mindspore/profiler/profiler_action_controller.py +1 -1
  420. mindspore/profiler/profiler_interface.py +4 -0
  421. mindspore/profiler/schedule.py +57 -22
  422. mindspore/rewrite/api/node.py +15 -13
  423. mindspore/rewrite/api/symbol_tree.py +1 -1
  424. mindspore/run_check/_check_version.py +25 -14
  425. mindspore/run_check/run_check.py +1 -1
  426. mindspore/runtime/__init__.py +2 -2
  427. mindspore/runtime/executor.py +40 -11
  428. mindspore/runtime/memory.py +25 -8
  429. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  430. mindspore/swresample-4.dll +0 -0
  431. mindspore/swscale-6.dll +0 -0
  432. mindspore/tbbmalloc.dll +0 -0
  433. mindspore/tinyxml2.dll +0 -0
  434. mindspore/train/__init__.py +8 -8
  435. mindspore/train/_utils.py +35 -7
  436. mindspore/train/amp.py +1 -1
  437. mindspore/train/callback/__init__.py +2 -2
  438. mindspore/train/callback/_callback.py +2 -16
  439. mindspore/train/callback/_checkpoint.py +24 -40
  440. mindspore/train/callback/_cluster_monitor.py +14 -18
  441. mindspore/train/callback/_flops_collector.py +2 -3
  442. mindspore/train/callback/_history.py +7 -4
  443. mindspore/train/callback/_lambda_callback.py +2 -2
  444. mindspore/train/callback/_landscape.py +0 -3
  445. mindspore/train/callback/_loss_monitor.py +2 -1
  446. mindspore/train/callback/_on_request_exit.py +6 -5
  447. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  448. mindspore/train/callback/_summary_collector.py +8 -13
  449. mindspore/train/callback/_time_monitor.py +2 -1
  450. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +179 -103
  451. mindspore/train/data_sink.py +25 -2
  452. mindspore/train/dataset_helper.py +4 -5
  453. mindspore/train/loss_scale_manager.py +8 -7
  454. mindspore/train/metrics/accuracy.py +3 -3
  455. mindspore/train/metrics/confusion_matrix.py +9 -9
  456. mindspore/train/metrics/error.py +3 -3
  457. mindspore/train/metrics/hausdorff_distance.py +4 -4
  458. mindspore/train/metrics/mean_surface_distance.py +3 -3
  459. mindspore/train/metrics/metric.py +0 -12
  460. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  461. mindspore/train/metrics/precision.py +8 -6
  462. mindspore/train/metrics/recall.py +9 -9
  463. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  464. mindspore/train/mind_ir_pb2.py +19 -12
  465. mindspore/train/model.py +176 -103
  466. mindspore/train/serialization.py +246 -988
  467. mindspore/train/summary/_summary_adapter.py +2 -2
  468. mindspore/train/summary/summary_record.py +1 -1
  469. mindspore/turbojpeg.dll +0 -0
  470. mindspore/utils/__init__.py +3 -2
  471. mindspore/utils/dryrun.py +4 -2
  472. mindspore/utils/hooks.py +81 -0
  473. mindspore/utils/utils.py +138 -4
  474. mindspore/vcmeta.dll +0 -0
  475. mindspore/vcruntime140.dll +0 -0
  476. mindspore/vcruntime140_1.dll +0 -0
  477. mindspore/version.py +1 -1
  478. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +2 -1
  479. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +483 -438
  480. mindspore/_install_custom.py +0 -43
  481. mindspore/common/_register_for_adapter.py +0 -74
  482. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  483. mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
  484. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  485. mindspore/ops_generate/gen_constants.py +0 -190
  486. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  487. mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
  488. mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
  489. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  490. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +0 -0
  491. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/numpy/math_ops.py CHANGED
@@ -5845,13 +5845,13 @@ def correlate(a, v, mode='valid'):
5845
5845
  >>> from mindspore import Tensor
5846
5846
  >>> output = mnp.correlate(Tensor([1., 2., 3.]), Tensor([0., 1., 0.5]))
5847
5847
  >>> print(output)
5848
- [3.5]
5848
+ Tensor(shape=[1], dtype=Float32, value= [ 3.50000000e+00])
5849
5849
  >>> output = mnp.correlate(Tensor([1., 2., 3.]), Tensor([0., 1., 0.5]), mode="same")
5850
5850
  >>> print(output)
5851
- [2. 3.5 3. ]
5852
- >>> output = mnp.correlate(Tensor([1., 2., 3., 4., 5.]), Tensor([1., 2.]), mode="full")
5851
+ Tensor(shape=[3], dtype=Float32, value= [ 2.00000000e+00, 3.50000000e+00, 3.00000000e+00])
5852
+ >>> output = mnp.correlate(Tensor([1., 2., 3.]), Tensor([1., 2.]), mode="full")
5853
5853
  >>> print(output)
5854
- [ 2. 5. 8. 11. 14. 5.]
5854
+ Tensor(shape=[4], dtype=Float32, value= [ 2.00000000e+00, 5.00000000e+00, 8.00000000e+00, 3.00000000e+00])
5855
5855
  """
5856
5856
  if isinstance(a, list):
5857
5857
  a = ops.auto_generate.list_to_tuple(a)
mindspore/numpy/utils.py CHANGED
@@ -18,7 +18,6 @@ from __future__ import absolute_import
18
18
  import types
19
19
 
20
20
  from mindspore.common import Tensor
21
- from mindspore._c_expression import Tensor as Tensor_
22
21
  from mindspore.common import dtype as mstype
23
22
  from mindspore import ops
24
23
 
@@ -129,7 +128,7 @@ def _to_tensor(*args):
129
128
  for arg in args:
130
129
  if isinstance(arg, (int, float, bool, list, tuple)):
131
130
  if isinstance(arg, (list, tuple)) and not arg:
132
- arg = Tensor_(arg)
131
+ arg = Tensor(arg)
133
132
  arg = _convert_64_to_32(_type_convert(Tensor, arg))
134
133
  elif not isinstance(arg, Tensor):
135
134
  _raise_type_error("Expect input to be array like.")
mindspore/numpy/utils_const.py CHANGED
@@ -24,7 +24,6 @@ from mindspore.ops.primitive import constexpr
24
24
  from mindspore.ops.primitive import _primexpr
25
25
  from mindspore.common import dtype as mstype
26
26
  from mindspore.common import Tensor
27
- from mindspore._c_expression import Tensor as Tensor_
28
27
  from mindspore._c_expression import typing
29
28
  from mindspore import _checkparam as validator
30
29
  from mindspore import ops
@@ -282,7 +281,7 @@ def _raise_unimplemented_error(info, param=None):
282
281
  @_primexpr
283
282
  def _empty(dtype, shape):
284
283
  """Returns an uninitialized array with dtype and shape."""
285
- return Tensor_(dtype, shape)
284
+ return Tensor(dtype=dtype, shape=shape)
286
285
 
287
286
 
288
287
  @constexpr
Binary file
Binary file
Binary file
mindspore/ops/__init__.py CHANGED
@@ -31,13 +31,14 @@ from mindspore.ops.op_info_register import op_info_register, custom_info_registe
31
31
  from mindspore.ops.primitive import constexpr
32
32
  from mindspore.ops import composite, operations, functional, function
33
33
  from mindspore.ops import signature
34
- from mindspore.ops.auto_generate import cpp_create_prim_instance_helper, gen_arg_dtype_cast, gen_arg_handler, \
34
+ from mindspore.ops.auto_generate import cpp_create_prim_instance_helper, \
35
35
  gen_extend_func, gen_ops_def, gen_ops_prim, pyboost_inner_prim
36
36
  from mindspore.ops.functional_overload import all_gather_matmul, matmul_reduce_scatter
37
37
  from mindspore.ops.composite import *
38
38
  from mindspore.ops.operations import *
39
39
  from mindspore.ops.function import *
40
40
  from mindspore.ops.functional import *
41
+ from mindspore.ops._utils import arg_dtype_cast, arg_handler
41
42
 
42
43
  __primitive__ = [
43
44
  "prim_attr_register", "prim_arg_register", "Primitive", "PrimitiveWithInfer", "PrimitiveWithCheck", "signature"
@@ -47,7 +48,7 @@ __all__ = ["get_vm_impl_fn", "vm_impl_registry",
47
48
  "op_info_register", "custom_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp",
48
49
  "CpuRegOp", "CustomRegOp", "DataType",
49
50
  "constexpr", "reshard",
50
- "cpp_create_prim_instance_helper", "gen_arg_dtype_cast", "gen_arg_handler", "gen_extend_func", "gen_ops_def",
51
+ "cpp_create_prim_instance_helper", "arg_dtype_cast", "arg_handler", "gen_extend_func", "gen_ops_def",
51
52
  "gen_ops_prim", "pyboost_inner_prim", "all_gather_matmul", "matmul_reduce_scatter"]
52
53
  __all__.extend(__primitive__)
53
54
  __all__.extend(composite.__all__)
@@ -27,7 +27,7 @@ from mindspore.ops.operations._inner_ops import issubclass_
27
27
  from mindspore.common.sparse_tensor import RowTensorInner
28
28
  from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
29
29
  from mindspore.ops.operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce,
30
- NeighborExchange, AlltoAll, NeighborExchangeV2, Broadcast,
30
+ NeighborExchange, AlltoAll, AlltoAllV, NeighborExchangeV2, Broadcast,
31
31
  _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
32
32
  ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, _AllSwap,
33
33
  _VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator,
@@ -192,7 +192,7 @@ def get_bprop_virtual_assign_add(self):
192
192
 
193
193
  def bprop(x, y, out, dout):
194
194
  if reduce_scatter:
195
- dout = reduce_scatter(dout)
195
+ dout = reduce_scatter(cast(dout, dtype(y)))
196
196
  return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(y))), assign_add(y, dout))
197
197
 
198
198
  return bprop
@@ -451,7 +451,7 @@ def get_bprop_micro_step_all_gather(self):
451
451
  if with_mirror_operator:
452
452
  if not do_mirror:
453
453
  return (dout, cast(out_tensor, dtype(z)))
454
- real_grad = reduce_scatter(dout)
454
+ real_grad = reduce_scatter(cast(dout, dtype(z)))
455
455
  if mean_flag:
456
456
  real_grad = F.tensor_mul(real_grad, scale)
457
457
  return (real_grad, cast(out_tensor, dtype(z)))
@@ -635,6 +635,21 @@ def get_bprop_all_to_all(self):
635
635
  return bprop
636
636
 
637
637
 
638
+ @bprop_getters.register(AlltoAllV)
639
+ def get_bprop_all_to_all_v(self):
640
+ """Generate bprop for AlltoAllV."""
641
+ all_to_all_v_grad = AlltoAllV(self.group, self.block_size)
642
+ if hasattr(self, "instance_name") and self.instance_name:
643
+ instance_name = "grad" + self.instance_name
644
+ all_to_all_v_grad.set_prim_instance_name(instance_name)
645
+
646
+ def bprop(x, send_numel_list, recv_numel_list, out, dout):
647
+ dx = all_to_all_v_grad(dout, recv_numel_list, send_numel_list)
648
+ return (dx, zeros_like(send_numel_list), zeros_like(recv_numel_list))
649
+
650
+ return bprop
651
+
652
+
638
653
  @bprop_getters.register(NeighborExchangeV2)
639
654
  def get_bprop_neighborexchangev2(self):
640
655
  """Generate bprop for NeighborExchangeV2."""
@@ -17,7 +17,8 @@
17
17
 
18
18
  import mindspore.ops.functional as F
19
19
  from mindspore.ops import operations as P
20
- from mindspore.ops._grad_experimental.grad_base import bprop_getters
20
+ from mindspore.ops.composite import multitype_ops as C
21
+ from mindspore.ops._grad_experimental.grad_base import bprop_getters, bprops
21
22
 
22
23
  # Unused parameters are placeholders.
23
24
 
@@ -34,3 +35,9 @@ def get_bprop_insert_gradient_of(self):
34
35
  return (dout,)
35
36
  return (fdout,)
36
37
  return bprop
38
+
39
+
40
+ @bprops.register("TensorDump")
41
+ def bprop_tensor_dump(file, input_x, out, dout):
42
+ """Generate bprop for TensorDump"""
43
+ return file, C.zeros_like(input_x)
@@ -127,6 +127,35 @@ def taylor_realdiv(self):
127
127
  return taylor_fprop_realdiv
128
128
 
129
129
 
130
+ def _taylor_fprop_div(input_x, input_y):
131
+ """The rule to generate `Div` taylor rule."""
132
+ if not input_y.shape:
133
+ return ops.div(input_x, input_y)
134
+ input_x, input_y = _trans_scalar_inputs(input_x, input_y)
135
+ primals = ops.div(input_x[0], input_y[0])
136
+ series_num = len(input_x) - 1
137
+ factorial = _factorial(series_num)
138
+ series = zeros_like(input_x)
139
+ series[0] = primals
140
+ for k in range(1, series_num + 1):
141
+ for i in range(0, k):
142
+ tmp = ops.div(series[i] * input_y[k - i], (factorial[k - i] * factorial[i]))
143
+ series[k] += tmp
144
+ series[k] = ops.div(input_x[k] - factorial[k] * series[k], input_y[0])
145
+ return series
146
+
147
+
148
+ @taylor_fprop_getters.register(P.Div)
149
+ def taylor_div(self):
150
+ """Higher order derivatives rule definition for `Div` operation."""
151
+
152
+ def taylor_fprop_div(input_x, input_y):
153
+ series = _taylor_fprop_div(input_x, input_y)
154
+ return series
155
+
156
+ return taylor_fprop_div
157
+
158
+
130
159
  @taylor_fprop_getters.register(P.Exp)
131
160
  def taylor_exp(self):
132
161
  """Higher order derivatives rule definition for `Exp` operation."""
@@ -60,14 +60,3 @@ class PyFuncRegistry(UserDict):
60
60
  if key not in self:
61
61
  raise ValueError(f"Python function with key{key} not registered.")
62
62
  return self[key]
63
-
64
-
65
- class OpaquePredicateRegistry(PyFuncRegistry):
66
- """Registry opaque predicate functions used for dynamic obfuscation"""
67
- def __init__(self):
68
- super(OpaquePredicateRegistry, self).__init__()
69
- self.func_names = []
70
-
71
- def register(self, key, value):
72
- self[key] = value
73
- self.func_names.append(key)
@@ -20,14 +20,133 @@ import mindspore as ms
20
20
  from mindspore import ops
21
21
  from mindspore.common.tensor import Tensor
22
22
  from mindspore.ops.operations._sequence_ops import TensorToScalar, TensorToTuple
23
- from mindspore.ops_generate.gen_ops_inner_prim import TupleToList, ListToTuple
24
23
  from mindspore._c_expression import OpDtype
24
+ from mindspore._c_expression import typing
25
+ from mindspore._c_expression import op_enum
26
+ from mindspore.ops.primitive import Primitive, prim_attr_register, prim_arg_register
25
27
 
26
28
  tensor_to_tuple_ = TensorToTuple()
29
+ tensor_to_scalar_ = TensorToScalar()
30
+
31
+
32
+ class TupleToList(Primitive):
33
+ r"""
34
+ Convert tuple to list.
35
+
36
+ Inputs:
37
+ - **input** (tuple) - The input tuple.
38
+
39
+ Outputs:
40
+ List, has the same elements as the `input`.
41
+
42
+ Supported Platforms:
43
+ ``CPU``
44
+
45
+ Examples:
46
+ >>> from mindspore.ops._utils.arg_dtype_cast import TupleToList
47
+ >>> x = (1, 2, 3)
48
+ >>> result = TupleToList()(x)
49
+ >>> print(result)
50
+ [1, 2, 3]
51
+ """
52
+ @prim_arg_register
53
+ def __init__(self):
54
+ """Initialize TupleToList"""
55
+
56
+ def __call__(self, input):
57
+ return list(input)
58
+
59
+
60
+ class ListToTuple(Primitive):
61
+ r"""
62
+ Convert list to tuple.
63
+
64
+ Inputs:
65
+ - **input** (list) - The input list.
66
+
67
+ Outputs:
68
+ Tuple, has the same elements as the `input`.
69
+
70
+ Supported Platforms:
71
+ ``CPU``
72
+
73
+ Examples:
74
+ >>> from mindspore.ops._utils.arg_dtype_cast import ListToTuple
75
+ >>> x = [1, 2, 3]
76
+ >>> result = ListToTuple()(x)
77
+ >>> print(result)
78
+ (1, 2, 3)
79
+ """
80
+ @prim_arg_register
81
+ def __init__(self):
82
+ """Initialize ListToTuple"""
83
+
84
+ def __call__(self, input):
85
+ return tuple(input)
86
+
87
+
27
88
  tuple_to_list = TupleToList()
28
89
  list_to_tuple = ListToTuple()
29
90
 
30
91
 
92
+ class DtypeToEnum(Primitive):
93
+ r"""
94
+ Convert mindspore dtype to enum.
95
+
96
+ Inputs:
97
+ - **op_name** (str) - The op name
98
+ - **arg_name** (str) - The arg name
99
+ - **dtype** (mindspore.dtype) - The data type.
100
+
101
+ Outputs:
102
+ An integer.
103
+
104
+ Supported Platforms:
105
+ ``Ascend`` ``GPU`` ``CPU``
106
+ """
107
+
108
+ @prim_attr_register
109
+ def __init__(self):
110
+ """Initialize"""
111
+
112
+ def __call__(self, op_name, arg_name, dtype):
113
+ """Run in PyNative mode"""
114
+ if not isinstance(dtype, typing.Type):
115
+ raise TypeError(
116
+ f"For '{op_name}', the input '{arg_name}' should be mindspore dtype, but got {dtype}.")
117
+ return typing.type_to_type_id(dtype)
118
+
119
+
120
+ class StringToEnum(Primitive):
121
+ r"""
122
+ Convert string to enum.
123
+
124
+ Inputs:
125
+ - **op_name** (str) - The op name
126
+ - **arg_name** (str) - The arg name
127
+ - **enum_str** (str) - The str data.
128
+
129
+ Outputs:
130
+ An integer.
131
+
132
+ Supported Platforms:
133
+ ``CPU``
134
+ """
135
+
136
+ @prim_attr_register
137
+ def __init__(self):
138
+ """Initialize"""
139
+
140
+ def __call__(self, op_name, arg_name, enum_str):
141
+ """Run in PyNative mode"""
142
+ if enum_str is None:
143
+ return None
144
+ if not isinstance(enum_str, str):
145
+ raise TypeError(
146
+ f"For '{op_name}', the input '{arg_name}' should be a str, but got {type(enum_str)}.")
147
+ return op_enum.str_to_enum(op_name, arg_name, enum_str)
148
+
149
+
31
150
  def int_to_float(data):
32
151
  return float(data)
33
152
 
@@ -184,7 +303,7 @@ def get_support_dtype_list(src_type, dst_type):
184
303
  return support_list
185
304
 
186
305
 
187
- def to_py_number(data, dst_type):
306
+ def tensor_to_number(data, dst_type):
188
307
  """Convert tensor to python number"""
189
308
  if dst_type == DT_INT_VAL:
190
309
  data = ops.cast(data, ms.int64)
@@ -197,7 +316,7 @@ def to_py_number(data, dst_type):
197
316
  data = ops.cast(data, ms.int64)
198
317
  elif src_type in (ms.bfloat16, ms.float16, ms.float32, ms.float64):
199
318
  data = ops.cast(data, ms.float32)
200
- return TensorToScalar()(data)
319
+ return tensor_to_scalar_(data)
201
320
 
202
321
 
203
322
  def do_type_cast(data, dst_type):
@@ -230,7 +349,7 @@ def do_type_cast(data, dst_type):
230
349
  return list_to_tensor(data)
231
350
  elif is_number(dst_type):
232
351
  if isinstance(data, Tensor):
233
- return to_py_number(data, dst_type)
352
+ return tensor_to_number(data, dst_type)
234
353
  raise TypeError("Type conversion failed.")
235
354
 
236
355
 
@@ -14,13 +14,11 @@
14
14
  # ============================================================================
15
15
  """Operator argument handle function."""
16
16
 
17
- from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum, StringToEnum
18
- # Enum Class:
19
- from mindspore._c_expression import FormatEnum as Format
20
- from mindspore._c_expression import ReductionEnum as Reduction
21
17
  from mindspore.common import Tensor
22
18
  from mindspore.common import dtype as mstype
23
19
 
20
+ from .arg_dtype_cast import DtypeToEnum, StringToEnum
21
+
24
22
 
25
23
  def arg_invalid_info(op_name, arg_name, arg_val):
26
24
  """
@@ -128,6 +126,7 @@ def generator_handler(op_name, arg_name, inputs):
128
126
  new_inputs.append(input_)
129
127
  return tuple(new_inputs)
130
128
 
129
+
131
130
  dtype_to_type_id = DtypeToEnum()
132
131
 
133
132
  # string to enum
@@ -21,7 +21,7 @@ import mindspore
21
21
  import mindspore.numpy as mnp
22
22
  from mindspore import ops
23
23
  from mindspore.common import Tensor
24
- from mindspore._c_expression import Tensor as Tensor_
24
+ from mindspore._c_expression import TensorPy as Tensor_
25
25
  from mindspore.ops import operations as P
26
26
  from mindspore.ops import functional as F
27
27
  from mindspore.ops.primitive import constexpr, _primexpr
@@ -141,6 +141,8 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
141
141
  the generated prefix is a Tensor([[[0], [0]],
142
142
  [[1], [1]]])
143
143
  """
144
+ cast_op = P.Cast()
145
+
144
146
  def _check(indices_shape):
145
147
  if not indices_shape:
146
148
  raise ValueError("indices_shape is empty in _get_prefix.")
@@ -148,8 +150,8 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
148
150
  _check(indices_shape)
149
151
  indices_len = len(indices_shape)
150
152
  if indices_len == 1:
151
- prefix = P.Range()(Tensor(0, indices_dtype), Tensor(axis_size, indices_dtype), Tensor(1, indices_dtype))
152
- return prefix
153
+ prefix = P.Range()(0, axis_size, 1)
154
+ return cast_op(prefix, indices_dtype)
153
155
 
154
156
  indices_end = indices_len - 1
155
157
  prefix_shape = ()
@@ -164,9 +166,8 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
164
166
  else:
165
167
  expand_shape = expand_shape + (1,)
166
168
 
167
- prefix = P.BroadcastTo(prefix_shape)(P.Reshape()(P.Range()(Tensor(
168
- 0, indices_dtype), Tensor(axis_size, indices_dtype), Tensor(1, indices_dtype)), expand_shape))
169
- return prefix
169
+ prefix = P.BroadcastTo(prefix_shape)(P.Reshape()(P.Range()(0, axis_size, 1), expand_shape))
170
+ return cast_op(prefix, indices_dtype)
170
171
 
171
172
 
172
173
  @vmap_rules_getters.register(P.Transpose)
@@ -25,8 +25,9 @@ from mindspore.ops.primitive import _primexpr
25
25
  from mindspore.ops.function import _VmapGeneralRule
26
26
  from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, \
27
27
  _bdim_at_front, _vmap_clone_prim, _bdim_at_any, _handle_broadcasting
28
- from mindspore.ops.auto_generate.gen_arg_handler import Format, Reduction
29
28
  from mindspore.ops import auto_generate as gen
29
+ from mindspore._c_expression import FormatEnum as Format
30
+ from mindspore._c_expression import ReductionEnum as Reduction
30
31
 
31
32
 
32
33
  @vmap_rules_getters.register(G.NLLLossGrad)
@@ -561,20 +561,17 @@ def get_index_add_vmap_rule(prim, axis_size):
561
561
  @vmap_rules_getters.register(linalg_ops.Svd)
562
562
  def get_svd_vmap_rule(prim, axis_size):
563
563
  """VmapRule for 'Svd' operation."""
564
- if isinstance(prim, str):
565
- prim = Primitive(prim)
566
- compute_uv = True
567
- else:
568
- compute_uv = prim.compute_uv
569
564
 
570
- def vmap_rule(x_bdim):
565
+ def vmap_rule(x_bdim, full_matrices_bdim, compute_uv_bdim):
571
566
  is_all_none, result = vmap_general_preprocess(prim, x_bdim)
572
567
  if is_all_none:
573
568
  return result
574
569
 
575
570
  x, x_dim = x_bdim
571
+ full_matrices, _ = full_matrices_bdim
572
+ compute_uv, _ = compute_uv_bdim
576
573
  x = _bdim_at_front(x, x_dim, axis_size)
577
- s, u, v = prim(x)
574
+ s, u, v = prim(x, full_matrices, compute_uv)
578
575
  if compute_uv:
579
576
  return (s, 0), (u, 0), (v, 0)
580
577
  return (s, 0), (u, None), (v, None)
@@ -29,9 +29,9 @@ from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_prepr
29
29
  _bdim_at_any, _bdim_at_front, _bdim_at_back, _handle_broadcasting, get_unary_grad_vmap_rule, _raise_value_error, \
30
30
  _vmap_clone_prim, _get_reduce_batch_axis
31
31
  from mindspore.ops.primitive import Primitive
32
- from mindspore.ops.auto_generate.gen_arg_handler import Format
33
32
  from mindspore.ops.auto_generate import Embedding
34
- from mindspore.ops.auto_generate import gen_arg_handler as handler
33
+ from mindspore.ops._utils import arg_handler as handler
34
+ from mindspore._c_expression import FormatEnum as Format
35
35
 
36
36
 
37
37
  @vmap_rules_getters.register(P.ApplyAdaMax)
@@ -1632,13 +1632,12 @@ def get_apply_adagrad_da_vmap_rule(prim, axis_size):
1632
1632
  return vmap_rule
1633
1633
 
1634
1634
 
1635
- @vmap_rules_getters.register(NN.AdaptiveMaxPool2D)
1635
+ @vmap_rules_getters.register(P.AdaptiveMaxPool2D)
1636
1636
  def get_adaptive_max_pool_2d_vmap_rule(prim, axis_size):
1637
1637
  """VmapRule for `AdaptiveMaxPool2D`."""
1638
1638
  nchw_index = 4
1639
1639
  chw_reverse_index = -3
1640
1640
  hw_size = 2
1641
- output_size = prim.output_size
1642
1641
 
1643
1642
  @_primexpr
1644
1643
  def get_output_shape(x_ori_shape, output_size):
@@ -1661,8 +1660,8 @@ def get_adaptive_max_pool_2d_vmap_rule(prim, axis_size):
1661
1660
  output_shape += (w_out,)
1662
1661
  return output_shape
1663
1662
 
1664
- def vmap_rule(input_x_bdim):
1665
- is_all_none, result = vmap_general_preprocess(prim, input_x_bdim)
1663
+ def vmap_rule(input_x_bdim, output_size_bdim):
1664
+ is_all_none, result = vmap_general_preprocess(prim, input_x_bdim, output_size_bdim)
1666
1665
  if is_all_none:
1667
1666
  return result
1668
1667
 
@@ -1670,18 +1669,20 @@ def get_adaptive_max_pool_2d_vmap_rule(prim, axis_size):
1670
1669
  x = _bdim_at_front(input_x, input_x_dim, axis_size)
1671
1670
  x_ndim = F.rank(x)
1672
1671
 
1672
+ output_size, _ = output_size_bdim
1673
+
1673
1674
  if x_ndim > nchw_index:
1674
1675
  # for the case of NCHW
1675
1676
  x_ori_shape = F.shape(x)
1676
1677
  x = F.reshape(x, (-1,) + x_ori_shape[chw_reverse_index:])
1677
1678
  output_shape = get_output_shape(x_ori_shape, output_size)
1678
- out, indices = prim(x)
1679
+ out, indices = prim(x, output_size)
1679
1680
  out = F.reshape(out, output_shape)
1680
1681
  indices = F.reshape(indices, output_shape)
1681
1682
  return (out, 0), (indices, 0)
1682
1683
 
1683
1684
  # for the case of CHW
1684
- out, indices = prim(x)
1685
+ out, indices = prim(x, output_size)
1685
1686
  return (out, 0), (indices, 0)
1686
1687
 
1687
1688
  return vmap_rule
@@ -19,13 +19,14 @@ Primitive operator classes and operator functional.
19
19
  A collection of operators to build neural networks or to compute functions.
20
20
  """
21
21
 
22
- from . import gen_ops_def, gen_arg_handler, gen_arg_dtype_cast
22
+ from . import gen_ops_def
23
+ from .._utils import arg_handler, arg_dtype_cast
23
24
 
24
25
  from .gen_ops_prim import *
25
26
  from .gen_ops_def import *
26
- from .gen_arg_handler import *
27
- from .gen_arg_dtype_cast import *
28
27
  from ..operations.manually_defined.ops_def import *
28
+ from .._utils.arg_handler import *
29
+ from .._utils.arg_dtype_cast import *
29
30
 
30
31
 
31
32
  __all__ = []