mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0rc1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (602) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +13 -6
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -38
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +6 -7
  15. mindspore/_extends/parse/compile_config.py +83 -0
  16. mindspore/_extends/parse/deprecated/__init__.py +0 -0
  17. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
  18. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  19. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  20. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  21. mindspore/_extends/parse/parser.py +46 -197
  22. mindspore/_extends/parse/resources.py +1 -5
  23. mindspore/_extends/parse/standard_method.py +217 -98
  24. mindspore/_extends/pijit/__init__.py +2 -2
  25. mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
  26. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  27. mindspore/_extends/utils.py +1 -1
  28. mindspore/amp.py +11 -5
  29. mindspore/atlprov.dll +0 -0
  30. mindspore/avcodec-59.dll +0 -0
  31. mindspore/avdevice-59.dll +0 -0
  32. mindspore/avfilter-8.dll +0 -0
  33. mindspore/avformat-59.dll +0 -0
  34. mindspore/avutil-57.dll +0 -0
  35. mindspore/boost/__init__.py +2 -2
  36. mindspore/boost/base.py +3 -7
  37. mindspore/boost/boost_cell_wrapper.py +138 -43
  38. mindspore/c1.dll +0 -0
  39. mindspore/c1xx.dll +0 -0
  40. mindspore/c2.dll +0 -0
  41. mindspore/common/__init__.py +6 -3
  42. mindspore/common/_grad_function.py +56 -0
  43. mindspore/common/_pijit_context.py +14 -5
  44. mindspore/common/_register_for_tensor.py +1 -2
  45. mindspore/common/_stub_tensor.py +30 -14
  46. mindspore/common/_tensor_cpp_method.py +17 -0
  47. mindspore/common/_tensor_docs.py +4760 -0
  48. mindspore/common/api.py +435 -371
  49. mindspore/common/auto_dynamic_shape.py +41 -44
  50. mindspore/common/dtype.py +39 -36
  51. mindspore/common/dump.py +9 -6
  52. mindspore/common/file_system.py +9 -1
  53. mindspore/common/generator.py +2 -0
  54. mindspore/common/hook_handle.py +6 -2
  55. mindspore/common/initializer.py +13 -10
  56. mindspore/common/jit_begin_end.py +94 -0
  57. mindspore/common/jit_config.py +6 -1
  58. mindspore/common/jit_context.py +76 -0
  59. mindspore/common/jit_trace.py +378 -0
  60. mindspore/common/lazy_inline.py +9 -3
  61. mindspore/common/mindir_util.py +10 -2
  62. mindspore/common/mutable.py +5 -4
  63. mindspore/common/parameter.py +135 -52
  64. mindspore/common/seed.py +2 -2
  65. mindspore/common/sparse_tensor.py +23 -17
  66. mindspore/common/tensor.py +951 -1992
  67. mindspore/communication/__init__.py +7 -5
  68. mindspore/communication/_comm_helper.py +52 -2
  69. mindspore/communication/comm_func.py +240 -181
  70. mindspore/communication/management.py +95 -26
  71. mindspore/context.py +314 -566
  72. mindspore/dataset/__init__.py +65 -37
  73. mindspore/dataset/audio/__init__.py +2 -8
  74. mindspore/dataset/audio/transforms.py +3 -17
  75. mindspore/dataset/callback/ds_callback.py +2 -1
  76. mindspore/dataset/core/config.py +87 -6
  77. mindspore/dataset/engine/cache_admin.py +3 -3
  78. mindspore/dataset/engine/cache_client.py +6 -5
  79. mindspore/dataset/engine/datasets.py +292 -267
  80. mindspore/dataset/engine/datasets_audio.py +22 -8
  81. mindspore/dataset/engine/datasets_standard_format.py +46 -27
  82. mindspore/dataset/engine/datasets_text.py +78 -48
  83. mindspore/dataset/engine/datasets_user_defined.py +182 -116
  84. mindspore/dataset/engine/datasets_vision.py +120 -44
  85. mindspore/dataset/engine/iterators.py +283 -63
  86. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  87. mindspore/dataset/engine/obs/util.py +8 -0
  88. mindspore/dataset/engine/queue.py +40 -0
  89. mindspore/dataset/engine/samplers.py +289 -43
  90. mindspore/dataset/engine/serializer_deserializer.py +3 -2
  91. mindspore/dataset/engine/validators.py +53 -11
  92. mindspore/dataset/text/__init__.py +7 -6
  93. mindspore/dataset/text/transforms.py +6 -5
  94. mindspore/dataset/text/utils.py +3 -3
  95. mindspore/dataset/transforms/__init__.py +0 -9
  96. mindspore/dataset/transforms/py_transforms_util.py +17 -0
  97. mindspore/dataset/transforms/transforms.py +31 -14
  98. mindspore/dataset/utils/browse_dataset.py +1 -1
  99. mindspore/dataset/vision/__init__.py +2 -9
  100. mindspore/dataset/vision/transforms.py +202 -158
  101. mindspore/dataset/vision/utils.py +7 -5
  102. mindspore/dataset/vision/validators.py +1 -2
  103. mindspore/device_context/__init__.py +21 -0
  104. mindspore/device_context/ascend/__init__.py +25 -0
  105. mindspore/device_context/ascend/device.py +72 -0
  106. mindspore/device_context/ascend/op_debug.py +153 -0
  107. mindspore/device_context/ascend/op_precision.py +193 -0
  108. mindspore/device_context/ascend/op_tuning.py +123 -0
  109. mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
  110. mindspore/device_context/cpu/device.py +62 -0
  111. mindspore/device_context/cpu/op_tuning.py +43 -0
  112. mindspore/device_context/gpu/__init__.py +21 -0
  113. mindspore/device_context/gpu/device.py +70 -0
  114. mindspore/device_context/gpu/op_precision.py +67 -0
  115. mindspore/device_context/gpu/op_tuning.py +175 -0
  116. mindspore/device_manager.py +170 -0
  117. mindspore/dnnl.dll +0 -0
  118. mindspore/dpcmi.dll +0 -0
  119. mindspore/experimental/es/embedding_service.py +35 -27
  120. mindspore/experimental/llm_boost/__init__.py +1 -0
  121. mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
  122. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
  123. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
  124. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  125. mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
  126. mindspore/experimental/llm_boost/register.py +1 -0
  127. mindspore/experimental/map_parameter.py +4 -4
  128. mindspore/experimental/optim/adadelta.py +6 -6
  129. mindspore/experimental/optim/adagrad.py +4 -4
  130. mindspore/experimental/optim/adam.py +7 -0
  131. mindspore/experimental/optim/adamax.py +4 -4
  132. mindspore/experimental/optim/adamw.py +4 -0
  133. mindspore/experimental/optim/asgd.py +1 -1
  134. mindspore/experimental/optim/lr_scheduler.py +73 -46
  135. mindspore/experimental/optim/radam.py +34 -31
  136. mindspore/experimental/optim/rprop.py +1 -1
  137. mindspore/experimental/optim/sgd.py +1 -1
  138. mindspore/hal/contiguous_tensors_handle.py +6 -10
  139. mindspore/hal/device.py +55 -53
  140. mindspore/hal/event.py +52 -52
  141. mindspore/hal/memory.py +157 -117
  142. mindspore/hal/stream.py +150 -109
  143. mindspore/include/api/context.h +0 -1
  144. mindspore/include/dataset/constants.h +7 -4
  145. mindspore/include/dataset/execute.h +2 -2
  146. mindspore/jpeg62.dll +0 -0
  147. mindspore/log.py +50 -0
  148. mindspore/mindrecord/__init__.py +21 -8
  149. mindspore/mindrecord/config.py +17 -316
  150. mindspore/mindrecord/filereader.py +1 -9
  151. mindspore/mindrecord/filewriter.py +5 -15
  152. mindspore/mindrecord/mindpage.py +1 -9
  153. mindspore/mindspore_backend_common.dll +0 -0
  154. mindspore/mindspore_backend_manager.dll +0 -0
  155. mindspore/mindspore_common.dll +0 -0
  156. mindspore/mindspore_core.dll +0 -0
  157. mindspore/mindspore_dump.dll +0 -0
  158. mindspore/mindspore_frontend.dll +0 -0
  159. mindspore/mindspore_glog.dll +0 -0
  160. mindspore/mindspore_memory_pool.dll +0 -0
  161. mindspore/mindspore_ms_backend.dll +0 -0
  162. mindspore/mindspore_ops.dll +0 -0
  163. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  164. mindspore/mindspore_ops_kernel_common.dll +0 -0
  165. mindspore/mindspore_profiler.dll +0 -0
  166. mindspore/mindspore_pyboost.dll +0 -0
  167. mindspore/mindspore_pynative.dll +0 -0
  168. mindspore/mindspore_res_manager.dll +0 -0
  169. mindspore/mindspore_runtime_pipeline.dll +0 -0
  170. mindspore/mint/__init__.py +796 -759
  171. mindspore/mint/distributed/__init__.py +70 -4
  172. mindspore/mint/distributed/distributed.py +2679 -44
  173. mindspore/mint/linalg/__init__.py +8 -0
  174. mindspore/mint/nn/__init__.py +743 -22
  175. mindspore/mint/nn/functional.py +716 -23
  176. mindspore/mint/nn/layer/__init__.py +21 -4
  177. mindspore/mint/nn/layer/_functions.py +334 -0
  178. mindspore/mint/nn/layer/activation.py +276 -1
  179. mindspore/mint/nn/layer/basic.py +123 -0
  180. mindspore/mint/nn/layer/conv.py +921 -0
  181. mindspore/mint/nn/layer/normalization.py +223 -28
  182. mindspore/mint/nn/layer/padding.py +797 -0
  183. mindspore/mint/nn/layer/pooling.py +235 -0
  184. mindspore/mint/optim/__init__.py +3 -1
  185. mindspore/mint/optim/adam.py +223 -0
  186. mindspore/mint/optim/adamw.py +26 -19
  187. mindspore/mint/optim/sgd.py +171 -0
  188. mindspore/mint/special/__init__.py +2 -1
  189. mindspore/msobj140.dll +0 -0
  190. mindspore/mspdb140.dll +0 -0
  191. mindspore/mspdbcore.dll +0 -0
  192. mindspore/mspdbst.dll +0 -0
  193. mindspore/mspft140.dll +0 -0
  194. mindspore/msvcdis140.dll +0 -0
  195. mindspore/msvcp140_1.dll +0 -0
  196. mindspore/msvcp140_2.dll +0 -0
  197. mindspore/msvcp140_atomic_wait.dll +0 -0
  198. mindspore/msvcp140_codecvt_ids.dll +0 -0
  199. mindspore/multiprocessing/__init__.py +5 -0
  200. mindspore/nn/__init__.py +4 -1
  201. mindspore/nn/cell.py +1370 -189
  202. mindspore/nn/dynamic_lr.py +2 -1
  203. mindspore/nn/layer/activation.py +29 -27
  204. mindspore/nn/layer/basic.py +51 -35
  205. mindspore/nn/layer/channel_shuffle.py +3 -3
  206. mindspore/nn/layer/container.py +1 -1
  207. mindspore/nn/layer/conv.py +22 -17
  208. mindspore/nn/layer/embedding.py +12 -11
  209. mindspore/nn/layer/normalization.py +56 -49
  210. mindspore/nn/layer/padding.py +4 -3
  211. mindspore/nn/layer/pooling.py +120 -42
  212. mindspore/nn/layer/rnn_cells.py +1 -1
  213. mindspore/nn/layer/rnns.py +2 -1
  214. mindspore/nn/layer/timedistributed.py +5 -5
  215. mindspore/nn/layer/transformer.py +59 -36
  216. mindspore/nn/learning_rate_schedule.py +8 -4
  217. mindspore/nn/loss/loss.py +58 -55
  218. mindspore/nn/optim/ada_grad.py +7 -5
  219. mindspore/nn/optim/adadelta.py +11 -9
  220. mindspore/nn/optim/adafactor.py +1 -1
  221. mindspore/nn/optim/adam.py +17 -13
  222. mindspore/nn/optim/adamax.py +8 -7
  223. mindspore/nn/optim/adasum.py +5 -5
  224. mindspore/nn/optim/asgd.py +1 -1
  225. mindspore/nn/optim/ftrl.py +11 -9
  226. mindspore/nn/optim/lamb.py +1 -1
  227. mindspore/nn/optim/lars.py +1 -4
  228. mindspore/nn/optim/lazyadam.py +12 -10
  229. mindspore/nn/optim/momentum.py +7 -6
  230. mindspore/nn/optim/optimizer.py +3 -3
  231. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  232. mindspore/nn/optim/rmsprop.py +13 -12
  233. mindspore/nn/optim/rprop.py +11 -9
  234. mindspore/nn/optim/sgd.py +9 -6
  235. mindspore/nn/optim/tft_wrapper.py +5 -2
  236. mindspore/nn/optim/thor.py +2 -1
  237. mindspore/nn/probability/bijector/bijector.py +17 -11
  238. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  239. mindspore/nn/probability/bijector/invert.py +2 -2
  240. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  241. mindspore/nn/probability/bijector/softplus.py +3 -2
  242. mindspore/nn/probability/distribution/beta.py +3 -3
  243. mindspore/nn/probability/distribution/categorical.py +1 -1
  244. mindspore/nn/probability/distribution/cauchy.py +4 -2
  245. mindspore/nn/probability/distribution/exponential.py +6 -7
  246. mindspore/nn/probability/distribution/gamma.py +2 -2
  247. mindspore/nn/probability/distribution/gumbel.py +2 -2
  248. mindspore/nn/probability/distribution/half_normal.py +5 -3
  249. mindspore/nn/probability/distribution/logistic.py +5 -3
  250. mindspore/nn/probability/distribution/poisson.py +1 -1
  251. mindspore/nn/probability/distribution/uniform.py +5 -3
  252. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  253. mindspore/nn/reinforcement/tensor_array.py +1 -1
  254. mindspore/nn/utils/init.py +13 -11
  255. mindspore/nn/wrap/__init__.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +181 -122
  257. mindspore/nn/wrap/grad_reducer.py +45 -36
  258. mindspore/nn/wrap/loss_scale.py +6 -7
  259. mindspore/numpy/array_creations.py +63 -65
  260. mindspore/numpy/array_ops.py +149 -144
  261. mindspore/numpy/logic_ops.py +41 -42
  262. mindspore/numpy/math_ops.py +365 -363
  263. mindspore/numpy/utils.py +17 -18
  264. mindspore/numpy/utils_const.py +5 -6
  265. mindspore/opencv_core452.dll +0 -0
  266. mindspore/opencv_imgcodecs452.dll +0 -0
  267. mindspore/opencv_imgproc452.dll +0 -0
  268. mindspore/ops/__init__.py +5 -3
  269. mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
  270. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
  271. mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
  272. mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
  273. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  274. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  275. mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
  276. mindspore/ops/_register_for_op.py +0 -11
  277. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  278. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
  279. mindspore/ops/_vmap/vmap_array_ops.py +27 -25
  280. mindspore/ops/_vmap/vmap_base.py +0 -2
  281. mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
  282. mindspore/ops/_vmap/vmap_math_ops.py +15 -16
  283. mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
  284. mindspore/ops/auto_generate/__init__.py +4 -3
  285. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
  286. mindspore/ops/auto_generate/gen_extend_func.py +764 -124
  287. mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
  288. mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
  289. mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
  290. mindspore/ops/composite/__init__.py +2 -1
  291. mindspore/ops/composite/base.py +20 -25
  292. mindspore/ops/composite/math_ops.py +6 -16
  293. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  294. mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
  295. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  296. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  297. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  298. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  299. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  300. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  301. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  302. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  303. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  304. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  305. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  306. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  308. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  309. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  310. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  311. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  312. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  313. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  314. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  315. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  316. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  317. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  320. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
  321. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  324. mindspore/ops/function/__init__.py +40 -2
  325. mindspore/ops/function/_add_attr_func.py +58 -0
  326. mindspore/ops/function/array_func.py +2089 -2403
  327. mindspore/ops/function/clip_func.py +80 -23
  328. mindspore/ops/function/debug_func.py +57 -57
  329. mindspore/ops/function/grad/__init__.py +1 -0
  330. mindspore/ops/function/grad/grad_func.py +104 -71
  331. mindspore/ops/function/image_func.py +2 -2
  332. mindspore/ops/function/linalg_func.py +47 -78
  333. mindspore/ops/function/math_func.py +4501 -3802
  334. mindspore/ops/function/nn_func.py +1726 -620
  335. mindspore/ops/function/other_func.py +159 -1
  336. mindspore/ops/function/parameter_func.py +18 -84
  337. mindspore/ops/function/random_func.py +440 -387
  338. mindspore/ops/function/reshard_func.py +4 -70
  339. mindspore/ops/function/sparse_func.py +3 -3
  340. mindspore/ops/function/sparse_unary_func.py +6 -6
  341. mindspore/ops/function/spectral_func.py +25 -58
  342. mindspore/ops/function/vmap_func.py +24 -17
  343. mindspore/ops/functional.py +22 -7
  344. mindspore/ops/functional_overload.py +1440 -0
  345. mindspore/ops/op_info_register.py +32 -244
  346. mindspore/ops/operations/__init__.py +13 -7
  347. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  348. mindspore/ops/operations/_embedding_cache_ops.py +4 -4
  349. mindspore/ops/operations/_grad_ops.py +2 -43
  350. mindspore/ops/operations/_infer_ops.py +2 -1
  351. mindspore/ops/operations/_inner_ops.py +43 -84
  352. mindspore/ops/operations/_ms_kernel.py +4 -10
  353. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  354. mindspore/ops/operations/_scalar_ops.py +3 -2
  355. mindspore/ops/operations/_sequence_ops.py +1 -1
  356. mindspore/ops/operations/_tensor_array.py +1 -1
  357. mindspore/ops/operations/array_ops.py +81 -324
  358. mindspore/ops/operations/comm_ops.py +154 -108
  359. mindspore/ops/operations/custom_ops.py +232 -78
  360. mindspore/ops/operations/debug_ops.py +153 -59
  361. mindspore/ops/operations/inner_ops.py +7 -5
  362. mindspore/ops/operations/linalg_ops.py +1 -57
  363. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  364. mindspore/ops/operations/manually_defined/ops_def.py +928 -180
  365. mindspore/ops/operations/math_ops.py +32 -234
  366. mindspore/ops/operations/nn_ops.py +210 -498
  367. mindspore/ops/operations/other_ops.py +62 -9
  368. mindspore/ops/operations/random_ops.py +13 -7
  369. mindspore/ops/operations/reshard_ops.py +1 -1
  370. mindspore/ops/operations/sparse_ops.py +2 -2
  371. mindspore/ops/primitive.py +66 -53
  372. mindspore/ops/tensor_method.py +1888 -0
  373. mindspore/ops_generate/__init__.py +0 -5
  374. mindspore/ops_generate/aclnn/__init__.py +0 -0
  375. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
  376. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
  377. mindspore/ops_generate/api/__init__.py +0 -0
  378. mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
  379. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
  380. mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
  381. mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
  382. mindspore/ops_generate/api/functions_cc_generator.py +237 -0
  383. mindspore/ops_generate/api/gen_api.py +103 -0
  384. mindspore/ops_generate/api/op_api_proto.py +235 -0
  385. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
  386. mindspore/ops_generate/common/__init__.py +0 -0
  387. mindspore/ops_generate/common/base_generator.py +11 -0
  388. mindspore/ops_generate/common/gen_constants.py +91 -0
  389. mindspore/ops_generate/common/gen_utils.py +348 -0
  390. mindspore/ops_generate/common/op_proto.py +473 -0
  391. mindspore/ops_generate/common/template.py +523 -0
  392. mindspore/ops_generate/gen_ops.py +22 -1069
  393. mindspore/ops_generate/op_def/__init__.py +0 -0
  394. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  395. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
  396. mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
  397. mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
  398. mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
  399. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  400. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  401. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  402. mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
  403. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
  404. mindspore/ops_generate/pyboost/__init__.py +0 -0
  405. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
  406. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
  407. mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
  408. mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
  409. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
  410. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
  411. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
  412. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
  413. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
  414. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
  415. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
  416. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
  417. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
  418. mindspore/ops_generate/resources/__init__.py +0 -0
  419. mindspore/ops_generate/resources/resource_list.py +30 -0
  420. mindspore/ops_generate/resources/resource_loader.py +36 -0
  421. mindspore/ops_generate/resources/resource_manager.py +64 -0
  422. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  423. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  424. mindspore/parallel/__init__.py +7 -3
  425. mindspore/parallel/_auto_parallel_context.py +152 -34
  426. mindspore/parallel/_cell_wrapper.py +130 -15
  427. mindspore/parallel/_parallel_serialization.py +107 -5
  428. mindspore/parallel/_ps_context.py +1 -1
  429. mindspore/parallel/_recovery_context.py +7 -2
  430. mindspore/parallel/_tensor.py +142 -18
  431. mindspore/parallel/_utils.py +199 -23
  432. mindspore/parallel/algo_parameter_config.py +4 -4
  433. mindspore/parallel/auto_parallel.py +732 -0
  434. mindspore/parallel/checkpoint_convert.py +159 -0
  435. mindspore/parallel/checkpoint_transform.py +698 -35
  436. mindspore/parallel/cluster/process_entity/_api.py +276 -50
  437. mindspore/parallel/cluster/process_entity/_utils.py +41 -6
  438. mindspore/parallel/cluster/run.py +21 -4
  439. mindspore/parallel/function/__init__.py +24 -0
  440. mindspore/parallel/function/reshard_func.py +259 -0
  441. mindspore/parallel/nn/__init__.py +25 -0
  442. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  443. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  444. mindspore/parallel/parameter_broadcast.py +25 -14
  445. mindspore/parallel/shard.py +137 -58
  446. mindspore/parallel/transform_safetensors.py +363 -305
  447. mindspore/pgodb140.dll +0 -0
  448. mindspore/pgort140.dll +0 -0
  449. mindspore/profiler/__init__.py +22 -5
  450. mindspore/profiler/analysis/__init__.py +0 -0
  451. mindspore/profiler/analysis/parser/__init__.py +0 -0
  452. mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
  453. mindspore/profiler/analysis/parser/base_parser.py +158 -0
  454. mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
  455. mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
  456. mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
  457. mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
  458. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
  459. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
  460. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
  461. mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
  462. mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
  463. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
  464. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
  465. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
  466. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
  467. mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
  468. mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
  469. mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
  470. mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
  471. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
  472. mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
  473. mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
  474. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
  475. mindspore/profiler/analysis/task_manager.py +131 -0
  476. mindspore/profiler/analysis/time_converter.py +84 -0
  477. mindspore/profiler/analysis/viewer/__init__.py +0 -0
  478. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
  479. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
  480. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
  481. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
  482. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
  483. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
  484. mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
  485. mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
  486. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
  487. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
  488. mindspore/profiler/analysis/work_flow.py +73 -0
  489. mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
  490. mindspore/profiler/common/command_executor.py +90 -0
  491. mindspore/profiler/common/constant.py +186 -3
  492. mindspore/profiler/common/file_manager.py +208 -0
  493. mindspore/profiler/common/log.py +130 -0
  494. mindspore/profiler/common/msprof_cmd_tool.py +221 -0
  495. mindspore/profiler/common/path_manager.py +395 -0
  496. mindspore/profiler/common/process_bar.py +168 -0
  497. mindspore/profiler/common/process_pool.py +9 -3
  498. mindspore/profiler/common/profiler_context.py +500 -0
  499. mindspore/profiler/common/profiler_info.py +304 -0
  500. mindspore/profiler/common/profiler_meta_data.py +74 -0
  501. mindspore/profiler/common/profiler_output_path.py +284 -0
  502. mindspore/profiler/common/profiler_parameters.py +251 -0
  503. mindspore/profiler/common/profiler_path_manager.py +179 -0
  504. mindspore/profiler/common/record_function.py +76 -0
  505. mindspore/profiler/common/tlv_decoder.py +76 -0
  506. mindspore/profiler/common/util.py +75 -2
  507. mindspore/profiler/dynamic_profiler.py +341 -75
  508. mindspore/profiler/envprofiler.py +163 -0
  509. mindspore/profiler/experimental_config.py +197 -0
  510. mindspore/profiler/mstx.py +242 -0
  511. mindspore/profiler/platform/__init__.py +21 -0
  512. mindspore/profiler/platform/base_profiler.py +40 -0
  513. mindspore/profiler/platform/cpu_profiler.py +124 -0
  514. mindspore/profiler/platform/gpu_profiler.py +74 -0
  515. mindspore/profiler/platform/npu_profiler.py +335 -0
  516. mindspore/profiler/profiler.py +1073 -90
  517. mindspore/profiler/profiler_action_controller.py +187 -0
  518. mindspore/profiler/profiler_interface.py +118 -0
  519. mindspore/profiler/schedule.py +243 -0
  520. mindspore/rewrite/api/node.py +15 -13
  521. mindspore/rewrite/api/symbol_tree.py +2 -3
  522. mindspore/run_check/_check_version.py +27 -20
  523. mindspore/run_check/run_check.py +1 -1
  524. mindspore/runtime/__init__.py +37 -0
  525. mindspore/runtime/device.py +27 -0
  526. mindspore/runtime/event.py +209 -0
  527. mindspore/runtime/executor.py +177 -0
  528. mindspore/runtime/memory.py +409 -0
  529. mindspore/runtime/stream.py +460 -0
  530. mindspore/runtime/thread_bind_core.py +401 -0
  531. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  532. mindspore/swresample-4.dll +0 -0
  533. mindspore/swscale-6.dll +0 -0
  534. mindspore/tbbmalloc.dll +0 -0
  535. mindspore/tinyxml2.dll +0 -0
  536. mindspore/train/__init__.py +8 -8
  537. mindspore/train/_utils.py +88 -25
  538. mindspore/train/amp.py +9 -5
  539. mindspore/train/callback/__init__.py +2 -2
  540. mindspore/train/callback/_callback.py +2 -16
  541. mindspore/train/callback/_checkpoint.py +53 -55
  542. mindspore/train/callback/_cluster_monitor.py +14 -18
  543. mindspore/train/callback/_early_stop.py +1 -1
  544. mindspore/train/callback/_flops_collector.py +103 -68
  545. mindspore/train/callback/_history.py +8 -5
  546. mindspore/train/callback/_lambda_callback.py +2 -2
  547. mindspore/train/callback/_landscape.py +0 -3
  548. mindspore/train/callback/_loss_monitor.py +2 -1
  549. mindspore/train/callback/_on_request_exit.py +6 -5
  550. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  551. mindspore/train/callback/_summary_collector.py +52 -19
  552. mindspore/train/callback/_time_monitor.py +2 -1
  553. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
  554. mindspore/train/data_sink.py +25 -2
  555. mindspore/train/dataset_helper.py +15 -16
  556. mindspore/train/loss_scale_manager.py +8 -7
  557. mindspore/train/metrics/accuracy.py +3 -3
  558. mindspore/train/metrics/confusion_matrix.py +9 -9
  559. mindspore/train/metrics/error.py +3 -3
  560. mindspore/train/metrics/hausdorff_distance.py +4 -4
  561. mindspore/train/metrics/mean_surface_distance.py +3 -3
  562. mindspore/train/metrics/metric.py +0 -12
  563. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  564. mindspore/train/metrics/precision.py +11 -10
  565. mindspore/train/metrics/recall.py +9 -9
  566. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  567. mindspore/train/mind_ir_pb2.py +174 -46
  568. mindspore/train/model.py +184 -113
  569. mindspore/train/serialization.py +622 -978
  570. mindspore/train/summary/_summary_adapter.py +2 -2
  571. mindspore/train/summary/summary_record.py +2 -3
  572. mindspore/train/train_thor/model_thor.py +1 -1
  573. mindspore/turbojpeg.dll +0 -0
  574. mindspore/utils/__init__.py +6 -3
  575. mindspore/utils/dryrun.py +140 -0
  576. mindspore/utils/hooks.py +81 -0
  577. mindspore/utils/runtime_execution_order_check.py +550 -0
  578. mindspore/utils/utils.py +138 -4
  579. mindspore/vcmeta.dll +0 -0
  580. mindspore/vcruntime140.dll +0 -0
  581. mindspore/vcruntime140_1.dll +0 -0
  582. mindspore/version.py +1 -1
  583. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
  584. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
  585. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
  586. mindspore/_install_custom.py +0 -43
  587. mindspore/common/_register_for_adapter.py +0 -74
  588. mindspore/common/_tensor_overload.py +0 -139
  589. mindspore/mindspore_np_dtype.dll +0 -0
  590. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  591. mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
  592. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  593. mindspore/ops_generate/gen_aclnn_implement.py +0 -263
  594. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  595. mindspore/ops_generate/gen_pyboost_func.py +0 -1052
  596. mindspore/ops_generate/gen_utils.py +0 -209
  597. mindspore/ops_generate/op_proto.py +0 -145
  598. mindspore/ops_generate/template.py +0 -261
  599. mindspore/profiler/envprofiling.py +0 -254
  600. mindspore/profiler/profiling.py +0 -1926
  601. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  602. {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -20,14 +20,133 @@ import mindspore as ms
20
20
  from mindspore import ops
21
21
  from mindspore.common.tensor import Tensor
22
22
  from mindspore.ops.operations._sequence_ops import TensorToScalar, TensorToTuple
23
- from mindspore.ops_generate.gen_ops_inner_prim import TupleToList, ListToTuple
24
23
  from mindspore._c_expression import OpDtype
24
+ from mindspore._c_expression import typing
25
+ from mindspore._c_expression import op_enum
26
+ from mindspore.ops.primitive import Primitive, prim_attr_register, prim_arg_register
25
27
 
26
28
  tensor_to_tuple_ = TensorToTuple()
29
+ tensor_to_scalar_ = TensorToScalar()
30
+
31
+
32
+ class TupleToList(Primitive):
33
+ r"""
34
+ Convert tuple to list.
35
+
36
+ Inputs:
37
+ - **x** (tuple) - The input
38
+
39
+ Outputs:
40
+ List, has the same elements as the `input`.
41
+
42
+ Supported Platforms:
43
+ ``CPU``
44
+
45
+ Examples:
46
+ >>> from mindspore.ops._utils.arg_dtype_cast import TupleToList
47
+ >>> x = (1, 2, 3)
48
+ >>> result = TupleToList()(x)
49
+ >>> print(result)
50
+ [1, 2, 3]
51
+ """
52
+ @prim_arg_register
53
+ def __init__(self):
54
+ """Initialize TupleToList"""
55
+
56
+ def __call__(self, input):
57
+ return list(input)
58
+
59
+
60
+ class ListToTuple(Primitive):
61
+ r"""
62
+ Convert list to tuple.
63
+
64
+ Inputs:
65
+ - **x** (list) - The input
66
+
67
+ Outputs:
68
+ Tuple, has the same elements as the `input`.
69
+
70
+ Supported Platforms:
71
+ ``CPU``
72
+
73
+ Examples:
74
+ >>> from mindspore.ops._utils.arg_dtype_cast import ListToTuple
75
+ >>> x = [1, 2, 3]
76
+ >>> result = ListToTuple()(x)
77
+ >>> print(result)
78
+ (1, 2, 3)
79
+ """
80
+ @prim_arg_register
81
+ def __init__(self):
82
+ """Initialize TupleToList"""
83
+
84
+ def __call__(self, input):
85
+ return tuple(input)
86
+
87
+
27
88
  tuple_to_list = TupleToList()
28
89
  list_to_tuple = ListToTuple()
29
90
 
30
91
 
92
+ class DtypeToEnum(Primitive):
93
+ r"""
94
+ Convert mindspore dtype to enum.
95
+
96
+ Inputs:
97
+ - **op_name** (str) - The op name
98
+ - **arg_name** (str) - The arg name
99
+ - **dtype** (mindspore.dtype) - The data type.
100
+
101
+ Outputs:
102
+ An integer.
103
+
104
+ Supported Platforms:
105
+ ``Ascend`` ``GPU`` ``CPU``
106
+ """
107
+
108
+ @prim_attr_register
109
+ def __init__(self):
110
+ """Initialize"""
111
+
112
+ def __call__(self, op_name, arg_name, dtype):
113
+ """Run in PyNative mode"""
114
+ if not isinstance(dtype, typing.Type):
115
+ raise TypeError(
116
+ f"For '{op_name}', the input '{arg_name}' should be mindspore dtype, but got {dtype}.")
117
+ return typing.type_to_type_id(dtype)
118
+
119
+
120
+ class StringToEnum(Primitive):
121
+ r"""
122
+ Convert string to enum.
123
+
124
+ Inputs:
125
+ - **op_name** (str) - The op name
126
+ - **arg_name** (str) - The arg name
127
+ - **enum_str** (str) - The str data.
128
+
129
+ Outputs:
130
+ An integer.
131
+
132
+ Supported Platforms:
133
+ ``CPU``
134
+ """
135
+
136
+ @prim_attr_register
137
+ def __init__(self):
138
+ """Initialize"""
139
+
140
+ def __call__(self, op_name, arg_name, enum_str):
141
+ """Run in PyNative mode"""
142
+ if enum_str is None:
143
+ return None
144
+ if not isinstance(enum_str, str):
145
+ raise TypeError(
146
+ f"For '{op_name}', the input '{arg_name}' should be a str, but got {type(enum_str)}.")
147
+ return op_enum.str_to_enum(op_name, arg_name, enum_str)
148
+
149
+
31
150
  def int_to_float(data):
32
151
  return float(data)
33
152
 
@@ -184,7 +303,7 @@ def get_support_dtype_list(src_type, dst_type):
184
303
  return support_list
185
304
 
186
305
 
187
- def to_py_number(data, dst_type):
306
+ def tensor_to_number(data, dst_type):
188
307
  """Convert tensor to python number"""
189
308
  if dst_type == DT_INT_VAL:
190
309
  data = ops.cast(data, ms.int64)
@@ -197,7 +316,7 @@ def to_py_number(data, dst_type):
197
316
  data = ops.cast(data, ms.int64)
198
317
  elif src_type in (ms.bfloat16, ms.float16, ms.float32, ms.float64):
199
318
  data = ops.cast(data, ms.float32)
200
- return TensorToScalar()(data)
319
+ return tensor_to_scalar_(data)
201
320
 
202
321
 
203
322
  def do_type_cast(data, dst_type):
@@ -230,7 +349,7 @@ def do_type_cast(data, dst_type):
230
349
  return list_to_tensor(data)
231
350
  elif is_number(dst_type):
232
351
  if isinstance(data, Tensor):
233
- return to_py_number(data, dst_type)
352
+ return tensor_to_number(data, dst_type)
234
353
  raise TypeError("Type conversion failed.")
235
354
 
236
355
 
@@ -14,13 +14,11 @@
14
14
  # ============================================================================
15
15
  """Operator argument handle function."""
16
16
 
17
- from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum, StringToEnum
18
- # Enum Class:
19
- from mindspore._c_expression import FormatEnum as Format
20
- from mindspore._c_expression import ReductionEnum as Reduction
21
17
  from mindspore.common import Tensor
22
18
  from mindspore.common import dtype as mstype
23
19
 
20
+ from .arg_dtype_cast import DtypeToEnum, StringToEnum
21
+
24
22
 
25
23
  def arg_invalid_info(op_name, arg_name, arg_val):
26
24
  """
@@ -116,67 +114,6 @@ def to_2d_paddings(op_name, arg_name, pad):
116
114
  raise ValueError(arg_invalid_info(op_name, arg_name, pad))
117
115
 
118
116
 
119
- def to_paddings(op_name, arg_name, pad):
120
- """
121
- convert paddings: int -> tuple[int*4].
122
- """
123
- if isinstance(pad, int):
124
- return (pad,) * 4
125
- if isinstance(pad, (tuple, list)):
126
- return pad
127
- raise ValueError(arg_invalid_info(op_name, arg_name, pad))
128
-
129
-
130
- def to_3d_kernel_size(op_name, arg_name, kernel_size):
131
- """
132
- convert 3d kernel_size: int/tuple[int*6] -> tuple[int*3].
133
- """
134
- if isinstance(kernel_size, int):
135
- return (kernel_size, kernel_size, kernel_size)
136
- if isinstance(kernel_size, (tuple, list)):
137
- if len(kernel_size) == 5:
138
- return (kernel_size[2], kernel_size[3], kernel_size[4])
139
- return kernel_size
140
- raise ValueError(arg_invalid_info(op_name, arg_name, kernel_size))
141
-
142
-
143
- def to_3d_strides(op_name, arg_name, stride):
144
- """
145
- convert 3d stride: int/tuple[int*6] -> tuple[int*3].
146
- """
147
- if isinstance(stride, int):
148
- return (stride, stride, stride)
149
- if isinstance(stride, (tuple, list)):
150
- if len(stride) == 5:
151
- return (stride[2], stride[3], stride[4])
152
- return stride
153
- raise ValueError(arg_invalid_info(op_name, arg_name, stride))
154
-
155
-
156
- def to_3d_dilations(op_name, arg_name, dilation):
157
- """
158
- convert 3d dilation: int/tuple[int*6] -> tuple[int*3].
159
- """
160
- if isinstance(dilation, int):
161
- return (dilation, dilation, dilation)
162
- if isinstance(dilation, (tuple, list)):
163
- if len(dilation) == 5:
164
- return (dilation[2], dilation[3], dilation[4])
165
- return dilation
166
- raise ValueError(arg_invalid_info(op_name, arg_name, dilation))
167
-
168
-
169
- def to_3d_paddings(op_name, arg_name, pad):
170
- """
171
- convert 3d paddings: int -> tuple[int*6].
172
- """
173
- if isinstance(pad, int):
174
- return (pad,) * 6
175
- if isinstance(pad, (tuple, list)):
176
- return pad
177
- raise ValueError(arg_invalid_info(op_name, arg_name, pad))
178
-
179
-
180
117
  def generator_handler(op_name, arg_name, inputs):
181
118
  """
182
119
  convert constant value in tuple to tensor
@@ -189,6 +126,7 @@ def generator_handler(op_name, arg_name, inputs):
189
126
  new_inputs.append(input_)
190
127
  return tuple(new_inputs)
191
128
 
129
+
192
130
  dtype_to_type_id = DtypeToEnum()
193
131
 
194
132
  # string to enum
@@ -15,12 +15,13 @@
15
15
 
16
16
  """array_ops vmap impl."""
17
17
  from __future__ import absolute_import
18
+ from enum import Enum
18
19
 
19
20
  import mindspore
20
21
  import mindspore.numpy as mnp
21
22
  from mindspore import ops
22
23
  from mindspore.common import Tensor
23
- from mindspore._c_expression import Tensor as Tensor_
24
+ from mindspore._c_expression import TensorPy as Tensor_
24
25
  from mindspore.ops import operations as P
25
26
  from mindspore.ops import functional as F
26
27
  from mindspore.ops.primitive import constexpr, _primexpr
@@ -140,6 +141,8 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
140
141
  the generated prefix is a Tensor([[[0], [0]],
141
142
  [[1], [1]]])
142
143
  """
144
+ cast_op = P.Cast()
145
+
143
146
  def _check(indices_shape):
144
147
  if not indices_shape:
145
148
  raise ValueError("indices_shape is empty in _get_prefix.")
@@ -147,8 +150,8 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
147
150
  _check(indices_shape)
148
151
  indices_len = len(indices_shape)
149
152
  if indices_len == 1:
150
- prefix = P.Range()(Tensor(0, indices_dtype), Tensor(axis_size, indices_dtype), Tensor(1, indices_dtype))
151
- return prefix
153
+ prefix = P.Range()(0, axis_size, 1)
154
+ return cast_op(prefix, indices_dtype)
152
155
 
153
156
  indices_end = indices_len - 1
154
157
  prefix_shape = ()
@@ -163,9 +166,8 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
163
166
  else:
164
167
  expand_shape = expand_shape + (1,)
165
168
 
166
- prefix = P.BroadcastTo(prefix_shape)(P.Reshape()(P.Range()(Tensor(
167
- 0, indices_dtype), Tensor(axis_size, indices_dtype), Tensor(1, indices_dtype)), expand_shape))
168
- return prefix
169
+ prefix = P.BroadcastTo(prefix_shape)(P.Reshape()(P.Range()(0, axis_size, 1), expand_shape))
170
+ return cast_op(prefix, indices_dtype)
169
171
 
170
172
 
171
173
  @vmap_rules_getters.register(P.Transpose)
@@ -1488,16 +1490,19 @@ def get_meshgrid_vmap_rule(prim, axis_size):
1488
1490
  """VmapRule for `P.Meshgrid` operation."""
1489
1491
  if isinstance(prim, str):
1490
1492
  prim = Primitive(prim)
1491
- indexing = prim.indexing
1492
1493
 
1493
- def vmap_rule(*inputs_bdim):
1494
- is_all_none, result = vmap_general_preprocess(prim, *inputs_bdim)
1494
+ class Indexing(Enum):
1495
+ ij = 0
1496
+ xy = 1
1497
+
1498
+ def vmap_rule(inputs_bdim, indexing_bdim):
1499
+ is_all_none, result = vmap_general_preprocess(prim, inputs_bdim, indexing_bdim)
1495
1500
  if is_all_none:
1496
1501
  return result
1497
1502
 
1498
1503
  if not isinstance(inputs_bdim, (tuple)):
1499
1504
  _raise_value_error("The inputs of P.Meshgrid is not tuple.")
1500
- args = inputs_bdim[0]
1505
+ args = inputs_bdim
1501
1506
  if len(args) <= 1:
1502
1507
  _raise_value_error(
1503
1508
  "The input number of P.Meshgrid must be greater than 1.")
@@ -1518,7 +1523,9 @@ def get_meshgrid_vmap_rule(prim, axis_size):
1518
1523
  output_shape.insert(0, axis_size)
1519
1524
  ones_shape.insert(0, axis_size)
1520
1525
 
1521
- if indexing == "xy":
1526
+ indexing, _ = indexing_bdim
1527
+
1528
+ if indexing == Indexing.xy.value:
1522
1529
  output_shape[1], output_shape[2] = output_shape[2], output_shape[1]
1523
1530
  shape = tuple(output_shape)
1524
1531
 
@@ -1531,7 +1538,7 @@ def get_meshgrid_vmap_rule(prim, axis_size):
1531
1538
  for each_arg in args:
1532
1539
  x, bdim = each_arg
1533
1540
  x = _bdim_at_front(x, bdim, axis_size)
1534
- shape_index = (1 - index) if (index <= 1 and indexing == "xy") else index
1541
+ shape_index = (1 - index) if (index <= 1 and indexing == Indexing.xy.value) else index
1535
1542
  ones_shape[shape_index + 1] = output_shape[shape_index + 1]
1536
1543
  x = P.Reshape()(x, tuple(ones_shape))
1537
1544
  output = P.Mul()(x, ones_tensor)
@@ -1889,10 +1896,6 @@ def get_slice_vmap_rule(prim, axis_size):
1889
1896
  @vmap_rules_getters.register(P.Squeeze)
1890
1897
  def get_squeeze_vmap_rule(prim, axis_size):
1891
1898
  """VmapRule for `Squeeze`."""
1892
- if hasattr(prim, 'axis'):
1893
- prim_axis = prim.axis
1894
- else:
1895
- prim_axis = None
1896
1899
 
1897
1900
  @_primexpr
1898
1901
  def move_axis(axes):
@@ -1911,27 +1914,26 @@ def get_squeeze_vmap_rule(prim, axis_size):
1911
1914
  new_axis += (i,)
1912
1915
  return new_axis
1913
1916
 
1914
- def vmap_rule(x_bdim):
1915
- is_all_none, result = vmap_general_preprocess(prim, x_bdim)
1917
+ def vmap_rule(x_bdim, axis_bdim):
1918
+ is_all_none, result = vmap_general_preprocess(prim, x_bdim, axis_bdim)
1916
1919
  if is_all_none:
1917
1920
  return result
1918
1921
 
1919
1922
  x, x_dim = x_bdim
1923
+ axis, _ = axis_bdim
1920
1924
  x = _bdim_at_front(x, x_dim, axis_size)
1921
1925
 
1922
- if prim_axis is None:
1926
+ if axis is None:
1923
1927
  if axis_size == 1:
1924
1928
  new_axis = generate_all_axis_except_first(F.rank(x))
1925
- batch_squeeze = P.Squeeze(axis=new_axis)
1926
- out = batch_squeeze(x)
1929
+ out = prim(x, new_axis)
1927
1930
  return out, 0
1928
1931
 
1929
- out = prim(x)
1932
+ out = prim(x, axis)
1930
1933
  return out, 0
1931
1934
 
1932
- new_axis = move_axis(prim_axis)
1933
- batch_squeeze = P.Squeeze(axis=new_axis)
1934
- out = batch_squeeze(x)
1935
+ new_axis = move_axis(axis)
1936
+ out = prim(x, new_axis)
1935
1937
  return out, 0
1936
1938
 
1937
1939
  return vmap_rule
@@ -512,8 +512,6 @@ _ops_vmap_clone_prim_dict = {
512
512
  "ApplyAdagradV2": P.ApplyAdagradV2,
513
513
  "UniformCandidateSampler": UniformCandidateSampler,
514
514
  "UniqueWithPad": P.UniqueWithPad,
515
- "CdistGrad": G.CdistGrad,
516
- "Cdist": P.Cdist,
517
515
  "STFT": math_ops.STFT,
518
516
  "Conv2D": P.Conv2D,
519
517
  "Conv3D": P.Conv3D,
@@ -25,7 +25,9 @@ from mindspore.ops.primitive import _primexpr
25
25
  from mindspore.ops.function import _VmapGeneralRule
26
26
  from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error, \
27
27
  _bdim_at_front, _vmap_clone_prim, _bdim_at_any, _handle_broadcasting
28
- from mindspore.ops.auto_generate.gen_arg_handler import Format, Reduction
28
+ from mindspore.ops import auto_generate as gen
29
+ from mindspore._c_expression import FormatEnum as Format
30
+ from mindspore._c_expression import ReductionEnum as Reduction
29
31
 
30
32
 
31
33
  @vmap_rules_getters.register(G.NLLLossGrad)
@@ -225,33 +227,35 @@ def get_max_pool3d_grad_with_argmax_vmap_rule(prim, axis_size):
225
227
  return vmap_rule
226
228
 
227
229
 
228
- @vmap_rules_getters.register(G.CdistGrad)
230
+ @vmap_rules_getters.register(gen.CdistGrad)
229
231
  def get_cdist_grad_vmap_rule(prim, axis_size):
230
232
  """VmapRule for `cdist grad` operation."""
231
- if hasattr(prim, 'batch_rank'):
232
- batch_rank = prim.batch_rank + 1
233
+ if prim.has_label("batch_rank"):
234
+ batch_rank = prim.get_label("batch_rank") + 1
233
235
  else:
234
236
  batch_rank = 1
235
237
 
236
- batch_prim = _vmap_clone_prim(prim)
237
- batch_prim.add_prim_attr("batch_rank", batch_rank)
238
+ prim = prim.clone()
239
+ prim.set_label('batch_rank', batch_rank)
238
240
 
239
- def vmap_rule(grad_bdim, x_bdim, y_bdim, cdist_bdim):
240
- is_all_none, result = vmap_general_preprocess(prim,
241
- grad_bdim, x_bdim, y_bdim, cdist_bdim)
241
+ def vmap_rule(grad_bdim, x_bdim, y_bdim, cdist_bdim, p_bdim):
242
+ is_all_none, result = vmap_general_preprocess(
243
+ prim, grad_bdim, x_bdim, y_bdim, cdist_bdim, p_bdim
244
+ )
242
245
  if is_all_none:
243
246
  return result
244
247
  grad, grad_dim = grad_bdim
245
248
  x, x_dim = x_bdim
246
249
  y, y_dim = y_bdim
247
250
  cdist, cdist_dim = cdist_bdim
251
+ p, _ = p_bdim
248
252
 
249
253
  grad = _bdim_at_front(grad, grad_dim, axis_size)
250
254
  x = _bdim_at_front(x, x_dim, axis_size)
251
255
  y = _bdim_at_front(y, y_dim, axis_size)
252
256
  cdist = _bdim_at_front(cdist, cdist_dim, axis_size)
253
257
 
254
- out = batch_prim(grad, x, y, cdist)
258
+ out = prim(grad, x, y, cdist, p)
255
259
  return out, 0
256
260
 
257
261
  return vmap_rule
@@ -673,10 +677,11 @@ def get_grid_sampler_grad_vmap_rule(prim, axis_size):
673
677
  else:
674
678
  _raise_value_error("The prim name must be `GridSampler2D` or `GridSampler3D`, but got {}.".format(prim_name))
675
679
 
676
-
677
- def vmap_rule(grad_bdim, input_x_bdim, grid_bdim, interpolation_mode_bdim, padding_mode_bdim, align_corners_bdim):
680
+ def vmap_rule(grad_bdim, input_x_bdim, grid_bdim, interpolation_mode_bdim, padding_mode_bdim, align_corners_bdim,
681
+ output_mask_bdim):
678
682
  is_all_none, result = vmap_general_preprocess(
679
- prim, grad_bdim, input_x_bdim, grid_bdim, interpolation_mode_bdim, padding_mode_bdim, align_corners_bdim)
683
+ prim, grad_bdim, input_x_bdim, grid_bdim, interpolation_mode_bdim, padding_mode_bdim, align_corners_bdim,
684
+ output_mask_bdim)
680
685
  if is_all_none:
681
686
  return result
682
687
 
@@ -686,6 +691,7 @@ def get_grid_sampler_grad_vmap_rule(prim, axis_size):
686
691
  interpolation_mode, _ = interpolation_mode_bdim
687
692
  padding_mode, _ = padding_mode_bdim
688
693
  align_corners, _ = align_corners_bdim
694
+ output_mask, _ = output_mask_bdim
689
695
 
690
696
  grad = _bdim_at_front(grad, grad_dim, axis_size)
691
697
  grad_shape = F.shape(grad)
@@ -699,7 +705,8 @@ def get_grid_sampler_grad_vmap_rule(prim, axis_size):
699
705
  grid_shape = F.shape(grid)
700
706
  grid = F.reshape(grid, (-1,) + grid_shape[non_batch_dim_index:])
701
707
 
702
- dx, dgrid = prim(grad, input_x, grid, interpolation_mode, padding_mode, align_corners)
708
+ dx, dgrid = prim(grad, input_x, grid, interpolation_mode,
709
+ padding_mode, align_corners, output_mask)
703
710
  dx_shape = F.shape(dx)
704
711
  dx_return_shape = input_x_shape[:non_batch_dim_index] + dx_shape[non_batch_dim_index:]
705
712
  dx = F.reshape(dx, dx_return_shape)
@@ -19,6 +19,7 @@ from __future__ import absolute_import
19
19
  import mindspore.numpy as mnp
20
20
  from mindspore.ops import operations as P
21
21
  from mindspore.ops import functional as F
22
+ from mindspore.ops import auto_generate as gen
22
23
  from mindspore.ops.auto_generate import MatMulExt
23
24
  from mindspore.ops.primitive import _primexpr
24
25
  from mindspore.common import Tensor
@@ -29,7 +30,7 @@ from mindspore.ops.primitive import Primitive
29
30
  from mindspore.ops.function import _VmapGeneralRule
30
31
  from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_assign_vmap_rule, \
31
32
  get_unop_vmap_rule, _raise_value_error, _bdim_at_front, _broadcast_by_axis, _handle_broadcasting, \
32
- _vmap_clone_prim, _bdim_at_any, _get_reduce_batch_axis, _get_reduce_out_dim
33
+ _bdim_at_any, _get_reduce_batch_axis, _get_reduce_out_dim
33
34
  from mindspore.ops.operations.math_ops import Bernoulli, BesselI0, BesselI1, BesselJ0, BesselJ1, \
34
35
  BesselK0, BesselK0e, BesselY0, BesselY1, BesselK1, BesselK1e, Median
35
36
 
@@ -128,28 +129,29 @@ def get_addcxxx_vmap_rule(prim, axis_size):
128
129
  return vmap_rule
129
130
 
130
131
 
131
- @vmap_rules_getters.register(P.Cdist)
132
+ @vmap_rules_getters.register(gen.Cdist)
132
133
  def get_cdist_vmap_rule(prim, axis_size):
133
134
  """VmapRule for `cdist` operation."""
134
- if hasattr(prim, 'batch_rank'):
135
- batch_rank = prim.batch_rank + 1
135
+ if prim.has_label("batch_rank"):
136
+ batch_rank = prim.get_label("batch_rank") + 1
136
137
  else:
137
138
  batch_rank = 1
138
139
 
139
- batch_prim = _vmap_clone_prim(prim)
140
- batch_prim.add_prim_attr("batch_rank", batch_rank)
140
+ prim = prim.clone()
141
+ prim.set_label('batch_rank', batch_rank)
141
142
 
142
- def vmap_rule(x_bdim, y_bdim):
143
+ def vmap_rule(x_bdim, y_bdim, p_bdim):
143
144
  x, x_dim = x_bdim
144
145
  y, y_dim = y_bdim
146
+ p, _ = p_bdim
145
147
 
146
- if x_dim is None and y_dim is None:
148
+ if x_dim is None and y_dim is None and p is None:
147
149
  out = prim(x, y)
148
150
  return (out, None)
149
151
  x = _bdim_at_front(x, x_dim, axis_size)
150
152
  y = _bdim_at_front(y, y_dim, axis_size)
151
153
 
152
- out = batch_prim(x, y)
154
+ out = prim(x, y, p)
153
155
  return out, 0
154
156
 
155
157
  return vmap_rule
@@ -559,20 +561,17 @@ def get_index_add_vmap_rule(prim, axis_size):
559
561
  @vmap_rules_getters.register(linalg_ops.Svd)
560
562
  def get_svd_vmap_rule(prim, axis_size):
561
563
  """VmapRule for 'Svd' operation."""
562
- if isinstance(prim, str):
563
- prim = Primitive(prim)
564
- compute_uv = True
565
- else:
566
- compute_uv = prim.compute_uv
567
564
 
568
- def vmap_rule(x_bdim):
565
+ def vmap_rule(x_bdim, full_matrices_bdim, compute_uv_bdim):
569
566
  is_all_none, result = vmap_general_preprocess(prim, x_bdim)
570
567
  if is_all_none:
571
568
  return result
572
569
 
573
570
  x, x_dim = x_bdim
571
+ full_matrices, _ = full_matrices_bdim
572
+ compute_uv, _ = compute_uv_bdim
574
573
  x = _bdim_at_front(x, x_dim, axis_size)
575
- s, u, v = prim(x)
574
+ s, u, v = prim(x, full_matrices, compute_uv)
576
575
  if compute_uv:
577
576
  return (s, 0), (u, 0), (v, 0)
578
577
  return (s, 0), (u, None), (v, None)