mindspore 2.5.0__cp310-cp310-win_amd64.whl → 2.6.0rc1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (491)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +6 -4
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -0
  9. mindspore/_checkparam.py +3 -33
  10. mindspore/_deprecated/__init__.py +17 -0
  11. mindspore/_deprecated/jit.py +198 -0
  12. mindspore/_extends/builtin_operations.py +1 -1
  13. mindspore/_extends/parse/__init__.py +6 -7
  14. mindspore/_extends/parse/compile_config.py +19 -0
  15. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
  16. mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
  17. mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
  18. mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
  19. mindspore/_extends/parse/parser.py +24 -193
  20. mindspore/_extends/parse/resources.py +1 -5
  21. mindspore/_extends/parse/standard_method.py +97 -74
  22. mindspore/_extends/pijit/__init__.py +2 -2
  23. mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
  24. mindspore/_extends/pijit/tensor_func_list.py +27 -0
  25. mindspore/_extends/utils.py +1 -1
  26. mindspore/amp.py +4 -4
  27. mindspore/atlprov.dll +0 -0
  28. mindspore/avcodec-59.dll +0 -0
  29. mindspore/avdevice-59.dll +0 -0
  30. mindspore/avfilter-8.dll +0 -0
  31. mindspore/avformat-59.dll +0 -0
  32. mindspore/avutil-57.dll +0 -0
  33. mindspore/boost/__init__.py +2 -2
  34. mindspore/boost/base.py +3 -7
  35. mindspore/boost/boost_cell_wrapper.py +2 -2
  36. mindspore/c1.dll +0 -0
  37. mindspore/c1xx.dll +0 -0
  38. mindspore/c2.dll +0 -0
  39. mindspore/common/__init__.py +4 -3
  40. mindspore/common/_grad_function.py +56 -0
  41. mindspore/common/_pijit_context.py +14 -5
  42. mindspore/common/_register_for_tensor.py +1 -1
  43. mindspore/common/_stub_tensor.py +5 -10
  44. mindspore/common/_tensor_cpp_method.py +1 -1
  45. mindspore/common/_tensor_docs.py +1915 -3287
  46. mindspore/common/api.py +341 -354
  47. mindspore/common/auto_dynamic_shape.py +41 -44
  48. mindspore/common/dtype.py +5 -2
  49. mindspore/common/dump.py +7 -5
  50. mindspore/common/file_system.py +3 -0
  51. mindspore/common/hook_handle.py +5 -3
  52. mindspore/common/initializer.py +10 -6
  53. mindspore/common/jit_begin_end.py +94 -0
  54. mindspore/common/jit_config.py +6 -1
  55. mindspore/common/jit_context.py +76 -0
  56. mindspore/common/jit_trace.py +378 -0
  57. mindspore/common/lazy_inline.py +2 -2
  58. mindspore/common/mutable.py +5 -4
  59. mindspore/common/parameter.py +106 -39
  60. mindspore/common/seed.py +2 -2
  61. mindspore/common/sparse_tensor.py +23 -17
  62. mindspore/common/tensor.py +297 -714
  63. mindspore/communication/__init__.py +7 -5
  64. mindspore/communication/_comm_helper.py +47 -2
  65. mindspore/communication/comm_func.py +70 -53
  66. mindspore/communication/management.py +83 -17
  67. mindspore/context.py +214 -560
  68. mindspore/dataset/__init__.py +44 -20
  69. mindspore/dataset/audio/__init__.py +2 -8
  70. mindspore/dataset/audio/transforms.py +3 -17
  71. mindspore/dataset/core/config.py +3 -3
  72. mindspore/dataset/engine/cache_client.py +1 -1
  73. mindspore/dataset/engine/datasets.py +102 -120
  74. mindspore/dataset/engine/datasets_audio.py +22 -22
  75. mindspore/dataset/engine/datasets_standard_format.py +43 -24
  76. mindspore/dataset/engine/datasets_text.py +78 -85
  77. mindspore/dataset/engine/datasets_user_defined.py +108 -76
  78. mindspore/dataset/engine/datasets_vision.py +111 -108
  79. mindspore/dataset/engine/iterators.py +5 -3
  80. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
  81. mindspore/dataset/engine/samplers.py +279 -57
  82. mindspore/dataset/engine/serializer_deserializer.py +2 -1
  83. mindspore/dataset/engine/validators.py +10 -0
  84. mindspore/dataset/text/__init__.py +7 -6
  85. mindspore/dataset/text/transforms.py +6 -5
  86. mindspore/dataset/text/utils.py +3 -3
  87. mindspore/dataset/transforms/__init__.py +0 -9
  88. mindspore/dataset/transforms/transforms.py +3 -3
  89. mindspore/dataset/utils/browse_dataset.py +1 -1
  90. mindspore/dataset/vision/__init__.py +2 -9
  91. mindspore/dataset/vision/transforms.py +202 -158
  92. mindspore/dataset/vision/utils.py +7 -5
  93. mindspore/device_context/ascend/op_debug.py +60 -1
  94. mindspore/device_context/ascend/op_tuning.py +0 -4
  95. mindspore/device_manager.py +39 -3
  96. mindspore/dnnl.dll +0 -0
  97. mindspore/dpcmi.dll +0 -0
  98. mindspore/experimental/es/embedding_service.py +35 -27
  99. mindspore/experimental/map_parameter.py +4 -4
  100. mindspore/experimental/optim/adadelta.py +22 -26
  101. mindspore/experimental/optim/adagrad.py +4 -4
  102. mindspore/experimental/optim/adam.py +4 -0
  103. mindspore/experimental/optim/adamax.py +4 -4
  104. mindspore/experimental/optim/adamw.py +4 -0
  105. mindspore/experimental/optim/asgd.py +1 -1
  106. mindspore/experimental/optim/lr_scheduler.py +40 -22
  107. mindspore/experimental/optim/radam.py +5 -5
  108. mindspore/experimental/optim/rprop.py +1 -1
  109. mindspore/experimental/optim/sgd.py +1 -1
  110. mindspore/hal/contiguous_tensors_handle.py +6 -10
  111. mindspore/hal/device.py +55 -81
  112. mindspore/hal/event.py +38 -55
  113. mindspore/hal/memory.py +93 -144
  114. mindspore/hal/stream.py +81 -125
  115. mindspore/include/dataset/constants.h +7 -4
  116. mindspore/include/dataset/execute.h +2 -2
  117. mindspore/jpeg62.dll +0 -0
  118. mindspore/log.py +40 -2
  119. mindspore/mindrecord/__init__.py +20 -7
  120. mindspore/mindspore_backend_common.dll +0 -0
  121. mindspore/mindspore_backend_manager.dll +0 -0
  122. mindspore/mindspore_common.dll +0 -0
  123. mindspore/mindspore_core.dll +0 -0
  124. mindspore/mindspore_dump.dll +0 -0
  125. mindspore/mindspore_frontend.dll +0 -0
  126. mindspore/mindspore_glog.dll +0 -0
  127. mindspore/mindspore_memory_pool.dll +0 -0
  128. mindspore/mindspore_ms_backend.dll +0 -0
  129. mindspore/mindspore_ops.dll +0 -0
  130. mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
  131. mindspore/mindspore_ops_kernel_common.dll +0 -0
  132. mindspore/mindspore_profiler.dll +0 -0
  133. mindspore/mindspore_pyboost.dll +0 -0
  134. mindspore/mindspore_pynative.dll +0 -0
  135. mindspore/mindspore_res_manager.dll +0 -0
  136. mindspore/mindspore_runtime_pipeline.dll +0 -0
  137. mindspore/mint/__init__.py +131 -700
  138. mindspore/mint/distributed/__init__.py +5 -1
  139. mindspore/mint/distributed/distributed.py +194 -109
  140. mindspore/mint/linalg/__init__.py +2 -0
  141. mindspore/mint/nn/__init__.py +280 -18
  142. mindspore/mint/nn/functional.py +282 -64
  143. mindspore/mint/nn/layer/__init__.py +4 -0
  144. mindspore/mint/nn/layer/_functions.py +7 -3
  145. mindspore/mint/nn/layer/activation.py +120 -13
  146. mindspore/mint/nn/layer/conv.py +218 -24
  147. mindspore/mint/nn/layer/normalization.py +15 -16
  148. mindspore/mint/nn/layer/padding.py +1 -1
  149. mindspore/mint/nn/layer/pooling.py +66 -1
  150. mindspore/mint/optim/__init__.py +2 -1
  151. mindspore/mint/optim/sgd.py +171 -0
  152. mindspore/msobj140.dll +0 -0
  153. mindspore/mspdb140.dll +0 -0
  154. mindspore/mspdbcore.dll +0 -0
  155. mindspore/mspdbst.dll +0 -0
  156. mindspore/mspft140.dll +0 -0
  157. mindspore/msvcdis140.dll +0 -0
  158. mindspore/msvcp140_1.dll +0 -0
  159. mindspore/msvcp140_2.dll +0 -0
  160. mindspore/msvcp140_atomic_wait.dll +0 -0
  161. mindspore/msvcp140_codecvt_ids.dll +0 -0
  162. mindspore/nn/__init__.py +4 -1
  163. mindspore/nn/cell.py +1250 -176
  164. mindspore/nn/layer/activation.py +23 -21
  165. mindspore/nn/layer/basic.py +22 -16
  166. mindspore/nn/layer/container.py +1 -1
  167. mindspore/nn/layer/conv.py +22 -17
  168. mindspore/nn/layer/embedding.py +9 -8
  169. mindspore/nn/layer/normalization.py +48 -42
  170. mindspore/nn/layer/pooling.py +75 -31
  171. mindspore/nn/layer/transformer.py +11 -10
  172. mindspore/nn/learning_rate_schedule.py +4 -2
  173. mindspore/nn/loss/loss.py +27 -19
  174. mindspore/nn/optim/ada_grad.py +6 -5
  175. mindspore/nn/optim/adadelta.py +9 -7
  176. mindspore/nn/optim/adafactor.py +1 -1
  177. mindspore/nn/optim/adam.py +16 -12
  178. mindspore/nn/optim/adamax.py +8 -7
  179. mindspore/nn/optim/adasum.py +5 -5
  180. mindspore/nn/optim/asgd.py +1 -1
  181. mindspore/nn/optim/ftrl.py +11 -9
  182. mindspore/nn/optim/lamb.py +1 -1
  183. mindspore/nn/optim/lazyadam.py +12 -10
  184. mindspore/nn/optim/momentum.py +7 -6
  185. mindspore/nn/optim/optimizer.py +2 -2
  186. mindspore/nn/optim/proximal_ada_grad.py +12 -10
  187. mindspore/nn/optim/rmsprop.py +13 -12
  188. mindspore/nn/optim/rprop.py +9 -7
  189. mindspore/nn/optim/sgd.py +9 -6
  190. mindspore/nn/optim/tft_wrapper.py +5 -2
  191. mindspore/nn/probability/bijector/bijector.py +17 -11
  192. mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
  193. mindspore/nn/probability/bijector/invert.py +2 -2
  194. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  195. mindspore/nn/probability/bijector/softplus.py +3 -2
  196. mindspore/nn/probability/distribution/beta.py +3 -3
  197. mindspore/nn/probability/distribution/categorical.py +1 -1
  198. mindspore/nn/probability/distribution/cauchy.py +4 -2
  199. mindspore/nn/probability/distribution/exponential.py +6 -7
  200. mindspore/nn/probability/distribution/gamma.py +2 -2
  201. mindspore/nn/probability/distribution/gumbel.py +2 -2
  202. mindspore/nn/probability/distribution/half_normal.py +5 -3
  203. mindspore/nn/probability/distribution/logistic.py +5 -3
  204. mindspore/nn/probability/distribution/poisson.py +1 -1
  205. mindspore/nn/probability/distribution/uniform.py +5 -3
  206. mindspore/nn/reinforcement/_tensors_queue.py +1 -1
  207. mindspore/nn/reinforcement/tensor_array.py +1 -1
  208. mindspore/nn/wrap/__init__.py +6 -6
  209. mindspore/nn/wrap/cell_wrapper.py +178 -117
  210. mindspore/nn/wrap/grad_reducer.py +45 -36
  211. mindspore/nn/wrap/loss_scale.py +3 -3
  212. mindspore/numpy/array_creations.py +3 -3
  213. mindspore/numpy/array_ops.py +1 -1
  214. mindspore/numpy/math_ops.py +4 -4
  215. mindspore/numpy/utils.py +1 -2
  216. mindspore/numpy/utils_const.py +1 -2
  217. mindspore/opencv_core452.dll +0 -0
  218. mindspore/opencv_imgcodecs452.dll +0 -0
  219. mindspore/opencv_imgproc452.dll +0 -0
  220. mindspore/ops/__init__.py +3 -2
  221. mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
  222. mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
  223. mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
  224. mindspore/ops/_register_for_op.py +0 -11
  225. mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
  226. mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
  227. mindspore/ops/_vmap/vmap_array_ops.py +7 -6
  228. mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
  229. mindspore/ops/_vmap/vmap_math_ops.py +4 -7
  230. mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
  231. mindspore/ops/auto_generate/__init__.py +4 -3
  232. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +102 -49
  233. mindspore/ops/auto_generate/gen_extend_func.py +281 -135
  234. mindspore/ops/auto_generate/gen_ops_def.py +2574 -2326
  235. mindspore/ops/auto_generate/gen_ops_prim.py +8566 -2755
  236. mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
  237. mindspore/ops/composite/__init__.py +2 -1
  238. mindspore/ops/composite/base.py +19 -24
  239. mindspore/ops/composite/math_ops.py +6 -16
  240. mindspore/ops/composite/multitype_ops/__init__.py +5 -2
  241. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -3
  242. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
  243. mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
  244. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  245. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  246. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
  247. mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
  248. mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
  249. mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
  250. mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
  251. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
  252. mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
  253. mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
  254. mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
  255. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
  256. mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
  257. mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
  258. mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
  259. mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
  260. mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
  261. mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
  262. mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
  263. mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
  264. mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
  265. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
  266. mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
  267. mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
  268. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  270. mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
  271. mindspore/ops/function/__init__.py +28 -2
  272. mindspore/ops/function/_add_attr_func.py +58 -0
  273. mindspore/ops/function/array_func.py +1629 -2345
  274. mindspore/ops/function/clip_func.py +38 -45
  275. mindspore/ops/function/debug_func.py +36 -44
  276. mindspore/ops/function/grad/__init__.py +1 -0
  277. mindspore/ops/function/grad/grad_func.py +104 -71
  278. mindspore/ops/function/image_func.py +1 -1
  279. mindspore/ops/function/linalg_func.py +46 -78
  280. mindspore/ops/function/math_func.py +3035 -3705
  281. mindspore/ops/function/nn_func.py +676 -241
  282. mindspore/ops/function/other_func.py +159 -1
  283. mindspore/ops/function/parameter_func.py +17 -30
  284. mindspore/ops/function/random_func.py +204 -361
  285. mindspore/ops/function/reshard_func.py +4 -70
  286. mindspore/ops/function/sparse_func.py +3 -3
  287. mindspore/ops/function/sparse_unary_func.py +5 -5
  288. mindspore/ops/function/spectral_func.py +25 -58
  289. mindspore/ops/function/vmap_func.py +24 -17
  290. mindspore/ops/functional.py +6 -4
  291. mindspore/ops/functional_overload.py +547 -4
  292. mindspore/ops/op_info_register.py +32 -244
  293. mindspore/ops/operations/__init__.py +10 -5
  294. mindspore/ops/operations/_custom_ops_utils.py +247 -0
  295. mindspore/ops/operations/_grad_ops.py +1 -10
  296. mindspore/ops/operations/_inner_ops.py +5 -76
  297. mindspore/ops/operations/_ms_kernel.py +4 -10
  298. mindspore/ops/operations/_rl_inner_ops.py +1 -1
  299. mindspore/ops/operations/_scalar_ops.py +3 -2
  300. mindspore/ops/operations/_sequence_ops.py +1 -1
  301. mindspore/ops/operations/_tensor_array.py +1 -1
  302. mindspore/ops/operations/array_ops.py +37 -22
  303. mindspore/ops/operations/comm_ops.py +150 -107
  304. mindspore/ops/operations/custom_ops.py +221 -23
  305. mindspore/ops/operations/debug_ops.py +115 -16
  306. mindspore/ops/operations/inner_ops.py +1 -1
  307. mindspore/ops/operations/linalg_ops.py +1 -58
  308. mindspore/ops/operations/manually_defined/_inner.py +1 -1
  309. mindspore/ops/operations/manually_defined/ops_def.py +746 -79
  310. mindspore/ops/operations/math_ops.py +21 -18
  311. mindspore/ops/operations/nn_ops.py +65 -191
  312. mindspore/ops/operations/other_ops.py +62 -9
  313. mindspore/ops/operations/random_ops.py +13 -7
  314. mindspore/ops/operations/reshard_ops.py +1 -1
  315. mindspore/ops/operations/sparse_ops.py +2 -2
  316. mindspore/ops/primitive.py +43 -32
  317. mindspore/ops/tensor_method.py +232 -13
  318. mindspore/ops_generate/__init__.py +0 -5
  319. mindspore/ops_generate/aclnn/__init__.py +0 -0
  320. mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
  321. mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
  322. mindspore/ops_generate/api/__init__.py +0 -0
  323. mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
  324. mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
  325. mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
  326. mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
  327. mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
  328. mindspore/ops_generate/api/gen_api.py +103 -0
  329. mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
  330. mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
  331. mindspore/ops_generate/common/__init__.py +0 -0
  332. mindspore/ops_generate/common/gen_constants.py +91 -0
  333. mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
  334. mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
  335. mindspore/ops_generate/{template.py → common/template.py} +96 -84
  336. mindspore/ops_generate/gen_ops.py +23 -325
  337. mindspore/ops_generate/op_def/__init__.py +0 -0
  338. mindspore/ops_generate/op_def/gen_op_def.py +90 -0
  339. mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
  340. mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -7
  341. mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
  342. mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
  343. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
  344. mindspore/ops_generate/op_def_py/__init__.py +0 -0
  345. mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
  346. mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
  347. mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
  348. mindspore/ops_generate/pyboost/__init__.py +0 -0
  349. mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
  350. mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
  351. mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
  352. mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
  353. mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
  354. mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
  355. mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
  356. mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
  357. mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
  358. mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
  359. mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
  360. mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
  361. mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
  362. mindspore/ops_generate/resources/__init__.py +0 -0
  363. mindspore/ops_generate/resources/resource_list.py +30 -0
  364. mindspore/ops_generate/resources/resource_loader.py +36 -0
  365. mindspore/ops_generate/resources/resource_manager.py +64 -0
  366. mindspore/ops_generate/resources/yaml_loader.py +88 -0
  367. mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
  368. mindspore/parallel/__init__.py +6 -2
  369. mindspore/parallel/_auto_parallel_context.py +133 -6
  370. mindspore/parallel/_cell_wrapper.py +130 -15
  371. mindspore/parallel/_parallel_serialization.py +95 -4
  372. mindspore/parallel/_ps_context.py +1 -1
  373. mindspore/parallel/_recovery_context.py +7 -2
  374. mindspore/parallel/_tensor.py +142 -18
  375. mindspore/parallel/_utils.py +198 -25
  376. mindspore/parallel/algo_parameter_config.py +3 -3
  377. mindspore/parallel/auto_parallel.py +732 -0
  378. mindspore/parallel/checkpoint_convert.py +159 -0
  379. mindspore/parallel/checkpoint_transform.py +656 -37
  380. mindspore/parallel/cluster/process_entity/_api.py +151 -19
  381. mindspore/parallel/cluster/run.py +1 -1
  382. mindspore/parallel/function/__init__.py +24 -0
  383. mindspore/parallel/function/reshard_func.py +259 -0
  384. mindspore/parallel/nn/__init__.py +25 -0
  385. mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
  386. mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
  387. mindspore/parallel/parameter_broadcast.py +24 -13
  388. mindspore/parallel/shard.py +137 -61
  389. mindspore/parallel/transform_safetensors.py +287 -95
  390. mindspore/pgodb140.dll +0 -0
  391. mindspore/pgort140.dll +0 -0
  392. mindspore/profiler/__init__.py +9 -5
  393. mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
  394. mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
  395. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
  396. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +22 -0
  397. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  398. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
  399. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
  400. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
  401. mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
  402. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
  403. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
  404. mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
  405. mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
  406. mindspore/profiler/common/constant.py +12 -0
  407. mindspore/profiler/common/msprof_cmd_tool.py +42 -23
  408. mindspore/profiler/common/path_manager.py +24 -0
  409. mindspore/profiler/common/profiler_context.py +26 -2
  410. mindspore/profiler/common/profiler_meta_data.py +74 -0
  411. mindspore/profiler/common/profiler_parameters.py +59 -18
  412. mindspore/profiler/common/profiler_path_manager.py +66 -7
  413. mindspore/profiler/dynamic_profiler.py +112 -79
  414. mindspore/profiler/envprofiler.py +26 -1
  415. mindspore/profiler/experimental_config.py +197 -0
  416. mindspore/profiler/mstx.py +57 -14
  417. mindspore/profiler/platform/npu_profiler.py +33 -7
  418. mindspore/profiler/profiler.py +541 -45
  419. mindspore/profiler/profiler_action_controller.py +1 -1
  420. mindspore/profiler/profiler_interface.py +4 -0
  421. mindspore/profiler/schedule.py +57 -22
  422. mindspore/rewrite/api/node.py +15 -13
  423. mindspore/rewrite/api/symbol_tree.py +1 -1
  424. mindspore/run_check/_check_version.py +25 -14
  425. mindspore/run_check/run_check.py +1 -1
  426. mindspore/runtime/__init__.py +2 -2
  427. mindspore/runtime/executor.py +40 -11
  428. mindspore/runtime/memory.py +25 -8
  429. mindspore/safeguard/rewrite_obfuscation.py +12 -9
  430. mindspore/swresample-4.dll +0 -0
  431. mindspore/swscale-6.dll +0 -0
  432. mindspore/tbbmalloc.dll +0 -0
  433. mindspore/tinyxml2.dll +0 -0
  434. mindspore/train/__init__.py +8 -8
  435. mindspore/train/_utils.py +35 -7
  436. mindspore/train/amp.py +1 -1
  437. mindspore/train/callback/__init__.py +2 -2
  438. mindspore/train/callback/_callback.py +2 -16
  439. mindspore/train/callback/_checkpoint.py +24 -40
  440. mindspore/train/callback/_cluster_monitor.py +14 -18
  441. mindspore/train/callback/_flops_collector.py +2 -3
  442. mindspore/train/callback/_history.py +7 -4
  443. mindspore/train/callback/_lambda_callback.py +2 -2
  444. mindspore/train/callback/_landscape.py +0 -3
  445. mindspore/train/callback/_loss_monitor.py +2 -1
  446. mindspore/train/callback/_on_request_exit.py +6 -5
  447. mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
  448. mindspore/train/callback/_summary_collector.py +8 -13
  449. mindspore/train/callback/_time_monitor.py +2 -1
  450. mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +179 -103
  451. mindspore/train/data_sink.py +25 -2
  452. mindspore/train/dataset_helper.py +4 -5
  453. mindspore/train/loss_scale_manager.py +8 -7
  454. mindspore/train/metrics/accuracy.py +3 -3
  455. mindspore/train/metrics/confusion_matrix.py +9 -9
  456. mindspore/train/metrics/error.py +3 -3
  457. mindspore/train/metrics/hausdorff_distance.py +4 -4
  458. mindspore/train/metrics/mean_surface_distance.py +3 -3
  459. mindspore/train/metrics/metric.py +0 -12
  460. mindspore/train/metrics/occlusion_sensitivity.py +4 -2
  461. mindspore/train/metrics/precision.py +8 -6
  462. mindspore/train/metrics/recall.py +9 -9
  463. mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
  464. mindspore/train/mind_ir_pb2.py +19 -12
  465. mindspore/train/model.py +176 -103
  466. mindspore/train/serialization.py +246 -988
  467. mindspore/train/summary/_summary_adapter.py +2 -2
  468. mindspore/train/summary/summary_record.py +1 -1
  469. mindspore/turbojpeg.dll +0 -0
  470. mindspore/utils/__init__.py +3 -2
  471. mindspore/utils/dryrun.py +4 -2
  472. mindspore/utils/hooks.py +81 -0
  473. mindspore/utils/utils.py +138 -4
  474. mindspore/vcmeta.dll +0 -0
  475. mindspore/vcruntime140.dll +0 -0
  476. mindspore/vcruntime140_1.dll +0 -0
  477. mindspore/version.py +1 -1
  478. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +2 -1
  479. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +483 -438
  480. mindspore/_install_custom.py +0 -43
  481. mindspore/common/_register_for_adapter.py +0 -74
  482. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
  483. mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
  484. mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
  485. mindspore/ops_generate/gen_constants.py +0 -190
  486. mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
  487. mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
  488. mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
  489. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
  490. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +0 -0
  491. {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
1
+ # Copyright 2023-2025 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -13,30 +13,12 @@
13
13
  # limitations under the License.
14
14
  # ============================================================================
15
15
  """Defines parameter operators with functional form."""
16
- from mindspore.ops import operations as P
17
- from mindspore.ops._primitive_cache import _get_cache_prim
18
- from mindspore.parallel.shard import Layout
19
- from mindspore.common.tensor import Tensor
16
+ from mindspore.parallel.function.reshard_func import reshard as new_reshard
20
17
 
21
18
 
22
19
  def reshard(tensor, layout):
23
20
  r"""
24
- Specify the tensor by the given layout. The given layout must be type mindspore.Layout,
25
- can check :class:`mindspore.Layout` for reference.
26
-
27
- - In the Graph mode, this function can set the sharding propagation strategy of a tensor.
28
- For those tensor do not manually be set, their strategies are decided by the sharding
29
- strategy propagation algorithm automatically.
30
- - In the PyNative mode, this function can set a tensor sharding strategy in a Cell that
31
- runs in the Graph mode (i.e. inside the Cell processed by Cell.shard/F.shard).
32
-
33
- Note:
34
- - In the auto parallel mode, an exception will throw if the search mode is not
35
- "sharding_propagation".
36
- - In the semi-auto parallel mode, the parallel mode will automatically switch to auto
37
- parallel mode with the search mode be set to "sharding_propagation".
38
- - Currently, configuring multi-dimension and multi-copy reshard strategy in
39
- mindspore.Layout is not supported.
21
+ Specify the tensor by the given layout.
40
22
 
41
23
  Args:
42
24
  tensor (Tensor): The tensor to be set the sharding strategy.
@@ -46,56 +28,8 @@ def reshard(tensor, layout):
46
28
 
47
29
  Returns:
48
30
  Tensor. The mathematically equivalent of the input tensor.
49
-
50
- Raises:
51
- TypeError: Reshard takes in Tensor type as the first input param, but got: `type(tensor)`.
52
- TypeError: Reshard only support type mindspore.Layout but got: `type(layout)`.
53
-
54
- Examples:
55
- >>> import numpy as np
56
- >>> import mindspore as ms
57
- >>> from mindspore import ops, nn, Tensor, context, Layout
58
- >>> context.set_context(mode=ms.GRAPH_MODE)
59
- >>> context.set_auto_parallel_context(parallel_mode=ms.ParallelMode.AUTO_PARALLEL,
60
- ... search_mode="sharding_propagation")
61
- >>> class Network(nn.Cell):
62
- ... def __init__(self):
63
- ... super().__init__()
64
- ... self.matmul = ops.MatMul()
65
- ... self.relu = ops.ReLU()
66
- ... def construct(self, x, layout):
67
- ... x = self.relu(x)
68
- ... x_reshard = ops.reshard(x, layout)
69
- ... y = Tensor(np.ones(shape=(128, 128)), dtype=ms.float32)
70
- ... x = self.matmul(x_reshard, y)
71
- ... return x
72
- >>>
73
- >>> layout = Layout((4, 2), ("dp", "mp"))
74
- >>> input_layout = layout("dp", "mp")
75
- >>> net = Network()
76
- >>> tensor = Tensor(np.ones(shape=(128, 128)), dtype=ms.float32)
77
- >>> out = net(tensor, input_layout)
78
31
  """
79
- if not isinstance(tensor, Tensor):
80
- raise TypeError(f"Reshard takes in Tensor type as the first input param, but got: {type(tensor)}.")
81
- if not isinstance(layout, Layout):
82
- raise TypeError(f"Reshard only support type mindspore.Layout, but got: {type(layout)}.")
83
-
84
- def layout_to_tuple(layout):
85
- layout_dict = layout.to_dict()
86
- tensor_map = layout_dict["tensor_map"]
87
- device_matrix_rev = layout_dict["device_matrix"][::-1]
88
- axis_stgy = ()
89
- for ind in tensor_map:
90
- if ind == -1:
91
- axis_stgy += (1,)
92
- else:
93
- axis_stgy += (device_matrix_rev[ind],)
94
- return axis_stgy
95
-
96
- in_strategy = layout_to_tuple(layout)
97
- _reshard = _get_cache_prim(P.Reshard)(in_layout=(layout,), out_layout=(layout,), in_strategy=(in_strategy,))
98
- return _reshard(tensor)
32
+ return new_reshard(tensor, layout)
99
33
 
100
34
  __all__ = [
101
35
  'reshard'
@@ -461,8 +461,8 @@ def dense_to_sparse_coo(tensor: Tensor) -> COOTensor:
461
461
  - shape (tuple(int)): the shape of the COOTensor, is the same as the original dense tensor.
462
462
 
463
463
  Raises:
464
- TypeError: If input is not a tensor.
465
- ValueError: If input tensor is not 2-D.
464
+ TypeError: If `tensor` is not a tensor.
465
+ ValueError: If `tensor` is not 2-D.
466
466
 
467
467
  Supported Platforms:
468
468
  ``GPU``
@@ -798,7 +798,7 @@ def csr_add(a: CSRTensor, b: CSRTensor, alpha: Tensor, beta: Tensor) -> CSRTenso
798
798
 
799
799
  - **indptr** - Indicates the start and end point for non-zero values in each row.
800
800
  - **indices** - The column positions of all non-zero values of the input.
801
- - **values** - The non-zero values of the dense tensor.
801
+ - **values** - The non-zero values.
802
802
  - **shape** - The shape of the CSRTensor.
803
803
 
804
804
  Supported Platforms:
@@ -674,7 +674,7 @@ def csr_sqrt(x: CSRTensor) -> CSRTensor:
674
674
 
675
675
  def coo_sqrt(x: COOTensor) -> COOTensor:
676
676
  r"""
677
- Returns sqrt of a COOTensor element-wise.
677
+ Computes sqrt of a COOTensor element-wise.
678
678
 
679
679
  .. math::
680
680
 
@@ -798,11 +798,11 @@ def csr_isnan(x: CSRTensor) -> CSRTensor:
798
798
  .. math::
799
799
 
800
800
  out_i = \begin{cases}
801
- & \ True,\ \text{ if } x_{i} = \text{Nan} \\
802
- & \ False,\ \text{ if } x_{i} \ne \text{Nan}
801
+ & \ True,\ \text{ if } x_{i} = \text{NaN} \\
802
+ & \ False,\ \text{ if } x_{i} \ne \text{NaN}
803
803
  \end{cases}
804
804
 
805
- where :math:`Nan` means not a number.
805
+ where :math:`NaN` means not a number.
806
806
 
807
807
  Args:
808
808
  x (CSRTensor): The input CSRTensor.
@@ -1905,7 +1905,7 @@ def coo_neg(x: COOTensor) -> COOTensor:
1905
1905
  x (COOTensor): The input COOTensor with a dtype of Number.
1906
1906
 
1907
1907
  Returns:
1908
- COOTensor, has the same shape and dtype as input.
1908
+ COOTensor, has the same shape and dtype as input `x`.
1909
1909
 
1910
1910
  Raises:
1911
1911
  TypeError: If `x` is not a COOTensor.
@@ -22,15 +22,9 @@ from .._primitive_cache import _get_cache_prim
22
22
 
23
23
  def blackman_window(window_length, periodic=True, *, dtype=None):
24
24
  r"""
25
- Blackman window function, usually used to extract finite signal segment for FFT.
25
+ Blackman window function.
26
26
 
27
- The `window_length` is a input tensor which determines the returned window size, and its data should be
28
- an integer. In particular, if `window_length` is equal to `1`, only a single value `1` exists in the
29
- returned window.
30
-
31
- Attr `periodic` determines whether the returned window removes the last duplicate value
32
- from the symmetric window and prepares to be a periodic window with functions.
33
- Therefore, if attr `periodic` is true, the :math:`N` in formula is :math:`window\_length + 1`.
27
+ Usually used to extract finite signal segment for FFT.
34
28
 
35
29
  .. math::
36
30
 
@@ -39,35 +33,23 @@ def blackman_window(window_length, periodic=True, *, dtype=None):
39
33
  where :math:`N` is the full window size, and n is natural number less than :math:`N` :[0, 1, ..., N-1].
40
34
 
41
35
  Args:
42
- window_length (Tensor): The size of returned window, with data type int32, int64.
43
- The input data should be an integer with a value of [0, 1000000].
44
- periodic (bool, optional): Indicates whether to returns a window to be used as periodic function or
45
- a symmetric window. Default: ``True`` .
36
+ window_length (Tensor): The size of window.
37
+ periodic (bool, optional): If ``True`` , return a periodic window. If ``False``, return a symmetric window.
38
+ Default ``True`` .
46
39
 
47
40
  Keyword Args:
48
- dtype (mindspore.dtype, optional): The data type of returned tensor.
49
- Only float16, float32 and float64 is allowed. Default: ``None`` .
41
+ dtype (mindspore.dtype, optional): The data type specified. Default ``None`` .
50
42
 
51
43
  Returns:
52
- A 1-D tensor of size `window_length` containing the window. Its datatype is set by the attr `dtype`.
53
- If 'dtype' is None, output datatype is float32.
54
-
55
- Raises:
56
- TypeError: If `window_length` is not a Tensor.
57
- TypeError: If `periodic` is not a bool.
58
- TypeError: If `dtype` is not one of: float16, float32, float64.
59
- TypeError: If the type of `window_length` is not one of: int32, int64.
60
- ValueError: If the value range of `window_length` is not [0, 1000000].
61
- ValueError: If the dimension of `window_length` is not 0.
44
+ A 1-D tensor.
62
45
 
63
46
  Supported Platforms:
64
47
  ``Ascend`` ``GPU`` ``CPU``
65
48
 
66
49
  Examples:
67
50
  >>> import mindspore
68
- >>> from mindspore import Tensor, ops
69
- >>> window_length = Tensor(10, mindspore.int32)
70
- >>> output = ops.blackman_window(window_length, periodic=True, dtype=mindspore.float32)
51
+ >>> window_length = mindspore.tensor(10)
52
+ >>> output = mindspore.ops.blackman_window(window_length)
71
53
  >>> print(output)
72
54
  [-2.9802322e-08 4.0212840e-02 2.0077014e-01 5.0978714e-01
73
55
  8.4922993e-01 1.0000000e+00 8.4922981e-01 5.0978690e-01
@@ -82,16 +64,10 @@ def blackman_window(window_length, periodic=True, *, dtype=None):
82
64
 
83
65
  def bartlett_window(window_length, periodic=True, *, dtype=None):
84
66
  r"""
85
- Bartlett window function is a triangular-shaped weighting function used for smoothing or frequency analysis of
86
- signals in digital signal processing.
67
+ Bartlett window function.
87
68
 
88
- The `window_length` is a input tensor which determines the returned window size, and its data should be
89
- an integer. In particular, if `window_length` is equal to `1`, only a single value 1 exists in the
90
- returned window.
91
-
92
- Attr `periodic` determines whether the returned window removes the last duplicate value from the symmetric
93
- window and prepares to be a periodic window with functions. Therefore, if attr `periodic` is true,
94
- the :math:`N` in formula is :math:`window\_length + 1`.
69
+ A triangular-shaped weighting function used for smoothing or frequency analysis of signals in digital signal
70
+ processing.
95
71
 
96
72
  .. math::
97
73
 
@@ -100,40 +76,31 @@ def bartlett_window(window_length, periodic=True, *, dtype=None):
100
76
  2 - \frac{2n}{N - 1} & \text{if } \frac{N - 1}{2} < n < N \\
101
77
  \end{cases},
102
78
 
103
- where N is the full window size.
79
+ where :math:`N` is the full window size, and n is natural number less than :math:`N` :[0, 1, ..., N-1].
104
80
 
105
81
  Args:
106
- window_length (Tensor): The size of returned window, with data type int32, int64.
107
- The input data should be an integer with a value of [0, 1000000].
108
- periodic (bool, optional): Indicates whether to returns a window to be used as periodic function or
109
- a symmetric window. Default: ``True`` , indicating that the returned window is a periodic function.
82
+ window_length (Tensor): The size of window.
83
+ periodic (bool, optional): If ``True`` , return a periodic window. If ``False``, return a symmetric window.
84
+ Default ``True`` .
110
85
 
111
86
  Keyword Args:
112
- dtype (mindspore.dtype, optional): The datatype of returned tensor.
113
- Only float16, float32 and float64 are allowed. Default: ``None`` .
87
+ dtype (mindspore.dtype, optional): The data type specified. Default ``None`` .
114
88
 
115
89
  Returns:
116
- A 1-D tensor of size `window_length` containing the window. Its datatype is set by the attr `dtype`.
117
- If `dtype` is None, output datatype is float32.
118
-
119
- Raises:
120
- TypeError: If `window_length` is not a Tensor.
121
- TypeError: If the type of `window_length` is not one of: int32, int64.
122
- TypeError: If `periodic` is not a bool.
123
- TypeError: If `dtype` is not one of: float16, float32, float64.
124
- ValueError: If the value range of `window_length` is not [0, 1000000].
125
- ValueError: If the dimension of `window_length` is not 0.
90
+ A 1-D tensor.
126
91
 
127
92
  Supported Platforms:
128
93
  ``Ascend`` ``GPU`` ``CPU``
129
94
 
130
95
  Examples:
131
- >>> from mindspore import Tensor, ops
132
- >>> from mindspore import dtype as mstype
133
- >>> window_length = Tensor(5, mstype.int32)
134
- >>> output = ops.bartlett_window(window_length, periodic=True, dtype=mstype.float32)
96
+ >>> import mindspore
97
+ >>> window_length = mindspore.tensor(5)
98
+ >>> output = mindspore.ops.bartlett_window(window_length)
99
+ >>> print(output)
100
+ [0. 0.4 0.8 0.8 0.4]
101
+ >>> output = mindspore.ops.bartlett_window(window_length, periodic=False)
135
102
  >>> print(output)
136
- [0. 0.4 0.8 0.8 0.4]
103
+ [0. 0.5 1. 0.5 0. ]
137
104
  """
138
105
  if dtype is None:
139
106
  dtype = mstype.float32
@@ -26,9 +26,7 @@ def vmap(fn, in_axes=0, out_axes=0):
26
26
 
27
27
  Vmap is pioneered by Jax and it removes the restriction of batch dimension on the operator, and provides a
28
28
  more convenient and unified operator expression. Moreover, it allows users to composite with other functional
29
- modules such as :func:`mindspore.grad`, to improve the development efficiency, please refer to the
30
- `Automatic Vectorization (Vmap) <https://www.mindspore.cn/docs/en/master/model_train/train_process/optimize/vmap.html>`_
31
- tutorial for more detail.
29
+ modules such as :func:`mindspore.grad`, to improve the development efficiency.
32
30
  In addition, the vectorizing map does not execute loops outside the function, but sinks loops
33
31
  into the primitive operations of the function for better performance. When combined with `Graph Kernel Fusion`,
34
32
  operational efficiency would be further improved.
@@ -49,21 +47,30 @@ def vmap(fn, in_axes=0, out_axes=0):
49
47
  argument and returns one or more Tensors or the type of data supported by the MindSpore Tensor. When it is
50
48
  a CellList, the model ensembling scenario, please make sure that the structure of each cell is the same
51
49
  and the number of cells is consistent with the sizes of the mapped axes (`axis_size`).
52
- in_axes (Union[int, list, tuple]): Specifies which dimensions (axes) of the inputs should be mapped over.
53
- If `in_axes` is an integer, all arguments of `fn` are mapped over according to this axis index. If `in_axes`
54
- is a tuple or list, which only composed of integers or Nones and the length should equal to the number of
55
- positional arguments to `fn`, indicates which axis to map for each corresponding positional argument.
56
- Note that, axis integers must be in range :math:`[-ndim, ndim)` for each argument, where `ndim` is the
57
- number of dimensions of the corresponding argument. None means not mapping along any axis. Also the
58
- mapping axis index of the `in_axes` must have at least one positional parameter not None. The sizes of
59
- the mapped axes (`axis_size`) for all arguments must be equal. Default: ``0`` .
50
+ in_axes (Union[int, list, tuple]): Specifies which dimensions (axes)
51
+ of the inputs should be mapped over. Default: ``0`` .
52
+
53
+ - If `in_axes` is an integer, all arguments of `fn` are mapped over according to this axis index.
54
+ - If `in_axes` is a tuple or list, which only composed of integers or Nones
55
+ and the length should equal to the number of
56
+ positional arguments to `fn`, indicates which axis to map for each corresponding positional argument.
57
+ Note that, axis integers must be in range :math:`[-ndim, ndim)` for each argument, where `ndim` is the
58
+ number of dimensions of the corresponding argument.
59
+ - None means not mapping along any axis. Also the
60
+ mapping axis index of the `in_axes` must have at least one positional parameter not None. The sizes of
61
+ the mapped axes (`axis_size`) for all arguments must be equal.
62
+
60
63
  out_axes (Union[int, list, tuple]): Specifies where the mapped dimensions (axes) should appear in the
61
- outputs. If `out_axes` is an integer, all outputs of `fn` are specified according to this axis. If
62
- `out_axes` is a tuple or list, which only composed of integers or Nones. And its length also should be equal
63
- to the number of outputs of `fn`. Note that, axis integers must be in range :math:`[-ndim, ndim)` for each
64
- output, where `ndim` is the dimension of the output of the `vmap`-mapped function. All outputs with a
65
- non-None mapped axis must specify a non-None `out_axes`, and if outputs with None mapped axis specifies
66
- a non-None `out_axes`, the result broadcasts across the mapped axis. Default: ``0`` .
64
+ outputs. Default: ``0`` .
65
+
66
+ - If `out_axes` is an integer, all outputs of `fn` are specified according to this axis.
67
+ - If `out_axes` is a tuple or list, which only composed of integers or Nones.
68
+ And its length also should be equal
69
+ to the number of outputs of `fn`. Note that, axis integers must be in range :math:`[-ndim, ndim)` for each
70
+ output, where `ndim` is the dimension of the output of the `vmap`-mapped function.
71
+ - All outputs with a
72
+ non-None mapped axis must specify a non-None `out_axes`, and if outputs with None mapped axis specifies
73
+ a non-None `out_axes`, the result broadcasts across the mapped axis.
67
74
 
68
75
  Returns:
69
76
  Function, returns the Vectorized/Batched version function of `fn`. The arguments and outputs of this function
@@ -20,7 +20,7 @@ from mindspore.common._register_for_tensor import tensor_operator_registry
20
20
  from mindspore.ops import _constants
21
21
  from mindspore.ops.function import *
22
22
  from mindspore.ops.function.array_func import chunk_ext, zero_
23
- from mindspore.ops.function.math_func import all, argmax_ext, float_power_ext, erfinv_, tanh_
23
+ from mindspore.ops.function.math_func import all, argmax_ext, float_power_ext, erfinv_, tanh_, bernoulli_ext
24
24
  from mindspore.ops.function.random_func import random_, uniform_ext, uniform_, normal_
25
25
  from mindspore.ops import operations as P
26
26
  from mindspore.ops.operations import array_ops
@@ -36,10 +36,10 @@ from mindspore.ops.function.math_func import dot
36
36
  from mindspore.ops.function.array_func import new_empty
37
37
  from mindspore.ops import auto_generate
38
38
  from mindspore.ops.auto_generate import cast
39
- from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum
39
+ from mindspore.ops._utils.arg_dtype_cast import DtypeToEnum
40
40
  from mindspore.ops.operations.manually_defined.ops_def import scalar_div, scalar_mod, scalar_add, scalar_mul, \
41
41
  scalar_sub, scalar_gt, scalar_ge, scalar_le, scalar_lt, scalar_eq, scalar_floordiv, scalar_log, scalar_pow, \
42
- scalar_uadd, scalar_usub
42
+ scalar_uadd, scalar_usub, scalar_max, scalar_min
43
43
 
44
44
  typeof = Primitive('typeof')
45
45
  hastype = Primitive('hastype')
@@ -117,6 +117,7 @@ switch_layer = Primitive('switch_layer')
117
117
  reduced_shape = Primitive("reduced_shape")
118
118
  # shape_mul:input must be shape multiply elements in tuple(shape)
119
119
  shape_mul = _sequence_ops.shape_mul()
120
+ put_ = auto_generate.put_
120
121
 
121
122
  setattr(tensor_operator_registry, 'tuple_to_tensor',
122
123
  _sequence_ops.TupleToTensor)
@@ -307,6 +308,7 @@ setattr(tensor_operator_registry, 'ormqr', ormqr)
307
308
  setattr(tensor_operator_registry, 'masked_scatter', array_ops.MaskedScatter)
308
309
  setattr(tensor_operator_registry, 'index_put', array_ops.IndexPut)
309
310
  setattr(tensor_operator_registry, 'index_put_', auto_generate.index_put_)
311
+ setattr(tensor_operator_registry, 'put_', put_)
310
312
  setattr(tensor_operator_registry, 'quantile', quantile)
311
313
  setattr(tensor_operator_registry, 'nanquantile', nanquantile)
312
314
  setattr(tensor_operator_registry, 'orgqr', orgqr)
@@ -395,7 +397,7 @@ setattr(tensor_operator_registry, 'tensor_scatter_add', tensor_scatter_add)
395
397
  setattr(tensor_operator_registry, 'inplace_scatter_add', auto_generate.inplace_scatter_add)
396
398
  setattr(tensor_operator_registry, 'slice_scatter', slice_scatter)
397
399
  setattr(tensor_operator_registry, 'select_scatter', select_scatter)
398
- setattr(tensor_operator_registry, 'bernoulli', bernoulli)
400
+ setattr(tensor_operator_registry, 'bernoulli', bernoulli_ext)
399
401
  setattr(tensor_operator_registry, 'poisson', P.Poisson)
400
402
  setattr(tensor_operator_registry, 'randperm', P.Randperm)
401
403
  setattr(tensor_operator_registry, 'multinomial', multinomial)