mindspore 2.7.0rc1-cp311-cp311-win_amd64.whl → 2.7.1-cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (370)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +5 -2
  3. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +2 -2
  7. mindspore/_extends/builtin_operations.py +3 -3
  8. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/__init__.py +3 -3
  11. mindspore/_extends/parse/compile_config.py +24 -1
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
  13. mindspore/_extends/parse/parser.py +28 -22
  14. mindspore/_extends/parse/resources.py +1 -1
  15. mindspore/_extends/parse/standard_method.py +23 -2
  16. mindspore/_extends/parse/trope.py +2 -1
  17. mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
  18. mindspore/amp.py +0 -18
  19. mindspore/avcodec-59.dll +0 -0
  20. mindspore/avdevice-59.dll +0 -0
  21. mindspore/avfilter-8.dll +0 -0
  22. mindspore/avformat-59.dll +0 -0
  23. mindspore/avutil-57.dll +0 -0
  24. mindspore/boost/base.py +29 -2
  25. mindspore/common/__init__.py +18 -12
  26. mindspore/common/_decorator.py +3 -2
  27. mindspore/common/_grad_function.py +3 -1
  28. mindspore/common/_tensor_cpp_method.py +1 -1
  29. mindspore/common/_tensor_docs.py +371 -96
  30. mindspore/common/_utils.py +7 -43
  31. mindspore/common/api.py +434 -135
  32. mindspore/common/dtype.py +98 -57
  33. mindspore/common/dump.py +7 -108
  34. mindspore/common/dynamic_shape/__init__.py +0 -0
  35. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
  36. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  37. mindspore/common/file_system.py +59 -9
  38. mindspore/common/hook_handle.py +82 -3
  39. mindspore/common/jit_config.py +5 -1
  40. mindspore/common/jit_trace.py +27 -12
  41. mindspore/common/lazy_inline.py +5 -3
  42. mindspore/common/np_dtype.py +3 -3
  43. mindspore/common/parameter.py +17 -127
  44. mindspore/common/recompute.py +4 -13
  45. mindspore/common/tensor.py +50 -217
  46. mindspore/communication/_comm_helper.py +11 -1
  47. mindspore/communication/comm_func.py +138 -4
  48. mindspore/communication/management.py +85 -1
  49. mindspore/config/op_info.config +0 -15
  50. mindspore/context.py +20 -106
  51. mindspore/dataset/__init__.py +1 -1
  52. mindspore/dataset/audio/transforms.py +1 -1
  53. mindspore/dataset/core/config.py +35 -1
  54. mindspore/dataset/engine/datasets.py +338 -319
  55. mindspore/dataset/engine/datasets_user_defined.py +38 -22
  56. mindspore/dataset/engine/datasets_vision.py +1 -1
  57. mindspore/dataset/engine/validators.py +1 -15
  58. mindspore/dataset/transforms/c_transforms.py +2 -2
  59. mindspore/dataset/transforms/transforms.py +3 -3
  60. mindspore/dataset/vision/__init__.py +1 -1
  61. mindspore/dataset/vision/py_transforms.py +8 -8
  62. mindspore/dataset/vision/transforms.py +17 -5
  63. mindspore/dataset/vision/utils.py +632 -21
  64. mindspore/device_context/ascend/op_tuning.py +35 -1
  65. mindspore/dnnl.dll +0 -0
  66. mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
  67. mindspore/graph/custom_pass.py +55 -0
  68. mindspore/include/api/cell.h +28 -4
  69. mindspore/include/api/cfg.h +24 -7
  70. mindspore/include/api/context.h +1 -0
  71. mindspore/include/api/delegate.h +0 -2
  72. mindspore/include/api/dual_abi_helper.h +100 -19
  73. mindspore/include/api/graph.h +14 -1
  74. mindspore/include/api/kernel.h +16 -3
  75. mindspore/include/api/kernel_api.h +9 -1
  76. mindspore/include/api/metrics/accuracy.h +9 -0
  77. mindspore/include/api/model.h +5 -1
  78. mindspore/include/api/model_group.h +4 -0
  79. mindspore/include/api/model_parallel_runner.h +2 -0
  80. mindspore/include/api/status.h +48 -10
  81. mindspore/include/api/types.h +6 -1
  82. mindspore/include/dataset/constants.h +9 -0
  83. mindspore/include/dataset/execute.h +2 -2
  84. mindspore/jpeg62.dll +0 -0
  85. mindspore/mindrecord/__init__.py +3 -3
  86. mindspore/mindrecord/common/exceptions.py +1 -0
  87. mindspore/mindrecord/config.py +1 -1
  88. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  89. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  90. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  91. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  92. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  93. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  94. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  95. mindspore/mindrecord/filereader.py +4 -4
  96. mindspore/mindrecord/filewriter.py +5 -5
  97. mindspore/mindrecord/mindpage.py +2 -2
  98. mindspore/mindrecord/tools/cifar10.py +4 -3
  99. mindspore/mindrecord/tools/cifar100.py +1 -1
  100. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  101. mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
  102. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  103. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  104. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  105. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  106. mindspore/mindspore_backend_common.dll +0 -0
  107. mindspore/mindspore_backend_manager.dll +0 -0
  108. mindspore/mindspore_cluster.dll +0 -0
  109. mindspore/mindspore_common.dll +0 -0
  110. mindspore/mindspore_core.dll +0 -0
  111. mindspore/mindspore_cpu.dll +0 -0
  112. mindspore/mindspore_dump.dll +0 -0
  113. mindspore/mindspore_frontend.dll +0 -0
  114. mindspore/mindspore_glog.dll +0 -0
  115. mindspore/mindspore_hardware_abstract.dll +0 -0
  116. mindspore/mindspore_memory_pool.dll +0 -0
  117. mindspore/mindspore_ms_backend.dll +0 -0
  118. mindspore/mindspore_ops.dll +0 -0
  119. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  120. mindspore/mindspore_profiler.dll +0 -0
  121. mindspore/mindspore_pyboost.dll +0 -0
  122. mindspore/mindspore_pynative.dll +0 -0
  123. mindspore/mindspore_runtime_pipeline.dll +0 -0
  124. mindspore/mindspore_runtime_utils.dll +0 -0
  125. mindspore/mindspore_tools.dll +0 -0
  126. mindspore/mint/__init__.py +15 -10
  127. mindspore/mint/distributed/__init__.py +4 -0
  128. mindspore/mint/distributed/distributed.py +392 -69
  129. mindspore/mint/nn/__init__.py +2 -16
  130. mindspore/mint/nn/functional.py +4 -110
  131. mindspore/mint/nn/layer/__init__.py +0 -2
  132. mindspore/mint/nn/layer/_functions.py +1 -2
  133. mindspore/mint/nn/layer/activation.py +0 -6
  134. mindspore/mint/nn/layer/basic.py +0 -47
  135. mindspore/mint/nn/layer/conv.py +10 -10
  136. mindspore/mint/nn/layer/normalization.py +11 -16
  137. mindspore/mint/nn/layer/pooling.py +0 -4
  138. mindspore/nn/__init__.py +1 -3
  139. mindspore/nn/cell.py +231 -239
  140. mindspore/nn/layer/activation.py +4 -2
  141. mindspore/nn/layer/basic.py +56 -14
  142. mindspore/nn/layer/container.py +16 -0
  143. mindspore/nn/layer/embedding.py +4 -169
  144. mindspore/nn/layer/image.py +1 -1
  145. mindspore/nn/layer/normalization.py +2 -1
  146. mindspore/nn/layer/thor_layer.py +4 -85
  147. mindspore/nn/optim/ada_grad.py +0 -1
  148. mindspore/nn/optim/adafactor.py +0 -1
  149. mindspore/nn/optim/adam.py +32 -127
  150. mindspore/nn/optim/adamax.py +0 -1
  151. mindspore/nn/optim/asgd.py +0 -1
  152. mindspore/nn/optim/ftrl.py +8 -102
  153. mindspore/nn/optim/lamb.py +1 -4
  154. mindspore/nn/optim/lars.py +0 -3
  155. mindspore/nn/optim/lazyadam.py +25 -218
  156. mindspore/nn/optim/momentum.py +5 -43
  157. mindspore/nn/optim/optimizer.py +6 -55
  158. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  159. mindspore/nn/optim/rmsprop.py +0 -1
  160. mindspore/nn/optim/rprop.py +0 -1
  161. mindspore/nn/optim/sgd.py +0 -1
  162. mindspore/nn/optim/tft_wrapper.py +2 -4
  163. mindspore/nn/optim/thor.py +0 -2
  164. mindspore/nn/probability/bijector/bijector.py +7 -8
  165. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  166. mindspore/nn/probability/bijector/power_transform.py +20 -21
  167. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  168. mindspore/nn/probability/bijector/softplus.py +13 -14
  169. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  170. mindspore/nn/wrap/cell_wrapper.py +39 -5
  171. mindspore/nn/wrap/grad_reducer.py +4 -89
  172. mindspore/numpy/array_creations.py +4 -4
  173. mindspore/numpy/fft.py +9 -9
  174. mindspore/numpy/utils_const.py +1 -1
  175. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  176. mindspore/onnx/onnx_export.py +137 -0
  177. mindspore/opencv_core4110.dll +0 -0
  178. mindspore/opencv_imgcodecs4110.dll +0 -0
  179. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  180. mindspore/ops/__init__.py +2 -0
  181. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  182. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  183. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  184. mindspore/ops/_op_impl/cpu/__init__.py +1 -5
  185. mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
  186. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
  187. mindspore/ops/auto_generate/gen_extend_func.py +6 -11
  188. mindspore/ops/auto_generate/gen_ops_def.py +385 -154
  189. mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
  190. mindspore/ops/communication.py +97 -0
  191. mindspore/ops/composite/__init__.py +5 -2
  192. mindspore/ops/composite/base.py +16 -2
  193. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  194. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  195. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  196. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  197. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  198. mindspore/ops/function/__init__.py +2 -0
  199. mindspore/ops/function/array_func.py +24 -18
  200. mindspore/ops/function/comm_func.py +3883 -0
  201. mindspore/ops/function/debug_func.py +7 -6
  202. mindspore/ops/function/grad/grad_func.py +4 -12
  203. mindspore/ops/function/math_func.py +89 -86
  204. mindspore/ops/function/nn_func.py +92 -313
  205. mindspore/ops/function/random_func.py +9 -18
  206. mindspore/ops/functional.py +4 -1
  207. mindspore/ops/functional_overload.py +377 -30
  208. mindspore/ops/operations/__init__.py +2 -5
  209. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  210. mindspore/ops/operations/_inner_ops.py +12 -50
  211. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  212. mindspore/ops/operations/array_ops.py +5 -50
  213. mindspore/ops/operations/comm_ops.py +95 -17
  214. mindspore/ops/operations/custom_ops.py +237 -22
  215. mindspore/ops/operations/debug_ops.py +33 -35
  216. mindspore/ops/operations/manually_defined/ops_def.py +39 -318
  217. mindspore/ops/operations/math_ops.py +5 -5
  218. mindspore/ops/operations/nn_ops.py +3 -3
  219. mindspore/ops/operations/sparse_ops.py +0 -83
  220. mindspore/ops/primitive.py +4 -27
  221. mindspore/ops/tensor_method.py +88 -10
  222. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  223. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  224. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  225. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  226. mindspore/ops_generate/common/gen_constants.py +11 -10
  227. mindspore/ops_generate/common/op_proto.py +18 -1
  228. mindspore/ops_generate/common/template.py +102 -245
  229. mindspore/ops_generate/common/template_utils.py +212 -0
  230. mindspore/ops_generate/gen_custom_ops.py +69 -0
  231. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  232. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  233. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  234. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  235. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  236. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  237. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  238. mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
  239. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  240. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  241. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  242. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  243. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  244. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  245. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  246. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  247. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  248. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  249. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  250. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  251. mindspore/parallel/_auto_parallel_context.py +5 -15
  252. mindspore/parallel/_cell_wrapper.py +1 -1
  253. mindspore/parallel/_parallel_serialization.py +4 -6
  254. mindspore/parallel/_ps_context.py +2 -2
  255. mindspore/parallel/_utils.py +34 -17
  256. mindspore/parallel/auto_parallel.py +23 -9
  257. mindspore/parallel/checkpoint_transform.py +20 -2
  258. mindspore/parallel/cluster/process_entity/_api.py +28 -33
  259. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  260. mindspore/parallel/cluster/run.py +5 -3
  261. mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
  262. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  263. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  264. mindspore/parallel/function/reshard_func.py +6 -5
  265. mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
  266. mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
  267. mindspore/parallel/shard.py +7 -21
  268. mindspore/parallel/strategy.py +336 -0
  269. mindspore/parallel/transform_safetensors.py +127 -20
  270. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
  271. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
  272. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  273. mindspore/profiler/common/constant.py +5 -0
  274. mindspore/profiler/common/file_manager.py +9 -0
  275. mindspore/profiler/common/msprof_cmd_tool.py +40 -4
  276. mindspore/profiler/common/path_manager.py +65 -24
  277. mindspore/profiler/common/profiler_context.py +27 -14
  278. mindspore/profiler/common/profiler_info.py +3 -3
  279. mindspore/profiler/common/profiler_meta_data.py +1 -0
  280. mindspore/profiler/common/profiler_op_analyse.py +10 -6
  281. mindspore/profiler/common/profiler_path_manager.py +13 -0
  282. mindspore/profiler/common/util.py +30 -3
  283. mindspore/profiler/dynamic_profiler.py +91 -46
  284. mindspore/profiler/envprofiler.py +30 -5
  285. mindspore/profiler/experimental_config.py +18 -2
  286. mindspore/profiler/platform/cpu_profiler.py +10 -4
  287. mindspore/profiler/platform/npu_profiler.py +34 -7
  288. mindspore/profiler/profiler.py +193 -145
  289. mindspore/profiler/profiler_action_controller.py +1 -1
  290. mindspore/profiler/profiler_interface.py +2 -2
  291. mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
  292. mindspore/run_check/_check_version.py +108 -24
  293. mindspore/runtime/__init__.py +9 -6
  294. mindspore/runtime/executor.py +35 -0
  295. mindspore/runtime/memory.py +113 -0
  296. mindspore/runtime/thread_bind_core.py +1 -1
  297. mindspore/swresample-4.dll +0 -0
  298. mindspore/swscale-6.dll +0 -0
  299. mindspore/tinyxml2.dll +0 -0
  300. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  301. mindspore/tools/data_dump.py +130 -0
  302. mindspore/tools/sdc_detect.py +91 -0
  303. mindspore/tools/stress_detect.py +63 -0
  304. mindspore/train/__init__.py +6 -6
  305. mindspore/train/_utils.py +8 -21
  306. mindspore/train/amp.py +6 -7
  307. mindspore/train/callback/_callback.py +2 -1
  308. mindspore/train/callback/_checkpoint.py +1 -17
  309. mindspore/train/callback/_flops_collector.py +10 -6
  310. mindspore/train/callback/_train_fault_tolerance.py +72 -25
  311. mindspore/train/data_sink.py +5 -9
  312. mindspore/train/dataset_helper.py +5 -5
  313. mindspore/train/model.py +41 -230
  314. mindspore/train/serialization.py +160 -401
  315. mindspore/train/train_thor/model_thor.py +2 -2
  316. mindspore/turbojpeg.dll +0 -0
  317. mindspore/utils/__init__.py +6 -3
  318. mindspore/utils/dlpack.py +92 -0
  319. mindspore/utils/dryrun.py +1 -1
  320. mindspore/utils/runtime_execution_order_check.py +10 -0
  321. mindspore/utils/sdc_detect.py +14 -12
  322. mindspore/utils/stress_detect.py +43 -0
  323. mindspore/utils/utils.py +152 -16
  324. mindspore/version.py +1 -1
  325. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  326. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
  327. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  328. mindspore/communication/_hccl_management.py +0 -297
  329. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
  330. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  331. mindspore/experimental/llm_boost/atb/__init__.py +0 -23
  332. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  333. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  334. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  335. mindspore/experimental/llm_boost/register.py +0 -130
  336. mindspore/experimental/llm_boost/utils.py +0 -31
  337. mindspore/include/OWNERS +0 -7
  338. mindspore/mindspore_cpu_res_manager.dll +0 -0
  339. mindspore/mindspore_ops_kernel_common.dll +0 -0
  340. mindspore/mindspore_res_manager.dll +0 -0
  341. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  342. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  343. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  344. mindspore/nn/reinforcement/tensor_array.py +0 -145
  345. mindspore/opencv_core452.dll +0 -0
  346. mindspore/opencv_imgcodecs452.dll +0 -0
  347. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  348. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  349. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  350. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  351. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  352. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  353. mindspore/ops/operations/_tensor_array.py +0 -359
  354. mindspore/ops/operations/rl_ops.py +0 -288
  355. mindspore/parallel/_offload_context.py +0 -275
  356. mindspore/parallel/_recovery_context.py +0 -115
  357. mindspore/parallel/_transformer/__init__.py +0 -35
  358. mindspore/parallel/_transformer/layers.py +0 -765
  359. mindspore/parallel/_transformer/loss.py +0 -251
  360. mindspore/parallel/_transformer/moe.py +0 -693
  361. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  362. mindspore/parallel/_transformer/transformer.py +0 -3124
  363. mindspore/parallel/mpi/_mpi_config.py +0 -116
  364. mindspore/profiler/common/validator/validate_path.py +0 -84
  365. mindspore/train/memory_profiling_pb2.py +0 -298
  366. mindspore/utils/hooks.py +0 -81
  367. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  368. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  369. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  370. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/parallel/distributed/flatten_grad_buffer.py (new file)
@@ -0,0 +1,295 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" Param and grad buffer, bucket implemenatrion. """
+from __future__ import absolute_import
+
+__all__ = ["Bucket", "FlattenGradBuffer"]
+
+from enum import Enum
+import numpy as np
+from mindspore import mint, Tensor
+from mindspore.common.initializer import Zero
+from mindspore.communication.management import get_group_size
+import mindspore.communication.comm_func as comm_func
+
+
+class BufferType(Enum):
+    PARAM = 0
+    GRAD = 1
+
+
+MEM_ALIGN_SIZE = 512
+ALIGN_BYTES = 32
+MIN_BUCKET_SIZE = int(1 * 1024 * 1024)
+DEFAULT_BUCKET_SIZE = int(25 * 1024 * 1024)
+
+
+class Bucket:
+    """
+    Bucket to track a subset of parameters and gradients in the buffer. Bucket records the parameters
+    whose gradient has already been computed. It also provide functionality to synchronize gradients among
+    data parallel group when all parameters' graidents have been computed.
+
+    Args:
+        average_in_collective (bool): Scaling grads before/after AllReduce, True: scaling after AllReduce.
+        params (List(Parameters)): Parameters belongs to this bucket.
+        grad_data (Tensor): A section of buffers' gradient data, coressponding to parameters in this bucket.
+        offset (int): Start index in the buffer.
+        numel_unpadded (int): Number of unpadded elements in bucket.
+        data_parallel_group (str): Data parallel group name.
+        data_parallel_world_size (int): Data parallel group size.
+        gradient_scaling_factor (float): Work with average_in_collective, it is 1.0 when average_in_collective
+            true else 1.0/dp
+    """
+
+    def __init__(self, average_in_collective, params, grad_data, offset, numel_unpadded, data_parallel_group,
+                 data_parallel_world_size, gradient_scaling_factor):
+        self.average_in_collective = average_in_collective
+        self.params_list = params
+        self.params = set(params)
+        self.params_grad_ready = set()
+        self.grad_data = grad_data
+        self.grad_data_numel = self.grad_data.numel()
+        self.offset = offset
+        self.numel_unpadded = numel_unpadded
+        self.data_parallel_group = data_parallel_group
+        self.data_parallel_world_size = data_parallel_world_size
+        self.gradient_scaling_factor = gradient_scaling_factor
+
+        if self.data_parallel_world_size > 1:
+            self.grad_reducer = comm_func.all_reduce
+
+        self.reset()
+
+    def inplace_reduce_dp(self, src):
+        """conduct all-reduce/reduce-scatter on src tensor and inplace update result into target."""
+        self.communication_result, self.communication_handle = self.grad_reducer(
+            src, "sum", self.data_parallel_group, async_op=True
+        )
+
+    def reset(self):
+        """reset bucket for the next iteration."""
+        self.params_grad_ready = set()
+        self.is_reduce_issued = False
+        self.communication_handle = None
+        self.communication_result = None
+
+    def issue_grad_reduce(self):
+        """issue grad reduce for the local grad data view."""
+        if self.is_reduce_issued:
+            raise RuntimeError("The bucket reduce is already issued")
+
+        if self.gradient_scaling_factor != 1.0:
+            self.grad_data.copy_(mint.mul(self.grad_data, self.gradient_scaling_factor))
+
+        if self.data_parallel_world_size > 1:
+            self.inplace_reduce_dp(self.grad_data)
+
+        self.is_reduce_issued = True
+
+    def final_grad_reduce(self):
+        """finalize grad reduce for the local grad data view."""
+        start_idx = 0
+        end_idx = self.grad_data_numel
+        target = self.grad_data[start_idx:end_idx]
+
+        if not self.is_reduce_issued:
+            raise RuntimeError(
+                f"The bucket reduce has not been issued "
+                f"with only {len(self.params_grad_ready)}/{len(self.params)} params ready"
+            )
+
+        if self.data_parallel_world_size > 1:
+            self.communication_handle.wait()
+            target.copy_(self.communication_result)
+            self.communication_result = None
+            if self.average_in_collective:
+                target.copy_(mint.div(target, self.data_parallel_world_size))
+
+    def register_grad_ready(self, param):
+        """register grad ready and issue bucket grad reduce when the bucket is ready."""
+        if param not in self.params:
+            raise ValueError("The param to be registered is not in the bucket")
+
+        if param in self.params_grad_ready:
+            raise ValueError(f"The param {param} is already registered")
+
+        self.params_grad_ready.add(param)
+        if len(self.params_grad_ready) == len(self.params):
+            self.issue_grad_reduce()
+            return True
+
+        return False
+
+    def __repr__(self):
+        return f"Bucket (offset={self.offset}, param_lens={len(self.params)})"
+
+
+class FlattenGradBuffer:
+    """
+    Allocate contiguous memory buffer for given parameters and corresponding gradients. Breaking
+    up parameters and gradients buffer into small buckets, which is the unit for all-reduce/reduce-scatter
+    communication during back-propagation.
+
+    Args:
+        average_in_collective (bool): Scaling grads before/after AllReduce, True: scaling after AllReduce.
+        param_dtype (mindspore.dtype): The parameters' datatype.
+        grad_dtype (mindspore.dtype): The gradients' datatype.
+        params (List(Parameters)): Parameters belongs to this buffer.
+        data_parallel_group (str): Data parallel group name.
+        bucket_size (int): Bucket size threshold used to partition bucekts.
+        gradient_scaling_factor (float):
+    """
+
+    def __init__(self, average_in_collective, param_dtype, grad_dtype, params, data_parallel_group,
+                 bucket_size, gradient_scaling_factor, ddp_handle):
+        super(FlattenGradBuffer, self).__init__()
+        self.param_dtype = param_dtype
+        self.grad_dtype = grad_dtype
+        self.data_parallel_group = data_parallel_group
+        self.data_parallel_world_size = get_group_size(group=self.data_parallel_group)
+        self.gradient_scaling_factor = gradient_scaling_factor
+        self.average_in_collective = average_in_collective
+
+        self.buckets = []
+        self.param_index_map = {}
+        self.param_to_bucket = {}
+        self.sync_enabled = True
+        self.issued = 0
+        self.ddp_handle = ddp_handle
+
+        buckets_metadata = self.calc_partition_metadata(bucket_size, params)
+        self.instantiate_buckets(buckets_metadata, params)
+
+    def calc_partition_metadata(self, bucket_size, params):
+        """calc bucket partition metadata"""
+        # helper func
+        def _need_new_bucket(bucket_numel, bucket_id):
+            target_bucket_size = bucket_size
+            if bucket_id == 0 and bucket_size == DEFAULT_BUCKET_SIZE:
+                target_bucket_size = MIN_BUCKET_SIZE
+            return (
+                bucket_size is not None
+                and bucket_numel != 0
+                and bucket_numel >= target_bucket_size
+            )
+
+        def _build_bucket():
+            nonlocal buckets_metadata, bucket_start_index, bucket_params, bucket_id
+            bucket_end_index = data_start_index
+            buckets_metadata.append(
+                (bucket_start_index, bucket_end_index, bucket_params)
+            )
+            bucket_start_index = bucket_end_index
+            bucket_id = bucket_id + 1
+            bucket_params = []
+
+        param_data_list = []
+        buckets_metadata = []
+        data_start_index = 0
+        data_end_index = 0
+        bucket_id = 0
+        bucket_start_index = 0
+        bucket_params = []
+        for param in params[::]:  # traverse from the beginning
+            last_bucket_numel = data_start_index - bucket_start_index
+            if _need_new_bucket(last_bucket_numel, bucket_id):
+                _build_bucket()
+            data_end_index = data_start_index + param.numel()
+            bucket_params.append(param)
+            param_data_list.append(param)
+            self.param_index_map[param] = (data_start_index, data_end_index, bucket_id)
+            data_start_index = data_end_index
+
+        # add bucket for the last few params which do not reach the bucket_size threshold
+        if data_start_index - bucket_start_index > 0:
+            bucket_end_index = data_start_index
+            buckets_metadata.append(
+                (bucket_start_index, bucket_end_index, bucket_params)
+            )
+            data_start_index = bucket_end_index
+
+        # allocate contiguous memory for parameters and gradients
+        self.numel = data_start_index
+        self.grad_data = Tensor(shape=(self.numel), dtype=self.grad_dtype, init=Zero())
+        self.grad_data.init_data()
+        self.numel_unpadded = 0
+        return buckets_metadata
+
+    def instantiate_buckets(self, buckets_metadata, params):
+        """build bucket instance according to partition metadata"""
+        for bucket_start_index, bucket_end_index, bucket_params in buckets_metadata:
+            local_grad_data = self.grad_data[bucket_start_index:bucket_end_index]
+            self.numel_unpadded += bucket_end_index - bucket_start_index
+            bucket = Bucket(
+                average_in_collective=self.average_in_collective,
+                params=bucket_params,
+                grad_data=local_grad_data,
+                offset=bucket_start_index,
+                numel_unpadded=bucket_end_index - bucket_start_index,
+                data_parallel_group=self.data_parallel_group,
+                data_parallel_world_size=self.data_parallel_world_size,
+                gradient_scaling_factor=self.gradient_scaling_factor,
+            )
+            self.buckets.append(bucket)
+            for param in bucket_params:
+                self.param_to_bucket[param] = bucket
+
+        for param in params:
+            data_start_index, _, _ = self.param_index_map[param]
+            param.grad = self._get_buffer_slice(
+                param.shape, data_start_index, BufferType.GRAD
+            )
+
+    def _get_buffer_slice(self, shape, start_index, buffer_type):
+        """get the buffer view with the same shape"""
+        end_index = start_index + int(np.prod(shape))
+        if start_index < 0 or end_index > self.numel:
+            raise ValueError("index out of range")
+        if buffer_type == BufferType.GRAD:
+            buffer_tensor = self.grad_data[start_index:end_index]
+        else:
+            raise TypeError("Invalid buffer type for _get_buffer_slice.")
+        buffer_tensor = buffer_tensor.view(shape)
+        return buffer_tensor
+
+    def reset(self):
+        """reset buffer for the next iteration."""
+        self.grad_data.zero_()
+        for bucket in self.buckets:
+            bucket.reset()
+        self.sync_enabled = True
+
+    def final_grad_reduce(self):
+        """finalize grad reduce for each bucket"""
+        for bucket in self.buckets:
+            bucket.final_grad_reduce()
+
+    def register_grad_ready(self, param):
+        """register ready grad in its buckets"""
+        if self.sync_enabled:
+            bucket = self.param_to_bucket[param]
+            if bucket.register_grad_ready(param):
+                self.issued += 1
+                if self.issued == len(self.buckets):
+                    self.ddp_handle.buffer_issued += 1
+                    if self.ddp_handle.buffer_issued == len(self.ddp_handle.buffers):
+                        self.ddp_handle.final_grad_reduce()
+
+    def __repr__(self):
+        param_index_with_name = {
+            param.name: index for (param, index) in self.param_index_map.items()
+        }
+        return f"Buffer has buckets: \n {self.buckets} \n and param_index_map: \n {param_index_with_name}"
mindspore/parallel/function/reshard_func.py
@@ -42,11 +42,12 @@ def reshard(tensor, layout):
         can check :class:`mindspore.parallel.Layout` for reference.
 
     Note:
-        - In the Graph mode, this function can set the sharding propagation strategy of a tensor.
-          For those tensor do not manually be set, their strategies are decided by the sharding
-          strategy propagation algorithm automatically.
-        - In PyNative mode, you can use this method to arrange tensors in a cell (that is, cells
-          that use Cell.shard/F.shard in PyNative mode) that is executed in parallel in graph mode.
+        In the Graph mode, this function can set the sharding propagation strategy of a tensor.
+        For those tensor do not manually be set, their strategies are decided by the sharding
+        strategy propagation algorithm automatically.
+
+    .. warning::
+        The method is currently not supported in PyNative mode.
 
     Args:
         tensor (Tensor): The tensor to be set the sharding strategy.
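Since the docstring now restricts reshard to Graph mode, a typical call site looks roughly like the sketch below. The import paths are assumed from the file layout in this diff (mindspore.parallel.function / mindspore.parallel.Layout), the 2x4 device grid with axes "dp"/"mp" is an illustrative assumption, and the usual distributed setup (msrun/init and an auto-parallel context with sharding propagation) is assumed to be configured elsewhere.

import mindspore as ms
from mindspore import nn, ops
from mindspore.parallel import Layout
from mindspore.parallel.function import reshard  # path assumed from this diff


class MatMulNet(nn.Cell):
    """Graph-mode cell that pins the layout of one intermediate tensor."""

    def __init__(self, layout):
        super().__init__()
        self.layout = layout
        self.matmul = ops.MatMul()

    def construct(self, x, w):
        y = self.matmul(x, w)
        # Tensors without an explicit layout are left to the sharding
        # propagation algorithm, as described in the note above.
        return reshard(y, self.layout)


device_layout = Layout((2, 4), ("dp", "mp"))   # 8 devices: dp=2, mp=4 (assumed)
net = MatMulNet(device_layout("dp", "mp"))     # shard rows by dp, columns by mp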
mindspore/parallel/nn/parallel_cell_wrapper.py
@@ -28,7 +28,8 @@ from mindspore import log as logger
 
 class PipelineCell(Cell):
     """
-    Slice MiniBatch into finer-grained MicroBatch for use in pipeline-parallel training.
+    Slice MiniBatch into finer-grained MicroBatch for use in pipeline-parallel training,
+    and specify the segment info.
 
     Note:
         micro_size must be greater or equal to pipeline stages.
@@ -37,6 +38,8 @@ class PipelineCell(Cell):
         network (Cell): The target network to wrap.
         micro_size (int): MicroBatch size.
         stage_config (dict, optional): The stage configuration for each cell's execution in pipeline parallel.
+        segment_config (dict, optional): The segment configuration for each cell's execution in pipeline parallel.
+            Default ``None``.
 
     Supported Platforms:
         ``Ascend``
@@ -48,7 +51,7 @@ class PipelineCell(Cell):
         >>> net = LeNet5()
         >>> net = nn.PipelineCell(net, 4, stage_config={"cell_name_0": 0, "cell_name_1": 1})
     """
-    def __init__(self, network, micro_size, stage_config=None):
+    def __init__(self, network, micro_size, stage_config=None, segment_config=None):
         super(PipelineCell, self).__init__(auto_prefix=False)
         self.network = network
         self.micro_inputs = nn.CellList()
@@ -104,6 +107,37 @@ class PipelineCell(Cell):
                    logger.warning(cell_name)
                raise KeyError("For 'PipelineCell', the argument 'stage_config' : {} is not "
                               "found in 'network' : {}".format(config_dict, network))
+        if segment_config is None:
+            return
+        self._config_segment(segment_config)
+
+
+    def _config_segment(self, segment_config):
+        """
+        Config segment num for cell.
+        """
+        config_dict = segment_config.copy()
+
+        for cell_name, cell in self.network.cells_and_names():
+            if cell_name in segment_config:
+                setattr(cell, "pipeline_segment", segment_config[cell_name])
+                del config_dict[cell_name]
+        if str(self.network) in segment_config:
+            setattr(self.network, "pipeline_segment", segment_config[str(self.network)])
+            del config_dict[str(self.network)]
+        # if there are any config elements left, print them
+        if config_dict:
+            for config_cell_name, config_segment_num in config_dict.items():
+                logger.error("pipeline_cell segment_config set pipeline_segment fail!")
+                logger.warning("config cell name:" + str(config_cell_name) +
+                               " config segment num:" + str(config_segment_num))
+            logger.warning("network:" + str(self.network))
+            logger.warning("cell name available:")
+            for cell_name, _ in self.network.cells_and_names():
+                logger.warning(cell_name)
+            raise KeyError("For 'PipelineCell', the argument 'segment_config' : {} is not "
+                           "found in 'network' : {}".format(config_dict, self.network))
+
 
     def construct(self, *args, **kwargs):
         ret = None
@@ -119,7 +153,8 @@ class PipelineCell(Cell):
 
 class Pipeline(PipelineCell):
     """
-    Specify the number of micro_batch for pipeline parallelism and the division rules for stage.
+    Specify the number of micro_batch for pipeline parallelism and the division rules for stage,
+    and specify the segment info.
 
     Note:
         micro_size must be greater or equal to pipeline stages.
@@ -128,6 +163,8 @@ class Pipeline(PipelineCell):
         network (Cell): The target network to wrap.
         micro_size (int): MicroBatch size.
         stage_config (dict, optional): Stage configuration for cell's execution in pipeline parallel. Default ``None``.
+        segment_config (dict, optional): The segment configuration for each cell's execution in pipeline parallel.
+            Default ``None``.
 
     Raises:
         TypeError: The type of `net` is not cell.
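The new segment_config argument mirrors stage_config: a dict keyed by sub-cell name, validated against network.cells_and_names(), whose values are written to each cell's pipeline_segment attribute (unknown names raise KeyError). A hedged sketch with a hypothetical two-block network; the usual pipeline-parallel setup (e.g. pipeline_stages=2 in the parallel context) is assumed elsewhere.

import mindspore.nn as nn


class TwoBlockNet(nn.Cell):
    """Toy network; the attribute names become the keys expected in the configs."""

    def __init__(self):
        super().__init__()
        self.block0 = nn.Dense(16, 16)
        self.block1 = nn.Dense(16, 4)

    def construct(self, x):
        return self.block1(self.block0(x))


net = nn.PipelineCell(
    TwoBlockNet(),
    4,                                           # micro_size
    stage_config={"block0": 0, "block1": 1},     # pipeline stage per sub-cell
    segment_config={"block0": 0, "block1": 1},   # pipeline segment per sub-cell (new)
)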
mindspore/parallel/nn/parallel_grad_reducer.py
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 
 __all__ = ['PipelineGradReducer']
 
-from mindspore import context
 from mindspore.nn.cell import Cell
 from mindspore.ops import functional as F, composite as C, operations as P
 import mindspore.common.dtype as mstype
@@ -140,7 +139,6 @@ class PipelineGradReducer(Cell):
     """
     def __init__(self, parameters, scale_sense=1.0, opt_shard=None):
         super(PipelineGradReducer, self).__init__(auto_prefix=False)
-        self._check_mode()
         self.accu_grads = parameters.clone(prefix="accu_grads", init="zeros")
         self.grad_reducer = Identity()
         self.degree = Tensor(1, mstype.float32)
@@ -162,9 +160,3 @@ class PipelineGradReducer(Cell):
         accu_grads = self.grad_reducer(self.accu_grads)
         new_grads = self.hyper_map(F.partial(grad_scale, self.scale_sense * self.degree), grads, accu_grads)
         return new_grads
-
-    def _check_mode(self):
-        """check parallel mode"""
-        mode = context.get_context('mode')
-        if mode != context.GRAPH_MODE:
-            raise RuntimeError(f"PipelineGradReducer only support graph mode, but get {mode}")
mindspore/parallel/shard.py
@@ -253,13 +253,6 @@ class Shard(Shard_):
                            "will be overwritten as False.")
            ms.set_algo_parameters(fully_use_devices=False)
 
-        if ms.context.get_auto_parallel_context("full_batch_is_set") is False and \
-                ms.context.get_context("mode") == ms.context.PYNATIVE_MODE:
-            logger.warning("When calling the shard interface, "
-                           "'dataset_strategy' or 'full_batch' is not manually set by the user, "
-                           "and the 'dataset_strategy' will be set to 'full_batch'.")
-            ms.context.set_auto_parallel_context(dataset_strategy="full_batch")
-
         if self._is_attrs_has_been_set(fn, in_strategy, out_strategy, device, level):
             return self.shard_fn
         shard_ = Shard()
@@ -394,11 +387,10 @@
                             f"The tuple strategy for each dimension should be tuple(int).")
 
 
-def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
+def shard(fn, in_strategy, out_strategy=None, parameter_plan=None):
     """
     Specify the input and output slicing strategy for a Cell or function.
-    In PyNative mode, use this method to specify a Cell for distributed
-    execution in graph mode. In Graph mode, use this method to specify distribution strategy for a Cell,
+    In Graph mode, use this method to specify distribution strategy for a Cell,
     strategy for others will be set by sharding propagation.
     in_strategy and out_strategy define the input and output layout respectively.
     in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of
@@ -410,7 +402,9 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
         - It is valid only in semi auto parallel or auto parallel mode.
           In other parallel modes, strategies set here will be ignored.
         - If the input contain Parameter, its strategy should be set in `in_strategy`.
-        - This method currently does not support dynamic shapes.
+
+    .. warning::
+        The method is currently not supported in PyNative mode.
 
     Args:
         fn (Union[Cell, Function]): Function to be executed in parallel.
@@ -432,19 +426,12 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
             has been set, the parameter setting will be ignored. Supported
             only when `fn` is a Cell with parameters.
             Default: ``None`` .
-        device (str, optional): Select a certain `device` target. It is not in use right now.
-            Support ["CPU", "GPU", "Ascend"]. Default: ``"Ascend"`` .
-        level (int, optional): Option for parallel strategy infer algorithm, namely the object function,
-            maximize computation
-            over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
-            use right now. Support [0, 1, 2]. Default: ``0`` .
 
     Returns:
         Function, return the function that will be executed under auto parallel process.
 
     Raises:
         AssertionError: If parallel mode is not "auto_parallel" nor "semi_auto_parallel".
-        AssertionError: If device_target it not "Ascend" or "GPU".
         TypeError: If `in_strategy` is not a tuple.
         TypeError: If `out_strategy` is not a tuple or None.
         TypeError: If any element in `in_strategy` is not a tuple(int) or tuple(mindspore.parallel.Layout).
@@ -452,8 +439,6 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
         TypeError: If `parameter_plan` is not a dict or None.
         TypeError: If any key in `parameter_plan` is not a str.
         TypeError: If any value in `parameter_plan` is not a tuple(int) or a tuple(mindspore.parallel.Layout).
-        TypeError: If `device` is not a str.
-        TypeError: If `level` is not an integer.
 
     Supported Platforms:
         ``Ascend``
@@ -556,4 +541,5 @@ def shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device="Ascen
     if not isinstance(fn, (ms.nn.Cell)):
         logger.warning("'fn' is not a mindspore.nn.Cell, and its definition cannot involve Parameter; "
                        "otherwise, the result may be incorrect.")
-    return Shard()(fn, in_strategy, out_strategy, parameter_plan, device, level)
+
+    return Shard()(fn, in_strategy, out_strategy, parameter_plan)
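With the unused device and level parameters removed, a call to the top-level shard now passes at most four arguments. A minimal sketch, assuming eight devices and a graph-mode auto-parallel context with sharding propagation already configured; the (2, 4) strategy and the Dense layer are illustrative.

import mindspore as ms
from mindspore import nn

net = nn.Dense(64, 64)
# Split the first (2-D) input across 2 data-parallel and 4 model-parallel
# devices; strategies for the remaining operators are derived by sharding
# propagation, as the docstring above describes.
sharded_net = ms.shard(net, in_strategy=((2, 4),))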