mindspore 2.6.0__cp310-cp310-win_amd64.whl → 2.7.0__cp310-cp310-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (455)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +2 -2
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +42 -11
  9. mindspore/_extends/builtin_operations.py +3 -3
  10. mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
  11. mindspore/_extends/optimize/cell_utils.py +96 -0
  12. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +3 -3
  15. mindspore/_extends/parse/compile_config.py +44 -22
  16. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -2
  17. mindspore/_extends/parse/parser.py +64 -83
  18. mindspore/_extends/parse/resources.py +39 -0
  19. mindspore/_extends/parse/standard_method.py +47 -14
  20. mindspore/_extends/parse/trope.py +8 -1
  21. mindspore/_extends/pijit/__init__.py +1 -2
  22. mindspore/_extends/pijit/pijit_func_white_list.py +2 -5
  23. mindspore/amp.py +4 -22
  24. mindspore/atlprov.dll +0 -0
  25. mindspore/avcodec-59.dll +0 -0
  26. mindspore/avdevice-59.dll +0 -0
  27. mindspore/avfilter-8.dll +0 -0
  28. mindspore/avformat-59.dll +0 -0
  29. mindspore/avutil-57.dll +0 -0
  30. mindspore/boost/adasum.py +1 -1
  31. mindspore/boost/boost_cell_wrapper.py +4 -4
  32. mindspore/c1.dll +0 -0
  33. mindspore/c1xx.dll +0 -0
  34. mindspore/c2.dll +0 -0
  35. mindspore/common/__init__.py +43 -12
  36. mindspore/common/_grad_function.py +2 -1
  37. mindspore/common/_pijit_context.py +28 -7
  38. mindspore/common/_stub_tensor.py +1 -209
  39. mindspore/common/_tensor_cpp_method.py +1 -1
  40. mindspore/common/_tensor_docs.py +177 -52
  41. mindspore/common/_utils.py +9 -1
  42. mindspore/common/api.py +338 -208
  43. mindspore/common/dtype.py +108 -57
  44. mindspore/common/dump.py +11 -16
  45. mindspore/common/dynamic_shape/__init__.py +0 -0
  46. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +17 -23
  47. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  48. mindspore/common/file_system.py +59 -9
  49. mindspore/common/generator.py +2 -3
  50. mindspore/common/hook_handle.py +33 -5
  51. mindspore/common/jit_config.py +1 -1
  52. mindspore/common/jit_trace.py +84 -105
  53. mindspore/common/np_dtype.py +3 -3
  54. mindspore/common/parameter.py +27 -29
  55. mindspore/common/recompute.py +5 -7
  56. mindspore/common/sparse_tensor.py +0 -3
  57. mindspore/common/symbol.py +0 -1
  58. mindspore/common/tensor.py +84 -133
  59. mindspore/communication/_comm_helper.py +46 -4
  60. mindspore/communication/management.py +79 -7
  61. mindspore/context.py +47 -38
  62. mindspore/dataset/__init__.py +1 -1
  63. mindspore/dataset/audio/transforms.py +1 -1
  64. mindspore/dataset/core/config.py +38 -4
  65. mindspore/dataset/engine/datasets.py +350 -322
  66. mindspore/dataset/engine/datasets_user_defined.py +69 -23
  67. mindspore/dataset/engine/iterators.py +2 -2
  68. mindspore/dataset/engine/obs/config_loader.py +2 -2
  69. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
  70. mindspore/dataset/transforms/c_transforms.py +2 -2
  71. mindspore/dataset/transforms/py_transforms.py +7 -3
  72. mindspore/dataset/transforms/transforms.py +10 -6
  73. mindspore/dataset/vision/__init__.py +1 -1
  74. mindspore/dataset/vision/py_transforms.py +8 -8
  75. mindspore/dataset/vision/transforms.py +17 -5
  76. mindspore/dataset/vision/utils.py +632 -21
  77. mindspore/dataset/vision/validators.py +1 -0
  78. mindspore/device_context/ascend/device.py +1 -1
  79. mindspore/device_context/ascend/op_tuning.py +35 -1
  80. mindspore/device_context/gpu/__init__.py +2 -2
  81. mindspore/device_context/gpu/device.py +1 -1
  82. mindspore/device_context/gpu/op_precision.py +4 -2
  83. mindspore/device_context/gpu/op_tuning.py +6 -3
  84. mindspore/device_manager.py +16 -9
  85. mindspore/dnnl.dll +0 -0
  86. mindspore/dpcmi.dll +0 -0
  87. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +5 -4
  88. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  89. mindspore/experimental/optim/adadelta.py +13 -20
  90. mindspore/experimental/optim/adagrad.py +15 -22
  91. mindspore/experimental/optim/adam.py +17 -24
  92. mindspore/experimental/optim/adamax.py +14 -22
  93. mindspore/experimental/optim/adamw.py +28 -34
  94. mindspore/experimental/optim/asgd.py +15 -25
  95. mindspore/experimental/optim/lr_scheduler.py +27 -45
  96. mindspore/experimental/optim/nadam.py +14 -24
  97. mindspore/experimental/optim/optimizer.py +13 -23
  98. mindspore/experimental/optim/radam.py +18 -24
  99. mindspore/experimental/optim/rmsprop.py +14 -25
  100. mindspore/experimental/optim/rprop.py +15 -26
  101. mindspore/experimental/optim/sgd.py +9 -19
  102. mindspore/hal/__init__.py +4 -4
  103. mindspore/hal/contiguous_tensors_handle.py +2 -2
  104. mindspore/hal/memory.py +1 -0
  105. mindspore/include/api/cell.h +65 -5
  106. mindspore/include/api/cfg.h +24 -7
  107. mindspore/include/api/context.h +1 -0
  108. mindspore/include/api/delegate.h +10 -2
  109. mindspore/include/api/dual_abi_helper.h +100 -19
  110. mindspore/include/api/graph.h +14 -1
  111. mindspore/include/api/kernel.h +16 -3
  112. mindspore/include/api/kernel_api.h +9 -1
  113. mindspore/include/api/metrics/accuracy.h +9 -0
  114. mindspore/include/api/model.h +8 -1
  115. mindspore/include/api/model_group.h +4 -0
  116. mindspore/include/api/model_parallel_runner.h +2 -0
  117. mindspore/include/api/status.h +48 -10
  118. mindspore/include/api/types.h +8 -3
  119. mindspore/include/c_api/model_c.h +0 -58
  120. mindspore/include/c_api/tensor_c.h +0 -26
  121. mindspore/include/dataset/constants.h +9 -0
  122. mindspore/include/dataset/vision_ascend.h +1 -1
  123. mindspore/jpeg62.dll +0 -0
  124. mindspore/mindrecord/tools/cifar10.py +61 -11
  125. mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
  126. mindspore/mindspore_backend_common.dll +0 -0
  127. mindspore/mindspore_backend_manager.dll +0 -0
  128. mindspore/mindspore_common.dll +0 -0
  129. mindspore/mindspore_core.dll +0 -0
  130. mindspore/mindspore_cpu_res_manager.dll +0 -0
  131. mindspore/mindspore_dump.dll +0 -0
  132. mindspore/mindspore_frontend.dll +0 -0
  133. mindspore/mindspore_glog.dll +0 -0
  134. mindspore/mindspore_memory_pool.dll +0 -0
  135. mindspore/mindspore_ms_backend.dll +0 -0
  136. mindspore/mindspore_ops.dll +0 -0
  137. mindspore/mindspore_ops_host.dll +0 -0
  138. mindspore/mindspore_ops_kernel_common.dll +0 -0
  139. mindspore/mindspore_profiler.dll +0 -0
  140. mindspore/mindspore_pyboost.dll +0 -0
  141. mindspore/mindspore_pynative.dll +0 -0
  142. mindspore/mindspore_res_manager.dll +0 -0
  143. mindspore/mindspore_runtime_pipeline.dll +0 -0
  144. mindspore/mint/__init__.py +4 -44
  145. mindspore/mint/distributed/__init__.py +5 -0
  146. mindspore/mint/distributed/distributed.py +425 -19
  147. mindspore/mint/nn/__init__.py +1 -1
  148. mindspore/mint/nn/functional.py +53 -6
  149. mindspore/mint/nn/layer/_functions.py +163 -294
  150. mindspore/mint/nn/layer/activation.py +8 -6
  151. mindspore/mint/nn/layer/conv.py +125 -101
  152. mindspore/mint/nn/layer/normalization.py +11 -25
  153. mindspore/mint/optim/adam.py +19 -18
  154. mindspore/mint/optim/adamw.py +14 -8
  155. mindspore/mint/optim/sgd.py +5 -5
  156. mindspore/msobj140.dll +0 -0
  157. mindspore/mspdb140.dll +0 -0
  158. mindspore/mspdbcore.dll +0 -0
  159. mindspore/mspdbst.dll +0 -0
  160. mindspore/mspft140.dll +0 -0
  161. mindspore/msvcdis140.dll +0 -0
  162. mindspore/msvcp140_1.dll +0 -0
  163. mindspore/msvcp140_2.dll +0 -0
  164. mindspore/msvcp140_atomic_wait.dll +0 -0
  165. mindspore/msvcp140_codecvt_ids.dll +0 -0
  166. mindspore/nn/cell.py +488 -620
  167. mindspore/nn/grad/cell_grad.py +11 -12
  168. mindspore/nn/layer/activation.py +36 -36
  169. mindspore/nn/layer/basic.py +74 -77
  170. mindspore/nn/layer/channel_shuffle.py +4 -4
  171. mindspore/nn/layer/combined.py +4 -2
  172. mindspore/nn/layer/conv.py +86 -85
  173. mindspore/nn/layer/dense.py +9 -7
  174. mindspore/nn/layer/embedding.py +50 -52
  175. mindspore/nn/layer/image.py +38 -40
  176. mindspore/nn/layer/math.py +111 -112
  177. mindspore/nn/layer/normalization.py +56 -44
  178. mindspore/nn/layer/pooling.py +58 -63
  179. mindspore/nn/layer/rnn_cells.py +33 -33
  180. mindspore/nn/layer/rnns.py +56 -56
  181. mindspore/nn/layer/thor_layer.py +74 -73
  182. mindspore/nn/layer/transformer.py +11 -1
  183. mindspore/nn/learning_rate_schedule.py +20 -20
  184. mindspore/nn/loss/loss.py +79 -81
  185. mindspore/nn/optim/adam.py +2 -4
  186. mindspore/nn/optim/adasum.py +2 -2
  187. mindspore/nn/optim/lamb.py +1 -3
  188. mindspore/nn/optim/optimizer.py +1 -1
  189. mindspore/nn/optim/tft_wrapper.py +2 -3
  190. mindspore/nn/optim/thor.py +2 -2
  191. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  192. mindspore/nn/probability/distribution/exponential.py +2 -1
  193. mindspore/nn/probability/distribution/poisson.py +2 -1
  194. mindspore/nn/sparse/sparse.py +3 -3
  195. mindspore/nn/wrap/cell_wrapper.py +73 -42
  196. mindspore/nn/wrap/grad_reducer.py +37 -52
  197. mindspore/nn/wrap/loss_scale.py +72 -74
  198. mindspore/numpy/array_creations.py +7 -7
  199. mindspore/numpy/fft.py +1 -1
  200. mindspore/numpy/math_ops.py +1 -1
  201. mindspore/numpy/utils_const.py +1 -1
  202. mindspore/opencv_core452.dll +0 -0
  203. mindspore/opencv_imgcodecs452.dll +0 -0
  204. mindspore/opencv_imgproc452.dll +0 -0
  205. mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
  206. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
  207. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  208. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  209. mindspore/{experimental/es/__init__.py → ops/_op_impl/cpu/joinedstr_op.py} +12 -6
  210. mindspore/ops/_vmap/vmap_array_ops.py +6 -13
  211. mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
  212. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +29 -10
  213. mindspore/ops/auto_generate/gen_extend_func.py +5 -55
  214. mindspore/ops/auto_generate/gen_ops_def.py +753 -273
  215. mindspore/ops/auto_generate/gen_ops_prim.py +1687 -958
  216. mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
  217. mindspore/ops/composite/__init__.py +10 -0
  218. mindspore/ops/composite/base.py +9 -5
  219. mindspore/ops/composite/multitype_ops/__init__.py +12 -1
  220. mindspore/ops/composite/multitype_ops/_compile_utils.py +132 -108
  221. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  222. mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
  223. mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
  224. mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
  225. mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
  226. mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
  227. mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
  228. mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
  229. mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
  230. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
  231. mindspore/ops/function/__init__.py +4 -1
  232. mindspore/ops/function/_add_attr_func.py +11 -6
  233. mindspore/ops/function/array_func.py +17 -100
  234. mindspore/ops/function/debug_func.py +8 -5
  235. mindspore/ops/function/grad/grad_func.py +5 -13
  236. mindspore/ops/function/math_func.py +65 -399
  237. mindspore/ops/function/nn_func.py +44 -61
  238. mindspore/ops/function/other_func.py +4 -1
  239. mindspore/ops/function/random_func.py +31 -4
  240. mindspore/ops/functional.py +2 -3
  241. mindspore/ops/functional_overload.py +486 -18
  242. mindspore/ops/op_info_register.py +21 -0
  243. mindspore/ops/operations/__init__.py +5 -2
  244. mindspore/ops/operations/_custom_ops_utils.py +675 -8
  245. mindspore/ops/operations/_inner_ops.py +14 -18
  246. mindspore/ops/operations/_sequence_ops.py +1 -1
  247. mindspore/ops/operations/array_ops.py +4 -50
  248. mindspore/ops/operations/comm_ops.py +186 -41
  249. mindspore/ops/operations/custom_ops.py +244 -175
  250. mindspore/ops/operations/debug_ops.py +55 -4
  251. mindspore/ops/operations/image_ops.py +13 -13
  252. mindspore/ops/operations/manually_defined/ops_def.py +27 -28
  253. mindspore/ops/operations/math_ops.py +8 -9
  254. mindspore/ops/operations/nn_ops.py +6 -7
  255. mindspore/ops/primitive.py +9 -20
  256. mindspore/ops/tensor_method.py +52 -11
  257. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
  258. mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
  259. mindspore/ops_generate/api/functions_cc_generator.py +58 -10
  260. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
  261. mindspore/ops_generate/common/base_generator.py +14 -0
  262. mindspore/ops_generate/common/gen_constants.py +7 -2
  263. mindspore/ops_generate/common/gen_utils.py +0 -19
  264. mindspore/ops_generate/common/op_proto.py +11 -4
  265. mindspore/ops_generate/common/template.py +88 -11
  266. mindspore/ops_generate/gen_ops.py +1 -1
  267. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
  268. mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
  269. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
  270. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
  271. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
  272. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
  273. mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -16
  274. mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
  275. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
  276. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
  277. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
  278. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
  279. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
  280. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
  281. mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
  282. mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
  283. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
  284. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
  285. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
  286. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
  287. mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
  288. mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
  289. mindspore/parallel/_auto_parallel_context.py +9 -17
  290. mindspore/parallel/_cell_wrapper.py +106 -40
  291. mindspore/parallel/_parallel_serialization.py +4 -3
  292. mindspore/parallel/_ps_context.py +4 -6
  293. mindspore/parallel/_tensor.py +167 -12
  294. mindspore/parallel/_transformer/moe.py +1 -1
  295. mindspore/parallel/_transformer/transformer.py +17 -12
  296. mindspore/parallel/_utils.py +5 -11
  297. mindspore/parallel/auto_parallel.py +33 -12
  298. mindspore/parallel/checkpoint_convert.py +3 -3
  299. mindspore/parallel/checkpoint_transform.py +5 -1
  300. mindspore/parallel/cluster/process_entity/_api.py +88 -49
  301. mindspore/parallel/cluster/process_entity/_utils.py +95 -7
  302. mindspore/parallel/cluster/run.py +48 -7
  303. mindspore/parallel/function/__init__.py +8 -1
  304. mindspore/parallel/function/reshard_func.py +7 -6
  305. mindspore/parallel/nn/__init__.py +15 -2
  306. mindspore/parallel/nn/parallel_cell_wrapper.py +50 -14
  307. mindspore/parallel/nn/parallel_grad_reducer.py +7 -14
  308. mindspore/parallel/shard.py +9 -23
  309. mindspore/parallel/transform_safetensors.py +468 -174
  310. mindspore/pgodb140.dll +0 -0
  311. mindspore/pgort140.dll +0 -0
  312. mindspore/profiler/__init__.py +2 -1
  313. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
  314. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
  315. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +3 -0
  316. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
  317. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  318. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
  319. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
  320. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
  321. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
  322. mindspore/profiler/analysis/task_manager.py +1 -1
  323. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
  324. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
  325. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +10 -9
  326. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +43 -23
  327. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
  328. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
  329. mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
  330. mindspore/profiler/common/constant.py +16 -0
  331. mindspore/profiler/common/msprof_cmd_tool.py +2 -2
  332. mindspore/profiler/common/path_manager.py +9 -0
  333. mindspore/profiler/common/profiler_context.py +50 -29
  334. mindspore/profiler/common/profiler_info.py +0 -16
  335. mindspore/profiler/common/profiler_meta_data.py +1 -0
  336. mindspore/profiler/common/profiler_op_analyse.py +239 -0
  337. mindspore/profiler/common/profiler_output_path.py +23 -8
  338. mindspore/profiler/common/profiler_parameters.py +128 -35
  339. mindspore/profiler/dynamic_profile/__init__.py +0 -0
  340. mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
  341. mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
  342. mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
  343. mindspore/profiler/dynamic_profiler.py +374 -338
  344. mindspore/profiler/envprofiler.py +42 -12
  345. mindspore/profiler/experimental_config.py +112 -7
  346. mindspore/profiler/mstx.py +33 -12
  347. mindspore/profiler/platform/__init__.py +2 -3
  348. mindspore/profiler/platform/cpu_profiler.py +10 -4
  349. mindspore/profiler/platform/npu_profiler.py +30 -20
  350. mindspore/profiler/profiler.py +218 -154
  351. mindspore/profiler/profiler_action_controller.py +65 -77
  352. mindspore/profiler/profiler_interface.py +2 -2
  353. mindspore/profiler/schedule.py +10 -4
  354. mindspore/rewrite/common/config.py +1 -0
  355. mindspore/rewrite/common/namer.py +1 -0
  356. mindspore/rewrite/common/namespace.py +1 -0
  357. mindspore/rewrite/node/node.py +31 -11
  358. mindspore/rewrite/parsers/assign_parser.py +1 -1
  359. mindspore/rewrite/symbol_tree/symbol_tree.py +2 -2
  360. mindspore/run_check/_check_version.py +7 -10
  361. mindspore/runtime/__init__.py +8 -6
  362. mindspore/runtime/event.py +10 -4
  363. mindspore/runtime/executor.py +87 -45
  364. mindspore/runtime/memory.py +22 -30
  365. mindspore/runtime/thread_bind_core.py +299 -165
  366. mindspore/safeguard/rewrite_obfuscation.py +12 -13
  367. mindspore/swresample-4.dll +0 -0
  368. mindspore/swscale-6.dll +0 -0
  369. mindspore/tbbmalloc.dll +0 -0
  370. mindspore/tinyxml2.dll +0 -0
  371. mindspore/train/_utils.py +9 -5
  372. mindspore/train/amp.py +43 -23
  373. mindspore/train/callback/__init__.py +5 -5
  374. mindspore/train/callback/_callback.py +2 -1
  375. mindspore/train/callback/_checkpoint.py +4 -14
  376. mindspore/train/callback/_flops_collector.py +11 -7
  377. mindspore/train/callback/_landscape.py +0 -1
  378. mindspore/train/callback/_train_fault_tolerance.py +72 -18
  379. mindspore/train/data_sink.py +15 -6
  380. mindspore/train/dataset_helper.py +14 -5
  381. mindspore/train/model.py +49 -47
  382. mindspore/train/serialization.py +168 -126
  383. mindspore/train/summary/summary_record.py +13 -2
  384. mindspore/train/train_thor/model_thor.py +2 -2
  385. mindspore/turbojpeg.dll +0 -0
  386. mindspore/utils/__init__.py +3 -2
  387. mindspore/utils/dryrun.py +0 -6
  388. mindspore/utils/runtime_execution_order_check.py +162 -78
  389. mindspore/utils/sdc_detect.py +68 -0
  390. mindspore/utils/utils.py +14 -17
  391. mindspore/vcmeta.dll +0 -0
  392. mindspore/vcruntime140.dll +0 -0
  393. mindspore/vcruntime140_1.dll +0 -0
  394. mindspore/version.py +1 -1
  395. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/METADATA +5 -4
  396. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/RECORD +400 -439
  397. mindspore/_deprecated/jit.py +0 -198
  398. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  399. mindspore/communication/_hccl_management.py +0 -297
  400. mindspore/experimental/es/embedding_service.py +0 -891
  401. mindspore/experimental/es/embedding_service_layer.py +0 -581
  402. mindspore/profiler/common/validator/__init__.py +0 -14
  403. mindspore/profiler/common/validator/validate_path.py +0 -84
  404. mindspore/profiler/parser/__init__.py +0 -14
  405. mindspore/profiler/parser/aicpu_data_parser.py +0 -272
  406. mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
  407. mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
  408. mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
  409. mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
  410. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
  411. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
  412. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
  413. mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
  414. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
  415. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
  416. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
  417. mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
  418. mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
  419. mindspore/profiler/parser/ascend_flops_generator.py +0 -116
  420. mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
  421. mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
  422. mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
  423. mindspore/profiler/parser/ascend_memory_generator.py +0 -185
  424. mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
  425. mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
  426. mindspore/profiler/parser/ascend_op_generator.py +0 -334
  427. mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
  428. mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
  429. mindspore/profiler/parser/base_timeline_generator.py +0 -483
  430. mindspore/profiler/parser/container.py +0 -229
  431. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
  432. mindspore/profiler/parser/flops_parser.py +0 -531
  433. mindspore/profiler/parser/framework_enum.py +0 -111
  434. mindspore/profiler/parser/framework_parser.py +0 -464
  435. mindspore/profiler/parser/framework_struct.py +0 -61
  436. mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
  437. mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
  438. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
  439. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
  440. mindspore/profiler/parser/hccl_parser.py +0 -573
  441. mindspore/profiler/parser/hwts_log_parser.py +0 -122
  442. mindspore/profiler/parser/integrator.py +0 -526
  443. mindspore/profiler/parser/memory_usage_parser.py +0 -277
  444. mindspore/profiler/parser/minddata_analyzer.py +0 -800
  445. mindspore/profiler/parser/minddata_parser.py +0 -186
  446. mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
  447. mindspore/profiler/parser/op_intermediate_parser.py +0 -149
  448. mindspore/profiler/parser/optime_parser.py +0 -250
  449. mindspore/profiler/parser/profiler_info.py +0 -213
  450. mindspore/profiler/parser/step_trace_parser.py +0 -666
  451. mindspore/utils/hooks.py +0 -81
  452. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  453. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/WHEEL +0 -0
  454. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/entry_points.txt +0 -0
  455. {mindspore-2.6.0.dist-info → mindspore-2.7.0.dist-info}/top_level.txt +0 -0
mindspore/common/api.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
2
2
  #
3
- # Copyright 2020-2024 Huawei Technologies Co., Ltd
3
+ # Copyright 2020-2025 Huawei Technologies Co., Ltd
4
4
  #
5
5
  # Licensed under the Apache License, Version 2.0 (the "License");
6
6
  # you may not use this file except in compliance with the License.
@@ -17,6 +17,8 @@
17
17
  """Providing interface methods."""
18
18
  from __future__ import absolute_import
19
19
 
20
+ __all__ = ['ms_memory_recycle', 'jit', 'jit_class', 'flops_collection']
21
+
20
22
  import gc
21
23
  import types
22
24
  import sys
@@ -42,23 +44,25 @@ from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
42
44
  from mindspore._c_expression.amp import get_curr_amp_strategy
43
45
  from mindspore._c_expression import GraphExecutor_, JitExecutor_, CSRTensor, RowTensor, COOTensor, \
44
46
  PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
45
- _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, StubNode, MSContext, TensorPy as Tensor
47
+ _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, MSContext, TensorPy as Tensor
46
48
  from mindspore.parallel._ps_context import _is_role_sched
47
49
  from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_in_auto_parallel_mode, \
48
50
  _is_parallel_mode
49
51
  from mindspore import _checkparam as Validator
50
52
  from mindspore._checkparam import is_stub_tensor
51
- from mindspore.common._utils import is_shape_unknown
53
+ from mindspore.common._utils import is_shape_unknown, get_func
52
54
  from mindspore.common.mutable import mutable, _check_element_type
53
- from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
54
- get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature
55
+ from mindspore.common.dynamic_shape.auto_dynamic_shape import get_auto_dynamic_shape_args, \
56
+ update_auto_dynamic_shape_phase
57
+ from mindspore.common.dynamic_shape.enable_dynamic import generate_dynamic_tensor_args, ENABLE_DYNAMIC
55
58
  from mindspore.common._pijit_context import PIJitCaptureContext
56
- from mindspore.common.parameter import Parameter, set_parameter_hook_updated, parameter_hook_updated
59
+ from mindspore.common.parameter import Parameter
60
+ from mindspore.common.hook_handle import _hook_version
57
61
  from mindspore.common.jit_context import jit_context
58
62
  from mindspore.common.jit_trace import _jit_trace
59
63
  from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
60
64
 
61
- # Store ms_function class compiled pipeline cache.
65
+ # Store jit class compiled pipeline cache.
62
66
  ms_compile_cache = set()
63
67
  # Store cell compiled pipeline cache.
64
68
  cells_compile_cache = {}
@@ -72,6 +76,11 @@ ARG_SPECIFIED = "arg_specified_infos"
72
76
  TOTAL_ARG_LEN = "total_arg_length"
73
77
 
74
78
 
79
+ def _real_phase(phase, obj):
80
+ real_phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
81
+ return real_phase
82
+
83
+
75
84
  def _check_recompile_args(compile_args, kwargs):
76
85
  """Check recompile of graph"""
77
86
 
@@ -134,8 +143,6 @@ def _convert_python_data(data):
134
143
  """
135
144
  if isinstance(data, PythonTensor):
136
145
  return data
137
- if isinstance(data, StubNode):
138
- return ms.common._stub_tensor._convert_stub(data)
139
146
  if data.__class__ is tuple:
140
147
  # Handle namedtuple since its type is tuple.
141
148
  if hasattr(data, "_fields"):
@@ -278,13 +285,13 @@ def __get_compile_cache_dep_files(file_path, compile_cache_dep_files, pkg):
278
285
  module = importlib.util.module_from_spec(module_spec)
279
286
  if hasattr(module, '__file__'):
280
287
  dep_file_path = module.__file__
288
+ # Exclude the installed modules.
289
+ if not _in_sys_path(dep_file_path) and dep_file_path not in compile_cache_dep_files:
290
+ logger.debug(f"dependent file path: {dep_file_path}")
291
+ compile_cache_dep_files.append(dep_file_path)
292
+ __get_compile_cache_dep_files(dep_file_path, compile_cache_dep_files, module.__package__)
281
293
  else:
282
294
  continue
283
- # Exclude the installed modules.
284
- if not _in_sys_path(dep_file_path) and dep_file_path not in compile_cache_dep_files:
285
- logger.debug(f"dependent file path: {dep_file_path}")
286
- compile_cache_dep_files.append(dep_file_path)
287
- __get_compile_cache_dep_files(dep_file_path, compile_cache_dep_files, module.__package__)
288
295
 
289
296
 
290
297
  def _get_compile_cache_dep_files():
@@ -342,7 +349,7 @@ def _get_parameter_layout():
342
349
  return layout
343
350
 
344
351
 
345
- def _handle_arg(obj, arg, compile_arg):
352
+ def _handle_arg(obj, arg, has_mutable_arg):
346
353
  """Handle arg for runtime .If need handle the arg, return True"""
347
354
  from mindspore._extends.parse import compile_config
348
355
  if isinstance(arg, PythonTensor):
@@ -352,7 +359,7 @@ def _handle_arg(obj, arg, compile_arg):
352
359
  return arg
353
360
  elif isinstance(arg, (Tensor, CSRTensor, COOTensor)):
354
361
  return arg
355
- elif compile_arg is not None and hasattr(compile_arg, "__ms_mutable__") and getattr(compile_arg, "__ms_mutable__"):
362
+ elif has_mutable_arg:
356
363
  # mutable([]) will be eliminated by FuncGraphSpecializer, and empty list is not supported by backend.
357
364
  if isinstance(arg, list) and not arg:
358
365
  return None
@@ -366,7 +373,7 @@ def _handle_arg(obj, arg, compile_arg):
366
373
  return None
367
374
 
368
375
 
369
- def _handle_arg_predict(obj, arg, compile_arg):
376
+ def _handle_arg_predict(obj, arg, has_mutable_arg):
370
377
  """Handle arg for runtime .If need handle the arg, return True"""
371
378
  if arg is None:
372
379
  return None
@@ -375,8 +382,7 @@ def _handle_arg_predict(obj, arg, compile_arg):
375
382
  return None
376
383
 
377
384
  if isinstance(arg, (list, tuple)):
378
- if compile_arg is not None and hasattr(compile_arg, "__ms_mutable__") and \
379
- getattr(compile_arg, "__ms_mutable__"):
385
+ if has_mutable_arg:
380
386
  # mutable([]) will be eliminated by FuncGraphSpecializer, and empty list is not supported by backend.
381
387
  if isinstance(arg, list) and not arg:
382
388
  return None
@@ -388,35 +394,30 @@ def _handle_arg_predict(obj, arg, compile_arg):
388
394
  return arg
389
395
 
390
396
 
391
- def _get_args_for_run(obj, args, kwargs, compile_args):
397
+ def _get_args_for_run(obj, args, kwargs, has_mutable_args_list, is_predict):
392
398
  """Get the actual input args and kwargs for runtime."""
393
399
  new_args = []
394
- for arg, compile_arg in zip(args, compile_args):
395
- new_arg = _handle_arg(obj, arg, compile_arg)
400
+ fn = _handle_arg_predict if is_predict else _handle_arg
401
+ for arg, has_mutable_arg in zip(args, has_mutable_args_list):
402
+ new_arg = fn(obj, arg, has_mutable_arg)
396
403
  if new_arg is not None:
397
404
  new_args.append(new_arg)
398
405
 
399
406
  for _, value in kwargs.items():
400
- new_value = _handle_arg(obj, value, None)
407
+ new_value = fn(obj, value, None)
401
408
  if new_value is not None:
402
409
  new_args.append(new_value)
403
410
 
404
411
  return new_args
405
412
 
406
413
 
407
- def _get_args_for_run_predict(obj, args, kwargs, compile_args):
408
- """Get the actual input args and kwargs for runtime."""
414
+ def _get_mutable_flags(compile_args):
415
+ """Get a list of booleans indicating whether each argument is marked as mutable"""
409
416
  new_args = []
410
- for arg, compile_arg in zip(args, compile_args):
411
- new_arg = _handle_arg_predict(obj, arg, compile_arg)
412
- if new_arg is not None:
413
- new_args.append(new_arg)
414
-
415
- for _, value in kwargs.items():
416
- new_value = _handle_arg_predict(obj, value, None)
417
- if new_value is not None:
418
- new_args.append(new_value)
419
-
417
+ for compile_arg in compile_args:
418
+ has_mutable_arg = compile_arg is not None and hasattr(compile_arg, "__ms_mutable__") and \
419
+ getattr(compile_arg, "__ms_mutable__")
420
+ new_args.append(has_mutable_arg)
420
421
  return new_args
421
422
 
422
423
 
@@ -544,10 +545,12 @@ def _get_parameter_ids(args, kwargs):
544
545
  parameter_ids += str(id(value))
545
546
  return parameter_ids
546
547
 
548
+
547
549
  def _get_tensor_hook_key(tensor):
548
550
  """Get the hook key of Tensor/Parameter"""
549
551
  return ".".join(map(str, map(id, tensor.hooks())))
550
552
 
553
+
551
554
  def _get_hook_key(*args, **kwargs):
552
555
  """Get the hook key of Tensors/Parameters"""
553
556
  hook_key = ""
@@ -586,13 +589,16 @@ class _JitExecutor:
586
589
  The result of pipeline running in graph mode.
587
590
  """
588
591
 
589
- def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None, dynamic=0):
592
+ def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None, dynamic=0,
593
+ cell_cache_key_extend=''):
590
594
  init_pipeline()
591
595
  if not isinstance(fn, (types.FunctionType, types.MethodType)):
592
596
  raise RuntimeError('fn {} is not function or method'.format(fn))
593
597
 
594
598
  self.fn = fn
595
599
  self.input_signature = input_signature
600
+ self.dynamic_args_shapes = getattr(get_func(fn), ENABLE_DYNAMIC, None)
601
+ self.enable_jit_dynamic = self.dynamic_args_shapes is not None
596
602
  self.obj = None
597
603
  if obj and hasattr(obj, fn.__name__):
598
604
  self.obj = obj
@@ -606,6 +612,7 @@ class _JitExecutor:
606
612
  self._compile_args = None
607
613
  self._enable_auto_dynamic = dynamic == 1
608
614
  self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
615
+ self._cell_cache_key_extend = cell_cache_key_extend
609
616
 
610
617
  def _predict(self, *args, **kwargs):
611
618
  """Dedicated routine for predict."""
@@ -630,15 +637,18 @@ class _JitExecutor:
630
637
  else: # get compiled args to generate run args by _generate_run_args
631
638
  compile_args = self._generate_compile_args(args_list)
632
639
  key_id = self._get_key_id()
633
- compile_args = get_auto_dynamic_shape_args_with_check_input_signature(
634
- compile_args,
635
- key_id,
636
- self.input_signature,
637
- self._enable_auto_dynamic
638
- )
640
+ if self.input_signature is None:
641
+ compile_args = get_auto_dynamic_shape_args(
642
+ compile_args, key_id, self._enable_auto_dynamic
643
+ )
639
644
  self._compile_args = compile_args
640
645
 
641
646
  new_inputs = self._generate_run_args(args_list, kwargs)
647
+ if self.jit_config_dict:
648
+ jit_config_dict = self.jit_config_dict
649
+ else:
650
+ jit_config_dict = JitConfig().jit_config_dict
651
+ self._graph_executor.set_jit_config(jit_config_dict)
642
652
  output = self._graph_executor(
643
653
  tuple(new_inputs),
644
654
  self.obj.phase_cache[self.obj.phase]
@@ -658,12 +668,9 @@ class _JitExecutor:
658
668
  args_list = args_list[1:]
659
669
  phase = ""
660
670
  try:
661
- if context.get_context("mode") == context.PYNATIVE_MODE:
662
- _pynative_executor.set_jit_compile_status(True, phase)
663
- phase = self.compile(self.fn.__name__, *args_list, **kwargs)
664
- _pynative_executor.set_jit_compile_status(False, phase)
665
- else:
666
- phase = self.compile(self.fn.__name__, *args_list, **kwargs)
671
+ _pynative_executor.set_jit_compile_status(True, phase)
672
+ phase = self.compile(self.fn.__name__, *args_list, **kwargs)
673
+ _pynative_executor.set_jit_compile_status(False, phase)
667
674
  except Exception as err:
668
675
  _pynative_executor.clear_res()
669
676
  raise err
@@ -672,31 +679,27 @@ class _JitExecutor:
672
679
  return None
673
680
 
674
681
  new_inputs = self._generate_run_args(args_list, kwargs)
675
- if context.get_context("mode") == context.PYNATIVE_MODE and not jit_context():
676
- output = _pynative_executor.grad_jit(*new_inputs)
682
+ if self.jit_config_dict:
683
+ jit_config_dict = self.jit_config_dict
677
684
  else:
678
- output = self._graph_executor(tuple(new_inputs), phase)
679
- if jit_context():
680
- if is_stub_tensor(output):
681
- output = output.stub_sync()
682
- return jit_context().run_graph(phase, output, *tuple(new_inputs))
683
-
685
+ jit_config_dict = JitConfig().jit_config_dict
686
+ self._graph_executor.set_jit_config(jit_config_dict)
687
+ output = _pynative_executor.grad_jit(*new_inputs)
688
+ if jit_context():
689
+ if is_stub_tensor(output):
690
+ output = output.stub_sync()
691
+ return jit_context().run_graph(phase, output, *tuple(new_inputs))
684
692
  return output
685
693
 
686
694
  def compile(self, method_name, *args, **kwargs):
687
695
  """Returns pipeline for the given args."""
688
- # Check whether hook function registered on Cell object.
689
- if self.obj and hasattr(self.obj, "_hook_fn_registered"):
690
- if self.obj._hook_fn_registered():
691
- logger.warning(f"For 'Cell', it's not support hook function when using 'jit' decorator. "
692
- f"If you want to use hook function, please use context.set_context to set "
693
- f"pynative mode and remove 'jit' decorator.")
694
696
  # Chose dynamic shape tensors or actual input tensors as compile args.
695
697
  compile_args = self._generate_compile_args(args)
696
698
  key_id = self._get_key_id()
697
- compile_args = get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id,
698
- self.input_signature,
699
- self._enable_auto_dynamic)
699
+ if self.input_signature is None:
700
+ compile_args = get_auto_dynamic_shape_args(
701
+ compile_args, key_id, self._enable_auto_dynamic, self.enable_jit_dynamic
702
+ )
700
703
 
701
704
  # Add mutable for compile_args for two scene:
702
705
  # 1) Origin args is mutable.
@@ -736,18 +739,23 @@ class _JitExecutor:
736
739
 
737
740
  self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
738
741
  key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
742
+ key = str(key)
739
743
 
740
744
  parameter_ids = _get_parameter_ids(args, kwargs)
741
745
  if parameter_ids != "":
742
- key = str(key) + '.' + parameter_ids
746
+ key += '.' + parameter_ids
747
+
748
+ key += "." + _get_hook_key(*args, **kwargs)
749
+ key += "." + str(_hook_version())
743
750
 
744
- key = str(key) + "." + _get_hook_key(*args, **kwargs)
751
+ phase = generate_name + '.' + key
745
752
 
746
- phase = generate_name + '.' + str(key)
753
+ if self.input_signature is None:
754
+ update_auto_dynamic_shape_phase(compile_args, key_id, phase)
747
755
 
748
- update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)
756
+ phase = phase + self._cell_cache_key_extend
749
757
 
750
- if phase in ms_compile_cache and self._graph_executor.has_compiled(phase) and not parameter_hook_updated():
758
+ if phase in ms_compile_cache and self._graph_executor.has_compiled(phase):
751
759
  # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
752
760
  # generated in generate_arguments_key.
753
761
  self._graph_executor.clear_compile_arguments_resource()
@@ -758,30 +766,23 @@ class _JitExecutor:
758
766
  # If enable compile cache, get the dependency files list and set to graph executor.
759
767
  self._set_compile_cache_dep_files()
760
768
  if self.jit_config_dict:
761
- self._graph_executor.set_jit_config(self.jit_config_dict)
769
+ jit_config_dict = self.jit_config_dict
762
770
  else:
763
771
  jit_config_dict = JitConfig().jit_config_dict
764
- self._graph_executor.set_jit_config(jit_config_dict)
765
772
 
766
773
  if self.obj is None:
767
774
  # Set an attribute to fn as an identifier.
768
- if isinstance(self.fn, types.MethodType):
769
- setattr(self.fn.__func__, "__jit_function__", True)
770
- else:
771
- setattr(self.fn, "__jit_function__", True)
772
- is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase)
773
- if isinstance(self.fn, types.MethodType):
774
- delattr(self.fn.__func__, "__jit_function__")
775
- else:
776
- delattr(self.fn, "__jit_function__")
775
+ setattr(get_func(self.fn), "__jit_function__", True)
776
+ is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, jit_config_dict)
777
+ delattr(get_func(self.fn), "__jit_function__")
777
778
  else:
778
779
  if isinstance(self.obj, ms.nn.Cell):
779
780
  self._graph_executor.set_weights_values(self.obj.parameters_dict())
780
- is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase)
781
+ is_compile = self._graph_executor.compile(
782
+ self.obj, compile_args, kwargs, phase, jit_config_dict)
781
783
 
782
784
  if not is_compile:
783
785
  raise RuntimeError("Executor compile failed.")
784
- set_parameter_hook_updated(False)
785
786
  ms_compile_cache.add(phase)
786
787
  if hasattr(self.obj, "phase"):
787
788
  self.obj.phase_cache[self.obj.phase] = phase
@@ -829,41 +830,70 @@ class _JitExecutor:
829
830
  if enable_compile_cache is True or enable_compile_cache == "1":
830
831
  self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())
831
832
 
833
+ def _generate_compile_args_by_enable_dynamic(self, args_list):
834
+ """Generate compile args by enable_dynamic."""
835
+ compile_args = generate_dynamic_tensor_args(args_list, self.dynamic_args_shapes)
836
+ compile_args = _add_mutable_attr(args_list, compile_args, _pynative_executor.requires_grad())
837
+ if self.obj is not None:
838
+ _pynative_executor.set_dynamic_input(self.obj, *compile_args)
839
+ else:
840
+ _pynative_executor.set_dynamic_input(self.fn, *compile_args)
841
+ logger.info(f"dynamic shape compile_args: {compile_args}")
842
+ return compile_args
843
+
844
+ def _generate_compile_args_by_set_inputs(self, args_list):
845
+ """Generate compile args by set_inputs."""
846
+ compile_args = _generate_dyn_compile_args(args_list, self.obj.get_inputs())
847
+ if len(compile_args) != len(args_list):
848
+ raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
849
+ f"dynamic shape tensors: {len(compile_args)}.")
850
+ self._graph_executor.check_argument_consistency(compile_args, args_list, "set_inputs")
851
+ Validator.check_symbolic_shape(compile_args, args_list)
852
+ return compile_args
853
+
854
+ def _generate_compile_args_by_input_signature(self, args_list):
855
+ """Generate compile args by input_signature."""
856
+ compile_args = list(_generate_dyn_compile_args(args_list, self.input_signature))
857
+ dyn_shape = any([is_shape_unknown(elem.shape) for elem in compile_args if isinstance(elem, PythonTensor)])
858
+ Validator.check_symbolic_shape(self.input_signature, args_list)
859
+ if dyn_shape:
860
+ # Checkout whether the `sens` has been added to args_list.
861
+ if len(compile_args) == len(args_list) - 1:
862
+ logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
863
+ f"of input_signature args '{len(compile_args)}'. The last actual args may "
864
+ f"be 'sens' and added it to compile args.")
865
+ compile_args.append(args_list[-1])
866
+ compile_args = tuple(compile_args)
867
+ self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
868
+ if self.obj is not None:
869
+ _pynative_executor.set_dynamic_input(self.obj, *compile_args)
870
+ else:
871
+ _pynative_executor.set_dynamic_input(self.fn, *compile_args)
872
+ else:
873
+ if not verify_inputs_signature(compile_args, args_list):
874
+ raise ValueError("The input args is incompatible with the args in `input_signature`!")
875
+ return compile_args
876
+
877
+ def _check_set_inputs(self):
878
+ """Check if the `set_inputs()` of Cell object has been set."""
879
+ return self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs()
880
+
832
881
  def _generate_compile_args(self, args_list):
833
882
  """Chose dynamic shape tensors or actual input tensors as compile args."""
834
- # Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
835
- compile_args = _pynative_executor.get_dynamic_input(args_list)
883
+ # Case: The `enable_dynamic` is provided and `set_inputs()` of Cell object has been set.
884
+ if self.enable_jit_dynamic and self._check_set_inputs():
885
+ raise ValueError("When `enable_dynamic` is provided, the `set_inputs()` cannot be set!")
886
+ # Case: The `enable_dynamic` is provided.
887
+ if self.enable_jit_dynamic:
888
+ return self._generate_compile_args_by_enable_dynamic(args_list)
836
889
  # Case: The `set_inputs()` of Cell object has been set, using these dynamic shape args as compile args.
837
- if self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs():
838
- compile_args = _generate_dyn_compile_args(args_list, self.obj.get_inputs())
839
- if len(compile_args) != len(args_list):
840
- raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
841
- f"dynamic shape tensors: {len(compile_args)}.")
842
- self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
843
- Validator.check_symbolic_shape(compile_args, args_list)
844
-
890
+ if self._check_set_inputs():
891
+ return self._generate_compile_args_by_set_inputs(args_list)
845
892
  # Case: If dynamic shape tensors have been assigned to `input_signature`, they are preferred as compile args.
846
893
  if self.input_signature is not None:
847
- compile_args = list(_generate_dyn_compile_args(args_list, self.input_signature))
848
- dyn_shape = any([is_shape_unknown(elem.shape) for elem in compile_args if isinstance(elem, PythonTensor)])
849
- Validator.check_symbolic_shape(self.input_signature, args_list)
850
- if dyn_shape:
851
- # Checkout whether the `sens` has been added to args_list.
852
- if len(compile_args) == len(args_list) - 1:
853
- logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
854
- f"of input_signature args '{len(compile_args)}'. The last actual args may "
855
- f"be 'sens' and added it to compile args.")
856
- compile_args.append(args_list[-1])
857
- compile_args = tuple(compile_args)
858
- self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
859
- if self.obj is not None:
860
- _pynative_executor.set_dynamic_input(self.obj, *compile_args)
861
- else:
862
- _pynative_executor.set_dynamic_input(self.fn, *compile_args)
863
- else:
864
- if not verify_inputs_signature(compile_args, args_list):
865
- raise ValueError("The input args is incompatible with the args in `input_signature`!")
866
- return compile_args
894
+ return self._generate_compile_args_by_input_signature(args_list)
895
+ # Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
896
+ return _pynative_executor.get_dynamic_input(args_list)
867
897
 
868
898
  def _generate_run_args(self, args_list, kwargs):
869
899
  """
@@ -876,7 +906,7 @@ class _JitExecutor:
876
906
  Returns:
877
907
  new_inputs, new input args, which are required for running.
878
908
  """
879
- return _get_args_for_run(self, args_list, kwargs, self._compile_args)
909
+ return _get_args_for_run(self, args_list, kwargs, _get_mutable_flags(self._compile_args), False)
880
910
 
881
911
  def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
882
912
  """Get graph proto from pipeline."""
@@ -1037,6 +1067,67 @@ def _check_options(options, backend):
1037
1067
  _check_option_value(option, value)
1038
1068
 
1039
1069
 
1070
+ def _jit_ast(hash_obj, dynamic, jit_config, jit_graph_name):
1071
+ """Return the wrapped function for ast mode jit."""
1072
+ def wrap_func(func):
1073
+ nonlocal hash_obj
1074
+ if hasattr(func, "construct"):
1075
+ if isinstance(func, ms.nn.Cell):
1076
+ # Bound the cell object to get the self arg.
1077
+ return types.MethodType(_jit_ast(
1078
+ hash_obj, dynamic, jit_config, func._jit_graph_name)(func.construct.__func__), func)
1079
+ if isinstance(func, type) and issubclass(func, ms.nn.Cell):
1080
+ func.construct = _jit_ast(
1081
+ hash_obj, dynamic, jit_config, '')(func.construct)
1082
+ return func
1083
+
1084
+ if isinstance(func, types.MethodType):
1085
+ return types.MethodType(_jit_ast(hash_obj, dynamic, jit_config, '')(func.__func__), func.__self__)
1086
+
1087
+ if not isinstance(func, types.FunctionType):
1088
+ logger.warning(f"The func should be function, method or cell instance/class, but got {func}")
1089
+ return func
1090
+
1091
+ if hasattr(func, "__wrapped_by_jit__"):
1092
+ logger.warning(f"The func {func} should be wrapped by jit only once.")
1093
+
1094
+ if hash_obj is None or not _is_inner_func(func):
1095
+ hash_obj = int(time.time() * 1e9)
1096
+
1097
+ @wraps(func)
1098
+ def staging_specialize(*args, **kwargs):
1099
+ if os.getenv("MS_JIT") == '0':
1100
+ return func(*args, **kwargs)
1101
+
1102
+ args, kwargs = _handle_func_args(func, *args, **kwargs)
1103
+ process_obj = None
1104
+ if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
1105
+ process_obj = args[0]
1106
+ # Handle auto mixed precision strategy.
1107
+ if not hasattr(func, "amp_strategy"):
1108
+ setattr(get_func(func), "amp_strategy", get_curr_amp_strategy())
1109
+
1110
+ jit_graph_name = ''
1111
+ if hasattr(staging_specialize, "__jit_graph_name__"):
1112
+ jit_graph_name = staging_specialize.__jit_graph_name__
1113
+ jit_executor = _JitExecutor(
1114
+ func, hash_obj, None, process_obj, jit_config, dynamic, jit_graph_name)
1115
+ out = jit_executor(*args, **kwargs)
1116
+ if isinstance(process_obj, ms.nn.Cell):
1117
+ _clear_auto_parallel_context(process_obj)
1118
+ return out
1119
+
1120
+ # `inspect.getfullargspec(func)` will get the specification of the decorated function by default. By set
1121
+ # `__signature__` for the decorated function, `inspect.getfullargspec(func)` will get the specification of
1122
+ # original `func`.
1123
+ staging_specialize.__signature__ = inspect.signature(func)
1124
+ setattr(staging_specialize, "__wrapped_by_jit__", True)
1125
+ setattr(staging_specialize, "__jit_graph_name__", jit_graph_name)
1126
+ return staging_specialize
1127
+
1128
+ return wrap_func
1129
+
1130
+
1040
1131
  def jit(
1041
1132
  function: Optional[Callable] = None,
1042
1133
  *,
@@ -1059,45 +1150,45 @@ def jit(
1059
1150
  and the decoration @jit(capture_mode=“bytecode”) is considered invalid.
1060
1151
 
1061
1152
  Args:
1062
- function (Function, optional): The Python function that will be run as a graph. Default: ``None``.
1153
+ function (Callable, optional): The Python function or Cell that will be run as a graph. Default: ``None``.
1063
1154
 
1064
1155
  Keyword Args:
1065
1156
  capture_mode (str, optional): The method to create a callable MindSpore graph. The value of capture_mode
1066
- should be ``ast`` , ``bytecode`` or ``trace`` . Default: ``ast`` .
1157
+ should be ``"ast"`` , ``"bytecode"`` or ``"trace"`` . Default: ``"ast"`` .
1067
1158
 
1068
- - `ast <https://www.mindspore.cn/docs/en/r2.5.0/model_train/program_form/static_graph.html>`_ :
1069
- Parse Python ast to build graph.
1070
- - `bytecode <https://www.mindspore.cn/docs/en/r2.5.0/model_train/program_form/pynative.html#pijit>`_ :
1071
- Parse Python bytecode to build graph at runtime. This is an experimental prototype that is subject to
1072
- change and/or deletion.
1073
- - `trace` : Trace the execution of Python code to build graph. This is an experimental prototype that is
1074
- subject to change and/or deletion.
1159
+ - ast: Parse Python ast to build graph.
1160
+ - bytecode: Parse Python bytecode to build graph at runtime. This is an experimental prototype
1161
+ that is subject to change and/or deletion.
1162
+ - trace: Trace the execution of Python code to build graph. This is an experimental prototype
1163
+ that is subject to change and/or deletion.
1075
1164
 
1076
1165
  jit_level (str, optional): Used to control the compilation optimization level. Currently is only effective
1077
- with default backend. The value of jit_level should be ``O0`` or ``O1`` . Default: ``O0`` .
1166
+ with ms_backend. The value of jit_level should be ``"O0"`` or ``"O1"`` . Default: ``"O0"`` .
1078
1167
 
1079
- - `O0`: Except for optimizations that may affect functionality, all other optimizations are turned off.
1080
- - `O1`: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
1168
+ - O0: Except for optimizations that may affect functionality, all other optimizations are turned off.
1169
+ - O1: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
1081
1170
  level is experimental and is being improved.
1082
1171
 
1083
1172
  dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
1084
1173
  is as follows:
1085
1174
 
1086
- - `0`: Do not perform dynamic shape compilation.
1087
- - `1`: Enable dynamic shape compilation and automatically detect shape changes.
1175
+ - 0: Do not perform dynamic shape compilation.
1176
+ - 1: Enable dynamic shape compilation and automatically detect shape changes.
1088
1177
 
1089
1178
  fullgraph (bool, optional): Whether to capture the entire function into graph. If False, jit attempts to
1090
1179
  be compatible with all Python syntax in the function as much as possible. If True, we require that the
1091
1180
  entire function can be captured into graph. If this is not possible (that is, if there is Python syntax
1092
- not supported), then it will raise an exception. This currently only applies when capture_mode is ast.
1093
- Default: ``False``.
1181
+ not supported), then it will raise an exception. This currently only applies when capture_mode is ``ast``
1182
+ or ``bytecode``. Default: ``False``.
1094
1183
  backend (str, optional): The compilation backend to be used. If this parameter is not set, the framework will
1095
- use ``GE`` backend for Atlas training series products and ``ms_backend`` backend for others including Atlas
1096
- A2 training series products by default.
1184
+ use ``"GE"`` backend for Atlas training series products and ``"ms_backend"`` backend for others including
1185
+ Atlas A2 training series products by default.
1097
1186
 
1098
- - `ms_backend`: Adopt KernelByKernel execution mode.
1099
- - `GE`: Adopt Sink execution mode. The whole model will be sinked to device to execute, only applicable to
1100
- the top cell of model. And only can be used in Ascend platform.
1187
+ - ms_backend: Utilizes the built-in backend engine of MindSpore for hardware-related compilation
1188
+ optimization and execution, supporting multiple hardware forms such as Ascend, GPU, and CPU.
1189
+ - GE: Utilizes the GraphEngine, a graph compilation and execution engine within CANN,
1190
+ for Ascend model compilation and execution. Note: This backend takes effect only in static graph mode
1191
+ and can be executed only on Ascend hardware.
1101
1192
 
1102
1193
  **options (dict): A dictionary of options to pass to the compilation backend.
1103
1194
 
@@ -1120,11 +1211,11 @@ def jit(
1120
1211
  `disable_format_transform` can be set to ``True`` to try to improve training performance.
1121
1212
  Default: ``False`` .
1122
1213
  - exec_order (str, optional): Set the sorting method for operator execution, currently only two sorting
1123
- methods are supported: ``bfs`` and ``dfs`` . Default: ``bfs`` .
1214
+ methods are supported: ``"bfs"`` and ``"dfs"`` . Default: ``"bfs"`` .
1124
1215
 
1125
- - `bfs`: The default sorting method, breadth priority, good communication masking, relatively good
1216
+ - bfs: The default sorting method, breadth priority, good communication masking, relatively good
1126
1217
  performance.
1127
- - `dfs`: An optional sorting method, depth-first sorting. The performance is relatively worse than that
1218
+ - dfs: An optional sorting method, depth-first sorting. The performance is relatively worse than that
1128
1219
  of bfs execution order, but it occupies less memory. It is recommended to try dfs in scenarios where
1129
1220
  other execution orders run out of memory (OOM).
1130
1221
 
@@ -1135,11 +1226,11 @@ def jit(
1135
1226
  - global (dict): Set global options.
1136
1227
  - session (dict): Set session options.
1137
1228
 
1138
- - infer_boost (str, optional): Used to control the inference mode. Default: ``off``, which means
1229
+ - infer_boost (str, optional): Used to control the inference mode. Default: ``"off"``, which means
1139
1230
  the inference mode is disabled. The range is as follows:
1140
1231
 
1141
- - `on`: Enable inference mode, get better infer performance.
1142
- - `off`: Disable inference mode, use forward for inference. The performance is poor.
1232
+ - on: Enable inference mode, get better infer performance.
1233
+ - off: Disable inference mode, use forward for inference. The performance is poor.
1143
1234
 
1144
1235
  Returns:
1145
1236
  Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
@@ -1158,29 +1249,84 @@ def jit(
1158
1249
  >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
1159
1250
  >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
1160
1251
  ...
1161
- >>> # create a callable MindSpore graph by calling jit
1252
+ >>> # Create a callable MindSpore graph by calling jit.
1162
1253
  >>> def tensor_add(x, y):
1163
1254
  ... z = x + y
1164
1255
  ... return z
1165
1256
  ...
1166
1257
  >>> tensor_add_graph = jit(function=tensor_add)
1167
1258
  >>> out = tensor_add_graph(x, y)
1259
+ >>> print(out)
1260
+ Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
1261
+ [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
1262
+ [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
1263
+ [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
1168
1264
  ...
1169
- >>> # create a callable MindSpore graph through decorator @jit
1265
+ >>> # Create a callable MindSpore graph through decorator @jit.
1170
1266
  >>> @jit
1171
1267
  ... def tensor_add_with_dec(x, y):
1172
1268
  ... z = x + y
1173
1269
  ... return z
1174
1270
  ...
1175
1271
  >>> out = tensor_add_with_dec(x, y)
1272
+ >>> print(out)
1273
+ Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
1274
+ [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
1275
+ [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
1276
+ [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
1176
1277
  ...
1177
- >>> # create a callable MindSpore graph and capture the entire function into the graph
1278
+ >>> # Create a callable MindSpore graph and capture the entire function into the graph.
1178
1279
  >>> @jit(fullgraph=True)
1179
1280
  ... def tensor_add_fullgraph(x, y):
1180
1281
  ... z = x + y
1181
1282
  ... return z
1182
1283
  ...
1183
1284
  >>> out = tensor_add_fullgraph(x, y)
1285
+ >>> print(out)
1286
+ Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
1287
+ [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
1288
+ [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
1289
+ [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
1290
+ ...
1291
+ >>> # Create a callable MindSpore graph by trace mode.
1292
+ >>> @jit(capture_mode="trace")
1293
+ ... def tensor_add_by_trace(x, y):
1294
+ ... z = x + y
1295
+ ... return z
1296
+ ...
1297
+ >>> out = tensor_add_by_trace(x, y)
1298
+ >>> print(out)
1299
+ Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
1300
+ [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
1301
+ [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
1302
+ [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
1303
+ ...
1304
+ >>> # Create a callable MindSpore graph with ms_backend and jit_level="O1".
1305
+ >>> @jit(backend="ms_backend", jit_level="O1")
1306
+ ... def tensor_add_by_trace(x, y):
1307
+ ... z = x + y
1308
+ ... return z
1309
+ ...
1310
+ >>> out = tensor_add_by_trace(x, y)
1311
+ >>> print(out)
1312
+ Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
1313
+ [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
1314
+ [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
1315
+ [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
1316
+ ...
1317
+ >>> # Create a callable MindSpore graph with GE backend and some ge options on Ascend.
1318
+ >>> @jit(backend="GE", ge_options={"global": {"ge.opSelectImplmode": "high_precision"}})
1319
+ ... def tensor_add_by_trace(x, y):
1320
+ ... z = x + y
1321
+ ... return z
1322
+ ...
1323
+ >>> out = tensor_add_by_trace(x, y)
1324
+ >>> print(out)
1325
+ Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
1326
+ [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
1327
+ [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
1328
+ [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
1329
+ ...
1184
1330
  """
1185
1331
 
1186
1332
  capture_mode = Validator.check_string(capture_mode, ["ast", "bytecode", "trace"], "capture_mode", "jit")
@@ -1199,39 +1345,12 @@ def jit(
1199
1345
  jit_config = JitConfig(jit_level=jit_level, exc_mode=exc_mode, jit_syntax_level=jit_syntax_level,
1200
1346
  infer_boost=infer_boost, backend=backend, options=options_str)
1201
1347
 
1202
- def wrap_func(func):
1203
- nonlocal hash_obj
1204
- if hash_obj is None or not _is_inner_func(func):
1205
- hash_obj = int(time.time() * 1e9)
1206
-
1207
- @wraps(func)
1208
- def staging_specialize(*args, **kwargs):
1209
- if os.getenv("MS_JIT") == '0':
1210
- return func(*args, **kwargs)
1211
-
1212
- args, kwargs = _handle_func_args(func, *args, **kwargs)
1213
- process_obj = None
1214
- if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
1215
- process_obj = args[0]
1216
- # Handle auto mixed precision strategy.
1217
- if not hasattr(func, "amp_strategy"):
1218
- if isinstance(func, types.MethodType):
1219
- setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
1220
- else:
1221
- setattr(func, "amp_strategy", get_curr_amp_strategy())
1222
-
1223
- ms_function_executor = _JitExecutor(func, hash_obj, None, process_obj, jit_config, dynamic)
1224
- out = ms_function_executor(*args, **kwargs)
1225
- return out
1226
-
1227
- return staging_specialize
1228
-
1229
- if capture_mode == "bytecode":
1230
- wrap_func = PIJitCaptureContext(jit_config)
1231
- elif capture_mode == "trace":
1232
- if function is not None:
1233
- return _jit_trace(function)
1234
- return _jit_trace
1348
+ if capture_mode == "ast":
1349
+ wrap_func = _jit_ast(hash_obj, dynamic, jit_config, '')
1350
+ elif capture_mode == "bytecode":
1351
+ wrap_func = PIJitCaptureContext(fullgraph=fullgraph, jit_config=jit_config)
1352
+ else:
1353
+ wrap_func = _jit_trace()
1235
1354
 
1236
1355
  if function is not None:
1237
1356
  return wrap_func(function)
@@ -1547,7 +1666,7 @@ class _PyNativeExecutor:
1547
1666
  """
1548
1667
  self._executor.end_graph(obj, output, *args, *(kwargs.values()))
1549
1668
 
1550
- def check_run(self, grad, obj, weights, grad_hash_id, *args):
1669
+ def check_run(self, grad, obj, weights, grad_hash_id, *args, **kwargs):
1551
1670
  """
1552
1671
  Whether the forward graph need to construct.
1553
1672
 
@@ -1560,7 +1679,7 @@ class _PyNativeExecutor:
1560
1679
  Return:
1561
1680
  bool, specifies whether the forward graph needs to construct.
1562
1681
  """
1563
- return self._executor.check_run(grad, obj, weights, grad_hash_id, *args)
1682
+ return self._executor.check_run(grad, obj, weights, grad_hash_id, *args, **kwargs)
1564
1683
 
1565
1684
  def grad(self, obj, grad, weights, grad_position, *args):
1566
1685
  """
@@ -1802,6 +1921,19 @@ class _PyNativeExecutor:
1802
1921
  """
1803
1922
  return self._executor.constant_folding(*args)
1804
1923
 
1924
+ def set_creation_type(self, tensor, creation_type):
1925
+ """
1926
+ Set tensor's view creation type
1927
+
1928
+ Args:
1929
+ tensor (Tensor): input tensor.
1930
+ creation_type (CreationType): The type of view tensor when it is created.
1931
+
1932
+ Return:
1933
+ None.
1934
+ """
1935
+ return self._executor.set_creation_type(tensor, creation_type)
1936
+
1805
1937
 
1806
1938
  class _CellGraphExecutor:
1807
1939
  """
@@ -1878,13 +2010,6 @@ class _CellGraphExecutor:
1878
2010
  else:
1879
2011
  _set_dataset_mode_config('normal')
1880
2012
 
1881
- @staticmethod
1882
- def _use_vm_mode():
1883
- enable_ge = context.get_context("enable_ge")
1884
- enable_debug_runtime = context.get_context("enable_debug_runtime")
1885
- exe_mode = context.get_context("mode") == context.PYNATIVE_MODE
1886
- return not enable_ge or (enable_debug_runtime and exe_mode)
1887
-
1888
2013
  def _build_data_graph(self, obj, phase):
1889
2014
  self._graph_executor.build_data_graph(obj.parameters_dict(), phase)
1890
2015
 
@@ -1916,7 +2041,12 @@ class _CellGraphExecutor:
1916
2041
  obj.__parse_method__ = 'construct'
1917
2042
  if not hasattr(obj, obj.__parse_method__):
1918
2043
  raise AttributeError(
1919
- 'The class {} dose not have method {}'.format(obj.__class__.__name__, obj.__parse_method__))
2044
+ 'The class {} does not have method {}'.format(obj.__class__.__name__, obj.__parse_method__))
2045
+ inner_func = inspect.unwrap(obj.construct)
2046
+ if hasattr(get_func(inner_func), ENABLE_DYNAMIC):
2047
+ raise ValueError(
2048
+ "When using set_context(mode=GRAPH_MODE) together with nn.Cell, the 'enable_dynamic' cannot be set!"
2049
+ )
1920
2050
  key_id = str(id(obj)) + str(obj.create_time)
1921
2051
  args = get_auto_dynamic_shape_args(args, key_id)
1922
2052
 
@@ -1927,20 +2057,25 @@ class _CellGraphExecutor:
1927
2057
  self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
1928
2058
 
1929
2059
  key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
1930
- obj.arguments_key = str(key)
1931
-
1932
- obj.arguments_key = obj.arguments_key + "." + _get_hook_key(*args, **kwargs)
2060
+ key = str(key)
1933
2061
 
1934
2062
  # When exist parameter in the top graph inputs, need check if the parameter object has changed.
1935
2063
  parameter_ids = _get_parameter_ids(args, kwargs)
1936
2064
  if parameter_ids != "":
1937
- obj.arguments_key = obj.arguments_key + '.' + parameter_ids
2065
+ key += '.' + parameter_ids
2066
+
2067
+ key += "." + _get_hook_key(*args, **kwargs)
2068
+ key += "." + str(_hook_version())
2069
+
2070
+ obj.arguments_key = key
2071
+
1938
2072
  raw_phase = phase
1939
- phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
2073
+
2074
+ phase = _real_phase(phase, obj)
1940
2075
  obj.phase_cache[raw_phase] = phase
1941
2076
  update_auto_dynamic_shape_phase(args, key_id, phase)
1942
2077
  obj.current_phase = phase
1943
- if phase in obj.compile_cache and self.has_compiled(phase) and not parameter_hook_updated():
2078
+ if phase in obj.compile_cache and self.has_compiled(phase):
1944
2079
  logger.debug("%r graph has existed.", phase)
1945
2080
  # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
1946
2081
  # generated in generate_arguments_key.
@@ -1948,7 +2083,7 @@ class _CellGraphExecutor:
1948
2083
  _clear_auto_parallel_context(obj)
1949
2084
  return phase, False
1950
2085
 
1951
- full_function_name = obj.__class__.__name__ + '.' + str(obj.instance_count) + '.' + str(id(type(obj)))
2086
+ full_function_name = obj.__class__.__name__ + '.' + str(obj.total_instance_count) + '.' + str(id(type(obj)))
1952
2087
  echo_function_name = obj.__class__.__name__
1953
2088
  _check_recompile(obj, args, kwargs, full_function_name, obj.create_time, echo_function_name)
1954
2089
 
@@ -1958,17 +2093,14 @@ class _CellGraphExecutor:
1958
2093
  self._set_compile_cache_dep_files(phase)
1959
2094
 
1960
2095
  self._graph_executor.set_weights_values(obj.parameters_dict())
1961
- if jit_config_dict:
1962
- self._graph_executor.set_jit_config(jit_config_dict)
1963
- else:
2096
+ if not jit_config_dict:
1964
2097
  jit_config_dict = JitConfig().jit_config_dict
1965
- self._graph_executor.set_jit_config(jit_config_dict)
1966
2098
  gc.collect()
1967
- result = self._graph_executor.compile(obj, args, kwargs, phase)
2099
+ result = self._graph_executor.compile(
2100
+ obj, args, kwargs, phase, jit_config_dict)
1968
2101
  obj.compile_cache.add(phase)
1969
2102
  if not result:
1970
2103
  raise RuntimeError("Executor compile failed.")
1971
- set_parameter_hook_updated(False)
1972
2104
  graph = self._graph_executor.get_func_graph(phase)
1973
2105
 
1974
2106
  if graph is None:
@@ -1993,15 +2125,15 @@ class _CellGraphExecutor:
1993
2125
  return self._graph_executor.updata_param_node_default_input(phase, new_param)
1994
2126
 
1995
2127
  def _get_shard_strategy(self, obj):
1996
- real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
2128
+ real_phase = _real_phase(obj.phase, obj)
1997
2129
  return self._graph_executor.get_strategy(real_phase)
1998
2130
 
1999
2131
  def _get_num_parallel_ops(self, obj):
2000
- real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
2132
+ real_phase = _real_phase(obj.phase, obj)
2001
2133
  return self._graph_executor.get_num_parallel_ops(real_phase)
2002
2134
 
2003
2135
  def _get_allreduce_fusion(self, obj):
2004
- real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
2136
+ real_phase = _real_phase(obj.phase, obj)
2005
2137
  return self._graph_executor.get_allreduce_fusion(real_phase)
2006
2138
 
2007
2139
  def __call__(self, obj, *args, phase='predict'):
@@ -2053,10 +2185,10 @@ class _CellGraphExecutor:
2053
2185
  Tensor/Tuple, return execute result.
2054
2186
  """
2055
2187
  if phase == 'save':
2056
- exe_phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
2188
+ exe_phase = _real_phase(phase, obj)
2057
2189
  return self._graph_executor((), exe_phase)
2058
2190
 
2059
- phase_real = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
2191
+ phase_real = _real_phase(phase, obj)
2060
2192
  if self.has_compiled(phase_real):
2061
2193
  return self._exec_pip(obj, *args, phase=phase_real)
2062
2194
  raise KeyError('{} graph is not exist.'.format(phase_real))
@@ -2083,7 +2215,7 @@ class _CellGraphExecutor:
2083
2215
 
2084
2216
  def get_optimize_graph_proto(self, obj):
2085
2217
  """Return optimize graph binary proto."""
2086
- exec_id = obj.phase + "." + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
2218
+ exec_id = _real_phase(obj.phase, obj)
2087
2219
  if self._graph_executor.has_compiled(exec_id) is False:
2088
2220
  return None
2089
2221
  graph_proto = self._graph_executor.get_optimize_graph_proto(exec_id)
@@ -2165,5 +2297,3 @@ def flops_collection(phase='train'):
2165
2297
 
2166
2298
  _cell_graph_executor = _CellGraphExecutor()
2167
2299
  _pynative_executor = _PyNativeExecutor()
2168
-
2169
- __all__ = ['ms_memory_recycle', 'jit', 'jit_class', 'flops_collection']