mindspore 2.6.0rc1__cp311-cp311-win_amd64.whl → 2.7.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore has been flagged as potentially problematic by the registry.

Files changed (458)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +2 -2
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +42 -11
  9. mindspore/_extends/builtin_operations.py +3 -3
  10. mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
  11. mindspore/_extends/optimize/cell_utils.py +96 -0
  12. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +3 -3
  15. mindspore/_extends/parse/compile_config.py +44 -22
  16. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -2
  17. mindspore/_extends/parse/parser.py +65 -84
  18. mindspore/_extends/parse/resources.py +39 -0
  19. mindspore/_extends/parse/standard_method.py +58 -14
  20. mindspore/_extends/parse/trope.py +8 -1
  21. mindspore/_extends/pijit/__init__.py +1 -2
  22. mindspore/_extends/pijit/pijit_func_white_list.py +2 -5
  23. mindspore/amp.py +4 -22
  24. mindspore/atlprov.dll +0 -0
  25. mindspore/avcodec-59.dll +0 -0
  26. mindspore/avdevice-59.dll +0 -0
  27. mindspore/avfilter-8.dll +0 -0
  28. mindspore/avformat-59.dll +0 -0
  29. mindspore/avutil-57.dll +0 -0
  30. mindspore/boost/adasum.py +1 -1
  31. mindspore/boost/boost_cell_wrapper.py +4 -4
  32. mindspore/c1.dll +0 -0
  33. mindspore/c1xx.dll +0 -0
  34. mindspore/c2.dll +0 -0
  35. mindspore/common/__init__.py +43 -12
  36. mindspore/common/_grad_function.py +2 -1
  37. mindspore/common/_pijit_context.py +28 -7
  38. mindspore/common/_stub_tensor.py +1 -209
  39. mindspore/common/_tensor_cpp_method.py +1 -1
  40. mindspore/common/_tensor_docs.py +178 -53
  41. mindspore/common/_utils.py +9 -1
  42. mindspore/common/api.py +377 -203
  43. mindspore/common/dtype.py +108 -57
  44. mindspore/common/dump.py +11 -16
  45. mindspore/common/dynamic_shape/__init__.py +0 -0
  46. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +17 -23
  47. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  48. mindspore/common/file_system.py +59 -9
  49. mindspore/common/generator.py +5 -3
  50. mindspore/common/hook_handle.py +33 -5
  51. mindspore/common/jit_config.py +1 -1
  52. mindspore/common/jit_trace.py +84 -105
  53. mindspore/common/np_dtype.py +3 -3
  54. mindspore/common/parameter.py +27 -29
  55. mindspore/common/recompute.py +5 -7
  56. mindspore/common/sparse_tensor.py +0 -3
  57. mindspore/common/symbol.py +0 -1
  58. mindspore/common/tensor.py +117 -131
  59. mindspore/communication/_comm_helper.py +46 -4
  60. mindspore/communication/management.py +79 -7
  61. mindspore/context.py +67 -55
  62. mindspore/dataset/__init__.py +1 -1
  63. mindspore/dataset/audio/transforms.py +1 -1
  64. mindspore/dataset/core/config.py +38 -4
  65. mindspore/dataset/engine/datasets.py +350 -322
  66. mindspore/dataset/engine/datasets_user_defined.py +70 -24
  67. mindspore/dataset/engine/iterators.py +2 -2
  68. mindspore/dataset/engine/obs/config_loader.py +2 -2
  69. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
  70. mindspore/dataset/transforms/c_transforms.py +2 -2
  71. mindspore/dataset/transforms/py_transforms.py +7 -3
  72. mindspore/dataset/transforms/transforms.py +10 -6
  73. mindspore/dataset/vision/__init__.py +1 -1
  74. mindspore/dataset/vision/py_transforms.py +8 -8
  75. mindspore/dataset/vision/transforms.py +17 -5
  76. mindspore/dataset/vision/utils.py +632 -21
  77. mindspore/dataset/vision/validators.py +1 -0
  78. mindspore/device_context/ascend/device.py +1 -1
  79. mindspore/device_context/ascend/op_tuning.py +35 -1
  80. mindspore/device_context/gpu/__init__.py +2 -2
  81. mindspore/device_context/gpu/device.py +1 -1
  82. mindspore/device_context/gpu/op_precision.py +4 -2
  83. mindspore/device_context/gpu/op_tuning.py +6 -3
  84. mindspore/device_manager.py +16 -9
  85. mindspore/dnnl.dll +0 -0
  86. mindspore/dpcmi.dll +0 -0
  87. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +3 -4
  88. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  89. mindspore/experimental/optim/adadelta.py +13 -20
  90. mindspore/experimental/optim/adagrad.py +15 -22
  91. mindspore/experimental/optim/adam.py +17 -24
  92. mindspore/experimental/optim/adamax.py +14 -22
  93. mindspore/experimental/optim/adamw.py +28 -34
  94. mindspore/experimental/optim/asgd.py +15 -25
  95. mindspore/experimental/optim/lr_scheduler.py +27 -45
  96. mindspore/experimental/optim/nadam.py +14 -24
  97. mindspore/experimental/optim/optimizer.py +13 -23
  98. mindspore/experimental/optim/radam.py +18 -24
  99. mindspore/experimental/optim/rmsprop.py +14 -25
  100. mindspore/experimental/optim/rprop.py +15 -26
  101. mindspore/experimental/optim/sgd.py +9 -19
  102. mindspore/hal/__init__.py +4 -4
  103. mindspore/hal/contiguous_tensors_handle.py +2 -2
  104. mindspore/hal/memory.py +27 -7
  105. mindspore/include/api/cell.h +65 -5
  106. mindspore/include/api/cfg.h +24 -7
  107. mindspore/include/api/context.h +1 -0
  108. mindspore/include/api/delegate.h +10 -2
  109. mindspore/include/api/dual_abi_helper.h +100 -19
  110. mindspore/include/api/graph.h +14 -1
  111. mindspore/include/api/kernel.h +16 -3
  112. mindspore/include/api/kernel_api.h +9 -1
  113. mindspore/include/api/metrics/accuracy.h +9 -0
  114. mindspore/include/api/model.h +8 -1
  115. mindspore/include/api/model_group.h +4 -0
  116. mindspore/include/api/model_parallel_runner.h +2 -0
  117. mindspore/include/api/status.h +48 -10
  118. mindspore/include/api/types.h +8 -3
  119. mindspore/include/c_api/model_c.h +0 -58
  120. mindspore/include/c_api/tensor_c.h +0 -26
  121. mindspore/include/dataset/constants.h +9 -0
  122. mindspore/include/dataset/vision_ascend.h +1 -1
  123. mindspore/jpeg62.dll +0 -0
  124. mindspore/mindrecord/tools/cifar10.py +61 -11
  125. mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
  126. mindspore/mindspore_backend_common.dll +0 -0
  127. mindspore/mindspore_backend_manager.dll +0 -0
  128. mindspore/mindspore_common.dll +0 -0
  129. mindspore/mindspore_core.dll +0 -0
  130. mindspore/mindspore_cpu_res_manager.dll +0 -0
  131. mindspore/mindspore_dump.dll +0 -0
  132. mindspore/mindspore_frontend.dll +0 -0
  133. mindspore/mindspore_glog.dll +0 -0
  134. mindspore/mindspore_memory_pool.dll +0 -0
  135. mindspore/mindspore_ms_backend.dll +0 -0
  136. mindspore/mindspore_ops.dll +0 -0
  137. mindspore/mindspore_ops_host.dll +0 -0
  138. mindspore/mindspore_ops_kernel_common.dll +0 -0
  139. mindspore/mindspore_profiler.dll +0 -0
  140. mindspore/mindspore_pyboost.dll +0 -0
  141. mindspore/mindspore_pynative.dll +0 -0
  142. mindspore/mindspore_res_manager.dll +0 -0
  143. mindspore/mindspore_runtime_pipeline.dll +0 -0
  144. mindspore/mint/__init__.py +6 -46
  145. mindspore/mint/distributed/__init__.py +5 -0
  146. mindspore/mint/distributed/distributed.py +429 -23
  147. mindspore/mint/nn/__init__.py +1 -1
  148. mindspore/mint/nn/functional.py +53 -6
  149. mindspore/mint/nn/layer/_functions.py +163 -294
  150. mindspore/mint/nn/layer/activation.py +8 -6
  151. mindspore/mint/nn/layer/conv.py +140 -104
  152. mindspore/mint/nn/layer/normalization.py +11 -25
  153. mindspore/mint/optim/adam.py +19 -18
  154. mindspore/mint/optim/adamw.py +14 -8
  155. mindspore/mint/optim/sgd.py +5 -5
  156. mindspore/msobj140.dll +0 -0
  157. mindspore/mspdb140.dll +0 -0
  158. mindspore/mspdbcore.dll +0 -0
  159. mindspore/mspdbst.dll +0 -0
  160. mindspore/mspft140.dll +0 -0
  161. mindspore/msvcdis140.dll +0 -0
  162. mindspore/msvcp140_1.dll +0 -0
  163. mindspore/msvcp140_2.dll +0 -0
  164. mindspore/msvcp140_atomic_wait.dll +0 -0
  165. mindspore/msvcp140_codecvt_ids.dll +0 -0
  166. mindspore/nn/cell.py +491 -623
  167. mindspore/nn/grad/cell_grad.py +11 -12
  168. mindspore/nn/layer/activation.py +36 -36
  169. mindspore/nn/layer/basic.py +74 -77
  170. mindspore/nn/layer/channel_shuffle.py +4 -4
  171. mindspore/nn/layer/combined.py +4 -2
  172. mindspore/nn/layer/conv.py +117 -110
  173. mindspore/nn/layer/dense.py +9 -7
  174. mindspore/nn/layer/embedding.py +50 -52
  175. mindspore/nn/layer/image.py +38 -40
  176. mindspore/nn/layer/math.py +111 -112
  177. mindspore/nn/layer/normalization.py +56 -44
  178. mindspore/nn/layer/pooling.py +58 -63
  179. mindspore/nn/layer/rnn_cells.py +33 -33
  180. mindspore/nn/layer/rnns.py +56 -56
  181. mindspore/nn/layer/thor_layer.py +74 -73
  182. mindspore/nn/layer/transformer.py +11 -1
  183. mindspore/nn/learning_rate_schedule.py +20 -20
  184. mindspore/nn/loss/loss.py +79 -81
  185. mindspore/nn/optim/adam.py +4 -6
  186. mindspore/nn/optim/adasum.py +2 -2
  187. mindspore/nn/optim/asgd.py +2 -0
  188. mindspore/nn/optim/lamb.py +1 -3
  189. mindspore/nn/optim/optimizer.py +1 -1
  190. mindspore/nn/optim/tft_wrapper.py +2 -3
  191. mindspore/nn/optim/thor.py +2 -2
  192. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  193. mindspore/nn/probability/distribution/exponential.py +2 -1
  194. mindspore/nn/probability/distribution/poisson.py +2 -1
  195. mindspore/nn/sparse/sparse.py +3 -3
  196. mindspore/nn/wrap/cell_wrapper.py +73 -42
  197. mindspore/nn/wrap/grad_reducer.py +37 -52
  198. mindspore/nn/wrap/loss_scale.py +72 -74
  199. mindspore/numpy/array_creations.py +7 -7
  200. mindspore/numpy/fft.py +1 -1
  201. mindspore/numpy/math_ops.py +5 -5
  202. mindspore/numpy/utils_const.py +1 -1
  203. mindspore/opencv_core452.dll +0 -0
  204. mindspore/opencv_imgcodecs452.dll +0 -0
  205. mindspore/opencv_imgproc452.dll +0 -0
  206. mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
  207. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
  208. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  209. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  210. mindspore/{experimental/es/__init__.py → ops/_op_impl/cpu/joinedstr_op.py} +12 -6
  211. mindspore/ops/_vmap/vmap_array_ops.py +31 -13
  212. mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
  213. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +54 -13
  214. mindspore/ops/auto_generate/gen_extend_func.py +27 -145
  215. mindspore/ops/auto_generate/gen_ops_def.py +1027 -347
  216. mindspore/ops/auto_generate/gen_ops_prim.py +2341 -1117
  217. mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
  218. mindspore/ops/composite/__init__.py +10 -0
  219. mindspore/ops/composite/base.py +9 -5
  220. mindspore/ops/composite/multitype_ops/__init__.py +12 -1
  221. mindspore/ops/composite/multitype_ops/_compile_utils.py +133 -109
  222. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  223. mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
  224. mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
  225. mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
  226. mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
  227. mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
  228. mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
  229. mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
  230. mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
  231. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
  232. mindspore/ops/function/__init__.py +4 -1
  233. mindspore/ops/function/_add_attr_func.py +11 -6
  234. mindspore/ops/function/array_func.py +19 -102
  235. mindspore/ops/function/debug_func.py +8 -5
  236. mindspore/ops/function/grad/grad_func.py +5 -13
  237. mindspore/ops/function/math_func.py +77 -572
  238. mindspore/ops/function/nn_func.py +46 -94
  239. mindspore/ops/function/other_func.py +4 -1
  240. mindspore/ops/function/random_func.py +44 -5
  241. mindspore/ops/function/vmap_func.py +2 -1
  242. mindspore/ops/functional.py +4 -4
  243. mindspore/ops/functional_overload.py +594 -18
  244. mindspore/ops/op_info_register.py +21 -0
  245. mindspore/ops/operations/__init__.py +16 -11
  246. mindspore/ops/operations/_custom_ops_utils.py +689 -34
  247. mindspore/ops/operations/_inner_ops.py +14 -18
  248. mindspore/ops/operations/_sequence_ops.py +1 -1
  249. mindspore/ops/operations/array_ops.py +5 -51
  250. mindspore/ops/operations/comm_ops.py +186 -41
  251. mindspore/ops/operations/custom_ops.py +303 -177
  252. mindspore/ops/operations/debug_ops.py +59 -4
  253. mindspore/ops/operations/image_ops.py +13 -13
  254. mindspore/ops/operations/manually_defined/ops_def.py +27 -28
  255. mindspore/ops/operations/math_ops.py +8 -9
  256. mindspore/ops/operations/nn_ops.py +8 -40
  257. mindspore/ops/primitive.py +9 -20
  258. mindspore/ops/tensor_method.py +63 -15
  259. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
  260. mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
  261. mindspore/ops_generate/api/functions_cc_generator.py +58 -10
  262. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
  263. mindspore/ops_generate/common/base_generator.py +14 -0
  264. mindspore/ops_generate/common/gen_constants.py +8 -3
  265. mindspore/ops_generate/common/gen_utils.py +0 -19
  266. mindspore/ops_generate/common/op_proto.py +11 -4
  267. mindspore/ops_generate/common/template.py +88 -11
  268. mindspore/ops_generate/gen_ops.py +1 -1
  269. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
  270. mindspore/ops_generate/op_def/ops_def_cc_generator.py +0 -3
  271. mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
  272. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
  273. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
  274. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
  275. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
  276. mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -16
  277. mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
  278. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
  279. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
  280. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
  281. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
  282. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
  283. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
  284. mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
  285. mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
  286. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
  287. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
  288. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
  289. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
  290. mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
  291. mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
  292. mindspore/parallel/_auto_parallel_context.py +16 -23
  293. mindspore/parallel/_cell_wrapper.py +113 -45
  294. mindspore/parallel/_parallel_serialization.py +4 -3
  295. mindspore/parallel/_ps_context.py +4 -6
  296. mindspore/parallel/_tensor.py +167 -12
  297. mindspore/parallel/_transformer/moe.py +1 -1
  298. mindspore/parallel/_transformer/transformer.py +17 -12
  299. mindspore/parallel/_utils.py +5 -11
  300. mindspore/parallel/auto_parallel.py +35 -14
  301. mindspore/parallel/checkpoint_convert.py +3 -3
  302. mindspore/parallel/checkpoint_transform.py +13 -7
  303. mindspore/parallel/cluster/process_entity/_api.py +88 -49
  304. mindspore/parallel/cluster/process_entity/_utils.py +95 -7
  305. mindspore/parallel/cluster/run.py +48 -7
  306. mindspore/parallel/function/__init__.py +8 -1
  307. mindspore/parallel/function/reshard_func.py +12 -12
  308. mindspore/parallel/nn/__init__.py +15 -2
  309. mindspore/parallel/nn/parallel_cell_wrapper.py +50 -14
  310. mindspore/parallel/nn/parallel_grad_reducer.py +7 -14
  311. mindspore/parallel/shard.py +10 -25
  312. mindspore/parallel/transform_safetensors.py +469 -174
  313. mindspore/pgodb140.dll +0 -0
  314. mindspore/pgort140.dll +0 -0
  315. mindspore/profiler/__init__.py +2 -1
  316. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
  317. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
  318. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +12 -6
  319. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
  320. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  321. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
  322. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
  323. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
  324. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
  325. mindspore/profiler/analysis/task_manager.py +1 -1
  326. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
  327. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
  328. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +10 -9
  329. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +43 -23
  330. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
  331. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
  332. mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
  333. mindspore/profiler/common/constant.py +16 -0
  334. mindspore/profiler/common/msprof_cmd_tool.py +2 -2
  335. mindspore/profiler/common/path_manager.py +9 -0
  336. mindspore/profiler/common/profiler_context.py +50 -29
  337. mindspore/profiler/common/profiler_info.py +0 -16
  338. mindspore/profiler/common/profiler_meta_data.py +1 -0
  339. mindspore/profiler/common/profiler_op_analyse.py +239 -0
  340. mindspore/profiler/common/profiler_output_path.py +23 -8
  341. mindspore/profiler/common/profiler_parameters.py +128 -35
  342. mindspore/profiler/dynamic_profile/__init__.py +0 -0
  343. mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
  344. mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
  345. mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
  346. mindspore/profiler/dynamic_profiler.py +374 -338
  347. mindspore/profiler/envprofiler.py +42 -12
  348. mindspore/profiler/experimental_config.py +112 -7
  349. mindspore/profiler/mstx.py +33 -12
  350. mindspore/profiler/platform/__init__.py +2 -3
  351. mindspore/profiler/platform/cpu_profiler.py +10 -4
  352. mindspore/profiler/platform/npu_profiler.py +30 -20
  353. mindspore/profiler/profiler.py +218 -154
  354. mindspore/profiler/profiler_action_controller.py +65 -77
  355. mindspore/profiler/profiler_interface.py +2 -2
  356. mindspore/profiler/schedule.py +10 -4
  357. mindspore/rewrite/common/config.py +1 -0
  358. mindspore/rewrite/common/namer.py +1 -0
  359. mindspore/rewrite/common/namespace.py +1 -0
  360. mindspore/rewrite/node/node.py +31 -11
  361. mindspore/rewrite/parsers/assign_parser.py +1 -1
  362. mindspore/rewrite/symbol_tree/symbol_tree.py +2 -2
  363. mindspore/run_check/_check_version.py +7 -10
  364. mindspore/runtime/__init__.py +8 -6
  365. mindspore/runtime/event.py +10 -4
  366. mindspore/runtime/executor.py +87 -45
  367. mindspore/runtime/memory.py +31 -32
  368. mindspore/runtime/thread_bind_core.py +299 -165
  369. mindspore/safeguard/rewrite_obfuscation.py +12 -13
  370. mindspore/swresample-4.dll +0 -0
  371. mindspore/swscale-6.dll +0 -0
  372. mindspore/tbbmalloc.dll +0 -0
  373. mindspore/tinyxml2.dll +0 -0
  374. mindspore/train/_utils.py +17 -7
  375. mindspore/train/amp.py +43 -23
  376. mindspore/train/callback/__init__.py +5 -5
  377. mindspore/train/callback/_callback.py +2 -1
  378. mindspore/train/callback/_checkpoint.py +4 -14
  379. mindspore/train/callback/_flops_collector.py +11 -7
  380. mindspore/train/callback/_landscape.py +0 -1
  381. mindspore/train/callback/_train_fault_tolerance.py +98 -21
  382. mindspore/train/data_sink.py +15 -6
  383. mindspore/train/dataset_helper.py +14 -5
  384. mindspore/train/model.py +133 -69
  385. mindspore/train/serialization.py +168 -126
  386. mindspore/train/summary/summary_record.py +13 -2
  387. mindspore/train/train_thor/model_thor.py +2 -2
  388. mindspore/turbojpeg.dll +0 -0
  389. mindspore/utils/__init__.py +3 -2
  390. mindspore/utils/dryrun.py +0 -6
  391. mindspore/utils/runtime_execution_order_check.py +163 -77
  392. mindspore/utils/sdc_detect.py +68 -0
  393. mindspore/utils/utils.py +14 -17
  394. mindspore/vcmeta.dll +0 -0
  395. mindspore/vcruntime140.dll +0 -0
  396. mindspore/vcruntime140_1.dll +0 -0
  397. mindspore/version.py +1 -1
  398. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/METADATA +5 -4
  399. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/RECORD +403 -442
  400. mindspore/_deprecated/jit.py +0 -198
  401. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  402. mindspore/communication/_hccl_management.py +0 -297
  403. mindspore/experimental/es/embedding_service.py +0 -891
  404. mindspore/experimental/es/embedding_service_layer.py +0 -581
  405. mindspore/profiler/common/validator/__init__.py +0 -14
  406. mindspore/profiler/common/validator/validate_path.py +0 -84
  407. mindspore/profiler/parser/__init__.py +0 -14
  408. mindspore/profiler/parser/aicpu_data_parser.py +0 -272
  409. mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
  410. mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
  411. mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
  412. mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
  413. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
  414. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
  415. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
  416. mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
  417. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
  418. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
  419. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
  420. mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
  421. mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
  422. mindspore/profiler/parser/ascend_flops_generator.py +0 -116
  423. mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
  424. mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
  425. mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
  426. mindspore/profiler/parser/ascend_memory_generator.py +0 -185
  427. mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
  428. mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
  429. mindspore/profiler/parser/ascend_op_generator.py +0 -334
  430. mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
  431. mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
  432. mindspore/profiler/parser/base_timeline_generator.py +0 -483
  433. mindspore/profiler/parser/container.py +0 -229
  434. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
  435. mindspore/profiler/parser/flops_parser.py +0 -531
  436. mindspore/profiler/parser/framework_enum.py +0 -111
  437. mindspore/profiler/parser/framework_parser.py +0 -464
  438. mindspore/profiler/parser/framework_struct.py +0 -61
  439. mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
  440. mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
  441. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
  442. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
  443. mindspore/profiler/parser/hccl_parser.py +0 -573
  444. mindspore/profiler/parser/hwts_log_parser.py +0 -122
  445. mindspore/profiler/parser/integrator.py +0 -526
  446. mindspore/profiler/parser/memory_usage_parser.py +0 -277
  447. mindspore/profiler/parser/minddata_analyzer.py +0 -800
  448. mindspore/profiler/parser/minddata_parser.py +0 -186
  449. mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
  450. mindspore/profiler/parser/op_intermediate_parser.py +0 -149
  451. mindspore/profiler/parser/optime_parser.py +0 -250
  452. mindspore/profiler/parser/profiler_info.py +0 -213
  453. mindspore/profiler/parser/step_trace_parser.py +0 -666
  454. mindspore/utils/hooks.py +0 -81
  455. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  456. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/WHEEL +0 -0
  457. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/entry_points.txt +0 -0
  458. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/top_level.txt +0 -0
mindspore/common/api.py CHANGED
@@ -1,6 +1,6 @@
 # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 #
-# Copyright 2020-2024 Huawei Technologies Co., Ltd
+# Copyright 2020-2025 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,6 +17,8 @@
 """Providing interface methods."""
 from __future__ import absolute_import

+__all__ = ['ms_memory_recycle', 'jit', 'jit_class', 'flops_collection']
+
 import gc
 import types
 import sys
@@ -42,23 +44,25 @@ from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
 from mindspore._c_expression.amp import get_curr_amp_strategy
 from mindspore._c_expression import GraphExecutor_, JitExecutor_, CSRTensor, RowTensor, COOTensor, \
     PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
-    _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, StubNode, MSContext, TensorPy as Tensor
+    _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, MSContext, TensorPy as Tensor
 from mindspore.parallel._ps_context import _is_role_sched
 from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_in_auto_parallel_mode, \
     _is_parallel_mode
 from mindspore import _checkparam as Validator
 from mindspore._checkparam import is_stub_tensor
-from mindspore.common._utils import is_shape_unknown
+from mindspore.common._utils import is_shape_unknown, get_func
 from mindspore.common.mutable import mutable, _check_element_type
-from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
-    get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature
+from mindspore.common.dynamic_shape.auto_dynamic_shape import get_auto_dynamic_shape_args, \
+    update_auto_dynamic_shape_phase
+from mindspore.common.dynamic_shape.enable_dynamic import generate_dynamic_tensor_args, ENABLE_DYNAMIC
 from mindspore.common._pijit_context import PIJitCaptureContext
-from mindspore.common.parameter import Parameter, set_parameter_hook_updated, parameter_hook_updated
+from mindspore.common.parameter import Parameter
+from mindspore.common.hook_handle import _hook_version
 from mindspore.common.jit_context import jit_context
 from mindspore.common.jit_trace import _jit_trace
 from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context

-# Store ms_function class compiled pipeline cache.
+# Store jit class compiled pipeline cache.
 ms_compile_cache = set()
 # Store cell compiled pipeline cache.
 cells_compile_cache = {}
@@ -72,6 +76,11 @@ ARG_SPECIFIED = "arg_specified_infos"
 TOTAL_ARG_LEN = "total_arg_length"


+def _real_phase(phase, obj):
+    real_phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
+    return real_phase
+
+
 def _check_recompile_args(compile_args, kwargs):
     """Check recompile of graph"""

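The new `_real_phase` helper composes a per-object compile-phase key from the base phase, the Cell's creation time, its Python id, and the current arguments key. A minimal sketch of the string it produces, using a hypothetical stand-in object (`create_time` and `arguments_key` are just the attributes the helper reads; `id(obj)` varies per process):

    class _FakeCell:
        # Hypothetical stand-in exposing the attributes _real_phase reads.
        create_time = 3
        arguments_key = "float32_2x3"

    obj = _FakeCell()
    # Mirrors _real_phase: "<phase>.<create_time>.<id(obj)>.<arguments_key>"
    print("construct" + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key)
    # e.g. construct.3.140221234567890.float32_2x3
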
@@ -134,8 +143,6 @@
     """
     if isinstance(data, PythonTensor):
         return data
-    if isinstance(data, StubNode):
-        return ms.common._stub_tensor._convert_stub(data)
     if data.__class__ is tuple:
         # Handle namedtuple since its type is tuple.
         if hasattr(data, "_fields"):
@@ -278,13 +285,13 @@ def __get_compile_cache_dep_files(file_path, compile_cache_dep_files, pkg):
         module = importlib.util.module_from_spec(module_spec)
         if hasattr(module, '__file__'):
             dep_file_path = module.__file__
+            # Exclude the installed modules.
+            if not _in_sys_path(dep_file_path) and dep_file_path not in compile_cache_dep_files:
+                logger.debug(f"dependent file path: {dep_file_path}")
+                compile_cache_dep_files.append(dep_file_path)
+                __get_compile_cache_dep_files(dep_file_path, compile_cache_dep_files, module.__package__)
         else:
             continue
-        # Exclude the installed modules.
-        if not _in_sys_path(dep_file_path) and dep_file_path not in compile_cache_dep_files:
-            logger.debug(f"dependent file path: {dep_file_path}")
-            compile_cache_dep_files.append(dep_file_path)
-            __get_compile_cache_dep_files(dep_file_path, compile_cache_dep_files, module.__package__)


 def _get_compile_cache_dep_files():
@@ -342,7 +349,7 @@ def _get_parameter_layout():
     return layout


-def _handle_arg(obj, arg, compile_arg):
+def _handle_arg(obj, arg, has_mutable_arg):
     """Handle arg for runtime .If need handle the arg, return True"""
     from mindspore._extends.parse import compile_config
     if isinstance(arg, PythonTensor):
@@ -352,7 +359,7 @@ def _handle_arg(obj, arg, compile_arg):
         return arg
     elif isinstance(arg, (Tensor, CSRTensor, COOTensor)):
         return arg
-    elif compile_arg is not None and hasattr(compile_arg, "__ms_mutable__") and getattr(compile_arg, "__ms_mutable__"):
+    elif has_mutable_arg:
         # mutable([]) will be eliminated by FuncGraphSpecializer, and empty list is not supported by backend.
         if isinstance(arg, list) and not arg:
             return None
@@ -366,7 +373,7 @@
     return None


-def _handle_arg_predict(obj, arg, compile_arg):
+def _handle_arg_predict(obj, arg, has_mutable_arg):
     """Handle arg for runtime .If need handle the arg, return True"""
     if arg is None:
         return None
@@ -375,8 +382,7 @@ def _handle_arg_predict(obj, arg, compile_arg):
         return None

     if isinstance(arg, (list, tuple)):
-        if compile_arg is not None and hasattr(compile_arg, "__ms_mutable__") and \
-                getattr(compile_arg, "__ms_mutable__"):
+        if has_mutable_arg:
             # mutable([]) will be eliminated by FuncGraphSpecializer, and empty list is not supported by backend.
             if isinstance(arg, list) and not arg:
                 return None
@@ -388,35 +394,30 @@ def _handle_arg_predict(obj, arg, compile_arg):
     return arg


-def _get_args_for_run(obj, args, kwargs, compile_args):
+def _get_args_for_run(obj, args, kwargs, has_mutable_args_list, is_predict):
     """Get the actual input args and kwargs for runtime."""
     new_args = []
-    for arg, compile_arg in zip(args, compile_args):
-        new_arg = _handle_arg(obj, arg, compile_arg)
+    fn = _handle_arg_predict if is_predict else _handle_arg
+    for arg, has_mutable_arg in zip(args, has_mutable_args_list):
+        new_arg = fn(obj, arg, has_mutable_arg)
        if new_arg is not None:
            new_args.append(new_arg)

     for _, value in kwargs.items():
-        new_value = _handle_arg(obj, value, None)
+        new_value = fn(obj, value, None)
        if new_value is not None:
            new_args.append(new_value)

     return new_args


-def _get_args_for_run_predict(obj, args, kwargs, compile_args):
-    """Get the actual input args and kwargs for runtime."""
+def _get_mutable_flags(compile_args):
+    """Get a list of booleans indicating whether each argument is marked as mutable"""
     new_args = []
-    for arg, compile_arg in zip(args, compile_args):
-        new_arg = _handle_arg_predict(obj, arg, compile_arg)
-        if new_arg is not None:
-            new_args.append(new_arg)
-
-    for _, value in kwargs.items():
-        new_value = _handle_arg_predict(obj, value, None)
-        if new_value is not None:
-            new_args.append(new_value)
-
+    for compile_arg in compile_args:
+        has_mutable_arg = compile_arg is not None and hasattr(compile_arg, "__ms_mutable__") and \
+            getattr(compile_arg, "__ms_mutable__")
+        new_args.append(has_mutable_arg)
     return new_args


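`_get_mutable_flags` replaces the per-call `__ms_mutable__` probing that `_get_args_for_run` and the removed `_get_args_for_run_predict` used to duplicate: the flags are now computed once from the compile args and passed down. A minimal sketch of the predicate, assuming `mindspore.mutable` tags its result with `__ms_mutable__` as the code above relies on:

    import mindspore as ms
    from mindspore import Tensor, mutable

    compile_args = (mutable([1, 2, 3]), Tensor([1.0]), None)
    # The same per-argument predicate _get_mutable_flags applies.
    flags = [arg is not None and getattr(arg, "__ms_mutable__", False) for arg in compile_args]
    print(flags)  # expected: [True, False, False]
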
@@ -544,10 +545,12 @@ def _get_parameter_ids(args, kwargs):
                 parameter_ids += str(id(value))
     return parameter_ids

+
 def _get_tensor_hook_key(tensor):
     """Get the hook key of Tensor/Parameter"""
     return ".".join(map(str, map(id, tensor.hooks())))

+
 def _get_hook_key(*args, **kwargs):
     """Get the hook key of Tensors/Parameters"""
     hook_key = ""
@@ -586,13 +589,16 @@ class _JitExecutor:
         The result of pipeline running in graph mode.
     """

-    def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None, dynamic=0):
+    def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None, dynamic=0,
+                 cell_cache_key_extend=''):
         init_pipeline()
         if not isinstance(fn, (types.FunctionType, types.MethodType)):
             raise RuntimeError('fn {} is not function or method'.format(fn))

         self.fn = fn
         self.input_signature = input_signature
+        self.dynamic_args_shapes = getattr(get_func(fn), ENABLE_DYNAMIC, None)
+        self.enable_jit_dynamic = self.dynamic_args_shapes is not None
         self.obj = None
         if obj and hasattr(obj, fn.__name__):
             self.obj = obj
@@ -606,9 +612,55 @@ class _JitExecutor:
         self._compile_args = None
         self._enable_auto_dynamic = dynamic == 1
         self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
+        self._cell_cache_key_extend = cell_cache_key_extend
+
+    def _predict(self, *args, **kwargs):
+        """Dedicated routine for predict."""
+        if not hasattr(self.obj, "phase"):
+            return False, None
+
+        predict_vailid_phase = {"prefill", 'increment'}
+        predict_phase = self.obj.phase
+        if predict_phase not in predict_vailid_phase:
+            return False, None
+
+        args_list = args
+        if self.obj is not None:
+            args_list = args_list[1:]
+
+        if predict_phase not in self.obj.phase_cache:
+            try:
+                predict_phase = self.compile(self.fn.__name__, *args_list, **kwargs)
+            except Exception as err:
+                _pynative_executor.clear_res()
+                raise err
+        else:  # get compiled args to generate run args by _generate_run_args
+            compile_args = self._generate_compile_args(args_list)
+            key_id = self._get_key_id()
+            if self.input_signature is None:
+                compile_args = get_auto_dynamic_shape_args(
+                    compile_args, key_id, self._enable_auto_dynamic
+                )
+            self._compile_args = compile_args
+
+        new_inputs = self._generate_run_args(args_list, kwargs)
+        if self.jit_config_dict:
+            jit_config_dict = self.jit_config_dict
+        else:
+            jit_config_dict = JitConfig().jit_config_dict
+        self._graph_executor.set_jit_config(jit_config_dict)
+        output = self._graph_executor(
+            tuple(new_inputs),
+            self.obj.phase_cache[self.obj.phase]
+        )
+        res = _convert_python_data(output)
+        return True, res

     @_wrap_func
     def __call__(self, *args, **kwargs):
+        predict, res = self._predict(*args, **kwargs)
+        if predict:
+            return res
         if jit_context() and jit_context().is_nested():
             return jit_context().run_graph("", None, *())
         args_list = args
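
The new `_predict` path short-circuits `__call__` for inference Cells whose `phase` is "prefill" or "increment": the first call compiles and records the phase, and later calls look the compiled phase up in `phase_cache` and execute directly. A minimal sketch of that cache pattern, with hypothetical `compile_fn`/`execute_fn` standing in for the graph executor:

    phase_cache = {}

    def run(phase, inputs, compile_fn, execute_fn):
        # First call for a phase: compile and remember the compiled phase id.
        if phase not in phase_cache:
            phase_cache[phase] = compile_fn(inputs)
        # Subsequent calls skip compilation entirely.
        return execute_fn(inputs, phase_cache[phase])
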
@@ -616,12 +668,9 @@ class _JitExecutor:
             args_list = args_list[1:]
         phase = ""
         try:
-            if context.get_context("mode") == context.PYNATIVE_MODE:
-                _pynative_executor.set_jit_compile_status(True, phase)
-                phase = self.compile(self.fn.__name__, *args_list, **kwargs)
-                _pynative_executor.set_jit_compile_status(False, phase)
-            else:
-                phase = self.compile(self.fn.__name__, *args_list, **kwargs)
+            _pynative_executor.set_jit_compile_status(True, phase)
+            phase = self.compile(self.fn.__name__, *args_list, **kwargs)
+            _pynative_executor.set_jit_compile_status(False, phase)
         except Exception as err:
             _pynative_executor.clear_res()
             raise err
@@ -630,31 +679,27 @@ class _JitExecutor:
             return None

         new_inputs = self._generate_run_args(args_list, kwargs)
-        if context.get_context("mode") == context.PYNATIVE_MODE and not jit_context():
-            output = _pynative_executor.grad_jit(*new_inputs)
+        if self.jit_config_dict:
+            jit_config_dict = self.jit_config_dict
         else:
-            output = self._graph_executor(tuple(new_inputs), phase)
-            if jit_context():
-                if is_stub_tensor(output):
-                    output = output.stub_sync()
-                return jit_context().run_graph(phase, output, *tuple(new_inputs))
-
+            jit_config_dict = JitConfig().jit_config_dict
+        self._graph_executor.set_jit_config(jit_config_dict)
+        output = _pynative_executor.grad_jit(*new_inputs)
+        if jit_context():
+            if is_stub_tensor(output):
+                output = output.stub_sync()
+            return jit_context().run_graph(phase, output, *tuple(new_inputs))
         return output

     def compile(self, method_name, *args, **kwargs):
         """Returns pipeline for the given args."""
-        # Check whether hook function registered on Cell object.
-        if self.obj and hasattr(self.obj, "_hook_fn_registered"):
-            if self.obj._hook_fn_registered():
-                logger.warning(f"For 'Cell', it's not support hook function when using 'jit' decorator. "
-                               f"If you want to use hook function, please use context.set_context to set "
-                               f"pynative mode and remove 'jit' decorator.")
         # Chose dynamic shape tensors or actual input tensors as compile args.
         compile_args = self._generate_compile_args(args)
         key_id = self._get_key_id()
-        compile_args = get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id,
-                                                                              self.input_signature,
-                                                                              self._enable_auto_dynamic)
+        if self.input_signature is None:
+            compile_args = get_auto_dynamic_shape_args(
+                compile_args, key_id, self._enable_auto_dynamic, self.enable_jit_dynamic
+            )

         # Add mutable for compile_args for two scene:
         # 1) Origin args is mutable.
@@ -674,7 +719,7 @@ class _JitExecutor:
                                f'`{self.fn.__module__}`')
             self.obj.__parse_method__ = method_name
             if isinstance(self.obj, ms.nn.Cell):
-                generate_name = generate_name + '.' + str(self.obj.create_time)
+                generate_name = generate_name + '.' + str(self.obj.create_time) + self.obj.phase
                 create_time = str(self.obj.create_time)
             else:
                 generate_name = generate_name + '.' + str(self._create_time)
@@ -694,18 +739,23 @@ class _JitExecutor:

         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
         key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
+        key = str(key)

         parameter_ids = _get_parameter_ids(args, kwargs)
         if parameter_ids != "":
-            key = str(key) + '.' + parameter_ids
+            key += '.' + parameter_ids
+
+        key += "." + _get_hook_key(*args, **kwargs)
+        key += "." + str(_hook_version())

-        key = str(key) + "." + _get_hook_key(*args, **kwargs)
+        phase = generate_name + '.' + key

-        phase = generate_name + '.' + str(key)
+        if self.input_signature is None:
+            update_auto_dynamic_shape_phase(compile_args, key_id, phase)

-        update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)
+        phase = phase + self._cell_cache_key_extend

-        if phase in ms_compile_cache and self._graph_executor.has_compiled(phase) and not parameter_hook_updated():
+        if phase in ms_compile_cache and self._graph_executor.has_compiled(phase):
             # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
             # generated in generate_arguments_key.
             self._graph_executor.clear_compile_arguments_resource()
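
With this change the compile-cache key folds in the tensor hook key and a global hook version (from `mindspore.common.hook_handle._hook_version`) instead of the removed `parameter_hook_updated()` flag, so registering or removing a hook produces a new phase rather than reusing a stale graph. An illustrative composition of the key under that scheme (function and argument names here are stand-ins, not the executor API):

    def make_phase(generate_name, arguments_key, parameter_ids, hook_key, hook_version, cache_key_extend=''):
        key = str(arguments_key)
        if parameter_ids != "":
            key += '.' + parameter_ids
        key += "." + hook_key           # ids of hooks registered on the inputs
        key += "." + str(hook_version)  # bumped whenever hooks change
        return generate_name + '.' + key + cache_key_extend

    print(make_phase("construct.7prefill", 0x1A2B, "140001123", "140777456", 2))
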
@@ -716,31 +766,26 @@ class _JitExecutor:
         # If enable compile cache, get the dependency files list and set to graph executor.
         self._set_compile_cache_dep_files()
         if self.jit_config_dict:
-            self._graph_executor.set_jit_config(self.jit_config_dict)
+            jit_config_dict = self.jit_config_dict
         else:
             jit_config_dict = JitConfig().jit_config_dict
-            self._graph_executor.set_jit_config(jit_config_dict)

         if self.obj is None:
             # Set an attribute to fn as an identifier.
-            if isinstance(self.fn, types.MethodType):
-                setattr(self.fn.__func__, "__jit_function__", True)
-            else:
-                setattr(self.fn, "__jit_function__", True)
-            is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase)
-            if isinstance(self.fn, types.MethodType):
-                delattr(self.fn.__func__, "__jit_function__")
-            else:
-                delattr(self.fn, "__jit_function__")
+            setattr(get_func(self.fn), "__jit_function__", True)
+            is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, jit_config_dict)
+            delattr(get_func(self.fn), "__jit_function__")
         else:
             if isinstance(self.obj, ms.nn.Cell):
                 self._graph_executor.set_weights_values(self.obj.parameters_dict())
-            is_compile = self._graph_executor.compile(self.obj, compile_args, kwargs, phase)
+            is_compile = self._graph_executor.compile(
+                self.obj, compile_args, kwargs, phase, jit_config_dict)

         if not is_compile:
             raise RuntimeError("Executor compile failed.")
-        set_parameter_hook_updated(False)
         ms_compile_cache.add(phase)
+        if hasattr(self.obj, "phase"):
+            self.obj.phase_cache[self.obj.phase] = phase

         return phase

@@ -785,41 +830,70 @@ class _JitExecutor:
         if enable_compile_cache is True or enable_compile_cache == "1":
             self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())

+    def _generate_compile_args_by_enable_dynamic(self, args_list):
+        """Generate compile args by enable_dynamic."""
+        compile_args = generate_dynamic_tensor_args(args_list, self.dynamic_args_shapes)
+        compile_args = _add_mutable_attr(args_list, compile_args, _pynative_executor.requires_grad())
+        if self.obj is not None:
+            _pynative_executor.set_dynamic_input(self.obj, *compile_args)
+        else:
+            _pynative_executor.set_dynamic_input(self.fn, *compile_args)
+        logger.info(f"dynamic shape compile_args: {compile_args}")
+        return compile_args
+
+    def _generate_compile_args_by_set_inputs(self, args_list):
+        """Generate compile args by set_inputs."""
+        compile_args = _generate_dyn_compile_args(args_list, self.obj.get_inputs())
+        if len(compile_args) != len(args_list):
+            raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
+                             f"dynamic shape tensors: {len(compile_args)}.")
+        self._graph_executor.check_argument_consistency(compile_args, args_list, "set_inputs")
+        Validator.check_symbolic_shape(compile_args, args_list)
+        return compile_args
+
+    def _generate_compile_args_by_input_signature(self, args_list):
+        """Generate compile args by input_signature."""
+        compile_args = list(_generate_dyn_compile_args(args_list, self.input_signature))
+        dyn_shape = any([is_shape_unknown(elem.shape) for elem in compile_args if isinstance(elem, PythonTensor)])
+        Validator.check_symbolic_shape(self.input_signature, args_list)
+        if dyn_shape:
+            # Checkout whether the `sens` has been added to args_list.
+            if len(compile_args) == len(args_list) - 1:
+                logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
+                               f"of input_signature args '{len(compile_args)}'. The last actual args may "
+                               f"be 'sens' and added it to compile args.")
+                compile_args.append(args_list[-1])
+            compile_args = tuple(compile_args)
+            self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
+            if self.obj is not None:
+                _pynative_executor.set_dynamic_input(self.obj, *compile_args)
+            else:
+                _pynative_executor.set_dynamic_input(self.fn, *compile_args)
+        else:
+            if not verify_inputs_signature(compile_args, args_list):
+                raise ValueError("The input args is incompatible with the args in `input_signature`!")
+        return compile_args
+
+    def _check_set_inputs(self):
+        """Check if the `set_inputs()` of Cell object has been set."""
+        return self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs()
+
     def _generate_compile_args(self, args_list):
         """Chose dynamic shape tensors or actual input tensors as compile args."""
-        # Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
-        compile_args = _pynative_executor.get_dynamic_input(args_list)
+        # Case: The `enable_dynamic` is provided and `set_inputs()` of Cell object has been set.
+        if self.enable_jit_dynamic and self._check_set_inputs():
+            raise ValueError("When `enable_dynamic` is provided, the `set_inputs()` cannot be set!")
+        # Case: The `enable_dynamic` is provided.
+        if self.enable_jit_dynamic:
+            return self._generate_compile_args_by_enable_dynamic(args_list)
         # Case: The `set_inputs()` of Cell object has been set, using these dynamic shape args as compile args.
-        if self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs():
-            compile_args = _generate_dyn_compile_args(args_list, self.obj.get_inputs())
-            if len(compile_args) != len(args_list):
-                raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
-                                 f"dynamic shape tensors: {len(compile_args)}.")
-            self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
-            Validator.check_symbolic_shape(compile_args, args_list)
-
+        if self._check_set_inputs():
+            return self._generate_compile_args_by_set_inputs(args_list)
         # Case: If dynamic shape tensors have been assigned to `input_signature`, they are preferred as compile args.
         if self.input_signature is not None:
-            compile_args = list(_generate_dyn_compile_args(args_list, self.input_signature))
-            dyn_shape = any([is_shape_unknown(elem.shape) for elem in compile_args if isinstance(elem, PythonTensor)])
-            Validator.check_symbolic_shape(self.input_signature, args_list)
-            if dyn_shape:
-                # Checkout whether the `sens` has been added to args_list.
-                if len(compile_args) == len(args_list) - 1:
-                    logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
-                                   f"of input_signature args '{len(compile_args)}'. The last actual args may "
-                                   f"be 'sens' and added it to compile args.")
-                    compile_args.append(args_list[-1])
-                compile_args = tuple(compile_args)
-                self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
-                if self.obj is not None:
-                    _pynative_executor.set_dynamic_input(self.obj, *compile_args)
-                else:
-                    _pynative_executor.set_dynamic_input(self.fn, *compile_args)
-            else:
-                if not verify_inputs_signature(compile_args, args_list):
-                    raise ValueError("The input args is incompatible with the args in `input_signature`!")
-        return compile_args
+            return self._generate_compile_args_by_input_signature(args_list)
+        # Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
+        return _pynative_executor.get_dynamic_input(args_list)

     def _generate_run_args(self, args_list, kwargs):
         """
@@ -832,7 +906,7 @@
         Returns:
             new_inputs, new input args, which are required for running.
         """
-        return _get_args_for_run(self, args_list, kwargs, self._compile_args)
+        return _get_args_for_run(self, args_list, kwargs, _get_mutable_flags(self._compile_args), False)

     def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
         """Get graph proto from pipeline."""
@@ -993,6 +1067,67 @@ def _check_options(options, backend):
             _check_option_value(option, value)


+def _jit_ast(hash_obj, dynamic, jit_config, jit_graph_name):
+    """Return the wrapped function for ast mode jit."""
+    def wrap_func(func):
+        nonlocal hash_obj
+        if hasattr(func, "construct"):
+            if isinstance(func, ms.nn.Cell):
+                # Bound the cell object to get the self arg.
+                return types.MethodType(_jit_ast(
+                    hash_obj, dynamic, jit_config, func._jit_graph_name)(func.construct.__func__), func)
+            if isinstance(func, type) and issubclass(func, ms.nn.Cell):
+                func.construct = _jit_ast(
+                    hash_obj, dynamic, jit_config, '')(func.construct)
+                return func
+
+        if isinstance(func, types.MethodType):
+            return types.MethodType(_jit_ast(hash_obj, dynamic, jit_config, '')(func.__func__), func.__self__)
+
+        if not isinstance(func, types.FunctionType):
+            logger.warning(f"The func should be function, method or cell instance/class, but got {func}")
+            return func
+
+        if hasattr(func, "__wrapped_by_jit__"):
+            logger.warning(f"The func {func} should be wrapped by jit only once.")
+
+        if hash_obj is None or not _is_inner_func(func):
+            hash_obj = int(time.time() * 1e9)
+
+        @wraps(func)
+        def staging_specialize(*args, **kwargs):
+            if os.getenv("MS_JIT") == '0':
+                return func(*args, **kwargs)
+
+            args, kwargs = _handle_func_args(func, *args, **kwargs)
+            process_obj = None
+            if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
+                process_obj = args[0]
+            # Handle auto mixed precision strategy.
+            if not hasattr(func, "amp_strategy"):
+                setattr(get_func(func), "amp_strategy", get_curr_amp_strategy())
+
+            jit_graph_name = ''
+            if hasattr(staging_specialize, "__jit_graph_name__"):
+                jit_graph_name = staging_specialize.__jit_graph_name__
+            jit_executor = _JitExecutor(
+                func, hash_obj, None, process_obj, jit_config, dynamic, jit_graph_name)
+            out = jit_executor(*args, **kwargs)
+            if isinstance(process_obj, ms.nn.Cell):
+                _clear_auto_parallel_context(process_obj)
+            return out
+
+        # `inspect.getfullargspec(func)` will get the specification of the decorated function by default. By set
+        # `__signature__` for the decorated function, `inspect.getfullargspec(func)` will get the specification of
+        # original `func`.
+        staging_specialize.__signature__ = inspect.signature(func)
+        setattr(staging_specialize, "__wrapped_by_jit__", True)
+        setattr(staging_specialize, "__jit_graph_name__", jit_graph_name)
+        return staging_specialize
+
+    return wrap_func
+
+
 def jit(
     function: Optional[Callable] = None,
     *,
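
`_jit_ast` dispatches on what it is asked to wrap: a plain function is wrapped directly, a bound method is re-wrapped around its `__func__`, a Cell instance has its `construct` wrapped and re-bound, and a Cell subclass gets `construct` replaced in place. A short usage sketch (illustrative) covering the function and Cell-instance cases:

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, jit, nn

    @jit
    def add(x, y):          # plain function: wrapped directly
        return x + y

    relu = jit(nn.ReLU())   # Cell instance: construct is wrapped and re-bound

    x = Tensor(np.ones((2, 2), np.float32))
    print(add(x, x))
    print(relu(x))
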
@@ -1015,45 +1150,45 @@ def jit(
     and the decoration @jit(capture_mode="bytecode") is considered invalid.

     Args:
-        function (Function, optional): The Python function that will be run as a graph. Default: ``None``.
+        function (Callable, optional): The Python function or Cell that will be run as a graph. Default: ``None``.

     Keyword Args:
         capture_mode (str, optional): The method to create a callable MindSpore graph. The value of capture_mode
-            should be ``ast`` , ``bytecode`` or ``trace`` . Default: ``ast`` .
+            should be ``"ast"`` , ``"bytecode"`` or ``"trace"`` . Default: ``"ast"`` .

-            - `ast <https://www.mindspore.cn/tutorials/en/master/compile/static_graph.html>`_ :
-              Parse Python ast to build graph.
-            - `bytecode` :
-              Parse Python bytecode to build graph at runtime. This is an experimental prototype that is subject to
-              change and/or deletion.
-            - `trace` : Trace the execution of Python code to build graph. This is an experimental prototype that is
-              subject to change and/or deletion.
+            - ast: Parse Python ast to build graph.
+            - bytecode: Parse Python bytecode to build graph at runtime. This is an experimental prototype
+              that is subject to change and/or deletion.
+            - trace: Trace the execution of Python code to build graph. This is an experimental prototype
+              that is subject to change and/or deletion.

         jit_level (str, optional): Used to control the compilation optimization level. Currently is only effective
-            with default backend. The value of jit_level should be ``O0`` or ``O1`` . Default: ``O0`` .
+            with ms_backend. The value of jit_level should be ``"O0"`` or ``"O1"`` . Default: ``"O0"`` .

-            - `O0`: Except for optimizations that may affect functionality, all other optimizations are turned off.
-            - `O1`: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
+            - O0: Except for optimizations that may affect functionality, all other optimizations are turned off.
+            - O1: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
              level is experimental and is being improved.

         dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
            is as follows:

-            - `0`: Do not perform dynamic shape compilation.
-            - `1`: Enable dynamic shape compilation and automatically detect shape changes.
+            - 0: Do not perform dynamic shape compilation.
+            - 1: Enable dynamic shape compilation and automatically detect shape changes.

         fullgraph (bool, optional): Whether to capture the entire function into graph. If False, jit attempts to
             be compatible with all Python syntax in the function as much as possible. If True, we require that the
             entire function can be captured into graph. If this is not possible (that is, if there is Python syntax
-            not supported), then it will raise an exception. This currently only applies when capture_mode is ast.
-            Default: ``False``.
+            not supported), then it will raise an exception. This currently only applies when capture_mode is ``ast``
+            or ``bytecode``. Default: ``False``.
         backend (str, optional): The compilation backend to be used. If this parameter is not set, the framework will
-            use ``GE`` backend for Atlas training series products and ``ms_backend`` backend for others including Atlas
-            A2 training series products by default.
+            use ``"GE"`` backend for Atlas training series products and ``"ms_backend"`` backend for others including
+            Atlas A2 training series products by default.

-            - `ms_backend`: Adopt KernelByKernel execution mode.
-            - `GE`: Adopt Sink execution mode. The whole model will be sinked to device to execute, only applicable to
-              the top cell of model. And only can be used in Ascend platform.
+            - ms_backend: Utilizes the built-in backend engine of MindSpore for hardware-related compilation
+              optimization and execution, supporting multiple hardware forms such as Ascend, GPU, and CPU.
+            - GE: Utilizes the GraphEngine, a graph compilation and execution engine within CANN,
+              for Ascend model compilation and execution. Note: This backend takes effect only in static graph mode
+              and can be executed only on Ascend hardware.

         **options (dict): A dictionary of options to pass to the compilation backend.

@@ -1076,11 +1211,11 @@ def jit(
             `disable_format_transform` can be set to ``True`` to try to improve training performance.
             Default: ``False`` .
           - exec_order (str, optional): Set the sorting method for operator execution, currently only two sorting
-            methods are supported: ``bfs`` and ``dfs`` . Default: ``bfs`` .
+            methods are supported: ``"bfs"`` and ``"dfs"`` . Default: ``"bfs"`` .

-            - `bfs`: The default sorting method, breadth priority, good communication masking, relatively good
+            - bfs: The default sorting method, breadth priority, good communication masking, relatively good
               performance.
-            - `dfs`: An optional sorting method, depth-first sorting. The performance is relatively worse than that
+            - dfs: An optional sorting method, depth-first sorting. The performance is relatively worse than that
               of bfs execution order, but it occupies less memory. It is recommended to try dfs in scenarios where
               other execution orders run out of memory (OOM).

@@ -1091,11 +1226,11 @@ def jit(
             - global (dict): Set global options.
             - session (dict): Set session options.

-          - infer_boost (str, optional): Used to control the inference mode. Default: ``off``, which means
+          - infer_boost (str, optional): Used to control the inference mode. Default: ``"off"``, which means
             the inference mode is disabled. The range is as follows:

-            - `on`: Enable inference mode, get better infer performance.
-            - `off`: Disable inference mode, use forward for inference. The performance is poor.
+            - on: Enable inference mode, get better infer performance.
+            - off: Disable inference mode, use forward for inference. The performance is poor.

     Returns:
         Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
@@ -1114,29 +1249,84 @@ def jit(
         >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
         >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
         ...
-        >>> # create a callable MindSpore graph by calling jit
+        >>> # Create a callable MindSpore graph by calling jit.
         >>> def tensor_add(x, y):
         ...     z = x + y
         ...     return z
         ...
         >>> tensor_add_graph = jit(function=tensor_add)
         >>> out = tensor_add_graph(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
         ...
-        >>> # create a callable MindSpore graph through decorator @jit
+        >>> # Create a callable MindSpore graph through decorator @jit.
         >>> @jit
         ... def tensor_add_with_dec(x, y):
         ...     z = x + y
         ...     return z
         ...
         >>> out = tensor_add_with_dec(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
         ...
-        >>> # create a callable MindSpore graph and capture the entire function into the graph
+        >>> # Create a callable MindSpore graph and capture the entire function into the graph.
         >>> @jit(fullgraph=True)
         ... def tensor_add_fullgraph(x, y):
         ...     z = x + y
         ...     return z
         ...
         >>> out = tensor_add_fullgraph(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
+        ...
+        >>> # Create a callable MindSpore graph by trace mode.
+        >>> @jit(capture_mode="trace")
+        ... def tensor_add_by_trace(x, y):
+        ...     z = x + y
+        ...     return z
+        ...
+        >>> out = tensor_add_by_trace(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
+        ...
+        >>> # Create a callable MindSpore graph with ms_backend and jit_level="O1".
+        >>> @jit(backend="ms_backend", jit_level="O1")
+        ... def tensor_add_with_o1(x, y):
+        ...     z = x + y
+        ...     return z
+        ...
+        >>> out = tensor_add_with_o1(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
+        ...
+        >>> # Create a callable MindSpore graph with GE backend and some ge options on Ascend.
+        >>> @jit(backend="GE", ge_options={"global": {"ge.opSelectImplmode": "high_precision"}})
+        ... def tensor_add_with_ge(x, y):
+        ...     z = x + y
+        ...     return z
+        ...
+        >>> out = tensor_add_with_ge(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
+        ...
     """

     capture_mode = Validator.check_string(capture_mode, ["ast", "bytecode", "trace"], "capture_mode", "jit")
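The Validator.check_string guard above rejects unsupported modes before any compilation work begins. A minimal sketch of that kind of fail-fast choice check (a stand-in, not the real Validator from mindspore._checkparam):

    def check_string(arg_value, valid_values, arg_name, prim_name):
        # Accept the value only if it is one of the allowed choices.
        if arg_value not in valid_values:
            raise ValueError(f"For '{prim_name}', the '{arg_name}' must be one of "
                             f"{valid_values}, but got '{arg_value}'.")
        return arg_value

    mode = check_string("ast", ["ast", "bytecode", "trace"], "capture_mode", "jit")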
@@ -1155,39 +1345,12 @@ def jit(
     jit_config = JitConfig(jit_level=jit_level, exc_mode=exc_mode, jit_syntax_level=jit_syntax_level,
                            infer_boost=infer_boost, backend=backend, options=options_str)

-    def wrap_func(func):
-        nonlocal hash_obj
-        if hash_obj is None or not _is_inner_func(func):
-            hash_obj = int(time.time() * 1e9)
-
-        @wraps(func)
-        def staging_specialize(*args, **kwargs):
-            if os.getenv("MS_JIT") == '0':
-                return func(*args, **kwargs)
-
-            args, kwargs = _handle_func_args(func, *args, **kwargs)
-            process_obj = None
-            if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
-                process_obj = args[0]
-            # Handle auto mixed precision strategy.
-            if not hasattr(func, "amp_strategy"):
-                if isinstance(func, types.MethodType):
-                    setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
-                else:
-                    setattr(func, "amp_strategy", get_curr_amp_strategy())
-
-            ms_function_executor = _JitExecutor(func, hash_obj, None, process_obj, jit_config, dynamic)
-            out = ms_function_executor(*args, **kwargs)
-            return out
-
-        return staging_specialize
-
-    if capture_mode == "bytecode":
-        wrap_func = PIJitCaptureContext(jit_config)
-    elif capture_mode == "trace":
-        if function is not None:
-            return _jit_trace(function)
-        return _jit_trace
+    if capture_mode == "ast":
+        wrap_func = _jit_ast(hash_obj, dynamic, jit_config, '')
+    elif capture_mode == "bytecode":
+        wrap_func = PIJitCaptureContext(fullgraph=fullgraph, jit_config=jit_config)
+    else:
+        wrap_func = _jit_trace()

     if function is not None:
         return wrap_func(function)
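The closing lines above are why jit works both as a bare decorator and as a decorator factory. A compact sketch of that dispatch pattern (jit_like and its keyword are illustrative names):

    def jit_like(function=None, *, capture_mode="ast"):
        def wrap_func(func):
            def wrapper(*args, **kwargs):
                # The real implementation would compile and cache here.
                return func(*args, **kwargs)
            return wrapper

        if function is not None:
            return wrap_func(function)   # used as @jit_like or jit_like(fn)
        return wrap_func                 # used as @jit_like(capture_mode="trace")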
@@ -1503,7 +1666,7 @@ class _PyNativeExecutor:
         """
         self._executor.end_graph(obj, output, *args, *(kwargs.values()))

-    def check_run(self, grad, obj, weights, grad_hash_id, *args):
+    def check_run(self, grad, obj, weights, grad_hash_id, *args, **kwargs):
         """
         Whether the forward graph needs to be constructed.

@@ -1516,7 +1679,7 @@ class _PyNativeExecutor:
         Return:
             bool, specifies whether the forward graph needs to be constructed.
         """
-        return self._executor.check_run(grad, obj, weights, grad_hash_id, *args)
+        return self._executor.check_run(grad, obj, weights, grad_hash_id, *args, **kwargs)

     def grad(self, obj, grad, weights, grad_position, *args):
         """
@@ -1758,6 +1921,19 @@ class _PyNativeExecutor:
         """
         return self._executor.constant_folding(*args)

+    def set_creation_type(self, tensor, creation_type):
+        """
+        Set the tensor's view creation type.
+
+        Args:
+            tensor (Tensor): The input tensor.
+            creation_type (CreationType): The type of the view tensor when it is created.
+
+        Return:
+            None.
+        """
+        return self._executor.set_creation_type(tensor, creation_type)
+


 class _CellGraphExecutor:
     """
@@ -1834,13 +2010,6 @@ class _CellGraphExecutor:
         else:
             _set_dataset_mode_config('normal')

-    @staticmethod
-    def _use_vm_mode():
-        enable_ge = context.get_context("enable_ge")
-        enable_debug_runtime = context.get_context("enable_debug_runtime")
-        exe_mode = context.get_context("mode") == context.PYNATIVE_MODE
-        return not enable_ge or (enable_debug_runtime and exe_mode)
-
     def _build_data_graph(self, obj, phase):
         self._graph_executor.build_data_graph(obj.parameters_dict(), phase)

@@ -1872,7 +2041,12 @@ class _CellGraphExecutor:
         obj.__parse_method__ = 'construct'
         if not hasattr(obj, obj.__parse_method__):
             raise AttributeError(
-                'The class {} dose not have method {}'.format(obj.__class__.__name__, obj.__parse_method__))
+                'The class {} does not have method {}'.format(obj.__class__.__name__, obj.__parse_method__))
+        inner_func = inspect.unwrap(obj.construct)
+        if hasattr(get_func(inner_func), ENABLE_DYNAMIC):
+            raise ValueError(
+                "When using set_context(mode=GRAPH_MODE) together with nn.Cell, the 'enable_dynamic' cannot be set!"
+            )
         key_id = str(id(obj)) + str(obj.create_time)
         args = get_auto_dynamic_shape_args(args, key_id)
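The new guard relies on inspect.unwrap to peel decorator layers off construct before probing for the enable_dynamic marker. A small, runnable illustration of that mechanism (the marker name is reused literally here; get_func and the real ENABLE_DYNAMIC constant are MindSpore internals not shown in this hunk):

    import functools
    import inspect

    def decorated(func):
        @functools.wraps(func)  # wraps() records __wrapped__, which unwrap() follows
        def inner(*args, **kwargs):
            return func(*args, **kwargs)
        return inner

    def construct(x):
        return x

    wrapped = decorated(construct)
    construct.enable_dynamic = True  # marker set on the innermost function after wrapping

    assert not hasattr(wrapped, "enable_dynamic")
    assert hasattr(inspect.unwrap(wrapped), "enable_dynamic")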

@@ -1883,20 +2057,25 @@ class _CellGraphExecutor:
         self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)

         key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
-        obj.arguments_key = str(key)
-
-        obj.arguments_key = obj.arguments_key + "." + _get_hook_key(*args, **kwargs)
+        key = str(key)

         # When a parameter exists in the top graph inputs, check whether the parameter object has changed.
         parameter_ids = _get_parameter_ids(args, kwargs)
         if parameter_ids != "":
-            obj.arguments_key = obj.arguments_key + '.' + parameter_ids
+            key += '.' + parameter_ids
+
+        key += "." + _get_hook_key(*args, **kwargs)
+        key += "." + str(_hook_version())
+
+        obj.arguments_key = key
+
         raw_phase = phase
-        phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
+
+        phase = _real_phase(phase, obj)
         obj.phase_cache[raw_phase] = phase
         update_auto_dynamic_shape_phase(args, key_id, phase)
         obj.current_phase = phase
-        if phase in obj.compile_cache and self.has_compiled(phase) and not parameter_hook_updated():
+        if phase in obj.compile_cache and self.has_compiled(phase):
             logger.debug("%r graph has existed.", phase)
             # Resources such as cur_convert_input_ generated in generate_arguments_key should be
             # released when CompileInner won't be executed.
@@ -1904,7 +2083,7 @@ class _CellGraphExecutor:
             _clear_auto_parallel_context(obj)
             return phase, False

-        full_function_name = obj.__class__.__name__ + '.' + str(obj.instance_count) + '.' + str(id(type(obj)))
+        full_function_name = obj.__class__.__name__ + '.' + str(obj.total_instance_count) + '.' + str(id(type(obj)))
         echo_function_name = obj.__class__.__name__
         _check_recompile(obj, args, kwargs, full_function_name, obj.create_time, echo_function_name)
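The inline 'phase.create_time.id.arguments_key' concatenations are now funneled through a single _real_phase helper. Its body is outside this diff, but judging from the expressions it replaces (here and in the getters below), it presumably amounts to:

    def _real_phase(phase, obj):
        # Reconstructed from the inline expression the helper replaces.
        return '.'.join((phase, str(obj.create_time), str(id(obj)), obj.arguments_key))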

@@ -1914,17 +2093,14 @@ class _CellGraphExecutor:
         self._set_compile_cache_dep_files(phase)

         self._graph_executor.set_weights_values(obj.parameters_dict())
-        if jit_config_dict:
-            self._graph_executor.set_jit_config(jit_config_dict)
-        else:
+        if not jit_config_dict:
             jit_config_dict = JitConfig().jit_config_dict
-            self._graph_executor.set_jit_config(jit_config_dict)
         gc.collect()
-        result = self._graph_executor.compile(obj, args, kwargs, phase)
+        result = self._graph_executor.compile(
+            obj, args, kwargs, phase, jit_config_dict)
         obj.compile_cache.add(phase)
         if not result:
             raise RuntimeError("Executor compile failed.")
-        set_parameter_hook_updated(False)
         graph = self._graph_executor.get_func_graph(phase)

         if graph is None:
@@ -1949,15 +2125,15 @@ class _CellGraphExecutor:
         return self._graph_executor.updata_param_node_default_input(phase, new_param)

     def _get_shard_strategy(self, obj):
-        real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
+        real_phase = _real_phase(obj.phase, obj)
         return self._graph_executor.get_strategy(real_phase)

     def _get_num_parallel_ops(self, obj):
-        real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
+        real_phase = _real_phase(obj.phase, obj)
         return self._graph_executor.get_num_parallel_ops(real_phase)

     def _get_allreduce_fusion(self, obj):
-        real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
+        real_phase = _real_phase(obj.phase, obj)
         return self._graph_executor.get_allreduce_fusion(real_phase)

     def __call__(self, obj, *args, phase='predict'):
@@ -2009,10 +2185,10 @@
             Tensor/Tuple, return execute result.
         """
         if phase == 'save':
-            exe_phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
+            exe_phase = _real_phase(phase, obj)
             return self._graph_executor((), exe_phase)

-        phase_real = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
+        phase_real = _real_phase(phase, obj)
         if self.has_compiled(phase_real):
             return self._exec_pip(obj, *args, phase=phase_real)
         raise KeyError('{} graph is not exist.'.format(phase_real))
@@ -2039,7 +2215,7 @@ class _CellGraphExecutor:

     def get_optimize_graph_proto(self, obj):
         """Return optimize graph binary proto."""
-        exec_id = obj.phase + "." + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
+        exec_id = _real_phase(obj.phase, obj)
         if self._graph_executor.has_compiled(exec_id) is False:
             return None
         graph_proto = self._graph_executor.get_optimize_graph_proto(exec_id)
@@ -2121,5 +2297,3 @@ def flops_collection(phase='train'):

 _cell_graph_executor = _CellGraphExecutor()
 _pynative_executor = _PyNativeExecutor()
-
-__all__ = ['ms_memory_recycle', 'jit', 'jit_class', 'flops_collection']
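Removing __all__ only affects star-imports: 'from ... import *' now falls back to every public (non-underscore) name instead of the explicit export list. A quick refresher (module and names below are illustrative):

    # exports_demo.py
    __all__ = ['jit']            # star-import binds only 'jit' while __all__ exists
    def jit(): ...
    def flops_collection(): ...  # public, but hidden from star-import by __all__
    def _helper(): ...           # underscore names are always skipped by star-import
    # Without __all__, star-import exposes jit and flops_collection (not _helper).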