mindspore 2.5.0__cp310-cp310-win_amd64.whl → 2.6.0__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the registry's release page for this version for more details.

Files changed (493):
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +6 -4
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -33
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parse/__init__.py +6 -7
  14. mindspore/_extends/parse/compile_config.py +19 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +25 -194
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +109 -75
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +4 -4
  27. mindspore/atlprov.dll +0 -0
  28. mindspore/avcodec-59.dll +0 -0
  29. mindspore/avdevice-59.dll +0 -0
  30. mindspore/avfilter-8.dll +0 -0
  31. mindspore/avformat-59.dll +0 -0
  32. mindspore/avutil-57.dll +0 -0
  33. mindspore/boost/__init__.py +2 -2
  34. mindspore/boost/base.py +3 -7
  35. mindspore/boost/boost_cell_wrapper.py +2 -2
  36. mindspore/c1.dll +0 -0
  37. mindspore/c1xx.dll +0 -0
  38. mindspore/c2.dll +0 -0
  39. mindspore/common/__init__.py +4 -3
  40. mindspore/common/_grad_function.py +56 -0
  41. mindspore/common/_pijit_context.py +14 -5
  42. mindspore/common/_register_for_tensor.py +1 -1
  43. mindspore/common/_stub_tensor.py +5 -10
  44. mindspore/common/_tensor_cpp_method.py +1 -1
  45. mindspore/common/_tensor_docs.py +2014 -3386
  46. mindspore/common/api.py +386 -355
  47. mindspore/common/auto_dynamic_shape.py +41 -44
  48. mindspore/common/dtype.py +5 -2
  49. mindspore/common/dump.py +7 -5
  50. mindspore/common/file_system.py +3 -0
  51. mindspore/common/generator.py +3 -0
  52. mindspore/common/hook_handle.py +5 -3
  53. mindspore/common/initializer.py +10 -6
  54. mindspore/common/jit_begin_end.py +94 -0
  55. mindspore/common/jit_config.py +6 -1
  56. mindspore/common/jit_context.py +76 -0
  57. mindspore/common/jit_trace.py +378 -0
  58. mindspore/common/lazy_inline.py +2 -2
  59. mindspore/common/mutable.py +5 -4
  60. mindspore/common/parameter.py +106 -39
  61. mindspore/common/seed.py +2 -2
  62. mindspore/common/sparse_tensor.py +23 -17
  63. mindspore/common/tensor.py +332 -714
  64. mindspore/communication/__init__.py +7 -5
  65. mindspore/communication/_comm_helper.py +47 -2
  66. mindspore/communication/comm_func.py +70 -53
  67. mindspore/communication/management.py +83 -17
  68. mindspore/context.py +228 -571
  69. mindspore/dataset/__init__.py +44 -20
  70. mindspore/dataset/audio/__init__.py +2 -8
  71. mindspore/dataset/audio/transforms.py +3 -17
  72. mindspore/dataset/core/config.py +3 -3
  73. mindspore/dataset/engine/cache_client.py +1 -1
  74. mindspore/dataset/engine/datasets.py +102 -120
  75. mindspore/dataset/engine/datasets_audio.py +22 -22
  76. mindspore/dataset/engine/datasets_standard_format.py +43 -24
  77. mindspore/dataset/engine/datasets_text.py +78 -85
  78. mindspore/dataset/engine/datasets_user_defined.py +109 -77
  79. mindspore/dataset/engine/datasets_vision.py +111 -108
  80. mindspore/dataset/engine/iterators.py +5 -3
  81. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  82. mindspore/dataset/engine/samplers.py +279 -57
  83. mindspore/dataset/engine/serializer_deserializer.py +2 -1
  84. mindspore/dataset/engine/validators.py +10 -0
  85. mindspore/dataset/text/__init__.py +7 -6
  86. mindspore/dataset/text/transforms.py +6 -5
  87. mindspore/dataset/text/utils.py +3 -3
  88. mindspore/dataset/transforms/__init__.py +0 -9
  89. mindspore/dataset/transforms/transforms.py +3 -3
  90. mindspore/dataset/utils/browse_dataset.py +1 -1
  91. mindspore/dataset/vision/__init__.py +2 -9
  92. mindspore/dataset/vision/transforms.py +202 -158
  93. mindspore/dataset/vision/utils.py +7 -5
  94. mindspore/device_context/ascend/op_debug.py +60 -1
  95. mindspore/device_context/ascend/op_tuning.py +0 -4
  96. mindspore/device_manager.py +39 -3
  97. mindspore/dnnl.dll +0 -0
  98. mindspore/dpcmi.dll +0 -0
  99. mindspore/experimental/es/embedding_service.py +35 -27
  100. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -2
  101. mindspore/experimental/map_parameter.py +4 -4
  102. mindspore/experimental/optim/adadelta.py +22 -26
  103. mindspore/experimental/optim/adagrad.py +4 -4
  104. mindspore/experimental/optim/adam.py +4 -0
  105. mindspore/experimental/optim/adamax.py +4 -4
  106. mindspore/experimental/optim/adamw.py +4 -0
  107. mindspore/experimental/optim/asgd.py +1 -1
  108. mindspore/experimental/optim/lr_scheduler.py +40 -22
  109. mindspore/experimental/optim/radam.py +5 -5
  110. mindspore/experimental/optim/rprop.py +1 -1
  111. mindspore/experimental/optim/sgd.py +1 -1
  112. mindspore/hal/contiguous_tensors_handle.py +6 -10
  113. mindspore/hal/device.py +55 -81
  114. mindspore/hal/event.py +38 -55
  115. mindspore/hal/memory.py +115 -147
  116. mindspore/hal/stream.py +81 -125
  117. mindspore/include/dataset/constants.h +7 -4
  118. mindspore/include/dataset/execute.h +2 -2
  119. mindspore/jpeg62.dll +0 -0
  120. mindspore/log.py +40 -2
  121. mindspore/mindrecord/__init__.py +20 -7
  122. mindspore/mindspore_backend_common.dll +0 -0
  123. mindspore/mindspore_backend_manager.dll +0 -0
  124. mindspore/mindspore_common.dll +0 -0
  125. mindspore/mindspore_core.dll +0 -0
  126. mindspore/mindspore_dump.dll +0 -0
  127. mindspore/mindspore_frontend.dll +0 -0
  128. mindspore/mindspore_glog.dll +0 -0
  129. mindspore/mindspore_memory_pool.dll +0 -0
  130. mindspore/mindspore_ms_backend.dll +0 -0
  131. mindspore/mindspore_ops.dll +0 -0
  132. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  133. mindspore/mindspore_ops_kernel_common.dll +0 -0
  134. mindspore/mindspore_profiler.dll +0 -0
  135. mindspore/mindspore_pyboost.dll +0 -0
  136. mindspore/mindspore_pynative.dll +0 -0
  137. mindspore/mindspore_res_manager.dll +0 -0
  138. mindspore/mindspore_runtime_pipeline.dll +0 -0
  139. mindspore/mint/__init__.py +133 -702
  140. mindspore/mint/distributed/__init__.py +5 -1
  141. mindspore/mint/distributed/distributed.py +198 -113
  142. mindspore/mint/linalg/__init__.py +2 -0
  143. mindspore/mint/nn/__init__.py +280 -18
  144. mindspore/mint/nn/functional.py +282 -64
  145. mindspore/mint/nn/layer/__init__.py +4 -0
  146. mindspore/mint/nn/layer/_functions.py +7 -3
  147. mindspore/mint/nn/layer/activation.py +120 -13
  148. mindspore/mint/nn/layer/conv.py +234 -28
  149. mindspore/mint/nn/layer/normalization.py +15 -16
  150. mindspore/mint/nn/layer/padding.py +1 -1
  151. mindspore/mint/nn/layer/pooling.py +66 -1
  152. mindspore/mint/optim/__init__.py +2 -1
  153. mindspore/mint/optim/sgd.py +171 -0
  154. mindspore/msobj140.dll +0 -0
  155. mindspore/mspdb140.dll +0 -0
  156. mindspore/mspdbcore.dll +0 -0
  157. mindspore/mspdbst.dll +0 -0
  158. mindspore/mspft140.dll +0 -0
  159. mindspore/msvcdis140.dll +0 -0
  160. mindspore/msvcp140_1.dll +0 -0
  161. mindspore/msvcp140_2.dll +0 -0
  162. mindspore/msvcp140_atomic_wait.dll +0 -0
  163. mindspore/msvcp140_codecvt_ids.dll +0 -0
  164. mindspore/nn/__init__.py +4 -1
  165. mindspore/nn/cell.py +1253 -179
  166. mindspore/nn/layer/activation.py +23 -21
  167. mindspore/nn/layer/basic.py +22 -16
  168. mindspore/nn/layer/container.py +1 -1
  169. mindspore/nn/layer/conv.py +53 -42
  170. mindspore/nn/layer/embedding.py +9 -8
  171. mindspore/nn/layer/normalization.py +48 -42
  172. mindspore/nn/layer/pooling.py +75 -31
  173. mindspore/nn/layer/transformer.py +11 -10
  174. mindspore/nn/learning_rate_schedule.py +4 -2
  175. mindspore/nn/loss/loss.py +27 -19
  176. mindspore/nn/optim/ada_grad.py +6 -5
  177. mindspore/nn/optim/adadelta.py +9 -7
  178. mindspore/nn/optim/adafactor.py +1 -1
  179. mindspore/nn/optim/adam.py +18 -14
  180. mindspore/nn/optim/adamax.py +8 -7
  181. mindspore/nn/optim/adasum.py +5 -5
  182. mindspore/nn/optim/asgd.py +3 -1
  183. mindspore/nn/optim/ftrl.py +11 -9
  184. mindspore/nn/optim/lamb.py +1 -1
  185. mindspore/nn/optim/lazyadam.py +12 -10
  186. mindspore/nn/optim/momentum.py +7 -6
  187. mindspore/nn/optim/optimizer.py +2 -2
  188. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  189. mindspore/nn/optim/rmsprop.py +13 -12
  190. mindspore/nn/optim/rprop.py +9 -7
  191. mindspore/nn/optim/sgd.py +9 -6
  192. mindspore/nn/optim/tft_wrapper.py +5 -2
  193. mindspore/nn/probability/bijector/bijector.py +17 -11
  194. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  195. mindspore/nn/probability/bijector/invert.py +2 -2
  196. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  197. mindspore/nn/probability/bijector/softplus.py +3 -2
  198. mindspore/nn/probability/distribution/beta.py +3 -3
  199. mindspore/nn/probability/distribution/categorical.py +1 -1
  200. mindspore/nn/probability/distribution/cauchy.py +4 -2
  201. mindspore/nn/probability/distribution/exponential.py +6 -7
  202. mindspore/nn/probability/distribution/gamma.py +2 -2
  203. mindspore/nn/probability/distribution/gumbel.py +2 -2
  204. mindspore/nn/probability/distribution/half_normal.py +5 -3
  205. mindspore/nn/probability/distribution/logistic.py +5 -3
  206. mindspore/nn/probability/distribution/poisson.py +1 -1
  207. mindspore/nn/probability/distribution/uniform.py +5 -3
  208. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  209. mindspore/nn/reinforcement/tensor_array.py +1 -1
  210. mindspore/nn/wrap/__init__.py +6 -6
  211. mindspore/nn/wrap/cell_wrapper.py +178 -117
  212. mindspore/nn/wrap/grad_reducer.py +45 -36
  213. mindspore/nn/wrap/loss_scale.py +3 -3
  214. mindspore/numpy/array_creations.py +3 -3
  215. mindspore/numpy/array_ops.py +1 -1
  216. mindspore/numpy/utils.py +1 -2
  217. mindspore/numpy/utils_const.py +1 -2
  218. mindspore/opencv_core452.dll +0 -0
  219. mindspore/opencv_imgcodecs452.dll +0 -0
  220. mindspore/opencv_imgproc452.dll +0 -0
  221. mindspore/ops/__init__.py +3 -2
  222. mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
  223. mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
  224. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  225. mindspore/ops/_register_for_op.py +0 -11
  226. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  227. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
  228. mindspore/ops/_vmap/vmap_array_ops.py +32 -6
  229. mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
  230. mindspore/ops/_vmap/vmap_math_ops.py +4 -7
  231. mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
  232. mindspore/ops/auto_generate/__init__.py +4 -3
  233. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +127 -52
  234. mindspore/ops/auto_generate/gen_extend_func.py +286 -208
  235. mindspore/ops/auto_generate/gen_ops_def.py +2783 -2335
  236. mindspore/ops/auto_generate/gen_ops_prim.py +8992 -2686
  237. mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
  238. mindspore/ops/composite/__init__.py +2 -1
  239. mindspore/ops/composite/base.py +19 -24
  240. mindspore/ops/composite/math_ops.py +6 -16
  241. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  242. mindspore/ops/composite/multitype_ops/_compile_utils.py +4 -5
  243. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  244. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  245. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  246. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  247. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  248. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  249. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  250. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  251. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  252. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  253. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  254. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  255. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  256. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  257. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  258. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  259. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  260. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  261. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  262. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  263. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  264. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  265. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  266. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  267. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  268. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  270. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  271. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  272. mindspore/ops/function/__init__.py +28 -2
  273. mindspore/ops/function/_add_attr_func.py +58 -0
  274. mindspore/ops/function/array_func.py +1631 -2347
  275. mindspore/ops/function/clip_func.py +38 -45
  276. mindspore/ops/function/debug_func.py +36 -44
  277. mindspore/ops/function/grad/__init__.py +1 -0
  278. mindspore/ops/function/grad/grad_func.py +104 -71
  279. mindspore/ops/function/image_func.py +1 -1
  280. mindspore/ops/function/linalg_func.py +46 -78
  281. mindspore/ops/function/math_func.py +3024 -3855
  282. mindspore/ops/function/nn_func.py +678 -274
  283. mindspore/ops/function/other_func.py +159 -1
  284. mindspore/ops/function/parameter_func.py +17 -30
  285. mindspore/ops/function/random_func.py +216 -361
  286. mindspore/ops/function/reshard_func.py +4 -70
  287. mindspore/ops/function/sparse_func.py +3 -3
  288. mindspore/ops/function/sparse_unary_func.py +5 -5
  289. mindspore/ops/function/spectral_func.py +25 -58
  290. mindspore/ops/function/vmap_func.py +26 -18
  291. mindspore/ops/functional.py +8 -5
  292. mindspore/ops/functional_overload.py +655 -4
  293. mindspore/ops/op_info_register.py +32 -244
  294. mindspore/ops/operations/__init__.py +21 -14
  295. mindspore/ops/operations/_custom_ops_utils.py +235 -0
  296. mindspore/ops/operations/_grad_ops.py +1 -10
  297. mindspore/ops/operations/_inner_ops.py +5 -76
  298. mindspore/ops/operations/_ms_kernel.py +4 -10
  299. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  300. mindspore/ops/operations/_scalar_ops.py +3 -2
  301. mindspore/ops/operations/_sequence_ops.py +1 -1
  302. mindspore/ops/operations/_tensor_array.py +1 -1
  303. mindspore/ops/operations/array_ops.py +39 -24
  304. mindspore/ops/operations/comm_ops.py +150 -107
  305. mindspore/ops/operations/custom_ops.py +287 -32
  306. mindspore/ops/operations/debug_ops.py +119 -16
  307. mindspore/ops/operations/inner_ops.py +1 -1
  308. mindspore/ops/operations/linalg_ops.py +1 -58
  309. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  310. mindspore/ops/operations/manually_defined/ops_def.py +746 -79
  311. mindspore/ops/operations/math_ops.py +21 -18
  312. mindspore/ops/operations/nn_ops.py +67 -224
  313. mindspore/ops/operations/other_ops.py +62 -9
  314. mindspore/ops/operations/random_ops.py +13 -7
  315. mindspore/ops/operations/reshard_ops.py +1 -1
  316. mindspore/ops/operations/sparse_ops.py +2 -2
  317. mindspore/ops/primitive.py +43 -32
  318. mindspore/ops/tensor_method.py +243 -17
  319. mindspore/ops_generate/__init__.py +0 -5
  320. mindspore/ops_generate/aclnn/__init__.py +0 -0
  321. mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
  322. mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
  323. mindspore/ops_generate/api/__init__.py +0 -0
  324. mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
  325. mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
  326. mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
  327. mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
  328. mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
  329. mindspore/ops_generate/api/gen_api.py +103 -0
  330. mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
  331. mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
  332. mindspore/ops_generate/common/__init__.py +0 -0
  333. mindspore/ops_generate/common/gen_constants.py +91 -0
  334. mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
  335. mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
  336. mindspore/ops_generate/{template.py → common/template.py} +96 -84
  337. mindspore/ops_generate/gen_ops.py +23 -325
  338. mindspore/ops_generate/op_def/__init__.py +0 -0
  339. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  340. mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
  341. mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -10
  342. mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
  343. mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
  344. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  345. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  346. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  347. mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
  348. mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
  349. mindspore/ops_generate/pyboost/__init__.py +0 -0
  350. mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
  351. mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
  352. mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
  353. mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
  354. mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
  355. mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
  356. mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
  357. mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
  358. mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
  359. mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
  360. mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
  361. mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
  362. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
  363. mindspore/ops_generate/resources/__init__.py +0 -0
  364. mindspore/ops_generate/resources/resource_list.py +30 -0
  365. mindspore/ops_generate/resources/resource_loader.py +36 -0
  366. mindspore/ops_generate/resources/resource_manager.py +64 -0
  367. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  368. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  369. mindspore/parallel/__init__.py +6 -2
  370. mindspore/parallel/_auto_parallel_context.py +140 -12
  371. mindspore/parallel/_cell_wrapper.py +132 -15
  372. mindspore/parallel/_parallel_serialization.py +95 -4
  373. mindspore/parallel/_ps_context.py +1 -1
  374. mindspore/parallel/_recovery_context.py +7 -2
  375. mindspore/parallel/_tensor.py +142 -18
  376. mindspore/parallel/_utils.py +198 -25
  377. mindspore/parallel/algo_parameter_config.py +3 -3
  378. mindspore/parallel/auto_parallel.py +732 -0
  379. mindspore/parallel/checkpoint_convert.py +159 -0
  380. mindspore/parallel/checkpoint_transform.py +658 -37
  381. mindspore/parallel/cluster/process_entity/_api.py +151 -19
  382. mindspore/parallel/cluster/run.py +1 -1
  383. mindspore/parallel/function/__init__.py +24 -0
  384. mindspore/parallel/function/reshard_func.py +258 -0
  385. mindspore/parallel/nn/__init__.py +25 -0
  386. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  387. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  388. mindspore/parallel/parameter_broadcast.py +24 -13
  389. mindspore/parallel/shard.py +137 -62
  390. mindspore/parallel/transform_safetensors.py +288 -95
  391. mindspore/pgodb140.dll +0 -0
  392. mindspore/pgort140.dll +0 -0
  393. mindspore/profiler/__init__.py +9 -5
  394. mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
  395. mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
  396. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
  397. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +25 -0
  398. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  399. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
  400. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
  401. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
  402. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
  403. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
  404. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
  405. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
  406. mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
  407. mindspore/profiler/common/constant.py +12 -0
  408. mindspore/profiler/common/msprof_cmd_tool.py +42 -23
  409. mindspore/profiler/common/path_manager.py +24 -0
  410. mindspore/profiler/common/profiler_context.py +26 -2
  411. mindspore/profiler/common/profiler_meta_data.py +74 -0
  412. mindspore/profiler/common/profiler_parameters.py +59 -18
  413. mindspore/profiler/common/profiler_path_manager.py +66 -7
  414. mindspore/profiler/dynamic_profiler.py +112 -79
  415. mindspore/profiler/envprofiler.py +26 -1
  416. mindspore/profiler/experimental_config.py +197 -0
  417. mindspore/profiler/mstx.py +57 -14
  418. mindspore/profiler/platform/npu_profiler.py +33 -7
  419. mindspore/profiler/profiler.py +541 -45
  420. mindspore/profiler/profiler_action_controller.py +1 -1
  421. mindspore/profiler/profiler_interface.py +4 -0
  422. mindspore/profiler/schedule.py +57 -22
  423. mindspore/rewrite/api/node.py +15 -13
  424. mindspore/rewrite/api/symbol_tree.py +1 -1
  425. mindspore/run_check/_check_version.py +25 -14
  426. mindspore/run_check/run_check.py +1 -1
  427. mindspore/runtime/__init__.py +2 -2
  428. mindspore/runtime/executor.py +40 -11
  429. mindspore/runtime/memory.py +37 -13
  430. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  431. mindspore/swresample-4.dll +0 -0
  432. mindspore/swscale-6.dll +0 -0
  433. mindspore/tbbmalloc.dll +0 -0
  434. mindspore/tinyxml2.dll +0 -0
  435. mindspore/train/__init__.py +8 -8
  436. mindspore/train/_utils.py +43 -9
  437. mindspore/train/amp.py +1 -1
  438. mindspore/train/callback/__init__.py +2 -2
  439. mindspore/train/callback/_callback.py +2 -16
  440. mindspore/train/callback/_checkpoint.py +24 -40
  441. mindspore/train/callback/_cluster_monitor.py +14 -18
  442. mindspore/train/callback/_flops_collector.py +2 -3
  443. mindspore/train/callback/_history.py +7 -4
  444. mindspore/train/callback/_lambda_callback.py +2 -2
  445. mindspore/train/callback/_landscape.py +0 -3
  446. mindspore/train/callback/_loss_monitor.py +2 -1
  447. mindspore/train/callback/_on_request_exit.py +6 -5
  448. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  449. mindspore/train/callback/_summary_collector.py +8 -13
  450. mindspore/train/callback/_time_monitor.py +2 -1
  451. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -105
  452. mindspore/train/data_sink.py +25 -2
  453. mindspore/train/dataset_helper.py +4 -5
  454. mindspore/train/loss_scale_manager.py +8 -7
  455. mindspore/train/metrics/accuracy.py +3 -3
  456. mindspore/train/metrics/confusion_matrix.py +9 -9
  457. mindspore/train/metrics/error.py +3 -3
  458. mindspore/train/metrics/hausdorff_distance.py +4 -4
  459. mindspore/train/metrics/mean_surface_distance.py +3 -3
  460. mindspore/train/metrics/metric.py +0 -12
  461. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  462. mindspore/train/metrics/precision.py +8 -6
  463. mindspore/train/metrics/recall.py +9 -9
  464. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  465. mindspore/train/mind_ir_pb2.py +19 -12
  466. mindspore/train/model.py +262 -127
  467. mindspore/train/serialization.py +246 -988
  468. mindspore/train/summary/_summary_adapter.py +2 -2
  469. mindspore/train/summary/summary_record.py +1 -1
  470. mindspore/turbojpeg.dll +0 -0
  471. mindspore/utils/__init__.py +3 -2
  472. mindspore/utils/dryrun.py +4 -2
  473. mindspore/utils/hooks.py +81 -0
  474. mindspore/utils/runtime_execution_order_check.py +2 -0
  475. mindspore/utils/utils.py +138 -4
  476. mindspore/vcmeta.dll +0 -0
  477. mindspore/vcruntime140.dll +0 -0
  478. mindspore/vcruntime140_1.dll +0 -0
  479. mindspore/version.py +1 -1
  480. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/METADATA +2 -1
  481. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/RECORD +485 -440
  482. mindspore/_install_custom.py +0 -43
  483. mindspore/common/_register_for_adapter.py +0 -74
  484. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  485. mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
  486. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  487. mindspore/ops_generate/gen_constants.py +0 -190
  488. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  489. mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
  490. /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
  491. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
  492. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +0 -0
  493. {mindspore-2.5.0.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
@@ -27,7 +27,7 @@ from mindspore.ops.operations._inner_ops import issubclass_
27
27
  from mindspore.common.sparse_tensor import RowTensorInner
28
28
  from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
29
29
  from mindspore.ops.operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce,
30
- NeighborExchange, AlltoAll, NeighborExchangeV2, Broadcast,
30
+ NeighborExchange, AlltoAll, AlltoAllV, NeighborExchangeV2, Broadcast,
31
31
  _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
32
32
  ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, _AllSwap,
33
33
  _VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator,
@@ -192,7 +192,7 @@ def get_bprop_virtual_assign_add(self):
192
192
 
193
193
  def bprop(x, y, out, dout):
194
194
  if reduce_scatter:
195
- dout = reduce_scatter(dout)
195
+ dout = reduce_scatter(cast(dout, dtype(y)))
196
196
  return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(y))), assign_add(y, dout))
197
197
 
198
198
  return bprop
@@ -451,7 +451,7 @@ def get_bprop_micro_step_all_gather(self):
451
451
  if with_mirror_operator:
452
452
  if not do_mirror:
453
453
  return (dout, cast(out_tensor, dtype(z)))
454
- real_grad = reduce_scatter(dout)
454
+ real_grad = reduce_scatter(cast(dout, dtype(z)))
455
455
  if mean_flag:
456
456
  real_grad = F.tensor_mul(real_grad, scale)
457
457
  return (real_grad, cast(out_tensor, dtype(z)))
@@ -635,6 +635,21 @@ def get_bprop_all_to_all(self):
635
635
  return bprop
636
636
 
637
637
 
638
+ @bprop_getters.register(AlltoAllV)
639
+ def get_bprop_all_to_all_v(self):
640
+ """Generate bprop for AlltoAll."""
641
+ all_to_all_v_grad = AlltoAllV(self.group, self.block_size)
642
+ if hasattr(self, "instance_name") and self.instance_name:
643
+ instance_name = "grad" + self.instance_name
644
+ all_to_all_v_grad.set_prim_instance_name(instance_name)
645
+
646
+ def bprop(x, send_numel_list, recv_numel_list, out, dout):
647
+ dx = all_to_all_v_grad(dout, recv_numel_list, send_numel_list)
648
+ return (dx, zeros_like(send_numel_list), zeros_like(recv_numel_list))
649
+
650
+ return bprop
651
+
652
+
638
653
  @bprop_getters.register(NeighborExchangeV2)
639
654
  def get_bprop_neighborexchangev2(self):
640
655
  """Generate bprop for NeighborExchangeV2."""
@@ -17,7 +17,8 @@
17
17
 
18
18
  import mindspore.ops.functional as F
19
19
  from mindspore.ops import operations as P
20
- from mindspore.ops._grad_experimental.grad_base import bprop_getters
20
+ from mindspore.ops.composite import multitype_ops as C
21
+ from mindspore.ops._grad_experimental.grad_base import bprop_getters, bprops
21
22
 
22
23
  # Unused parameters are placeholders.
23
24
 
@@ -34,3 +35,9 @@ def get_bprop_insert_gradient_of(self):
34
35
  return (dout,)
35
36
  return (fdout,)
36
37
  return bprop
38
+
39
+
40
+ @bprops.register("TensorDump")
41
+ def bprop_tensor_dump(file, input_x, out, dout):
42
+ """Generate bprop for TensorDump"""
43
+ return file, C.zeros_like(input_x)
@@ -127,6 +127,35 @@ def taylor_realdiv(self):
127
127
  return taylor_fprop_realdiv
128
128
 
129
129
 
130
+ def _taylor_fprop_div(input_x, input_y):
131
+ """The rule to generate `Div` taylor rule."""
132
+ if not input_y.shape:
133
+ return ops.div(input_x, input_y)
134
+ input_x, input_y = _trans_scalar_inputs(input_x, input_y)
135
+ primals = ops.div(input_x[0], input_y[0])
136
+ series_num = len(input_x) - 1
137
+ factorial = _factorial(series_num)
138
+ series = zeros_like(input_x)
139
+ series[0] = primals
140
+ for k in range(1, series_num + 1):
141
+ for i in range(0, k):
142
+ tmp = ops.div(series[i] * input_y[k - i], (factorial[k - i] * factorial[i]))
143
+ series[k] += tmp
144
+ series[k] = ops.div(input_x[k] - factorial[k] * series[k], input_y[0])
145
+ return series
146
+
147
+
148
+ @taylor_fprop_getters.register(P.Div)
149
+ def taylor_div(self):
150
+ """Higher order derivatives rule definition for `Div` operation."""
151
+
152
+ def taylor_fprop_div(input_x, input_y):
153
+ series = _taylor_fprop_div(input_x, input_y)
154
+ return series
155
+
156
+ return taylor_fprop_div
157
+
158
+
130
159
  @taylor_fprop_getters.register(P.Exp)
131
160
  def taylor_exp(self):
132
161
  """Higher order derivatives rule definition for `Exp` operation."""
@@ -60,14 +60,3 @@ class PyFuncRegistry(UserDict):
60
60
  if key not in self:
61
61
  raise ValueError(f"Python function with key{key} not registered.")
62
62
  return self[key]
63
-
64
-
65
- class OpaquePredicateRegistry(PyFuncRegistry):
66
- """Registry opaque predicate functions used for dynamic obfuscation"""
67
- def __init__(self):
68
- super(OpaquePredicateRegistry, self).__init__()
69
- self.func_names = []
70
-
71
- def register(self, key, value):
72
- self[key] = value
73
- self.func_names.append(key)
@@ -20,14 +20,133 @@ import mindspore as ms
20
20
  from mindspore import ops
21
21
  from mindspore.common.tensor import Tensor
22
22
  from mindspore.ops.operations._sequence_ops import TensorToScalar, TensorToTuple
23
- from mindspore.ops_generate.gen_ops_inner_prim import TupleToList, ListToTuple
24
23
  from mindspore._c_expression import OpDtype
24
+ from mindspore._c_expression import typing
25
+ from mindspore._c_expression import op_enum
26
+ from mindspore.ops.primitive import Primitive, prim_attr_register, prim_arg_register
25
27
 
26
28
  tensor_to_tuple_ = TensorToTuple()
29
+ tensor_to_scalar_ = TensorToScalar()
30
+
31
+
32
+ class TupleToList(Primitive):
33
+ r"""
34
+ Convert tuple to list.
35
+
36
+ Inputs:
37
+ - **x** (tuple) - The input
38
+
39
+ Outputs:
40
+ List, has the same elements as the `input`.
41
+
42
+ Supported Platforms:
43
+ ``CPU``
44
+
45
+ Examples:
46
+ >>> from mindspore.ops._utils.arg_dtype_cast import TupleToList
47
+ >>> x = (1, 2, 3)
48
+ >>> result = TupleToList()(x)
49
+ >>> print(result)
50
+ [1, 2, 3]
51
+ """
52
+ @prim_arg_register
53
+ def __init__(self):
54
+ """Initialize TupleToList"""
55
+
56
+ def __call__(self, input):
57
+ return list(input)
58
+
59
+
60
+ class ListToTuple(Primitive):
61
+ r"""
62
+ Convert list to tuple.
63
+
64
+ Inputs:
65
+ - **x** (list) - The input
66
+
67
+ Outputs:
68
+ Tuple, has the same elements as the `input`.
69
+
70
+ Supported Platforms:
71
+ ``CPU``
72
+
73
+ Examples:
74
+ >>> from mindspore.ops._utils.arg_dtype_cast import ListToTuple
75
+ >>> x = [1, 2, 3]
76
+ >>> result = ListToTuple()(x)
77
+ >>> print(result)
78
+ (1, 2, 3)
79
+ """
80
+ @prim_arg_register
81
+ def __init__(self):
82
+ """Initialize TupleToList"""
83
+
84
+ def __call__(self, input):
85
+ return tuple(input)
86
+
87
+
27
88
  tuple_to_list = TupleToList()
28
89
  list_to_tuple = ListToTuple()
29
90
 
30
91
 
92
+ class DtypeToEnum(Primitive):
93
+ r"""
94
+ Convert mindspore dtype to enum.
95
+
96
+ Inputs:
97
+ - **op_name** (str) - The op name
98
+ - **arg_name** (str) - The arg name
99
+ - **dtype** (mindspore.dtype) - The data type.
100
+
101
+ Outputs:
102
+ An integer.
103
+
104
+ Supported Platforms:
105
+ ``Ascend`` ``GPU`` ``CPU``
106
+ """
107
+
108
+ @prim_attr_register
109
+ def __init__(self):
110
+ """Initialize"""
111
+
112
+ def __call__(self, op_name, arg_name, dtype):
113
+ """Run in PyNative mode"""
114
+ if not isinstance(dtype, typing.Type):
115
+ raise TypeError(
116
+ f"For '{op_name}', the input '{arg_name}' should be mindspore dtype, but got {dtype}.")
117
+ return typing.type_to_type_id(dtype)
118
+
119
+
120
+ class StringToEnum(Primitive):
121
+ r"""
122
+ Convert string to enum.
123
+
124
+ Inputs:
125
+ - **op_name** (str) - The op name
126
+ - **arg_name** (str) - The arg name
127
+ - **enum_str** (str) - The str data.
128
+
129
+ Outputs:
130
+ An integer.
131
+
132
+ Supported Platforms:
133
+ ``CPU``
134
+ """
135
+
136
+ @prim_attr_register
137
+ def __init__(self):
138
+ """Initialize"""
139
+
140
+ def __call__(self, op_name, arg_name, enum_str):
141
+ """Run in PyNative mode"""
142
+ if enum_str is None:
143
+ return None
144
+ if not isinstance(enum_str, str):
145
+ raise TypeError(
146
+ f"For '{op_name}', the input '{arg_name}' should be a str, but got {type(enum_str)}.")
147
+ return op_enum.str_to_enum(op_name, arg_name, enum_str)
148
+
149
+
31
150
  def int_to_float(data):
32
151
  return float(data)
33
152
 
@@ -184,7 +303,7 @@ def get_support_dtype_list(src_type, dst_type):
184
303
  return support_list
185
304
 
186
305
 
187
- def to_py_number(data, dst_type):
306
+ def tensor_to_number(data, dst_type):
188
307
  """Convert tensor to python number"""
189
308
  if dst_type == DT_INT_VAL:
190
309
  data = ops.cast(data, ms.int64)
@@ -197,7 +316,7 @@ def to_py_number(data, dst_type):
197
316
  data = ops.cast(data, ms.int64)
198
317
  elif src_type in (ms.bfloat16, ms.float16, ms.float32, ms.float64):
199
318
  data = ops.cast(data, ms.float32)
200
- return TensorToScalar()(data)
319
+ return tensor_to_scalar_(data)
201
320
 
202
321
 
203
322
  def do_type_cast(data, dst_type):
@@ -230,7 +349,7 @@ def do_type_cast(data, dst_type):
230
349
  return list_to_tensor(data)
231
350
  elif is_number(dst_type):
232
351
  if isinstance(data, Tensor):
233
- return to_py_number(data, dst_type)
352
+ return tensor_to_number(data, dst_type)
234
353
  raise TypeError("Type conversion failed.")
235
354
 
236
355
 
@@ -14,13 +14,11 @@
14
14
  # ============================================================================
15
15
  """Operator argument handle function."""
16
16
 
17
- from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum, StringToEnum
18
- # Enum Class:
19
- from mindspore._c_expression import FormatEnum as Format
20
- from mindspore._c_expression import ReductionEnum as Reduction
21
17
  from mindspore.common import Tensor
22
18
  from mindspore.common import dtype as mstype
23
19
 
20
+ from .arg_dtype_cast import DtypeToEnum, StringToEnum
21
+
24
22
 
25
23
  def arg_invalid_info(op_name, arg_name, arg_val):
26
24
  """
@@ -128,6 +126,7 @@ def generator_handler(op_name, arg_name, inputs):
128
126
  new_inputs.append(input_)
129
127
  return tuple(new_inputs)
130
128
 
129
+
131
130
  dtype_to_type_id = DtypeToEnum()
132
131
 
133
132
  # string to enum
@@ -21,7 +21,7 @@ import mindspore
21
21
  import mindspore.numpy as mnp
22
22
  from mindspore import ops
23
23
  from mindspore.common import Tensor
24
- from mindspore._c_expression import Tensor as Tensor_
24
+ from mindspore._c_expression import TensorPy as Tensor_
25
25
  from mindspore.ops import operations as P
26
26
  from mindspore.ops import functional as F
27
27
  from mindspore.ops.primitive import constexpr, _primexpr
@@ -141,6 +141,8 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
141
141
  the generated prefix is a Tensor([[[0], [0]],
142
142
  [[1], [1]]])
143
143
  """
144
+ cast_op = P.Cast()
145
+
144
146
  def _check(indices_shape):
145
147
  if not indices_shape:
146
148
  raise ValueError("indices_shape is empty in _get_prefix.")
@@ -148,8 +150,8 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
148
150
  _check(indices_shape)
149
151
  indices_len = len(indices_shape)
150
152
  if indices_len == 1:
151
- prefix = P.Range()(Tensor(0, indices_dtype), Tensor(axis_size, indices_dtype), Tensor(1, indices_dtype))
152
- return prefix
153
+ prefix = P.Range()(0, axis_size, 1)
154
+ return cast_op(prefix, indices_dtype)
153
155
 
154
156
  indices_end = indices_len - 1
155
157
  prefix_shape = ()
@@ -164,9 +166,8 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
164
166
  else:
165
167
  expand_shape = expand_shape + (1,)
166
168
 
167
- prefix = P.BroadcastTo(prefix_shape)(P.Reshape()(P.Range()(Tensor(
168
- 0, indices_dtype), Tensor(axis_size, indices_dtype), Tensor(1, indices_dtype)), expand_shape))
169
- return prefix
169
+ prefix = P.BroadcastTo(prefix_shape)(P.Reshape()(P.Range()(0, axis_size, 1), expand_shape))
170
+ return cast_op(prefix, indices_dtype)
170
171
 
171
172
 
172
173
  @vmap_rules_getters.register(P.Transpose)
@@ -213,6 +214,31 @@ def get_transpose_vmap_rule(prim, axis_size):
213
214
  return vmap_rule
214
215
 
215
216
 
217
+ @vmap_rules_getters.register("TransposeExtView")
218
+ def get_transpose_ext_vmap_rule(prim, axis_size):
219
+ """VmapRule for `TransposeExtView` operation."""
220
+ if isinstance(prim, str):
221
+ prim = Primitive(prim)
222
+
223
+ def vmap_rule(x_bdim, dim1_bdim, dim2_bdim):
224
+ is_all_none, result = vmap_general_preprocess(prim, x_bdim, dim1_bdim, dim2_bdim)
225
+ if is_all_none:
226
+ return result
227
+
228
+ x, dim = x_bdim
229
+ dim1, dim1_dim = dim1_bdim
230
+ dim2, dim2_dim = dim2_bdim
231
+ if dim1_dim is not None or dim2_dim is not None:
232
+ _raise_value_error("The source axis of dim1_dim and dim2_dim in `TransposeExtView` must be None, "
233
+ "but got {} and {}.".format(dim1_dim, dim2_dim))
234
+ batch_dim1 = dim1 if dim1 < dim else dim1 + 1
235
+ batch_dim2 = dim2 if dim2 < dim else dim2 + 1
236
+ out = prim(x, batch_dim1, batch_dim2)
237
+ return out, dim
238
+
239
+ return vmap_rule
240
+
241
+
216
242
  @vmap_rules_getters.register("Tile")
217
243
  def get_tile_vmap_rule(prim, axis_size):
218
244
  """VmapRule for `P.Tile` operation."""
@@ -25,8 +25,9 @@ from mindspore.ops.primitive import _primexpr
25
25
  from mindspore.ops.function import _VmapGeneralRule
26
26
  from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, \
27
27
  _bdim_at_front, _vmap_clone_prim, _bdim_at_any, _handle_broadcasting
28
- from mindspore.ops.auto_generate.gen_arg_handler import Format, Reduction
29
28
  from mindspore.ops import auto_generate as gen
29
+ from mindspore._c_expression import FormatEnum as Format
30
+ from mindspore._c_expression import ReductionEnum as Reduction
30
31
 
31
32
 
32
33
  @vmap_rules_getters.register(G.NLLLossGrad)
@@ -561,20 +561,17 @@ def get_index_add_vmap_rule(prim, axis_size):
561
561
  @vmap_rules_getters.register(linalg_ops.Svd)
562
562
  def get_svd_vmap_rule(prim, axis_size):
563
563
  """VmapRule for 'Svd' operation."""
564
- if isinstance(prim, str):
565
- prim = Primitive(prim)
566
- compute_uv = True
567
- else:
568
- compute_uv = prim.compute_uv
569
564
 
570
- def vmap_rule(x_bdim):
565
+ def vmap_rule(x_bdim, full_matrices_bdim, compute_uv_bdim):
571
566
  is_all_none, result = vmap_general_preprocess(prim, x_bdim)
572
567
  if is_all_none:
573
568
  return result
574
569
 
575
570
  x, x_dim = x_bdim
571
+ full_matrices, _ = full_matrices_bdim
572
+ compute_uv, _ = compute_uv_bdim
576
573
  x = _bdim_at_front(x, x_dim, axis_size)
577
- s, u, v = prim(x)
574
+ s, u, v = prim(x, full_matrices, compute_uv)
578
575
  if compute_uv:
579
576
  return (s, 0), (u, 0), (v, 0)
580
577
  return (s, 0), (u, None), (v, None)
@@ -29,9 +29,9 @@ from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_prepr
29
29
  _bdim_at_any, _bdim_at_front, _bdim_at_back, _handle_broadcasting, get_unary_grad_vmap_rule, _raise_value_error, \
30
30
  _vmap_clone_prim, _get_reduce_batch_axis
31
31
  from mindspore.ops.primitive import Primitive
32
- from mindspore.ops.auto_generate.gen_arg_handler import Format
33
32
  from mindspore.ops.auto_generate import Embedding
34
- from mindspore.ops.auto_generate import gen_arg_handler as handler
33
+ from mindspore.ops._utils import arg_handler as handler
34
+ from mindspore._c_expression import FormatEnum as Format
35
35
 
36
36
 
37
37
  @vmap_rules_getters.register(P.ApplyAdaMax)
@@ -1632,13 +1632,12 @@ def get_apply_adagrad_da_vmap_rule(prim, axis_size):
1632
1632
  return vmap_rule
1633
1633
 
1634
1634
 
1635
- @vmap_rules_getters.register(NN.AdaptiveMaxPool2D)
1635
+ @vmap_rules_getters.register(P.AdaptiveMaxPool2D)
1636
1636
  def get_adaptive_max_pool_2d_vmap_rule(prim, axis_size):
1637
1637
  """VmapRule for `AdaptiveMaxPool2D`."""
1638
1638
  nchw_index = 4
1639
1639
  chw_reverse_index = -3
1640
1640
  hw_size = 2
1641
- output_size = prim.output_size
1642
1641
 
1643
1642
  @_primexpr
1644
1643
  def get_output_shape(x_ori_shape, output_size):
@@ -1661,8 +1660,8 @@ def get_adaptive_max_pool_2d_vmap_rule(prim, axis_size):
1661
1660
  output_shape += (w_out,)
1662
1661
  return output_shape
1663
1662
 
1664
- def vmap_rule(input_x_bdim):
1665
- is_all_none, result = vmap_general_preprocess(prim, input_x_bdim)
1663
+ def vmap_rule(input_x_bdim, output_size_bdim):
1664
+ is_all_none, result = vmap_general_preprocess(prim, input_x_bdim, output_size_bdim)
1666
1665
  if is_all_none:
1667
1666
  return result
1668
1667
 
@@ -1670,18 +1669,20 @@ def get_adaptive_max_pool_2d_vmap_rule(prim, axis_size):
1670
1669
  x = _bdim_at_front(input_x, input_x_dim, axis_size)
1671
1670
  x_ndim = F.rank(x)
1672
1671
 
1672
+ output_size, _ = output_size_bdim
1673
+
1673
1674
  if x_ndim > nchw_index:
1674
1675
  # for the case of NCHW
1675
1676
  x_ori_shape = F.shape(x)
1676
1677
  x = F.reshape(x, (-1,) + x_ori_shape[chw_reverse_index:])
1677
1678
  output_shape = get_output_shape(x_ori_shape, output_size)
1678
- out, indices = prim(x)
1679
+ out, indices = prim(x, output_size)
1679
1680
  out = F.reshape(out, output_shape)
1680
1681
  indices = F.reshape(indices, output_shape)
1681
1682
  return (out, 0), (indices, 0)
1682
1683
 
1683
1684
  # for the case of CHW
1684
- out, indices = prim(x)
1685
+ out, indices = prim(x, output_size)
1685
1686
  return (out, 0), (indices, 0)
1686
1687
 
1687
1688
  return vmap_rule
@@ -19,13 +19,14 @@ Primitive operator classes and operator functional.
19
19
  A collection of operators to build neural networks or to compute functions.
20
20
  """
21
21
 
22
- from . import gen_ops_def, gen_arg_handler, gen_arg_dtype_cast
22
+ from . import gen_ops_def
23
+ from .._utils import arg_handler, arg_dtype_cast
23
24
 
24
25
  from .gen_ops_prim import *
25
26
  from .gen_ops_def import *
26
- from .gen_arg_handler import *
27
- from .gen_arg_dtype_cast import *
28
27
  from ..operations.manually_defined.ops_def import *
28
+ from .._utils.arg_handler import *
29
+ from .._utils.arg_dtype_cast import *
29
30
 
30
31
 
31
32
  __all__ = []