mindspore 2.7.0rc1__cp310-cp310-win_amd64.whl → 2.7.1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the package's registry page for more details.

Files changed (370)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +5 -2
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +2 -2
  7. mindspore/_extends/builtin_operations.py +3 -3
  8. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/__init__.py +3 -3
  11. mindspore/_extends/parse/compile_config.py +24 -1
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
  13. mindspore/_extends/parse/parser.py +28 -22
  14. mindspore/_extends/parse/resources.py +1 -1
  15. mindspore/_extends/parse/standard_method.py +23 -2
  16. mindspore/_extends/parse/trope.py +2 -1
  17. mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
  18. mindspore/amp.py +0 -18
  19. mindspore/avcodec-59.dll +0 -0
  20. mindspore/avdevice-59.dll +0 -0
  21. mindspore/avfilter-8.dll +0 -0
  22. mindspore/avformat-59.dll +0 -0
  23. mindspore/avutil-57.dll +0 -0
  24. mindspore/boost/base.py +29 -2
  25. mindspore/common/__init__.py +18 -12
  26. mindspore/common/_decorator.py +3 -2
  27. mindspore/common/_grad_function.py +3 -1
  28. mindspore/common/_tensor_cpp_method.py +1 -1
  29. mindspore/common/_tensor_docs.py +371 -96
  30. mindspore/common/_utils.py +7 -43
  31. mindspore/common/api.py +434 -135
  32. mindspore/common/dtype.py +98 -57
  33. mindspore/common/dump.py +7 -108
  34. mindspore/common/dynamic_shape/__init__.py +0 -0
  35. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
  36. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  37. mindspore/common/file_system.py +59 -9
  38. mindspore/common/hook_handle.py +82 -3
  39. mindspore/common/jit_config.py +5 -1
  40. mindspore/common/jit_trace.py +27 -12
  41. mindspore/common/lazy_inline.py +5 -3
  42. mindspore/common/np_dtype.py +3 -3
  43. mindspore/common/parameter.py +17 -127
  44. mindspore/common/recompute.py +4 -13
  45. mindspore/common/tensor.py +50 -217
  46. mindspore/communication/_comm_helper.py +11 -1
  47. mindspore/communication/comm_func.py +138 -4
  48. mindspore/communication/management.py +85 -1
  49. mindspore/config/op_info.config +0 -15
  50. mindspore/context.py +20 -106
  51. mindspore/dataset/__init__.py +1 -1
  52. mindspore/dataset/audio/transforms.py +1 -1
  53. mindspore/dataset/core/config.py +35 -1
  54. mindspore/dataset/engine/datasets.py +338 -319
  55. mindspore/dataset/engine/datasets_user_defined.py +38 -22
  56. mindspore/dataset/engine/datasets_vision.py +1 -1
  57. mindspore/dataset/engine/validators.py +1 -15
  58. mindspore/dataset/transforms/c_transforms.py +2 -2
  59. mindspore/dataset/transforms/transforms.py +3 -3
  60. mindspore/dataset/vision/__init__.py +1 -1
  61. mindspore/dataset/vision/py_transforms.py +8 -8
  62. mindspore/dataset/vision/transforms.py +17 -5
  63. mindspore/dataset/vision/utils.py +632 -21
  64. mindspore/device_context/ascend/op_tuning.py +35 -1
  65. mindspore/dnnl.dll +0 -0
  66. mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
  67. mindspore/graph/custom_pass.py +55 -0
  68. mindspore/include/api/cell.h +28 -4
  69. mindspore/include/api/cfg.h +24 -7
  70. mindspore/include/api/context.h +1 -0
  71. mindspore/include/api/delegate.h +0 -2
  72. mindspore/include/api/dual_abi_helper.h +100 -19
  73. mindspore/include/api/graph.h +14 -1
  74. mindspore/include/api/kernel.h +16 -3
  75. mindspore/include/api/kernel_api.h +9 -1
  76. mindspore/include/api/metrics/accuracy.h +9 -0
  77. mindspore/include/api/model.h +5 -1
  78. mindspore/include/api/model_group.h +4 -0
  79. mindspore/include/api/model_parallel_runner.h +2 -0
  80. mindspore/include/api/status.h +48 -10
  81. mindspore/include/api/types.h +6 -1
  82. mindspore/include/dataset/constants.h +9 -0
  83. mindspore/include/dataset/execute.h +2 -2
  84. mindspore/jpeg62.dll +0 -0
  85. mindspore/mindrecord/__init__.py +3 -3
  86. mindspore/mindrecord/common/exceptions.py +1 -0
  87. mindspore/mindrecord/config.py +1 -1
  88. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  89. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  90. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  91. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  92. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  93. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  94. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  95. mindspore/mindrecord/filereader.py +4 -4
  96. mindspore/mindrecord/filewriter.py +5 -5
  97. mindspore/mindrecord/mindpage.py +2 -2
  98. mindspore/mindrecord/tools/cifar10.py +4 -3
  99. mindspore/mindrecord/tools/cifar100.py +1 -1
  100. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  101. mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
  102. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  103. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  104. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  105. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  106. mindspore/mindspore_backend_common.dll +0 -0
  107. mindspore/mindspore_backend_manager.dll +0 -0
  108. mindspore/mindspore_cluster.dll +0 -0
  109. mindspore/mindspore_common.dll +0 -0
  110. mindspore/mindspore_core.dll +0 -0
  111. mindspore/mindspore_cpu.dll +0 -0
  112. mindspore/mindspore_dump.dll +0 -0
  113. mindspore/mindspore_frontend.dll +0 -0
  114. mindspore/mindspore_glog.dll +0 -0
  115. mindspore/mindspore_hardware_abstract.dll +0 -0
  116. mindspore/mindspore_memory_pool.dll +0 -0
  117. mindspore/mindspore_ms_backend.dll +0 -0
  118. mindspore/mindspore_ops.dll +0 -0
  119. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  120. mindspore/mindspore_profiler.dll +0 -0
  121. mindspore/mindspore_pyboost.dll +0 -0
  122. mindspore/mindspore_pynative.dll +0 -0
  123. mindspore/mindspore_runtime_pipeline.dll +0 -0
  124. mindspore/mindspore_runtime_utils.dll +0 -0
  125. mindspore/mindspore_tools.dll +0 -0
  126. mindspore/mint/__init__.py +15 -10
  127. mindspore/mint/distributed/__init__.py +4 -0
  128. mindspore/mint/distributed/distributed.py +392 -69
  129. mindspore/mint/nn/__init__.py +2 -16
  130. mindspore/mint/nn/functional.py +4 -110
  131. mindspore/mint/nn/layer/__init__.py +0 -2
  132. mindspore/mint/nn/layer/_functions.py +1 -2
  133. mindspore/mint/nn/layer/activation.py +0 -6
  134. mindspore/mint/nn/layer/basic.py +0 -47
  135. mindspore/mint/nn/layer/conv.py +10 -10
  136. mindspore/mint/nn/layer/normalization.py +11 -16
  137. mindspore/mint/nn/layer/pooling.py +0 -4
  138. mindspore/nn/__init__.py +1 -3
  139. mindspore/nn/cell.py +231 -239
  140. mindspore/nn/layer/activation.py +4 -2
  141. mindspore/nn/layer/basic.py +56 -14
  142. mindspore/nn/layer/container.py +16 -0
  143. mindspore/nn/layer/embedding.py +4 -169
  144. mindspore/nn/layer/image.py +1 -1
  145. mindspore/nn/layer/normalization.py +2 -1
  146. mindspore/nn/layer/thor_layer.py +4 -85
  147. mindspore/nn/optim/ada_grad.py +0 -1
  148. mindspore/nn/optim/adafactor.py +0 -1
  149. mindspore/nn/optim/adam.py +32 -127
  150. mindspore/nn/optim/adamax.py +0 -1
  151. mindspore/nn/optim/asgd.py +0 -1
  152. mindspore/nn/optim/ftrl.py +8 -102
  153. mindspore/nn/optim/lamb.py +1 -4
  154. mindspore/nn/optim/lars.py +0 -3
  155. mindspore/nn/optim/lazyadam.py +25 -218
  156. mindspore/nn/optim/momentum.py +5 -43
  157. mindspore/nn/optim/optimizer.py +6 -55
  158. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  159. mindspore/nn/optim/rmsprop.py +0 -1
  160. mindspore/nn/optim/rprop.py +0 -1
  161. mindspore/nn/optim/sgd.py +0 -1
  162. mindspore/nn/optim/tft_wrapper.py +2 -4
  163. mindspore/nn/optim/thor.py +0 -2
  164. mindspore/nn/probability/bijector/bijector.py +7 -8
  165. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  166. mindspore/nn/probability/bijector/power_transform.py +20 -21
  167. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  168. mindspore/nn/probability/bijector/softplus.py +13 -14
  169. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  170. mindspore/nn/wrap/cell_wrapper.py +39 -5
  171. mindspore/nn/wrap/grad_reducer.py +4 -89
  172. mindspore/numpy/array_creations.py +4 -4
  173. mindspore/numpy/fft.py +9 -9
  174. mindspore/numpy/utils_const.py +1 -1
  175. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  176. mindspore/onnx/onnx_export.py +137 -0
  177. mindspore/opencv_core4110.dll +0 -0
  178. mindspore/opencv_imgcodecs4110.dll +0 -0
  179. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  180. mindspore/ops/__init__.py +2 -0
  181. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  182. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  183. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  184. mindspore/ops/_op_impl/cpu/__init__.py +1 -5
  185. mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
  186. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
  187. mindspore/ops/auto_generate/gen_extend_func.py +6 -11
  188. mindspore/ops/auto_generate/gen_ops_def.py +385 -154
  189. mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
  190. mindspore/ops/communication.py +97 -0
  191. mindspore/ops/composite/__init__.py +5 -2
  192. mindspore/ops/composite/base.py +16 -2
  193. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  194. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  195. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  196. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  197. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  198. mindspore/ops/function/__init__.py +2 -0
  199. mindspore/ops/function/array_func.py +24 -18
  200. mindspore/ops/function/comm_func.py +3883 -0
  201. mindspore/ops/function/debug_func.py +7 -6
  202. mindspore/ops/function/grad/grad_func.py +4 -12
  203. mindspore/ops/function/math_func.py +89 -86
  204. mindspore/ops/function/nn_func.py +92 -313
  205. mindspore/ops/function/random_func.py +9 -18
  206. mindspore/ops/functional.py +4 -1
  207. mindspore/ops/functional_overload.py +377 -30
  208. mindspore/ops/operations/__init__.py +2 -5
  209. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  210. mindspore/ops/operations/_inner_ops.py +12 -50
  211. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  212. mindspore/ops/operations/array_ops.py +5 -50
  213. mindspore/ops/operations/comm_ops.py +95 -17
  214. mindspore/ops/operations/custom_ops.py +237 -22
  215. mindspore/ops/operations/debug_ops.py +33 -35
  216. mindspore/ops/operations/manually_defined/ops_def.py +39 -318
  217. mindspore/ops/operations/math_ops.py +5 -5
  218. mindspore/ops/operations/nn_ops.py +3 -3
  219. mindspore/ops/operations/sparse_ops.py +0 -83
  220. mindspore/ops/primitive.py +4 -27
  221. mindspore/ops/tensor_method.py +88 -10
  222. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  223. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  224. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  225. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  226. mindspore/ops_generate/common/gen_constants.py +11 -10
  227. mindspore/ops_generate/common/op_proto.py +18 -1
  228. mindspore/ops_generate/common/template.py +102 -245
  229. mindspore/ops_generate/common/template_utils.py +212 -0
  230. mindspore/ops_generate/gen_custom_ops.py +69 -0
  231. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  232. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  233. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  234. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  235. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  236. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  237. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  238. mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
  239. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  240. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  241. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  242. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  243. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  244. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  245. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  246. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  247. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  248. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  249. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  250. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  251. mindspore/parallel/_auto_parallel_context.py +5 -15
  252. mindspore/parallel/_cell_wrapper.py +1 -1
  253. mindspore/parallel/_parallel_serialization.py +4 -6
  254. mindspore/parallel/_ps_context.py +2 -2
  255. mindspore/parallel/_utils.py +34 -17
  256. mindspore/parallel/auto_parallel.py +23 -9
  257. mindspore/parallel/checkpoint_transform.py +20 -2
  258. mindspore/parallel/cluster/process_entity/_api.py +28 -33
  259. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  260. mindspore/parallel/cluster/run.py +5 -3
  261. mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
  262. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  263. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  264. mindspore/parallel/function/reshard_func.py +6 -5
  265. mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
  266. mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
  267. mindspore/parallel/shard.py +7 -21
  268. mindspore/parallel/strategy.py +336 -0
  269. mindspore/parallel/transform_safetensors.py +127 -20
  270. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
  271. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
  272. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  273. mindspore/profiler/common/constant.py +5 -0
  274. mindspore/profiler/common/file_manager.py +9 -0
  275. mindspore/profiler/common/msprof_cmd_tool.py +40 -4
  276. mindspore/profiler/common/path_manager.py +65 -24
  277. mindspore/profiler/common/profiler_context.py +27 -14
  278. mindspore/profiler/common/profiler_info.py +3 -3
  279. mindspore/profiler/common/profiler_meta_data.py +1 -0
  280. mindspore/profiler/common/profiler_op_analyse.py +10 -6
  281. mindspore/profiler/common/profiler_path_manager.py +13 -0
  282. mindspore/profiler/common/util.py +30 -3
  283. mindspore/profiler/dynamic_profiler.py +91 -46
  284. mindspore/profiler/envprofiler.py +30 -5
  285. mindspore/profiler/experimental_config.py +18 -2
  286. mindspore/profiler/platform/cpu_profiler.py +10 -4
  287. mindspore/profiler/platform/npu_profiler.py +34 -7
  288. mindspore/profiler/profiler.py +193 -145
  289. mindspore/profiler/profiler_action_controller.py +1 -1
  290. mindspore/profiler/profiler_interface.py +2 -2
  291. mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
  292. mindspore/run_check/_check_version.py +108 -24
  293. mindspore/runtime/__init__.py +9 -6
  294. mindspore/runtime/executor.py +35 -0
  295. mindspore/runtime/memory.py +113 -0
  296. mindspore/runtime/thread_bind_core.py +1 -1
  297. mindspore/swresample-4.dll +0 -0
  298. mindspore/swscale-6.dll +0 -0
  299. mindspore/tinyxml2.dll +0 -0
  300. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  301. mindspore/tools/data_dump.py +130 -0
  302. mindspore/tools/sdc_detect.py +91 -0
  303. mindspore/tools/stress_detect.py +63 -0
  304. mindspore/train/__init__.py +6 -6
  305. mindspore/train/_utils.py +8 -21
  306. mindspore/train/amp.py +6 -7
  307. mindspore/train/callback/_callback.py +2 -1
  308. mindspore/train/callback/_checkpoint.py +1 -17
  309. mindspore/train/callback/_flops_collector.py +10 -6
  310. mindspore/train/callback/_train_fault_tolerance.py +72 -25
  311. mindspore/train/data_sink.py +5 -9
  312. mindspore/train/dataset_helper.py +5 -5
  313. mindspore/train/model.py +41 -230
  314. mindspore/train/serialization.py +160 -401
  315. mindspore/train/train_thor/model_thor.py +2 -2
  316. mindspore/turbojpeg.dll +0 -0
  317. mindspore/utils/__init__.py +6 -3
  318. mindspore/utils/dlpack.py +92 -0
  319. mindspore/utils/dryrun.py +1 -1
  320. mindspore/utils/runtime_execution_order_check.py +10 -0
  321. mindspore/utils/sdc_detect.py +14 -12
  322. mindspore/utils/stress_detect.py +43 -0
  323. mindspore/utils/utils.py +152 -16
  324. mindspore/version.py +1 -1
  325. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  326. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
  327. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  328. mindspore/communication/_hccl_management.py +0 -297
  329. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
  330. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  331. mindspore/experimental/llm_boost/atb/__init__.py +0 -23
  332. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  333. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  334. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  335. mindspore/experimental/llm_boost/register.py +0 -130
  336. mindspore/experimental/llm_boost/utils.py +0 -31
  337. mindspore/include/OWNERS +0 -7
  338. mindspore/mindspore_cpu_res_manager.dll +0 -0
  339. mindspore/mindspore_ops_kernel_common.dll +0 -0
  340. mindspore/mindspore_res_manager.dll +0 -0
  341. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  342. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  343. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  344. mindspore/nn/reinforcement/tensor_array.py +0 -145
  345. mindspore/opencv_core452.dll +0 -0
  346. mindspore/opencv_imgcodecs452.dll +0 -0
  347. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  348. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  349. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  350. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  351. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  352. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  353. mindspore/ops/operations/_tensor_array.py +0 -359
  354. mindspore/ops/operations/rl_ops.py +0 -288
  355. mindspore/parallel/_offload_context.py +0 -275
  356. mindspore/parallel/_recovery_context.py +0 -115
  357. mindspore/parallel/_transformer/__init__.py +0 -35
  358. mindspore/parallel/_transformer/layers.py +0 -765
  359. mindspore/parallel/_transformer/loss.py +0 -251
  360. mindspore/parallel/_transformer/moe.py +0 -693
  361. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  362. mindspore/parallel/_transformer/transformer.py +0 -3124
  363. mindspore/parallel/mpi/_mpi_config.py +0 -116
  364. mindspore/profiler/common/validator/validate_path.py +0 -84
  365. mindspore/train/memory_profiling_pb2.py +0 -298
  366. mindspore/utils/hooks.py +0 -81
  367. mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  368. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  369. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  370. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/common/api.py CHANGED
@@ -44,18 +44,20 @@ from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
44
44
  from mindspore._c_expression.amp import get_curr_amp_strategy
45
45
  from mindspore._c_expression import GraphExecutor_, JitExecutor_, CSRTensor, RowTensor, COOTensor, \
46
46
  PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
47
- _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, MSContext, TensorPy as Tensor
47
+ _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, TensorPy as Tensor, dump_func_graph, _GraphFragment_
48
48
  from mindspore.parallel._ps_context import _is_role_sched
49
49
  from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_in_auto_parallel_mode, \
50
50
  _is_parallel_mode
51
51
  from mindspore import _checkparam as Validator
52
52
  from mindspore._checkparam import is_stub_tensor
53
- from mindspore.common._utils import is_shape_unknown
53
+ from mindspore.common._utils import is_shape_unknown, get_func
54
54
  from mindspore.common.mutable import mutable, _check_element_type
55
- from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
56
- get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature
55
+ from mindspore.common.dynamic_shape.auto_dynamic_shape import get_auto_dynamic_shape_args, \
56
+ update_auto_dynamic_shape_phase
57
+ from mindspore.common.dynamic_shape.enable_dynamic import generate_dynamic_tensor_args, ENABLE_DYNAMIC
57
58
  from mindspore.common._pijit_context import PIJitCaptureContext
58
- from mindspore.common.parameter import Parameter, set_parameter_hook_updated, parameter_hook_updated
59
+ from mindspore.common.parameter import Parameter
60
+ from mindspore.common.hook_handle import _hook_version
59
61
  from mindspore.common.jit_context import jit_context
60
62
  from mindspore.common.jit_trace import _jit_trace
61
63
  from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
@@ -74,6 +76,11 @@ ARG_SPECIFIED = "arg_specified_infos"
74
76
  TOTAL_ARG_LEN = "total_arg_length"
75
77
 
76
78
 
79
+ def _real_phase(phase, obj):
80
+ real_phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
81
+ return real_phase
82
+
83
+
77
84
  def _check_recompile_args(compile_args, kwargs):
78
85
  """Check recompile of graph"""
79
86
 
@@ -201,6 +208,11 @@ def _handle_func_args(func, *args, **kwargs):
201
208
  args = bound_arguments.args
202
209
  kwargs = bound_arguments.kwargs
203
210
 
211
+ return args, kwargs
212
+
213
+
214
+ def _check_func_args(func, *args):
215
+ """Check the *args inputs of the function"""
204
216
  positional_args = 0
205
217
  default_args = 0
206
218
  has_var = False
@@ -214,14 +226,13 @@ def _handle_func_args(func, *args, **kwargs):
214
226
  default_args += 1
215
227
 
216
228
  if has_var:
217
- return args, kwargs
229
+ return
218
230
 
219
231
  if len(args) < positional_args:
220
232
  raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument, but got {len(args)}.")
221
233
  if len(args) > positional_args + default_args:
222
234
  raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument and {default_args} "
223
235
  f"default argument, total {positional_args + default_args}, but got {len(args)}.")
224
- return args, kwargs
225
236
 
226
237
 
227
238
  sys_path = list(sys.path)
@@ -342,7 +353,7 @@ def _get_parameter_layout():
342
353
  return layout
343
354
 
344
355
 
345
- def _handle_arg(obj, arg, has_mutable_arg):
356
+ def _handle_arg(obj, arg, has_mutable_arg, is_predict):
346
357
  """Handle arg for runtime .If need handle the arg, return True"""
347
358
  from mindspore._extends.parse import compile_config
348
359
  if isinstance(arg, PythonTensor):
@@ -357,7 +368,7 @@ def _handle_arg(obj, arg, has_mutable_arg):
357
368
  if isinstance(arg, list) and not arg:
358
369
  return None
359
370
  return arg
360
- elif (context.get_context("grad_for_scalar") or str(compile_config.GRAD_FOR_SCALAR) == '1') and \
371
+ elif not is_predict and (context.get_context("grad_for_scalar") or str(compile_config.GRAD_FOR_SCALAR) == '1') and \
361
372
  isinstance(arg, (int, float)):
362
373
  return arg
363
374
  elif hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(arg, tuple) and \
@@ -387,17 +398,16 @@ def _handle_arg_predict(obj, arg, has_mutable_arg):
387
398
  return arg
388
399
 
389
400
 
390
- def _get_args_for_run(obj, args, kwargs, has_mutable_args_list, is_predict):
401
+ def _get_args_for_run(obj, args, kwargs, has_mutable_args_list, is_predict=False):
391
402
  """Get the actual input args and kwargs for runtime."""
392
403
  new_args = []
393
- fn = _handle_arg_predict if is_predict else _handle_arg
394
404
  for arg, has_mutable_arg in zip(args, has_mutable_args_list):
395
- new_arg = fn(obj, arg, has_mutable_arg)
405
+ new_arg = _handle_arg(obj, arg, has_mutable_arg, is_predict)
396
406
  if new_arg is not None:
397
407
  new_args.append(new_arg)
398
408
 
399
409
  for _, value in kwargs.items():
400
- new_value = fn(obj, value, None)
410
+ new_value = _handle_arg(obj, value, None, is_predict)
401
411
  if new_value is not None:
402
412
  new_args.append(new_value)
403
413
 
@@ -538,10 +548,12 @@ def _get_parameter_ids(args, kwargs):
538
548
  parameter_ids += str(id(value))
539
549
  return parameter_ids
540
550
 
551
+
541
552
  def _get_tensor_hook_key(tensor):
542
553
  """Get the hook key of Tensor/Parameter"""
543
554
  return ".".join(map(str, map(id, tensor.hooks())))
544
555
 
556
+
545
557
  def _get_hook_key(*args, **kwargs):
546
558
  """Get the hook key of Tensors/Parameters"""
547
559
  hook_key = ""
@@ -588,6 +600,8 @@ class _JitExecutor:
588
600
 
589
601
  self.fn = fn
590
602
  self.input_signature = input_signature
603
+ self.dynamic_args_shapes = getattr(get_func(fn), ENABLE_DYNAMIC, None)
604
+ self.enable_jit_dynamic = self.dynamic_args_shapes is not None
591
605
  self.obj = None
592
606
  if obj and hasattr(obj, fn.__name__):
593
607
  self.obj = obj
@@ -598,7 +612,7 @@ class _JitExecutor:
598
612
  else:
599
613
  self._graph_executor = GraphExecutor_.get_instance()
600
614
  self._create_time = ms_create_time
601
- self._compile_args = None
615
+ self._mutable_flags = None
602
616
  self._enable_auto_dynamic = dynamic == 1
603
617
  self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
604
618
  self._cell_cache_key_extend = cell_cache_key_extend
@@ -623,18 +637,8 @@ class _JitExecutor:
623
637
  except Exception as err:
624
638
  _pynative_executor.clear_res()
625
639
  raise err
626
- else: # get compiled args to generate run args by _generate_run_args
627
- compile_args = self._generate_compile_args(args_list)
628
- key_id = self._get_key_id()
629
- compile_args = get_auto_dynamic_shape_args_with_check_input_signature(
630
- compile_args,
631
- key_id,
632
- self.input_signature,
633
- self._enable_auto_dynamic
634
- )
635
- self._compile_args = compile_args
636
640
 
637
- new_inputs = self._generate_run_args(args_list, kwargs)
641
+ new_inputs = self._generate_run_args(args_list, kwargs, is_predict=True)
638
642
  if self.jit_config_dict:
639
643
  jit_config_dict = self.jit_config_dict
640
644
  else:
@@ -647,11 +651,25 @@ class _JitExecutor:
647
651
  res = _convert_python_data(output)
648
652
  return True, res
649
653
 
654
+ def compile_frontend(self, *args, **kwargs):
655
+ """Only compile to the frontend graph."""
656
+ args_list = args
657
+ if self.obj is not None:
658
+ args_list = args_list[1:]
659
+ os.environ['MS_DEV_PRECOMPILE_ONLY'] = '1'
660
+ phase = ""
661
+ _pynative_executor.set_jit_compile_phase(phase)
662
+ phase = self.compile(self.fn.__name__, *args_list, **kwargs)
663
+ _pynative_executor.set_jit_compile_phase(phase)
664
+ os.unsetenv('MS_DEV_PRECOMPILE_ONLY')
665
+ return self._graph_executor.get_func_graph(phase), self._mutable_flags, phase, self.enable_tuple_broaden
666
+
650
667
  @_wrap_func
651
668
  def __call__(self, *args, **kwargs):
652
669
  predict, res = self._predict(*args, **kwargs)
653
670
  if predict:
654
671
  return res
672
+ _check_func_args(self.fn, *args)
655
673
  if jit_context() and jit_context().is_nested():
656
674
  return jit_context().run_graph("", None, *())
657
675
  args_list = args
@@ -659,9 +677,9 @@ class _JitExecutor:
659
677
  args_list = args_list[1:]
660
678
  phase = ""
661
679
  try:
662
- _pynative_executor.set_jit_compile_status(True, phase)
680
+ _pynative_executor.set_jit_compile_phase(phase)
663
681
  phase = self.compile(self.fn.__name__, *args_list, **kwargs)
664
- _pynative_executor.set_jit_compile_status(False, phase)
682
+ _pynative_executor.set_jit_compile_phase(phase)
665
683
  except Exception as err:
666
684
  _pynative_executor.clear_res()
667
685
  raise err
@@ -684,24 +702,24 @@ class _JitExecutor:
684
702
 
685
703
  def compile(self, method_name, *args, **kwargs):
686
704
  """Returns pipeline for the given args."""
687
- # Check whether hook function registered on Cell object.
688
- if self.obj and hasattr(self.obj, "_hook_fn_registered"):
689
- if self.obj._hook_fn_registered():
690
- logger.warning(f"For 'Cell', it's not support hook function when using 'jit' decorator. "
691
- f"If you want to use hook function, please use context.set_context to set "
692
- f"pynative mode and remove 'jit' decorator.")
693
705
  # Chose dynamic shape tensors or actual input tensors as compile args.
706
+ self._graph_executor.set_real_args(args, kwargs)
694
707
  compile_args = self._generate_compile_args(args)
695
708
  key_id = self._get_key_id()
696
- compile_args = get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id,
697
- self.input_signature,
698
- self._enable_auto_dynamic)
709
+ if self.input_signature is None:
710
+ compile_args = get_auto_dynamic_shape_args(
711
+ compile_args, key_id, self._enable_auto_dynamic, self.enable_jit_dynamic
712
+ )
699
713
 
700
714
  # Add mutable for compile_args for two scene:
701
715
  # 1) Origin args is mutable.
702
716
  # 2) Args contains sequence with gradient tensor.
703
717
  compile_args = _add_mutable_attr(args, compile_args, _pynative_executor.requires_grad())
704
- self._compile_args = compile_args
718
+ mutable_flags = _get_mutable_flags(compile_args)
719
+ self._mutable_flags = mutable_flags
720
+ # Store the _mutable_flags in the cell obj for incremental inference.
721
+ if self.obj is not None:
722
+ self.obj._mutable_flags = mutable_flags
705
723
  generate_name, echo_function_name = self._get_generate_name()
706
724
  # The full Function name
707
725
  full_function_name = generate_name
@@ -735,20 +753,23 @@ class _JitExecutor:
735
753
 
736
754
  self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
737
755
  key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
756
+ key = str(key)
738
757
 
739
758
  parameter_ids = _get_parameter_ids(args, kwargs)
740
759
  if parameter_ids != "":
741
- key = str(key) + '.' + parameter_ids
760
+ key += '.' + parameter_ids
742
761
 
743
- key = str(key) + "." + _get_hook_key(*args, **kwargs)
762
+ key += "." + _get_hook_key(*args, **kwargs)
763
+ key += "." + str(_hook_version())
744
764
 
745
- phase = generate_name + '.' + str(key)
765
+ phase = generate_name + '.' + key
746
766
 
747
- update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)
767
+ if self.input_signature is None:
768
+ update_auto_dynamic_shape_phase(compile_args, key_id, phase)
748
769
 
749
770
  phase = phase + self._cell_cache_key_extend
750
771
 
751
- if phase in ms_compile_cache and self._graph_executor.has_compiled(phase) and not parameter_hook_updated():
772
+ if phase in ms_compile_cache and self._graph_executor.has_compiled(phase):
752
773
  # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
753
774
  # generated in generate_arguments_key.
754
775
  self._graph_executor.clear_compile_arguments_resource()
@@ -765,16 +786,9 @@ class _JitExecutor:
765
786
 
766
787
  if self.obj is None:
767
788
  # Set an attribute to fn as an identifier.
768
- if isinstance(self.fn, types.MethodType):
769
- setattr(self.fn.__func__, "__jit_function__", True)
770
- else:
771
- setattr(self.fn, "__jit_function__", True)
772
- is_compile = self._graph_executor.compile(
773
- self.fn, compile_args, kwargs, phase, jit_config_dict)
774
- if isinstance(self.fn, types.MethodType):
775
- delattr(self.fn.__func__, "__jit_function__")
776
- else:
777
- delattr(self.fn, "__jit_function__")
789
+ setattr(get_func(self.fn), "__jit_function__", True)
790
+ is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, jit_config_dict)
791
+ delattr(get_func(self.fn), "__jit_function__")
778
792
  else:
779
793
  if isinstance(self.obj, ms.nn.Cell):
780
794
  self._graph_executor.set_weights_values(self.obj.parameters_dict())
@@ -783,7 +797,6 @@ class _JitExecutor:
783
797
 
784
798
  if not is_compile:
785
799
  raise RuntimeError("Executor compile failed.")
786
- set_parameter_hook_updated(False)
787
800
  ms_compile_cache.add(phase)
788
801
  if hasattr(self.obj, "phase"):
789
802
  self.obj.phase_cache[self.obj.phase] = phase
@@ -831,43 +844,73 @@ class _JitExecutor:
831
844
  if enable_compile_cache is True or enable_compile_cache == "1":
832
845
  self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())
833
846
 
847
+ def _generate_compile_args_by_enable_dynamic(self, args_list):
848
+ """Generate compile args by enable_dynamic."""
849
+ compile_args = generate_dynamic_tensor_args(args_list, self.dynamic_args_shapes)
850
+ compile_args = _add_mutable_attr(args_list, compile_args, _pynative_executor.requires_grad())
851
+ if self.obj is not None:
852
+ _pynative_executor.set_dynamic_input(self.obj, *compile_args)
853
+ else:
854
+ _pynative_executor.set_dynamic_input(self.fn, *compile_args)
855
+ logger.info(f"dynamic shape compile_args: {compile_args}")
856
+ Validator.check_symbolic_shape(compile_args, args_list)
857
+ return compile_args
858
+
859
+ def _generate_compile_args_by_set_inputs(self, args_list):
860
+ """Generate compile args by set_inputs."""
861
+ compile_args = _generate_dyn_compile_args(args_list, self.obj.get_inputs())
862
+ if len(compile_args) != len(args_list):
863
+ raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
864
+ f"dynamic shape tensors: {len(compile_args)}.")
865
+ self._graph_executor.check_argument_consistency(compile_args, args_list, "set_inputs")
866
+ Validator.check_symbolic_shape(compile_args, args_list)
867
+ return compile_args
868
+
869
+ def _generate_compile_args_by_input_signature(self, args_list):
870
+ """Generate compile args by input_signature."""
871
+ compile_args = list(_generate_dyn_compile_args(args_list, self.input_signature))
872
+ dyn_shape = any([is_shape_unknown(elem.shape) for elem in compile_args if isinstance(elem, PythonTensor)])
873
+ Validator.check_symbolic_shape(self.input_signature, args_list)
874
+ if dyn_shape:
875
+ # Checkout whether the `sens` has been added to args_list.
876
+ if len(compile_args) == len(args_list) - 1:
877
+ logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
878
+ f"of input_signature args '{len(compile_args)}'. The last actual args may "
879
+ f"be 'sens' and added it to compile args.")
880
+ compile_args.append(args_list[-1])
881
+ compile_args = tuple(compile_args)
882
+ self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
883
+ if self.obj is not None:
884
+ _pynative_executor.set_dynamic_input(self.obj, *compile_args)
885
+ else:
886
+ _pynative_executor.set_dynamic_input(self.fn, *compile_args)
887
+ else:
888
+ if not verify_inputs_signature(compile_args, args_list):
889
+ raise ValueError("The input args is incompatible with the args in `input_signature`!")
890
+ return compile_args
891
+
892
+ def _check_set_inputs(self):
893
+ """Check if the `set_inputs()` of Cell object has been set."""
894
+ return self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs()
895
+
834
896
  def _generate_compile_args(self, args_list):
835
897
  """Chose dynamic shape tensors or actual input tensors as compile args."""
836
- # Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
837
- compile_args = _pynative_executor.get_dynamic_input(args_list)
898
+ # Case: The `enable_dynamic` is provided and `set_inputs()` of Cell object has been set.
899
+ if self.enable_jit_dynamic and self._check_set_inputs():
900
+ raise ValueError("When `enable_dynamic` is provided, the `set_inputs()` cannot be set!")
901
+ # Case: The `enable_dynamic` is provided.
902
+ if self.enable_jit_dynamic:
903
+ return self._generate_compile_args_by_enable_dynamic(args_list)
838
904
  # Case: The `set_inputs()` of Cell object has been set, using these dynamic shape args as compile args.
839
- if self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs():
840
- compile_args = _generate_dyn_compile_args(args_list, self.obj.get_inputs())
841
- if len(compile_args) != len(args_list):
842
- raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
843
- f"dynamic shape tensors: {len(compile_args)}.")
844
- self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
845
- Validator.check_symbolic_shape(compile_args, args_list)
846
-
905
+ if self._check_set_inputs():
906
+ return self._generate_compile_args_by_set_inputs(args_list)
847
907
  # Case: If dynamic shape tensors have been assigned to `input_signature`, they are preferred as compile args.
848
908
  if self.input_signature is not None:
849
- compile_args = list(_generate_dyn_compile_args(args_list, self.input_signature))
850
- dyn_shape = any([is_shape_unknown(elem.shape) for elem in compile_args if isinstance(elem, PythonTensor)])
851
- Validator.check_symbolic_shape(self.input_signature, args_list)
852
- if dyn_shape:
853
- # Checkout whether the `sens` has been added to args_list.
854
- if len(compile_args) == len(args_list) - 1:
855
- logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
856
- f"of input_signature args '{len(compile_args)}'. The last actual args may "
857
- f"be 'sens' and added it to compile args.")
858
- compile_args.append(args_list[-1])
859
- compile_args = tuple(compile_args)
860
- self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
861
- if self.obj is not None:
862
- _pynative_executor.set_dynamic_input(self.obj, *compile_args)
863
- else:
864
- _pynative_executor.set_dynamic_input(self.fn, *compile_args)
865
- else:
866
- if not verify_inputs_signature(compile_args, args_list):
867
- raise ValueError("The input args is incompatible with the args in `input_signature`!")
868
- return compile_args
909
+ return self._generate_compile_args_by_input_signature(args_list)
910
+ # Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
911
+ return _pynative_executor.get_dynamic_input(args_list)
869
912
 
870
- def _generate_run_args(self, args_list, kwargs):
913
+ def _generate_run_args(self, args_list, kwargs, is_predict=False):
871
914
  """
872
915
  Generate input args, which are required for running.
873
916
 
@@ -878,7 +921,11 @@ class _JitExecutor:
878
921
  Returns:
879
922
  new_inputs, new input args, which are required for running.
880
923
  """
881
- return _get_args_for_run(self, args_list, kwargs, _get_mutable_flags(self._compile_args), False)
924
+ if self.obj is not None and hasattr(self.obj, '_mutable_flags'):
925
+ mutable_flags = self.obj._mutable_flags
926
+ else:
927
+ mutable_flags = self._mutable_flags
928
+ return _get_args_for_run(self, args_list, kwargs, mutable_flags, is_predict)
882
929
 
883
930
  def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
884
931
  """Get graph proto from pipeline."""
@@ -950,7 +997,7 @@ def _check_option_backend(option, backend):
950
997
  'ge_options': ['GE'],
951
998
  'infer_boost': ['ms_backend'],
952
999
  }
953
- if option in option_backend_cfgs and backend not in option_backend_cfgs[option]:
1000
+ if option in option_backend_cfgs and backend != '' and backend not in option_backend_cfgs[option]:
954
1001
  logger.warning(f"For 'jit(options)', the option '{option}' is only support backend in "
955
1002
  f"'{option_backend_cfgs[option]}', but got '{backend}', ignore it.")
956
1003
 
@@ -1077,10 +1124,7 @@ def _jit_ast(hash_obj, dynamic, jit_config, jit_graph_name):
1077
1124
  process_obj = args[0]
1078
1125
  # Handle auto mixed precision strategy.
1079
1126
  if not hasattr(func, "amp_strategy"):
1080
- if isinstance(func, types.MethodType):
1081
- setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
1082
- else:
1083
- setattr(func, "amp_strategy", get_curr_amp_strategy())
1127
+ setattr(get_func(func), "amp_strategy", get_curr_amp_strategy())
1084
1128
 
1085
1129
  jit_graph_name = ''
1086
1130
  if hasattr(staging_specialize, "__jit_graph_name__"):
@@ -1088,6 +1132,8 @@ def _jit_ast(hash_obj, dynamic, jit_config, jit_graph_name):
1088
1132
  jit_executor = _JitExecutor(
1089
1133
  func, hash_obj, None, process_obj, jit_config, dynamic, jit_graph_name)
1090
1134
  out = jit_executor(*args, **kwargs)
1135
+ if isinstance(process_obj, ms.nn.Cell):
1136
+ _clear_auto_parallel_context(process_obj)
1091
1137
  return out
1092
1138
 
1093
1139
  # `inspect.getfullargspec(func)` will get the specification of the decorated function by default. By set
@@ -1127,28 +1173,26 @@ def jit(
1127
1173
 
1128
1174
  Keyword Args:
1129
1175
  capture_mode (str, optional): The method to create a callable MindSpore graph. The value of capture_mode
1130
- should be ``ast`` , ``bytecode`` or ``trace`` . Default: ``ast`` .
1176
+ should be ``"ast"`` , ``"bytecode"`` or ``"trace"`` . Default: ``"ast"`` .
1131
1177
 
1132
- - `ast <https://www.mindspore.cn/docs/en/master/features/compile/graph_construction.html#ast>`_ :
1133
- Parse Python ast to build graph.
1134
- - `bytecode <https://www.mindspore.cn/docs/en/master/features/compile/graph_construction.html#bytecode>`_ :
1135
- Parse Python bytecode to build graph at runtime. This is an experimental prototype that is subject to
1136
- change and/or deletion.
1137
- - `trace <https://www.mindspore.cn/docs/en/master/features/compile/graph_construction.html#trace>`_ : Trace the execution of Python code to build graph. This is an experimental prototype that is
1138
- subject to change and/or deletion.
1178
+ - ast: Parse Python ast to build graph.
1179
+ - bytecode: Parse Python bytecode to build graph at runtime. This is an experimental prototype
1180
+ that is subject to change and/or deletion.
1181
+ - trace: Trace the execution of Python code to build graph. This is an experimental prototype
1182
+ that is subject to change and/or deletion.
1139
1183
 
1140
1184
  jit_level (str, optional): Used to control the compilation optimization level. Currently is only effective
1141
- with ms_backend. The value of jit_level should be ``O0`` or ``O1`` . Default: ``O0`` .
1185
+ with ms_backend. The value of jit_level should be ``"O0"`` or ``"O1"`` . Default: ``"O0"`` .
1142
1186
 
1143
- - `O0`: Except for optimizations that may affect functionality, all other optimizations are turned off.
1144
- - `O1`: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
1187
+ - O0: Except for optimizations that may affect functionality, all other optimizations are turned off.
1188
+ - O1: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
1145
1189
  level is experimental and is being improved.
1146
1190
 
1147
1191
  dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
1148
1192
  is as follows:
1149
1193
 
1150
- - `0`: Do not perform dynamic shape compilation.
1151
- - `1`: Enable dynamic shape compilation and automatically detect shape changes.
1194
+ - 0: Do not perform dynamic shape compilation.
1195
+ - 1: Enable dynamic shape compilation and automatically detect shape changes.
1152
1196
 
1153
1197
  fullgraph (bool, optional): Whether to capture the entire function into graph. If False, jit attempts to
1154
1198
  be compatible with all Python syntax in the function as much as possible. If True, we require that the
@@ -1156,12 +1200,16 @@ def jit(
1156
1200
  not supported), then it will raise an exception. This currently only applies when capture_mode is ``ast``
1157
1201
  or ``bytecode``. Default: ``False``.
1158
1202
  backend (str, optional): The compilation backend to be used. If this parameter is not set, the framework will
1159
- use ``GE`` backend for Atlas training series products and ``ms_backend`` backend for others including Atlas
1160
- A2 training series products by default.
1203
+ use ``"GE"`` backend for Atlas training series products and ``"ms_backend"`` backend for others including
1204
+ Atlas A2 training series products by default.
1161
1205
 
1162
- - `ms_backend`: Adopt KernelByKernel execution mode.
1163
- - `GE`: Adopt Sink execution mode. The whole model will be sinked to device to execute, only applicable to
1164
- the top cell of model. And only can be used in Ascend platform.
1206
+ - ms_backend: Utilizes the built-in backend engine of MindSpore for hardware-related compilation
1207
+ optimization and execution, supporting multiple hardware forms such as Ascend, GPU, and CPU.
1208
+ - GE: Utilizes the GraphEngine, a graph compilation and execution engine within CANN,
1209
+ for Ascend model compilation and execution. Note: This backend only supports GRAPH Mode in Ascend,
1210
+ only supports whole graph sinking or sub graph sinking in pipeline parallel, and does not support
1211
+ dynamic shape scenes. In addition, this backend incurs additional compilation costs and is difficult to
1212
+ debug and tune.
1165
1213
 
1166
1214
  **options (dict): A dictionary of options to pass to the compilation backend.
1167
1215
 
@@ -1184,11 +1232,11 @@ def jit(
1184
1232
  `disable_format_transform` can be set to ``True`` to try to improve training performance.
1185
1233
  Default: ``False`` .
1186
1234
  - exec_order (str, optional): Set the sorting method for operator execution, currently only two sorting
1187
- methods are supported: ``bfs`` and ``dfs`` . Default: ``bfs`` .
1235
+ methods are supported: ``"bfs"`` and ``"dfs"`` . Default: ``"bfs"`` .
1188
1236
 
1189
- - `bfs`: The default sorting method, breadth priority, good communication masking, relatively good
1237
+ - bfs: The default sorting method, breadth priority, good communication masking, relatively good
1190
1238
  performance.
1191
- - `dfs`: An optional sorting method, depth-first sorting. The performance is relatively worse than that
1239
+ - dfs: An optional sorting method, depth-first sorting. The performance is relatively worse than that
1192
1240
  of bfs execution order, but it occupies less memory. It is recommended to try dfs in scenarios where
1193
1241
  other execution orders run out of memory (OOM).
1194
1242
 
@@ -1199,11 +1247,11 @@ def jit(
1199
1247
  - global (dict): Set global options.
1200
1248
  - session (dict): Set session options.
1201
1249
 
1202
- - infer_boost (str, optional): Used to control the inference mode. Default: ``off``, which means
1250
+ - infer_boost (str, optional): Used to control the inference mode. Default: ``"off"``, which means
1203
1251
  the inference mode is disabled. The range is as follows:
1204
1252
 
1205
- - `on`: Enable inference mode, get better infer performance.
1206
- - `off`: Disable inference mode, use forward for inference. The performance is poor.
1253
+ - on: Enable inference mode, get better infer performance.
1254
+ - off: Disable inference mode, use forward for inference. The performance is poor.
1207
1255
 
1208
1256
  Returns:
1209
1257
  Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
@@ -1306,9 +1354,8 @@ def jit(
1306
1354
  jit_level = Validator.check_string(jit_level, ["O0", "O1"], "jit_level", "jit")
1307
1355
  dynamic = Validator.check_int_range(dynamic, 0, 1, Validator.INC_BOTH, "dynamic", "jit")
1308
1356
  fullgraph = Validator.check_bool(fullgraph, "fullgraph", "jit")
1309
- if backend == "":
1310
- backend = "GE" if MSContext.get_instance().get_ascend_soc_version() == "ascend910" else "ms_backend"
1311
- backend = Validator.check_string(backend, ["ms_backend", "GE"], "backend", "jit")
1357
+ if backend != "":
1358
+ backend = Validator.check_string(backend, ["ms_backend", "GE"], "backend", "jit")
1312
1359
  jit_syntax_level = "LAX" if fullgraph is False else "STRICT"
1313
1360
  hash_obj = _get_hash_obj(options)
1314
1361
  _check_options(options, backend)
@@ -1323,7 +1370,7 @@ def jit(
1323
1370
  elif capture_mode == "bytecode":
1324
1371
  wrap_func = PIJitCaptureContext(fullgraph=fullgraph, jit_config=jit_config)
1325
1372
  else:
1326
- wrap_func = _jit_trace()
1373
+ wrap_func = _jit_trace(jit_config)
1327
1374
 
1328
1375
  if function is not None:
1329
1376
  return wrap_func(function)
@@ -1530,6 +1577,20 @@ def _parameter_broadcast(obj):
1530
1577
  _build_broadcast_graph(broadcast_params_dict, broadcast_phase)
1531
1578
 
1532
1579
 
1580
+ def _run_in_jit():
1581
+ """In jit, this function always returns true. Otherwise, returns false."""
1582
+ def _temp_func():
1583
+ return 0
1584
+
1585
+ from mindspore.ops.primitive import constexpr
1586
+
1587
+ @constexpr(check=False)
1588
+ def _check_func(func):
1589
+ return func is None
1590
+
1591
+ return _check_func(_temp_func)
1592
+
1593
+
1533
1594
  class _no_grad(contextlib.ContextDecorator):
1534
1595
  """
1535
1596
  Context Manager to disable gradient calculation. When enter this context, we will disable calculate
@@ -1799,17 +1860,16 @@ class _PyNativeExecutor:
1799
1860
  """
1800
1861
  return self._executor.requires_grad()
1801
1862
 
1802
- def set_jit_compile_status(self, status, phase):
1863
+ def set_jit_compile_phase(self, phase):
1803
1864
  """
1804
- Set jit is compiling
1865
+ Set jit phase
1805
1866
 
1806
1867
  Args:
1807
- status(bool): jit compile status
1808
1868
  phase (str): The phase of cell/function instance.
1809
1869
  Return:
1810
1870
  None.
1811
1871
  """
1812
- self._executor.set_jit_compile_status(status, phase)
1872
+ self._executor.set_jit_compile_phase(phase)
1813
1873
 
1814
1874
  def set_is_run_recompute(self, status):
1815
1875
  """
@@ -1894,6 +1954,32 @@ class _PyNativeExecutor:
1894
1954
  """
1895
1955
  return self._executor.constant_folding(*args)
1896
1956
 
1957
+ def set_creation_type(self, tensor, creation_type):
1958
+ """
1959
+ Set tensor's view creation type
1960
+
1961
+ Args:
1962
+ tensor (Tensor): input tensor.
1963
+ creation_type (CreationType): The type of view tensor when it is created.
1964
+
1965
+ Return:
1966
+ None.
1967
+ """
1968
+ return self._executor.set_creation_type(tensor, creation_type)
1969
+
1970
+ def queue_backward_final_callback(self, callback):
1971
+ """
1972
+ add backward final callback
1973
+
1974
+ Args:
1975
+ callback(Function): callback function.
1976
+
1977
+ Return:
1978
+ None.
1979
+ """
1980
+ return self._executor.queue_backward_final_callback(callback)
1981
+
1982
+
1897
1983
 
1898
1984
  class _CellGraphExecutor:
1899
1985
  """
@@ -2002,6 +2088,11 @@ class _CellGraphExecutor:
2002
2088
  if not hasattr(obj, obj.__parse_method__):
2003
2089
  raise AttributeError(
2004
2090
  'The class {} does not have method {}'.format(obj.__class__.__name__, obj.__parse_method__))
2091
+ inner_func = inspect.unwrap(obj.construct)
2092
+ if hasattr(get_func(inner_func), ENABLE_DYNAMIC):
2093
+ raise ValueError(
2094
+ "When using set_context(mode=GRAPH_MODE) together with nn.Cell, the 'enable_dynamic' cannot be set!"
2095
+ )
2005
2096
  key_id = str(id(obj)) + str(obj.create_time)
2006
2097
  args = get_auto_dynamic_shape_args(args, key_id)
2007
2098
 
@@ -2012,20 +2103,27 @@ class _CellGraphExecutor:
2012
2103
  self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
2013
2104
 
2014
2105
  key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
2015
- obj.arguments_key = str(key)
2016
-
2017
- obj.arguments_key = obj.arguments_key + "." + _get_hook_key(*args, **kwargs)
2106
+ key = str(key)
2018
2107
 
2019
2108
  # When exist parameter in the top graph inputs, need check if the parameter object has changed.
2020
2109
  parameter_ids = _get_parameter_ids(args, kwargs)
2021
2110
  if parameter_ids != "":
2022
- obj.arguments_key = obj.arguments_key + '.' + parameter_ids
2111
+ key += '.' + parameter_ids
2112
+
2113
+ key += "." + _get_hook_key(*args, **kwargs)
2114
+ key += "." + str(_hook_version())
2115
+
2116
+ obj.arguments_key = key
2117
+
2023
2118
  raw_phase = phase
2024
- phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
2119
+
2120
+ phase = _real_phase(phase, obj)
2025
2121
  obj.phase_cache[raw_phase] = phase
2026
2122
  update_auto_dynamic_shape_phase(args, key_id, phase)
2027
2123
  obj.current_phase = phase
2028
- if phase in obj.compile_cache and self.has_compiled(phase) and not parameter_hook_updated():
2124
+ obj._add_attr("compile_phase", phase)
2125
+ obj.compile_phase = phase
2126
+ if phase in obj.compile_cache and self.has_compiled(phase):
2029
2127
  logger.debug("%r graph has existed.", phase)
2030
2128
  # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
2031
2129
  # generated in generate_arguments_key.
@@ -2051,7 +2149,6 @@ class _CellGraphExecutor:
2051
2149
  obj.compile_cache.add(phase)
2052
2150
  if not result:
2053
2151
  raise RuntimeError("Executor compile failed.")
2054
- set_parameter_hook_updated(False)
2055
2152
  graph = self._graph_executor.get_func_graph(phase)
2056
2153
 
2057
2154
  if graph is None:
@@ -2075,16 +2172,20 @@ class _CellGraphExecutor:
2075
2172
  new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])}
2076
2173
  return self._graph_executor.updata_param_node_default_input(phase, new_param)
2077
2174
 
2175
+ def set_real_args(self, args, kwargs):
2176
+ """Set real arguments to graph executor."""
2177
+ self._graph_executor.set_real_args(args, kwargs)
2178
+
2078
2179
  def _get_shard_strategy(self, obj):
2079
- real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
2180
+ real_phase = _real_phase(obj.phase, obj)
2080
2181
  return self._graph_executor.get_strategy(real_phase)
2081
2182
 
2082
2183
  def _get_num_parallel_ops(self, obj):
2083
- real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
2184
+ real_phase = _real_phase(obj.phase, obj)
2084
2185
  return self._graph_executor.get_num_parallel_ops(real_phase)
2085
2186
 
2086
2187
  def _get_allreduce_fusion(self, obj):
2087
- real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
2188
+ real_phase = _real_phase(obj.phase, obj)
2088
2189
  return self._graph_executor.get_allreduce_fusion(real_phase)
2089
2190
 
2090
2191
  def __call__(self, obj, *args, phase='predict'):
@@ -2136,10 +2237,10 @@ class _CellGraphExecutor:
2136
2237
  Tensor/Tuple, return execute result.
2137
2238
  """
2138
2239
  if phase == 'save':
2139
- exe_phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
2240
+ exe_phase = _real_phase(phase, obj)
2140
2241
  return self._graph_executor((), exe_phase)
2141
2242
 
2142
- phase_real = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
2243
+ phase_real = _real_phase(phase, obj)
2143
2244
  if self.has_compiled(phase_real):
2144
2245
  return self._exec_pip(obj, *args, phase=phase_real)
2145
2246
  raise KeyError('{} graph is not exist.'.format(phase_real))
@@ -2164,9 +2265,22 @@ class _CellGraphExecutor:
2164
2265
  return None
2165
2266
  return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)
2166
2267
 
2268
+ def _get_onnx_func_graph_proto(self, obj, exec_id, use_prefix=False, input_names=None, output_names=None,
2269
+ opset_version=11, export_params=True, keep_initializers_as_inputs=False,
2270
+ dynamic_axes=None, extra_save_params=False, save_file_dir=None):
2271
+ """Get graph proto from pipeline."""
2272
+ if use_prefix:
2273
+ exec_id = exec_id + '.' + obj.arguments_key
2274
+ if self._graph_executor.has_compiled(exec_id) is False:
2275
+ return None
2276
+
2277
+ return self._graph_executor.get_onnx_func_graph_proto(exec_id, input_names, output_names, opset_version,
2278
+ export_params, keep_initializers_as_inputs, dynamic_axes,
2279
+ extra_save_params, save_file_dir)
2280
+
2167
2281
  def get_optimize_graph_proto(self, obj):
2168
2282
  """Return optimize graph binary proto."""
2169
- exec_id = obj.phase + "." + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
2283
+ exec_id = _real_phase(obj.phase, obj)
2170
2284
  if self._graph_executor.has_compiled(exec_id) is False:
2171
2285
  return None
2172
2286
  graph_proto = self._graph_executor.get_optimize_graph_proto(exec_id)
@@ -2246,5 +2360,190 @@ def flops_collection(phase='train'):
2246
2360
  return _cell_graph_executor.flops_collection(phase)
2247
2361
 
2248
2362
 
2363
+ class _ScriptGraph:
2364
+ """Store the graph compiled by the frontend compiler."""
2365
+ def __init__(self, func_graph, func, origin_cell, mutable_flags, phase, enable_tuple_broaden):
2366
+ self.func_graph = func_graph
2367
+ self.func = func
2368
+ self.origin_cell = origin_cell
2369
+ self.mutable_flags = mutable_flags
2370
+ self.phase = phase
2371
+ self.enable_tuple_broaden = enable_tuple_broaden
2372
+
2373
+ def print(self):
2374
+ """Print the MindIR of the frontend graph."""
2375
+ graph_str = dump_func_graph(self.func_graph)
2376
+ print(graph_str, flush=True)
2377
+
2378
+
2379
+ def _frontend_compile_ast(dynamic, jit_config, jit_graph_name=''):
2380
+ """Return the wrapped function for ast mode jit."""
2381
+ def wrap_func(func):
2382
+ if hasattr(func, "construct") and isinstance(func, ms.nn.Cell):
2383
+ # Bound the cell object to get the self arg.
2384
+ return types.MethodType(_frontend_compile_ast(dynamic, jit_config,
2385
+ func._jit_graph_name)(func.construct.__func__), func)
2386
+
2387
+ if isinstance(func, types.MethodType):
2388
+ return types.MethodType(_frontend_compile_ast(dynamic, jit_config)(func.__func__), func.__self__)
2389
+
2390
+ if not isinstance(func, types.FunctionType):
2391
+ logger.warning(f"The func should be function, method or cell instance/class, but got {func}")
2392
+ return func
2393
+
2394
+ hash_obj = int(time.time() * 1e9)
2395
+
2396
+ @wraps(func)
2397
+ def staging_specialize(*args, **kwargs):
2398
+ if os.getenv("MS_JIT") == '0':
2399
+ return func(*args, **kwargs)
2400
+
2401
+ args, kwargs = _handle_func_args(func, *args, **kwargs)
2402
+ process_obj = None
2403
+ if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
2404
+ process_obj = args[0]
2405
+ # Handle auto mixed precision strategy.
2406
+ if not hasattr(func, "amp_strategy"):
2407
+ setattr(get_func(func), "amp_strategy", get_curr_amp_strategy())
2408
+
2409
+ jit_graph_name = ''
2410
+ if hasattr(staging_specialize, "__jit_graph_name__"):
2411
+ jit_graph_name = staging_specialize.__jit_graph_name__
2412
+ jit_executor = _JitExecutor(func, hash_obj, None, process_obj, jit_config, dynamic, jit_graph_name)
2413
+ func_graph, mutable_flags, phase, enable_tuple_broaden = jit_executor.compile_frontend(*args, **kwargs)
2414
+ return _ScriptGraph(func_graph, func, process_obj, mutable_flags, phase, enable_tuple_broaden)
2415
+
2416
+ # `inspect.getfullargspec(func)` will get the specification of the decorated function by default. By set
2417
+ # `__signature__` for the decorated function, `inspect.getfullargspec(func)` will get the specification of
2418
+ # original `func`.
2419
+ staging_specialize.__signature__ = inspect.signature(func)
2420
+ setattr(staging_specialize, "__jit_graph_name__", jit_graph_name)
2421
+ return staging_specialize
2422
+
2423
+ return wrap_func
2424
+
2425
+
2426
def _frontend_compile(function: Callable,
                      *,
                      dynamic: int = 0,
                      fullgraph: bool = False):
    """
    Create a frontend MindSpore graph from a Python function by the ast capture mode.

    Args:
        function (Callable): The Python function, bound method or Cell instance to be compiled as a frontend
            graph. This argument is required.

    Keyword Args:
        dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
            is as follows:

            - `0`: Do not perform dynamic shape compilation.
            - `1`: Enable dynamic shape compilation and automatically detect shape changes.

        fullgraph (bool, optional): Whether to capture the entire function into graph. If ``False``, the ``LAX`` jit
            syntax level is used, which attempts to be compatible with as much of the Python syntax in the function
            as possible. If ``True``, the ``STRICT`` jit syntax level is used and the entire function is required to
            be captured into graph; if there is unsupported Python syntax, an exception is raised.
            Default: ``False``.

    Returns:
        Callable, a wrapper of `function`. Calling the wrapper with the function inputs compiles the frontend graph
        and returns a :class:`_ScriptGraph` object. Exceptions:
        - When the environment variable ``MS_JIT`` is ``'0'``, the wrapper instead runs `function` directly.
        - If `function` is not a function, method or Cell instance, it is returned unchanged and a warning is logged.

    Raises:
        Validation errors from ``Validator`` if `dynamic` is not an int in [0, 1] or `fullgraph` is not a bool.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore import ops
        >>> from mindspore.common.api import _frontend_compile
        ...
        >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        ...
        >>> def tensor_add(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> tensor_add_graph = _frontend_compile(tensor_add)(x, y)
        >>> tensor_add_graph.print()
        ...
    """

    dynamic = Validator.check_int_range(dynamic, 0, 1, Validator.INC_BOTH, "dynamic", "jit")
    fullgraph = Validator.check_bool(fullgraph, "fullgraph", "jit")
    # fullgraph=True requires whole-function capture, which corresponds to the STRICT syntax level.
    jit_syntax_level = "LAX" if fullgraph is False else "STRICT"
    jit_config = JitConfig(jit_syntax_level=jit_syntax_level)
    return _frontend_compile_ast(dynamic, jit_config)(function)
2479
+
2480
+
2481
class _GraphFragment(_GraphFragment_):
    """
    Python-side wrapper of one fragment produced by splitting a graph.

    Every public method forwards to the matching underscore-suffixed method of `_GraphFragment_`.
    """
    def __init__(self, frag):
        # isinstance(None, ...) is False, so a None `frag` is rejected here as well.
        if not isinstance(frag, _GraphFragment_):
            raise TypeError(f"Expect input `frag` to be a _GraphFragment_, but got {type(frag)}")
        _GraphFragment_.__init__(self, frag)

    def __call__(self, *args):
        """Run the fragment; the positional inputs are handed to the base class as one tuple."""
        return super().__call__(args)

    def __repr__(self):
        """Reuse the base class string form as the repr."""
        return self.__str__()

    def id(self):
        """Return the value of the base class ``id_()``."""
        return self.id_()

    def is_graph(self):
        """Return the value of the base class ``is_graph_()``."""
        return self.is_graph_()

    def py_key(self):
        """Return the value of the base class ``py_key_()``."""
        return self.py_key_()

    def args_list(self):
        """Return the value of the base class ``args_list_()``."""
        return self.args_list_()
2507
+
2508
+
2509
def _graph_split(script_graph):
    """
    Split the script_graph into several fragments according to the nodes with the split op attribute.

    Args:
        script_graph (_ScriptGraph): The frontend graph to split, e.g. the object returned by calling the wrapper
            produced by `_frontend_compile`. Only its ``func_graph`` attribute is used.

    Returns:
        list[_GraphFragment], one :class:`_GraphFragment` for each fragment returned by the jit executor's
        ``split_graph``, in the same order.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore import ops
        >>> from mindspore.common.api import _frontend_compile, _graph_split
        ...
        >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> add = ops.Add().add_prim_attr("split_op", True).add_prim_attr("func_id", "add_func")
        ...
        >>> def tensor_add(x, y):
        ...     z1 = x + y
        ...     z2 = add(z1, x)
        ...     return z2
        ...
        >>> tensor_add_graph = _frontend_compile(tensor_add)(x, y)
        >>> frags = _graph_split(tensor_add_graph)
        >>> print(frags)
        ...
    """
    outputs = JitExecutor_.get_instance().split_graph(script_graph.func_graph)
    # Wrap each raw fragment so callers get the Python-level `_GraphFragment` API.
    return [_GraphFragment(frag) for frag in outputs]
2547
+
2249
2548
  _cell_graph_executor = _CellGraphExecutor()
2250
2549
  _pynative_executor = _PyNativeExecutor()