mindspore 2.6.0rc1__cp39-cp39-win_amd64.whl → 2.7.0rc1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (384)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +1 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +40 -9
  7. mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
  8. mindspore/_extends/optimize/cell_utils.py +96 -0
  9. mindspore/_extends/parse/__init__.py +2 -2
  10. mindspore/_extends/parse/compile_config.py +44 -22
  11. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -1
  12. mindspore/_extends/parse/parser.py +37 -62
  13. mindspore/_extends/parse/resources.py +39 -0
  14. mindspore/_extends/parse/standard_method.py +43 -13
  15. mindspore/_extends/parse/trope.py +8 -1
  16. mindspore/_extends/pijit/__init__.py +1 -2
  17. mindspore/amp.py +4 -4
  18. mindspore/avcodec-59.dll +0 -0
  19. mindspore/avdevice-59.dll +0 -0
  20. mindspore/avfilter-8.dll +0 -0
  21. mindspore/avformat-59.dll +0 -0
  22. mindspore/avutil-57.dll +0 -0
  23. mindspore/boost/adasum.py +1 -1
  24. mindspore/boost/boost_cell_wrapper.py +4 -4
  25. mindspore/common/__init__.py +27 -2
  26. mindspore/common/_grad_function.py +2 -1
  27. mindspore/common/_pijit_context.py +28 -7
  28. mindspore/common/_stub_tensor.py +1 -209
  29. mindspore/common/_tensor_cpp_method.py +1 -1
  30. mindspore/common/_tensor_docs.py +77 -16
  31. mindspore/common/api.py +238 -113
  32. mindspore/common/dtype.py +21 -11
  33. mindspore/common/dump.py +10 -15
  34. mindspore/common/generator.py +5 -3
  35. mindspore/common/hook_handle.py +11 -2
  36. mindspore/common/jit_config.py +1 -1
  37. mindspore/common/jit_trace.py +84 -105
  38. mindspore/common/parameter.py +26 -12
  39. mindspore/common/recompute.py +3 -3
  40. mindspore/common/sparse_tensor.py +0 -3
  41. mindspore/common/symbol.py +0 -1
  42. mindspore/common/tensor.py +81 -81
  43. mindspore/communication/_comm_helper.py +46 -4
  44. mindspore/communication/management.py +79 -7
  45. mindspore/context.py +58 -40
  46. mindspore/dataset/core/config.py +3 -3
  47. mindspore/dataset/engine/datasets.py +20 -7
  48. mindspore/dataset/engine/datasets_user_defined.py +33 -3
  49. mindspore/dataset/engine/iterators.py +2 -2
  50. mindspore/dataset/engine/obs/config_loader.py +2 -2
  51. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
  52. mindspore/dataset/transforms/py_transforms.py +7 -3
  53. mindspore/dataset/transforms/transforms.py +7 -3
  54. mindspore/dataset/vision/validators.py +1 -0
  55. mindspore/device_context/ascend/device.py +1 -1
  56. mindspore/device_context/gpu/__init__.py +2 -2
  57. mindspore/device_context/gpu/device.py +1 -1
  58. mindspore/device_context/gpu/op_precision.py +4 -2
  59. mindspore/device_context/gpu/op_tuning.py +6 -3
  60. mindspore/device_manager.py +16 -9
  61. mindspore/dnnl.dll +0 -0
  62. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +3 -7
  63. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  64. mindspore/experimental/optim/adadelta.py +13 -20
  65. mindspore/experimental/optim/adagrad.py +15 -22
  66. mindspore/experimental/optim/adam.py +17 -24
  67. mindspore/experimental/optim/adamax.py +14 -22
  68. mindspore/experimental/optim/adamw.py +28 -34
  69. mindspore/experimental/optim/asgd.py +15 -25
  70. mindspore/experimental/optim/lr_scheduler.py +27 -45
  71. mindspore/experimental/optim/nadam.py +14 -24
  72. mindspore/experimental/optim/optimizer.py +13 -23
  73. mindspore/experimental/optim/radam.py +18 -24
  74. mindspore/experimental/optim/rmsprop.py +14 -25
  75. mindspore/experimental/optim/rprop.py +15 -26
  76. mindspore/experimental/optim/sgd.py +9 -19
  77. mindspore/hal/__init__.py +4 -4
  78. mindspore/hal/contiguous_tensors_handle.py +2 -2
  79. mindspore/hal/memory.py +27 -7
  80. mindspore/include/api/cell.h +37 -1
  81. mindspore/include/api/delegate.h +10 -0
  82. mindspore/include/api/model.h +3 -0
  83. mindspore/include/api/types.h +2 -2
  84. mindspore/include/c_api/model_c.h +0 -58
  85. mindspore/include/c_api/tensor_c.h +0 -26
  86. mindspore/include/dataset/vision_ascend.h +1 -1
  87. mindspore/jpeg62.dll +0 -0
  88. mindspore/mindrecord/tools/cifar10.py +60 -11
  89. mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
  90. mindspore/mindspore_backend_common.dll +0 -0
  91. mindspore/mindspore_backend_manager.dll +0 -0
  92. mindspore/mindspore_common.dll +0 -0
  93. mindspore/mindspore_core.dll +0 -0
  94. mindspore/mindspore_cpu_res_manager.dll +0 -0
  95. mindspore/mindspore_dump.dll +0 -0
  96. mindspore/mindspore_frontend.dll +0 -0
  97. mindspore/mindspore_glog.dll +0 -0
  98. mindspore/mindspore_memory_pool.dll +0 -0
  99. mindspore/mindspore_ms_backend.dll +0 -0
  100. mindspore/mindspore_ops.dll +0 -0
  101. mindspore/mindspore_ops_host.dll +0 -0
  102. mindspore/mindspore_ops_kernel_common.dll +0 -0
  103. mindspore/mindspore_profiler.dll +0 -0
  104. mindspore/mindspore_pyboost.dll +0 -0
  105. mindspore/mindspore_pynative.dll +0 -0
  106. mindspore/mindspore_res_manager.dll +0 -0
  107. mindspore/mindspore_runtime_pipeline.dll +0 -0
  108. mindspore/mint/__init__.py +6 -46
  109. mindspore/mint/distributed/__init__.py +1 -0
  110. mindspore/mint/distributed/distributed.py +212 -9
  111. mindspore/mint/nn/__init__.py +1 -1
  112. mindspore/mint/nn/functional.py +53 -6
  113. mindspore/mint/nn/layer/_functions.py +164 -294
  114. mindspore/mint/nn/layer/activation.py +8 -6
  115. mindspore/mint/nn/layer/conv.py +137 -101
  116. mindspore/mint/nn/layer/normalization.py +8 -22
  117. mindspore/mint/optim/adam.py +19 -18
  118. mindspore/mint/optim/adamw.py +14 -8
  119. mindspore/mint/optim/sgd.py +5 -5
  120. mindspore/nn/cell.py +328 -502
  121. mindspore/nn/grad/cell_grad.py +11 -12
  122. mindspore/nn/layer/activation.py +32 -34
  123. mindspore/nn/layer/basic.py +67 -64
  124. mindspore/nn/layer/channel_shuffle.py +4 -4
  125. mindspore/nn/layer/combined.py +4 -2
  126. mindspore/nn/layer/conv.py +117 -110
  127. mindspore/nn/layer/dense.py +9 -7
  128. mindspore/nn/layer/embedding.py +50 -52
  129. mindspore/nn/layer/image.py +37 -39
  130. mindspore/nn/layer/math.py +111 -112
  131. mindspore/nn/layer/normalization.py +56 -44
  132. mindspore/nn/layer/pooling.py +58 -63
  133. mindspore/nn/layer/rnn_cells.py +33 -33
  134. mindspore/nn/layer/rnns.py +56 -56
  135. mindspore/nn/layer/thor_layer.py +74 -73
  136. mindspore/nn/layer/transformer.py +11 -1
  137. mindspore/nn/learning_rate_schedule.py +20 -20
  138. mindspore/nn/loss/loss.py +79 -81
  139. mindspore/nn/optim/adam.py +3 -3
  140. mindspore/nn/optim/adasum.py +2 -2
  141. mindspore/nn/optim/asgd.py +2 -0
  142. mindspore/nn/optim/optimizer.py +1 -1
  143. mindspore/nn/optim/thor.py +2 -2
  144. mindspore/nn/probability/distribution/exponential.py +2 -1
  145. mindspore/nn/probability/distribution/poisson.py +2 -1
  146. mindspore/nn/sparse/sparse.py +3 -3
  147. mindspore/nn/wrap/cell_wrapper.py +34 -37
  148. mindspore/nn/wrap/grad_reducer.py +37 -37
  149. mindspore/nn/wrap/loss_scale.py +72 -74
  150. mindspore/numpy/array_creations.py +5 -5
  151. mindspore/numpy/fft.py +1 -1
  152. mindspore/numpy/math_ops.py +5 -5
  153. mindspore/opencv_core452.dll +0 -0
  154. mindspore/opencv_imgcodecs452.dll +0 -0
  155. mindspore/opencv_imgproc452.dll +0 -0
  156. mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
  157. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
  158. mindspore/ops/_vmap/vmap_array_ops.py +31 -13
  159. mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
  160. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +42 -11
  161. mindspore/ops/auto_generate/gen_extend_func.py +23 -141
  162. mindspore/ops/auto_generate/gen_ops_def.py +727 -321
  163. mindspore/ops/auto_generate/gen_ops_prim.py +1721 -984
  164. mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
  165. mindspore/ops/composite/__init__.py +10 -0
  166. mindspore/ops/composite/base.py +8 -4
  167. mindspore/ops/composite/multitype_ops/__init__.py +12 -1
  168. mindspore/ops/composite/multitype_ops/_compile_utils.py +133 -109
  169. mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
  170. mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
  171. mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
  172. mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
  173. mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
  174. mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
  175. mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
  176. mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
  177. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
  178. mindspore/ops/function/__init__.py +3 -1
  179. mindspore/ops/function/_add_attr_func.py +11 -6
  180. mindspore/ops/function/array_func.py +9 -96
  181. mindspore/ops/function/debug_func.py +4 -3
  182. mindspore/ops/function/grad/grad_func.py +1 -1
  183. mindspore/ops/function/math_func.py +33 -540
  184. mindspore/ops/function/nn_func.py +28 -74
  185. mindspore/ops/function/other_func.py +4 -1
  186. mindspore/ops/function/random_func.py +44 -5
  187. mindspore/ops/function/vmap_func.py +2 -1
  188. mindspore/ops/functional.py +2 -3
  189. mindspore/ops/functional_overload.py +571 -6
  190. mindspore/ops/op_info_register.py +21 -0
  191. mindspore/ops/operations/__init__.py +16 -11
  192. mindspore/ops/operations/_custom_ops_utils.py +689 -34
  193. mindspore/ops/operations/_inner_ops.py +3 -6
  194. mindspore/ops/operations/_sequence_ops.py +1 -1
  195. mindspore/ops/operations/array_ops.py +2 -2
  196. mindspore/ops/operations/comm_ops.py +185 -26
  197. mindspore/ops/operations/custom_ops.py +294 -174
  198. mindspore/ops/operations/debug_ops.py +59 -4
  199. mindspore/ops/operations/image_ops.py +13 -13
  200. mindspore/ops/operations/manually_defined/ops_def.py +15 -16
  201. mindspore/ops/operations/math_ops.py +3 -4
  202. mindspore/ops/operations/nn_ops.py +7 -39
  203. mindspore/ops/primitive.py +6 -10
  204. mindspore/ops/tensor_method.py +47 -8
  205. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
  206. mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
  207. mindspore/ops_generate/api/functions_cc_generator.py +58 -10
  208. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
  209. mindspore/ops_generate/common/base_generator.py +14 -0
  210. mindspore/ops_generate/common/gen_constants.py +8 -3
  211. mindspore/ops_generate/common/gen_utils.py +0 -19
  212. mindspore/ops_generate/common/op_proto.py +11 -4
  213. mindspore/ops_generate/common/template.py +88 -11
  214. mindspore/ops_generate/gen_ops.py +1 -1
  215. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
  216. mindspore/ops_generate/op_def/ops_def_cc_generator.py +0 -3
  217. mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
  218. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
  219. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
  220. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
  221. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
  222. mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -0
  223. mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
  224. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
  225. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
  226. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
  227. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
  228. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
  229. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
  230. mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
  231. mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
  232. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
  233. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
  234. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
  235. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
  236. mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
  237. mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
  238. mindspore/parallel/_auto_parallel_context.py +11 -8
  239. mindspore/parallel/_cell_wrapper.py +113 -45
  240. mindspore/parallel/_parallel_serialization.py +1 -1
  241. mindspore/parallel/_ps_context.py +4 -6
  242. mindspore/parallel/_tensor.py +167 -12
  243. mindspore/parallel/_transformer/moe.py +1 -1
  244. mindspore/parallel/_transformer/transformer.py +13 -8
  245. mindspore/parallel/auto_parallel.py +14 -7
  246. mindspore/parallel/checkpoint_convert.py +3 -3
  247. mindspore/parallel/checkpoint_transform.py +11 -7
  248. mindspore/parallel/cluster/process_entity/_api.py +84 -48
  249. mindspore/parallel/cluster/process_entity/_utils.py +95 -7
  250. mindspore/parallel/cluster/run.py +43 -4
  251. mindspore/parallel/function/__init__.py +8 -1
  252. mindspore/parallel/function/reshard_func.py +6 -7
  253. mindspore/parallel/nn/__init__.py +15 -2
  254. mindspore/parallel/nn/parallel_cell_wrapper.py +9 -10
  255. mindspore/parallel/nn/parallel_grad_reducer.py +7 -6
  256. mindspore/parallel/shard.py +3 -4
  257. mindspore/parallel/transform_safetensors.py +463 -174
  258. mindspore/profiler/__init__.py +2 -1
  259. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
  260. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
  261. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +12 -6
  262. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
  263. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  264. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
  265. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
  266. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
  267. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
  268. mindspore/profiler/analysis/task_manager.py +1 -1
  269. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
  270. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
  271. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +42 -22
  272. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
  273. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
  274. mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
  275. mindspore/profiler/common/constant.py +16 -0
  276. mindspore/profiler/common/profiler_context.py +25 -27
  277. mindspore/profiler/common/profiler_info.py +0 -16
  278. mindspore/profiler/common/profiler_op_analyse.py +235 -0
  279. mindspore/profiler/common/profiler_output_path.py +23 -8
  280. mindspore/profiler/common/profiler_parameters.py +128 -35
  281. mindspore/profiler/dynamic_profile/__init__.py +0 -0
  282. mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
  283. mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
  284. mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
  285. mindspore/profiler/dynamic_profiler.py +305 -314
  286. mindspore/profiler/envprofiler.py +12 -7
  287. mindspore/profiler/experimental_config.py +96 -6
  288. mindspore/profiler/mstx.py +33 -12
  289. mindspore/profiler/platform/__init__.py +2 -3
  290. mindspore/profiler/platform/npu_profiler.py +29 -19
  291. mindspore/profiler/profiler.py +35 -19
  292. mindspore/profiler/profiler_action_controller.py +64 -76
  293. mindspore/profiler/schedule.py +10 -4
  294. mindspore/rewrite/common/config.py +1 -0
  295. mindspore/rewrite/common/namer.py +1 -0
  296. mindspore/rewrite/common/namespace.py +1 -0
  297. mindspore/rewrite/node/node.py +31 -11
  298. mindspore/rewrite/parsers/assign_parser.py +1 -1
  299. mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
  300. mindspore/run_check/_check_version.py +7 -10
  301. mindspore/runtime/__init__.py +5 -5
  302. mindspore/runtime/event.py +10 -4
  303. mindspore/runtime/executor.py +60 -45
  304. mindspore/runtime/memory.py +30 -32
  305. mindspore/runtime/thread_bind_core.py +298 -164
  306. mindspore/safeguard/rewrite_obfuscation.py +12 -13
  307. mindspore/swresample-4.dll +0 -0
  308. mindspore/swscale-6.dll +0 -0
  309. mindspore/tinyxml2.dll +0 -0
  310. mindspore/train/_utils.py +14 -4
  311. mindspore/train/amp.py +43 -20
  312. mindspore/train/callback/__init__.py +5 -5
  313. mindspore/train/callback/_checkpoint.py +3 -6
  314. mindspore/train/callback/_flops_collector.py +1 -1
  315. mindspore/train/callback/_landscape.py +0 -1
  316. mindspore/train/callback/_train_fault_tolerance.py +97 -16
  317. mindspore/train/data_sink.py +11 -2
  318. mindspore/train/dataset_helper.py +9 -0
  319. mindspore/train/model.py +135 -55
  320. mindspore/train/serialization.py +133 -111
  321. mindspore/train/summary/summary_record.py +13 -2
  322. mindspore/turbojpeg.dll +0 -0
  323. mindspore/utils/__init__.py +3 -2
  324. mindspore/utils/dryrun.py +0 -6
  325. mindspore/utils/runtime_execution_order_check.py +163 -77
  326. mindspore/utils/sdc_detect.py +68 -0
  327. mindspore/utils/utils.py +6 -9
  328. mindspore/version.py +1 -1
  329. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/METADATA +5 -4
  330. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/RECORD +333 -371
  331. mindspore/_deprecated/jit.py +0 -198
  332. mindspore/experimental/es/__init__.py +0 -22
  333. mindspore/experimental/es/embedding_service.py +0 -891
  334. mindspore/experimental/es/embedding_service_layer.py +0 -581
  335. mindspore/profiler/parser/__init__.py +0 -14
  336. mindspore/profiler/parser/aicpu_data_parser.py +0 -272
  337. mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
  338. mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
  339. mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
  340. mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
  341. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
  342. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
  343. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
  344. mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
  345. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
  346. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
  347. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
  348. mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
  349. mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
  350. mindspore/profiler/parser/ascend_flops_generator.py +0 -116
  351. mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
  352. mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
  353. mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
  354. mindspore/profiler/parser/ascend_memory_generator.py +0 -185
  355. mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
  356. mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
  357. mindspore/profiler/parser/ascend_op_generator.py +0 -334
  358. mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
  359. mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
  360. mindspore/profiler/parser/base_timeline_generator.py +0 -483
  361. mindspore/profiler/parser/container.py +0 -229
  362. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
  363. mindspore/profiler/parser/flops_parser.py +0 -531
  364. mindspore/profiler/parser/framework_enum.py +0 -111
  365. mindspore/profiler/parser/framework_parser.py +0 -464
  366. mindspore/profiler/parser/framework_struct.py +0 -61
  367. mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
  368. mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
  369. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
  370. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
  371. mindspore/profiler/parser/hccl_parser.py +0 -573
  372. mindspore/profiler/parser/hwts_log_parser.py +0 -122
  373. mindspore/profiler/parser/integrator.py +0 -526
  374. mindspore/profiler/parser/memory_usage_parser.py +0 -277
  375. mindspore/profiler/parser/minddata_analyzer.py +0 -800
  376. mindspore/profiler/parser/minddata_parser.py +0 -186
  377. mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
  378. mindspore/profiler/parser/op_intermediate_parser.py +0 -149
  379. mindspore/profiler/parser/optime_parser.py +0 -250
  380. mindspore/profiler/parser/profiler_info.py +0 -213
  381. mindspore/profiler/parser/step_trace_parser.py +0 -666
  382. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/WHEEL +0 -0
  383. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/entry_points.txt +0 -0
  384. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py CHANGED
@@ -15,6 +15,10 @@
 """cell"""
 from __future__ import absolute_import
 
+__all__ = [
+    "register_cell_buffer_registration_hook",
+]
+
 import inspect
 import os
 import time
@@ -24,7 +28,6 @@ from collections import OrderedDict, namedtuple
 from typing import (
     Dict,
     Optional,
-    Set,
     Callable,
     List,
     Tuple,
@@ -34,36 +37,29 @@ from typing import (
     Mapping
 )
 
+import weakref
 import mindspore as ms
 from mindspore._checkparam import args_type_check, check_hook_fn
 from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
 from mindspore import log as logger
-from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
 from mindspore.common.hook_handle import HookHandle
-from mindspore.context import ParallelMode
 from mindspore import context
 from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
 from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache, \
-    _no_grad
-from mindspore.common.api import _convert_python_data, _get_args_for_run_predict
+    _no_grad, _get_mutable_flags
+from mindspore.common.api import _convert_python_data
 from mindspore.common.api import _process_dyn_args, _generate_dyn_compile_args
-from mindspore.common.parameter import _Buffer, Parameter, ParameterTuple
+from mindspore.common.parameter import _Buffer, Parameter, ParameterTuple, _is_parameter_generated
 from mindspore.common.tensor import Tensor
-from mindspore.ops.operations import Cast
 from mindspore.ops.primitive import Primitive
 from mindspore.ops.operations import _inner_ops as inner
 from mindspore.parallel.shard import Shard
 from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
 from mindspore._check_jit_forbidden_api import jit_forbidden_register
-from mindspore.common._decorator import deprecated
 from mindspore.common._register_for_recompute import recompute_registry
-
-
-__all__ = [
-    "register_cell_buffer_registration_hook",
-]
+from mindspore.common.jit_config import JitConfig
 
 _global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict()
 _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
@@ -102,7 +98,6 @@ def register_cell_buffer_registration_hook(hook: Callable[..., None],):
     return handle
 
 
-
 class Cell(Cell_):
     """
     The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this
@@ -160,51 +155,57 @@ class Cell(Cell_):
     IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time',
                    '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase', '_bprop_debug',
                    '_forward_pre_hook', '_forward_hook', '_backward_pre_hook', '_backward_hook',
-                   '_cell_backward_pre_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
-                   '_attr_synced', 'pynative', 'requires_grad', 'cell_type',
-                   '_parameters_forward_hook', '_parameters_backward_hook']
+                   '_cell_backward_pre_hook', '_cell_backward_hook', '_param_prefix', 'requires_grad', 'cell_type']
     total_instance_count = 0
     _buffers: Dict[str, Optional[Tensor]]
-    _non_persistent_buffers_set: Set[str]
+    global_cells = weakref.WeakKeyDictionary()
+    _no_auto_lazy_inline = True
+
+    def __new__(class_, *args, **kwargs):
+        # Use class_ to avoid name conflicts with input args and kwargs.
+        this = Cell_.__new__(class_, *args, **kwargs)
+        if Cell._no_auto_lazy_inline:
+            return this
+
+        Cell.global_cells[this] = (class_, args, kwargs)
+        return this
 
     def __init__(self, auto_prefix=True, flags=None):
         Cell_.__init__(self, self._cell_tag)
         Cell.total_instance_count += 1
-        self.instance_count = Cell.total_instance_count
-        self._params = OrderedDict()
-        self._cells = OrderedDict()
+        super().__setattr__("_params", OrderedDict())
+        super().__setattr__("_cells", OrderedDict())
         super().__setattr__("_buffers", {})
-        super().__setattr__("_non_persistent_buffers_set", set())
-        super().__setattr__("_state_dict_hooks", OrderedDict())
-        super().__setattr__("_state_dict_pre_hooks", OrderedDict())
-        super().__setattr__("_load_state_dict_pre_hooks", OrderedDict())
-        super().__setattr__("_load_state_dict_post_hooks", OrderedDict())
-        self._params_list = OrderedDict()
-        self._primitives = OrderedDict()
-        self.training = False
-        self.requires_grad = False
-        self.is_top_cell = False
-        self.pynative = False
-        self._attr_synced = False
-        self._param_prefix = ''
-        self._auto_prefix = auto_prefix
-        self._scope = None
-        self._phase = 'train'
-        self._parameter_layout_dict = {}
-        self._parallel_parameter_name_list = ()
-        self._parallel_parameter_merge_net_dict = {}
-        self._create_time = int(time.time() * 1e9)
-        self.arguments_key = ""
-        self.compile_cache = set()
-        self.phase_cache = dict()
+        super().__setattr__("_params_list", OrderedDict())
+        super().__setattr__("_primitives", OrderedDict())
+
+        super().__setattr__("_lazy_non_persistent_buffers_set", None)
+        super().__setattr__("_lazy_state_dict_hooks", None)
+        super().__setattr__("_lazy_state_dict_pre_hooks", None)
+        super().__setattr__("_lazy_load_state_dict_pre_hooks", None)
+        super().__setattr__("_lazy_load_state_dict_post_hooks", None)
+        super().__setattr__("training", False)
+        super().__setattr__("requires_grad", False)
+        super().__setattr__("is_top_cell", False)
+        super().__setattr__("_param_prefix", '')
+        super().__setattr__("_auto_prefix", auto_prefix)
+        super().__setattr__("_scope", None)
+        super().__setattr__("_phase", 'train')
+        super().__setattr__("_parameter_layout_dict", None)
+        super().__setattr__("_parallel_parameter_name_list", None)
+        super().__setattr__("_parallel_parameter_merge_net_dict", None)
+        super().__setattr__("_create_time", int(time.time() * 1e9))
+        super().__setattr__("arguments_key", "")
+        super().__setattr__("_compile_cache", None)
+        super().__setattr__("_phase_cache", None)
         cells_compile_cache[id(self)] = self.compile_cache
-        self.parameter_broadcast_done = False
-        self._id = 1
-        self._exist_objs = None
-        self._exist_names = None
-        self._recompute_cell = None
-        self.mixed_precision_type = None
-        self.sig = inspect.signature(self.construct)
+        super().__setattr__("_id", 1)
+        super().__setattr__("_exist_objs", None)
+        super().__setattr__("_exist_names", None)
+        super().__setattr__("_recompute_cell", None)
+        super().__setattr__("mixed_precision_type", None)
+        super().__setattr__("_lazy_construct_sig", None)
+        super().__setattr__("_jit_graph_name", '')
        init_pipeline()
 
        # call gc to release GE session resources used by non-used cell objects
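The rewritten `__init__` above replaces plain `self.attr = ...` assignments with `super().__setattr__(...)`. The diff itself does not state the motivation, but a plausible reading is that this bypasses the bookkeeping done by `Cell.__setattr__` (parameter and child-cell registration checks) for attributes that are known to be ordinary Python objects, reducing per-instance construction overhead. A minimal, self-contained sketch of that pattern follows; the class names are illustrative only and are not MindSpore APIs.

    class Plain:
        """Stand-in for a base class (here: the C++ Cell_) with no custom __setattr__."""

    class Net(Plain):
        def __setattr__(self, name, value):
            # Stand-in for Cell.__setattr__: bookkeeping on every ordinary assignment.
            object.__setattr__(self, "_writes", getattr(self, "_writes", 0) + 1)
            object.__setattr__(self, name, value)

        def __init__(self):
            # super().__setattr__ resolves past Net.__setattr__, so no bookkeeping runs here.
            super().__setattr__("cache", None)
            # An ordinary assignment still goes through Net.__setattr__.
            self.flag = True

    n = Net()
    print(getattr(n, "_writes", 0))  # 1 -- only the ordinary assignment was counted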
@@ -214,38 +215,33 @@ class Cell(Cell_):
 
         if flags:
             self.add_flags(**flags)
-        self._bprop_debug = False
+        super().__setattr__("_bprop_debug", False)
 
         # hook
-        self._forward_pre_hook = OrderedDict()
-        self._forward_hook = OrderedDict()
-        self._backward_pre_hook = OrderedDict()
-        self._cell_backward_pre_hook = None
-        self._backward_hook = OrderedDict()
-        self._cell_backward_hook = None
-        self._is_recursion_hook = False
-
-        # parameters hook
-        self._parameters_forward_hook = None
-        self._parameters_backward_hook = None
-
-        self.cell_type = None
-        self.cast = Cast()
-        self._has_config_recompute = False
-        self._user_parameters = []
-        self._dynamic_shape_inputs = None
-        self._compile_args = None
-        self.saved_dynamic_shape = None
-        self._jit_config_dict = dict()
-        self.grad_ops_label = False
-        self.ge_sync_data = False
-        self._is_check_and_refresh = False
-        self._amp_level = ""
-        self._init_flag = False
-        self._shard_fn = None
-        self.has_bprop = False
+        super().__setattr__("_lazy_forward_pre_hook", None)
+        super().__setattr__("_lazy_forward_hook", None)
+        super().__setattr__("_lazy_backward_pre_hook", None)
+        super().__setattr__("_lazy_backward_hook", None)
+        super().__setattr__("_lazy_forward_pre_hook_with_kwargs", None)
+        super().__setattr__("_lazy_forward_hook_with_kwargs", None)
+        super().__setattr__("_cell_backward_pre_hook", None)
+        super().__setattr__("_cell_backward_hook", None)
+        super().__setattr__("_is_recursion_hook", False)
+
+        super().__setattr__("cell_type", None)
+        super().__setattr__("_has_config_recompute", False)
+        super().__setattr__("_lazy_user_parameters", None)
+        super().__setattr__("_dynamic_shape_inputs", None)
+        super().__setattr__("_has_mutable_args_list", None)
+        super().__setattr__("_jit_config_dict", dict())
+        super().__setattr__("grad_ops_label", False)
+        super().__setattr__("_is_check_and_refresh", False)
+        super().__setattr__("_amp_level", "")
+        super().__setattr__("_init_flag", False)
+        super().__setattr__("_shard_fn", None)
+        super().__setattr__("has_bprop", False)
         if hasattr(self, "bprop"):
-            self.has_bprop = True
+            super().__setattr__("has_bprop", True)
 
     def __getstate__(self):
         base = Cell_.__getstate__(self)
@@ -255,7 +251,6 @@ class Cell(Cell_):
         base, dict_ = state
         Cell_.__setstate__(self, base)
         self.__dict__ = dict_
-        self._attr_synced = False
 
     def __bool__(self):
         return True
@@ -269,6 +264,112 @@ class Cell(Cell_):
     def create_time(self):
         return self._create_time
 
+    @property
+    def _non_persistent_buffers_set(self):
+        """_non_persistent_buffers_set"""
+        if self._lazy_non_persistent_buffers_set is None:
+            super().__setattr__("_lazy_non_persistent_buffers_set", set())
+        return self._lazy_non_persistent_buffers_set
+
+    @property
+    def _state_dict_hooks(self):
+        """_state_dict_hooks"""
+        if self._lazy_state_dict_hooks is None:
+            super().__setattr__("_lazy_state_dict_hooks", OrderedDict())
+        return self._lazy_state_dict_hooks
+
+    @property
+    def _state_dict_pre_hooks(self):
+        """_state_dict_pre_hooks"""
+        if self._lazy_state_dict_pre_hooks is None:
+            super().__setattr__("_lazy_state_dict_pre_hooks", OrderedDict())
+        return self._lazy_state_dict_pre_hooks
+
+    @property
+    def _load_state_dict_pre_hooks(self):
+        """_load_state_dict_pre_hooks"""
+        if self._lazy_load_state_dict_pre_hooks is None:
+            super().__setattr__("_lazy_load_state_dict_pre_hooks", OrderedDict())
+        return self._lazy_load_state_dict_pre_hooks
+
+    @property
+    def _load_state_dict_post_hooks(self):
+        """_load_state_dict_post_hooks"""
+        if self._lazy_load_state_dict_post_hooks is None:
+            super().__setattr__("_lazy_load_state_dict_post_hooks", OrderedDict())
+        return self._lazy_load_state_dict_post_hooks
+
+    @property
+    def compile_cache(self):
+        """compile_cache"""
+        if self._compile_cache is None:
+            super().__setattr__("_compile_cache", set())
+        return self._compile_cache
+
+    @property
+    def phase_cache(self):
+        """phase_cache"""
+        if self._phase_cache is None:
+            super().__setattr__("_phase_cache", dict())
+        return self._phase_cache
+
+    @property
+    def _forward_pre_hook(self):
+        """_forward_pre_hook"""
+        if self._lazy_forward_pre_hook is None:
+            super().__setattr__("_lazy_forward_pre_hook", OrderedDict())
+        return self._lazy_forward_pre_hook
+
+    @property
+    def _forward_hook(self):
+        """_forward_hook"""
+        if self._lazy_forward_hook is None:
+            super().__setattr__("_lazy_forward_hook", OrderedDict())
+        return self._lazy_forward_hook
+
+    @property
+    def _backward_pre_hook(self):
+        """_backward_pre_hook"""
+        if self._lazy_backward_pre_hook is None:
+            super().__setattr__("_lazy_backward_pre_hook", OrderedDict())
+        return self._lazy_backward_pre_hook
+
+    @property
+    def _backward_hook(self):
+        """_backward_hook"""
+        if self._lazy_backward_hook is None:
+            super().__setattr__("_lazy_backward_hook", OrderedDict())
+        return self._lazy_backward_hook
+
+    @property
+    def _forward_pre_hook_with_kwargs(self):
+        """_backward_hook"""
+        if self._lazy_forward_pre_hook_with_kwargs is None:
+            super().__setattr__("_lazy_forward_pre_hook_with_kwargs", OrderedDict())
+        return self._lazy_forward_pre_hook_with_kwargs
+
+    @property
+    def _forward_hook_with_kwargs(self):
+        """_backward_hook"""
+        if self._lazy_forward_hook_with_kwargs is None:
+            super().__setattr__("_lazy_forward_hook_with_kwargs", OrderedDict())
+        return self._lazy_forward_hook_with_kwargs
+
+    @property
+    def _user_parameters(self):
+        """_user_parameters"""
+        if self._lazy_user_parameters is None:
+            super().__setattr__("_lazy_user_parameters", [])
+        return self._lazy_user_parameters
+
+    @_user_parameters.setter
+    def _user_parameters(self, value):
+        """_user_parameters"""
+        if not isinstance(value, list):
+            raise TypeError(f"For 'Cell', the property '_user_parameters' must be list type, "
+                            f"but got type {type(value)}.")
+        self._lazy_user_parameters = value
+
     @property
     def cell_init_args(self):
         return self._cell_init_args
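The property block above converts several eagerly allocated containers into lazily allocated ones: `__init__` now stores `None`, and the corresponding property creates the real container on first access. A minimal, self-contained sketch of the pattern (illustrative names, not the MindSpore implementation):

    from collections import OrderedDict

    class LazyHooks:
        def __init__(self):
            # Defer allocation until the dict is actually needed.
            self._lazy_forward_hook = None

        @property
        def forward_hook(self):
            if self._lazy_forward_hook is None:
                self._lazy_forward_hook = OrderedDict()
            return self._lazy_forward_hook

    h = LazyHooks()
    print(h.__dict__)             # {'_lazy_forward_hook': None}: nothing allocated yet
    h.forward_hook["h0"] = print  # first access allocates the OrderedDict
    print(len(h.forward_hook))    # 1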
@@ -279,15 +380,21 @@ class Cell(Cell_):
         Get exist parameter names adding by tuple or list of parameter.
         """
         if self._exist_names is None:
-            self._exist_names = set("")
+            super().__setattr__("_exist_names", set(""))
         return self._exist_names
 
     @property
     def exist_objs(self):
         if self._exist_objs is None:
-            self._exist_objs = set()
+            super().__setattr__("_exist_objs", set())
         return self._exist_objs
 
+    @property
+    def _construct_sig(self):
+        if self._lazy_construct_sig is None:
+            super().__setattr__("_lazy_construct_sig", inspect.signature(self.construct))
+        return self._lazy_construct_sig
+
     @property
     def param_prefix(self):
         """
@@ -381,6 +488,8 @@ class Cell(Cell_):
         `parameter_layout_dict` represents the tensor layout of a parameter, which is inferred by shard strategy and
         distributed operator information.
         """
+        if self._parameter_layout_dict is None:
+            super().__setattr__("_parameter_layout_dict", {})
         return self._parameter_layout_dict
 
     @property
@@ -396,6 +505,8 @@ class Cell(Cell_):
 
     @property
     def parallel_parameter_name_list(self):
+        if self._parallel_parameter_name_list is None:
+            super().__setattr__("_parallel_parameter_name_list", ())
         return self._parallel_parameter_name_list
 
     @parallel_parameter_name_list.setter
@@ -450,6 +561,8 @@ class Cell(Cell_):
 
     @property
     def parallel_parameter_merge_net_dict(self):
+        if self._parallel_parameter_merge_net_dict is None:
+            super().__setattr__("_parallel_parameter_merge_net_dict", {})
         return self._parallel_parameter_merge_net_dict
 
     @parallel_parameter_merge_net_dict.setter
@@ -867,6 +980,7 @@ class Cell(Cell_):
         if hasattr(self, "compile_cache") and self.compile_cache:
             _cell_graph_executor.del_net_res(self, self.compile_cache)
         Cell.total_instance_count -= 1
+        Cell.global_cells.pop(self, None)
 
     def __delattr__(self, name):
         if name in self._params:
@@ -879,47 +993,15 @@ class Cell(Cell_):
             del self._params_list[name]
         else:
             object.__delattr__(self, name)
-        self._attr_synced = False
-
-    def _cast_mixed_precision_inputs(self, inputs, dst_type):
-        """Cast input for mixed precision"""
-        res = list()
-        for item in inputs:
-            if isinstance(item, tuple):
-                res.append(self._cast_mixed_precision_inputs(item, dst_type))
-            elif isinstance(item, float):
-                res.append(self.cast(item, dst_type))
-            elif hasattr(item, "dtype") and item.dtype in \
-                    {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
-                res.append(self.cast(item, dst_type))
-            else:
-                res.append(item)
-        return tuple(res)
 
     def cast_inputs(self, inputs, dst_type):
         """
         Cast inputs to specified type.
 
-        Args:
-            inputs (tuple[Tensor]): The cell inputs.
-            dst_type (mindspore.dtype): The specified data type.
-
-        returns:
-            tuple[Tensor], the result with destination data type.
+        .. warning::
+            This interface will be deprecated in future versions.
         """
-        res = list()
-        for item in inputs:
-            if isinstance(item, tuple):
-                res.append(self.cast_inputs(item, dst_type))
-            else:
-                res.append(self.cast(item, dst_type))
-        return tuple(res)
-
-    def _do_parameter_broadcast(self):
-        if context.get_auto_parallel_context("parallel_mode") == ParallelMode.DATA_PARALLEL:
-            if not self.parameter_broadcast_done:
-                _pynative_executor.parameter_broadcast(self, self.phase)
-                self.parameter_broadcast_done = True
+        logger.warning(f"'cast_inputs' will be deprecated in future versions.")
 
     def run_construct(self, cast_inputs, kwargs):
         """
@@ -940,29 +1022,29 @@ class Cell(Cell_):
         output = self._run_construct(cast_inputs, kwargs)
         return output
 
-    def _run_construct(self, *inputs, **kwargs):
+    def _run_construct(self, *args, **kwargs):
         """Run the construct function"""
         if self._forward_pre_hook:
-            inputs = self._run_forward_pre_hook(inputs)
+            args, kwargs = self._run_forward_pre_hook(args, kwargs)
 
         if self._shard_fn is not None:
-            output = self._shard_fn(*inputs, **kwargs)
+            output = self._shard_fn(*args, **kwargs)
         elif _pynative_executor.requires_grad():
             if self._backward_hook:
-                output = self._backward_hook_construct(*inputs, **kwargs)
+                output = self._backward_hook_construct(*args, **kwargs)
             elif self._recompute_cell is not None:
-                output = self._recompute_cell(*inputs, **kwargs)
+                output = self._recompute_cell(*args, **kwargs)
             elif self.has_bprop:
-                output = self._call_custom_bprop(*inputs, **kwargs)
+                output = self._call_custom_bprop(*args, **kwargs)
             else:
-                output = self.construct(*inputs, **kwargs)
+                output = self.construct(*args, **kwargs)
         else:
-            output = self.construct(*inputs, **kwargs)
+            output = self.construct(*args, **kwargs)
 
         if self._forward_hook:
-            output = self._run_forward_hook(inputs, output)
+            output = self._run_forward_hook(args, kwargs, output)
 
-        if self._backward_pre_hook:
+        if self._backward_pre_hook and _pynative_executor.requires_grad():
             output = self._run_backward_pre_hook(output)
 
         return output
@@ -998,6 +1080,7 @@ class Cell(Cell_):
                             f"{default_args} default argument, total {positional_args + default_args}, "
                             f"but got {len(args)}.")
 
+    # pylint: disable=E0203
     def _hook_fn_registered(self):
         '''Hook function in graph mode'''
         # Check super().__init__() in graph mode.
@@ -1141,9 +1224,9 @@ class Cell(Cell_):
         The parallel strategies of remaining operators are derived from the strategy specified by the input and output.
 
         Note:
-            If Cell.shard is called, the parallel mode in `set_auto_parallel_context` (parallel_mode) will be set to
-            "auto_parallel" and the search mode (search_mode) to "sharding_propagation".
-            If the input contain Parameter, its strategy should be set in `in_strategy`.
+            - It is valid only in semi auto parallel or auto parallel mode.
+              In other parallel modes, strategies set here will be ignored.
+            - If the input contain Parameter, its strategy should be set in `in_strategy`.
 
         Args:
             in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple. Tuple
@@ -1196,27 +1279,6 @@ class Cell(Cell_):
         self._shard_fn = fn
         return fn
 
-    def auto_cast_inputs(self, inputs):
-        """
-        Auto cast inputs in mixed precision scenarios.
-
-        Args:
-            inputs (tuple): the inputs of construct.
-
-        Returns:
-            Tuple, the inputs after data type cast.
-        """
-        msg = f"'auto_cast_inputs' is deprecated from version 2.0 and will be removed in a future version."
-        logger.warning(msg)
-        cast_inputs = inputs
-        mixed_type = self.get_mixed_precision_type()
-        if mixed_type == MixedPrecisionType.FP16:
-            cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float16)
-        if mixed_type == MixedPrecisionType.FP32:
-            cast_inputs = self._cast_mixed_precision_inputs(inputs, mstype.float32)
-
-        return cast_inputs
-
     def _init_check(self):
         for param in self.get_parameters(expand=False):
             if param.has_init:
@@ -1229,10 +1291,16 @@ class Cell(Cell_):
         self._is_check_and_refresh = True
 
     def _predict(self, *args, **kwargs):
+        '''Graph executor for predict'''
         if not hasattr(self, "phase"):
             return False, None
         if (self.phase == "prefill" or self.phase == 'increment') and self.phase in self.phase_cache:
-            new_args = _get_args_for_run_predict(self, args, kwargs, self._compile_args)
+            new_args = _get_args_for_run(self, args, kwargs, self._has_mutable_args_list, True)
+            if self.jit_config_dict:
+                jit_config_dict = self.jit_config_dict
+            else:
+                jit_config_dict = JitConfig().jit_config_dict
+            _cell_graph_executor._graph_executor.set_jit_config(jit_config_dict)
             res = _cell_graph_executor._graph_executor(tuple(new_args), self.phase_cache[self.phase])
             res = _convert_python_data(res)
             return True, res
@@ -1242,7 +1310,7 @@ class Cell(Cell_):
         # Run in Graph mode.
         if context._get_mode() == context.GRAPH_MODE and os.getenv("MS_JIT") != '0':
             if kwargs:
-                bound_arguments = self.sig.bind(*args, **kwargs)
+                bound_arguments = self._construct_sig.bind(*args, **kwargs)
                 bound_arguments.apply_defaults()
                 args = bound_arguments.args
                 kwargs = bound_arguments.kwargs
@@ -1324,37 +1392,12 @@ class Cell(Cell_):
         """
         with _no_grad():
             output = self.construct(*args, **kwargs)
-        _pynative_executor.call_custom_bprop(self, output, *args, **kwargs)
-        return output
+        return _pynative_executor.call_custom_bprop(self, output, *args, **kwargs)
 
     def _add_attr(self, name, value):
         if name and name[:2] != '__' and name not in Cell.IGNORE_LIST:
             super(Cell, self)._add_attr(name, value)
 
-    def _sync_attr_for_compile(self):
-        """Sync the attr to c++ object."""
-        if self._attr_synced:
-            return
-        cells = self.__dict__.get('_cells')
-        for key in cells:
-            cell = cells[key]
-            cell._sync_attr_for_compile()
-            self._add_attr(key, cell)
-        params = self.__dict__.get('_params')
-        for key in params:
-            if '.' in key:
-                continue
-            param = params[key]
-            self._add_attr(key, param)
-        params_list = self.__dict__.get('_params_list')
-        for key in params_list:
-            params_list_item = params_list[key]
-            self._add_attr(key, params_list_item)
-        for key in self.__dict__:
-            value = self.__dict__[key]
-            self._add_attr(key, value)
-        self._attr_synced = True
-
     def _set_attr_for_param_or_param_tuple(self, name, value):
         """Set attr for param and tensor."""
         if isinstance(value, Parameter):
@@ -1369,16 +1412,14 @@ class Cell(Cell_):
                     # If there are multiple identical objects, their names only check once.
                     continue
                 exist_objs.add(item)
-                if item.name == PARAMETER_NAME_DEFAULT:
-                    logger.warning("For 'Cell', the parameter definition is deprecated.\n"
-                                   "Please set a unique name for the parameter in ParameterTuple '{}'.".format(value))
-                    item.name = item.name + "$" + str(self._id)
+                if _is_parameter_generated(item.name):
+                    item.name = "Parameter$" + str(self._id)
                     self._id += 1
-                self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
                 if item.name in exist_names:
                     raise ValueError("The value {} , its name '{}' already exists. "
                                      "Please set a unique name for the parameter.".format(value, item.name))
                 exist_names.add(item.name)
+                self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
 
         if context._get_mode() == context.PYNATIVE_MODE:
             if name in self.__dict__:
@@ -1398,9 +1439,6 @@ class Cell(Cell_):
                     # If there are multiple identical objects, their names only check once.
                     continue
                 self.exist_objs.add(item)
-                if item.name == PARAMETER_NAME_DEFAULT:
-                    item.name = item.name + "$" + str(self._id)
-                    self._id += 1
                 if item.name in self.exist_names:
                     raise ValueError(f"The value {value} , its name '{item.name}' already exists. "
                                      "Please set a unique name for the parameter.")
@@ -1513,24 +1551,6 @@ class Cell(Cell_):
         main_str += ")"
         return main_str
 
-    def load_parameter_slice(self, params):
-        """
-        Replace parameters with sliced tensors by parallel strategies.
-
-        Note:
-            This interface is deprecated.
-        """
-        logger.warning("'load_parameter_slice' function is deprecated.")
-
-    def set_parallel_input_with_inputs(self, *inputs):
-        """
-        Slice inputs tensors by parallel strategies.
-
-        Note:
-            This interface is deprecated.
-        """
-        logger.warning("'set_parallel_input_with_inputs' function is deprecated.")
-
     def set_inputs(self, *inputs, **kwargs):
         """
         Save set inputs for computation graph. The number of inputs should be the same with that of the datasets. When
@@ -1665,7 +1685,6 @@ class Cell(Cell_):
             _cell_graph_executor._graph_executor.check_argument_consistency(compile_args, args, "set_inputs")
             self._check_parameter_consistency(compile_args, args)
             Validator.check_symbolic_shape(compile_args, args)
-            self.saved_dynamic_shape = compile_args
             return compile_args
         return args
 
@@ -1678,8 +1697,9 @@ class Cell(Cell_):
             kwargs (dict): Kwargs of the Cell object.
         """
         _init_auto_parallel_context(self)
-        self._compile_args = self._get_compile_args(args)
-        _cell_graph_executor.compile(self, *self._compile_args, phase=self.phase,
+        compile_args = self._get_compile_args(args)
+        self._has_mutable_args_list = _get_mutable_flags(compile_args)
+        _cell_graph_executor.compile(self, *compile_args, phase=self.phase,
                                      jit_config_dict=self._jit_config_dict, **kwargs)
         _clear_auto_parallel_context(self)
 
@@ -1698,25 +1718,14 @@ class Cell(Cell_):
            Object, the result of executing.
         """
         self.compile(*args, **kwargs)
-        self.add_flags(ge_sync_data=False)
-        new_args = _get_args_for_run(self, args, kwargs, self._compile_args)
+        new_args = _get_args_for_run(self, args, kwargs, self._has_mutable_args_list, False)
+        if self.jit_config_dict:
+            jit_config_dict = self.jit_config_dict
+        else:
+            jit_config_dict = JitConfig().jit_config_dict
+        _cell_graph_executor._graph_executor.set_jit_config(jit_config_dict)
         return _cell_graph_executor(self, *new_args, phase=self.phase)
 
-    def auto_parallel_compile_and_run(self):
-        """
-        Whether or not to execute compile and run in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode.
-
-        Note:
-            This interface is deprecated.
-        """
-        logger.warning("'auto_parallel_compile_and_run' function is deprecated.")
-
-    def exec_checkpoint_graph(self):
-        """Executes GE saving checkpoint graph operation."""
-        logger.warning("'exec_checkpoint_graph' function is deprecated.")
-        self.add_flags(ge_sync_data=True)
-        _cell_graph_executor(self, phase='save')
-
     def insert_param_to_cell(self, param_name, param, check_name_contain_dot=True):
         """
         Adds a parameter to the current cell.
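Both `_predict` and the graph run path above now push a jit config to the graph executor before execution: the cell's own config if one was attached, otherwise a default `JitConfig()`. As a hedged usage sketch, attaching a config to a cell might look like the following (the `jit_level` value is only an example; `Cell.set_jit_config` and `mindspore.JitConfig` are the public APIs this diff appears to rely on):

    import mindspore as ms
    from mindspore import nn

    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.dense = nn.Dense(4, 2)

        def construct(self, x):
            return self.dense(x)

    net = Net()
    # If a config is attached, the executor uses it; otherwise JitConfig() defaults apply.
    net.set_jit_config(ms.JitConfig(jit_level="O0"))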
@@ -1762,35 +1771,10 @@ class Cell(Cell_):
         if not isinstance(param, Parameter) and param is not None:
             raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must be 'Parameter' if not None, "
                             f"but got {type(param)}.")
-        if isinstance(param, Parameter) and param.name == PARAMETER_NAME_DEFAULT:
+        if isinstance(param, Parameter) and _is_parameter_generated(param.name):
             param.name = param_name
         self._params[param_name] = param
 
-    def cast_param(self, param):
-        """
-        Cast parameter according to auto mix precision level in pynative mode.
-
-        This interface is currently used in the case of auto mix precision and usually needs not to be used explicitly.
-
-        Args:
-            param (Parameter): Parameters, the type of which should be cast.
-
-        Returns:
-            Parameter, the input parameter with type automatically cast.
-        """
-        msg = f"'cast_param' is deprecated from version 2.0 and will be removed in a future version."
-        logger.warning(msg)
-        mixed_type = self.get_mixed_precision_type()
-        if mixed_type != MixedPrecisionType.NOTSET:
-            if mixed_type == MixedPrecisionType.FP32:
-                param.set_cast_dtype(mstype.float32)
-            elif mixed_type == MixedPrecisionType.FP16:
-                param.set_cast_dtype(mstype.float16)
-        elif hasattr(param, "set_cast_dtype"):
-            # retest dtype
-            param.set_cast_dtype()
-        return param
-
     def insert_child_to_cell(self, child_name, child_cell):
         """
         Adds a child cell to the current cell with a given name.
@@ -1850,27 +1834,10 @@ class Cell(Cell_):
         """
         Remove the redundant parameters.
 
-        This interface usually needs not to be used explicitly.
+        .. warning::
+            This interface will be deprecated in future versions.
         """
-        cells = self.cells_and_names()
-        for _, cell in cells:
-            params = cell._params.items()
-            for param_name, param in list(params):
-                if param.name not in self.parallel_parameter_name_list:
-                    cell._params.pop(param_name)
-                    logger.info("remove the redundant parameter: %s", param.name)
-                    continue
-            cell_dict = cell.__dict__
-            for key in cell_dict:
-                if isinstance(cell_dict[key], ParameterTuple):
-                    param_tuple = cell_dict[key]
-                    new_param_tuple = []
-                    for param in param_tuple:
-                        if param.name not in self.parallel_parameter_name_list:
-                            logger.info("remove the redundant parameter: %s in ParameterTuple", param.name)
-                            continue
-                        new_param_tuple.append(param)
-                    cell.__dict__[key] = ParameterTuple(new_param_tuple)
+        logger.warning(f"'remove_redundant_parameters' will be deprecated in future versions.")
 
     def _get_cell_parallel_mode(self):
         """Determine whether the current cell is in parallel mode."""
@@ -1926,16 +1893,13 @@ class Cell(Cell_):
         # replace all original usage.
         cells = self.cells_and_names()
         is_parallel_mode = self._get_cell_parallel_mode()
-        is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
 
         for _, cell in cells:
             params = cell._params.items()
             for param_name, param in params:
-                not_sliced = not param.sliced
-                judgment = not_sliced
                 if param.param_info.is_pipeline_shared_param:
                     continue
-                if is_graph_mode and is_parallel_mode and judgment:
+                if is_parallel_mode and not param.sliced:
                     continue
                 if not auto_parallel_mode:
                     cell._params[param_name] = _updata(param)
@@ -1948,11 +1912,9 @@ class Cell(Cell_):
                     param_tuple = cell_dict[key]
                     new_param_tuple = []
                     for param in param_tuple:
-                        not_sliced = not param.sliced
-                        judgment = not_sliced
                         if param.param_info.is_pipeline_shared_param:
                             continue
-                        if is_graph_mode and is_parallel_mode and judgment:
+                        if is_parallel_mode and not param.sliced:
                             continue
                         if not auto_parallel_mode:
                             new_param_tuple.append(_updata(param))
@@ -2591,15 +2553,6 @@ class Cell(Cell_):
2591
2553
  self.add_flags_recursive(broadcast_flag=mode)
2592
2554
  return self
2593
2555
 
2594
- def set_auto_parallel(self):
2595
- """
2596
- Set the cell to auto parallel mode.
2597
-
2598
- Note:
2599
- This interface is deprecated.
2600
- """
2601
- logger.warning("'set_auto_parallel' function is deprecated.")
2602
-
2603
2556
  def set_jit_config(self, jit_config):
2604
2557
  """
2605
2558
  Set jit config for cell.
@@ -2645,25 +2598,38 @@ class Cell(Cell_):
  raise ValueError(f"Negative 'fusion_size' {fusion_size} is invalid.")
  Tensor._flatten_tensors(self.trainable_params(), fusion_size) # pylint: disable=W0212

- def register_forward_pre_hook(self, hook_fn):
+ def register_forward_pre_hook(self, hook_fn, with_kwargs=False):
  """
  Register forward pre hook function for Cell object.

+ The hook will be called before :func:`mindspore.nn.Cell.construct` is invoked.
+
+ The hook function should be one of the following signatures:
+
+ - `hook_fn(cell, args) -> None or new_args` , when `with_kwargs` is ``False`` .
+ - `hook_fn(cell, args, kwargs) -> None or (new_args, new_kwargs)` , when `with_kwargs` is ``True`` .
+
+ where:
+
+ - `cell` (Cell): Cell object on which the hook is registered.
+ - `args` (tuple): Positional arguments passed to the `construct` function.
+ - `kwargs` (dict): Keyword arguments passed to the `construct` function. Only passed to `hook_fn` when
+ `with_kwargs` is ``True`` .
+
  Note:
- - The `register_forward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
- - 'hook_fn' must be defined as the following code.
- `cell` is the object of registered Cell. `inputs` is the forward
- input objects passed to the Cell. The 'hook_fn' can modify the forward input objects by returning new
- forward input objects.
- - It should have the following signature:
- hook_fn(cell, inputs) -> new input objects or none.
- - In order to prevent running failed when switching to graph mode, it is not recommended to write it in the
- `construct` function of Cell object. In the pynative mode, if the `register_forward_pre_hook` function is
- called in the `construct` function of the Cell object, a hook function will be added at each run time of
- Cell object.
+ - The feature does not take effect in graph mode or in PyNative mode with functions decorated by jit.
+ - The `hook_fn` can modify the forward inputs by returning new inputs. If `with_kwargs` is ``False`` , a
+ single value (which will be wrapped into a tuple unless already a tuple) or a tuple of args should be
+ returned. If `with_kwargs` is ``True`` , both `args` and `kwargs` should be returned.
+ - In order to prevent running failed when switching to graph mode, it is not recommended to call it in the
+ `construct` function of Cell object.
+ - In the pynative mode, if this method is called inside the `construct` function of the Cell object, a
+ `hook_fn` will be added at each run time of Cell object.

  Args:
  hook_fn (function): Python function. Forward pre hook function.
+ with_kwargs (bool, optional): Specifies whether hook_fn will be passed the kwargs given to the `construct`
+ function. Default: ``False`` .

  Returns:
  A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
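A minimal PyNative sketch of the signatures documented above. The `ScaleNet` cell, the hook bodies, and the input values are illustrative assumptions; only `register_forward_pre_hook`, the `with_kwargs` flag, and the return-value contract come from this hunk:

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    ms.set_context(mode=ms.PYNATIVE_MODE)

    class ScaleNet(nn.Cell):               # hypothetical cell, for illustration only
        def construct(self, x, scale=1.0):
            return x * scale

    def pre_hook(cell, args):
        # A non-tuple return value is wrapped into a tuple and replaces the args.
        return args[0] + 1

    def pre_hook_kw(cell, args, kwargs):
        # With with_kwargs=True, return (new_args, new_kwargs) or None.
        kwargs["scale"] = 2.0
        return args, kwargs

    net = ScaleNet()
    h1 = net.register_forward_pre_hook(pre_hook)
    h2 = net.register_forward_pre_hook(pre_hook_kw, with_kwargs=True)
    out = net(Tensor(np.ones(2, np.float32)), scale=1.0)   # construct sees x + 1 and scale=2.0
    h1.remove()
    h2.remove()
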
@@ -2705,60 +2671,66 @@ class Cell(Cell_):
  if context._get_mode() == context.GRAPH_MODE:
  return HookHandle()
  check_hook_fn(hook_fn)
- handle = HookHandle(self._forward_pre_hook)
+ handle = HookHandle(self._forward_pre_hook, extra_dict=self._forward_pre_hook_with_kwargs)
  self._forward_pre_hook[handle.handle_id] = hook_fn
+ if with_kwargs:
+ self._forward_pre_hook_with_kwargs[handle.handle_id] = True
  return handle

- def _run_forward_pre_hook(self, inputs):
+ def _run_forward_pre_hook(self, args, kwargs):
  """
  Running forward pre hook function registered on Cell object.
+ """
+ for hook_id, hook_fn in self._forward_pre_hook.items():
+ if hook_id in self._forward_pre_hook_with_kwargs:
+ ret = hook_fn(self, args, kwargs)
+ if ret is not None:
+ if isinstance(ret, tuple) and len(ret) == 2:
+ args, kwargs = ret
+ else:
+ raise RuntimeError(
+ "forward pre hook with kwargs must return None or a tuple of (new_args, new_kwargs), "
+ f"but got {ret}"
+ )
+ else:
+ ret = hook_fn(self, args)
+ if ret is not None:
+ if not isinstance(ret, tuple):
+ ret = (ret,)
+ args = ret
+ return args, kwargs

- Args:
- inputs: The input objects of cell object.
+ def register_forward_hook(self, hook_fn, with_kwargs=False):
+ """
+ Register forward hook function for Cell object.

- Returns:
- - **outputs** - New input objects or none.
+ This hook will be called after :func:`mindspore.nn.Cell.construct` has computed an output.

- Supported Platforms:
- ``Ascend`` ``GPU`` ``CPU``
- """
- forward_pre_hook_inputs = inputs
- for fn in self._forward_pre_hook.values():
- ret = fn(self, forward_pre_hook_inputs)
- if ret is not None:
- if not isinstance(ret, tuple):
- forward_pre_hook_inputs = (ret,)
- else:
- forward_pre_hook_inputs = ret
+ The hook function should be one of the following signatures:

- if isinstance(inputs, tuple):
- if not isinstance(forward_pre_hook_inputs, tuple):
- forward_pre_hook_inputs = (forward_pre_hook_inputs,)
- if len(forward_pre_hook_inputs) != len(inputs):
- raise TypeError(
- "The forward pre hook return value size is {} not equal to input size {}".format(
- len(forward_pre_hook_inputs), len(inputs)))
- return forward_pre_hook_inputs
+ - `hook_fn(cell, args, output) -> None or new_output` , when `with_kwargs` is ``False`` .
+ - `hook_fn(cell, args, kwargs, output) -> None or new_output` , when `with_kwargs` is ``True`` .

- def register_forward_hook(self, hook_fn):
- """
- Set the Cell forward hook function.
+ where:
+
+ - `cell` (Cell): Cell object on which the hook is registered.
+ - `args` (tuple): Positional arguments passed to the `construct` function.
+ - `kwargs` (dict): Keyword arguments passed to the `construct` function. Only passed to `hook_fn` when
+ `with_kwargs` is ``True`` .
+ - `output`: Output generated by the `construct` function.

  Note:
- - The `register_forward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
- - 'hook_fn' must be defined as the following code.
- `cell` is the object of registered Cell. `inputs` is the forward
- input objects passed to the Cell. `output` is the forward output object of the Cell. The 'hook_fn' can
- modify the forward output object by returning new forward output object.
- - It should have the following signature:
- hook_fn(cell, inputs, output) -> new output object or none.
- - In order to prevent running failed when switching to graph mode, it is not recommended to write it in the
- `construct` function of Cell object. In the pynative mode, if the `register_forward_hook` function is
- called in the `construct` function of the Cell object, a hook function will be added at each run time of
- Cell object.
+ - The feature does not take effect in graph mode or in PyNative mode with functions decorated by jit.
+ - The `hook_fn` can modify the forward outputs by returning new outputs.
+ - In order to prevent running failed when switching to graph mode, it is not recommended to call it in the
+ `construct` function of Cell object.
+ - In the pynative mode, if this method is called inside the `construct` function of the Cell object, a
+ `hook_fn` will be added at each run time of Cell object.

  Args:
  hook_fn (function): Python function. Forward hook function.
+ with_kwargs (bool, optional): Specifies whether hook_fn will be passed the kwargs given to the `construct`
+ function. Default: ``False`` .

  Returns:
  A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
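A companion sketch for the forward hook. The `AddNet` cell and the values are assumptions; the signatures and the rule that a non-None return replaces the output follow the docstring above:

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    ms.set_context(mode=ms.PYNATIVE_MODE)

    class AddNet(nn.Cell):                 # hypothetical cell, for illustration only
        def construct(self, x, bias=0.0):
            return x + bias

    def forward_hook(cell, args, output):
        # Returning a value replaces the forward output.
        return output * 2

    def forward_hook_kw(cell, args, kwargs, output):
        # kwargs is only received because the hook is registered with with_kwargs=True.
        print("bias passed to construct:", kwargs.get("bias"))
        return None                        # None keeps the current output unchanged

    net = AddNet()
    h1 = net.register_forward_hook(forward_hook)
    h2 = net.register_forward_hook(forward_hook_kw, with_kwargs=True)
    out = net(Tensor(np.ones(2, np.float32)), bias=1.0)
    h1.remove()
    h2.remove()
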
@@ -2804,38 +2776,24 @@ class Cell(Cell_):
  if context._get_mode() == context.GRAPH_MODE:
  return HookHandle()
  check_hook_fn(hook_fn)
- handle = HookHandle(self._forward_hook)
+ handle = HookHandle(self._forward_hook, extra_dict=self._forward_hook_with_kwargs)
  self._forward_hook[handle.handle_id] = hook_fn
+ if with_kwargs:
+ self._forward_hook_with_kwargs[handle.handle_id] = True
  return handle

- def _run_forward_hook(self, inputs, output):
+ def _run_forward_hook(self, args, kwargs, output):
  """
  Running forward hook function registered on Cell object.
-
- Args:
- inputs: The input objects of Cell object.
- output: The output object of Cell object.
-
- Returns:
- - **output** - New output object or none.
-
- Supported Platforms:
- ``Ascend`` ``GPU`` ``CPU``
  """
- forward_hook_output = output
- for fn in self._forward_hook.values():
- ret = fn(self, inputs, forward_hook_output)
+ for hook_id, hook_fn in self._forward_hook.items():
+ if hook_id in self._forward_hook_with_kwargs:
+ ret = hook_fn(self, args, kwargs, output)
+ else:
+ ret = hook_fn(self, args, output)
  if ret is not None:
- forward_hook_output = ret
-
- if isinstance(output, tuple):
- if not isinstance(forward_hook_output, tuple):
- forward_hook_output = (forward_hook_output,)
- if len(forward_hook_output) != len(output):
- raise TypeError(
- "The forward hook return value size is {} not equal to output size {}".format(
- len(forward_hook_output), len(output)))
- return forward_hook_output
+ output = ret
+ return output

  def register_backward_pre_hook(self, hook_fn):
  """
@@ -3116,7 +3074,6 @@ class Cell(Cell_):
  OrderedDict([('param_a', Parameter (name=param_a, shape=(3,), dtype=Int64, requires_grad=True)), \
  ('buffer_a', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]))])
  """
- # TODO: Remove `args` and the parsing logic when BC allows.
  if args:
  # DeprecationWarning is ignored by default
  warnings.warn(
@@ -3169,7 +3126,7 @@ class Cell(Cell_):

  It should have the following signature:

- hook(cell, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950
+ hook(cell, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None

  Args:
  hook (Callable): The hook function before `load_state_dict` is called.
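One plausible use of the documented signature is renaming legacy keys in place before `load_state_dict` consumes the state dict. The key names below are hypothetical; only the eight-argument signature comes from this hunk:

    def rename_keys_pre_hook(cell, state_dict, prefix, local_metadata, strict,
                             missing_keys, unexpected_keys, error_msgs):
        # Mutate state_dict in place; the hook itself returns None.
        for key in list(state_dict.keys()):
            if key.startswith(prefix + "old_block."):
                new_key = prefix + "new_block." + key[len(prefix + "old_block."):]
                state_dict[new_key] = state_dict.pop(key)
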
@@ -3584,12 +3541,6 @@ class Cell(Cell_):
  for param in params:
  param.set_param_ps(init_in_server)

- @deprecated("1.8", "set_param_fl")
- def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
- params = self.parameters_and_names()
- for param in params:
- param[1].set_param_fl(push_to_server, pull_from_server, requires_aggr)
-
  def set_comm_fusion(self, fusion_type, recurse=True):
  """
  Set `comm_fusion` for all the parameters in this cell. Please refer to the description of
@@ -3698,7 +3649,7 @@ class Cell(Cell_):
  self._recompute()
  if 'mp_comm_recompute' in kwargs.keys():
  self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
- if 'parallel_optimizer_comm_recompute' in kwargs.keys():
+ if 'parallel_optimizer_comm_recompute' in kwargs:
  if kwargs.get('parallel_optimizer_comm_recompute', False):
  logger.warning("Currently, the communication operator allgathers introduced by optimizer shard "
  "is replaced with zero3.")
@@ -3711,38 +3662,6 @@ class Cell(Cell_):
  "the key kwargs must be 'mp_comm_recompute', "
  "'parallel_optimizer_comm_recompute', 'recompute_slice_activation'" % key)
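For reference, the keyword names listed in the error message above are the ones this method accepts, and can be passed as in the sketch below; the `nn.Dense` sub-cell is only a placeholder:

    from mindspore import nn

    block = nn.Dense(4, 4)                 # hypothetical sub-cell of a larger network
    # Keyword names match the ones validated in Cell.recompute.
    block.recompute(mp_comm_recompute=True, recompute_slice_activation=True)
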
 
- @deprecated("2.3", "infer_param_pipeline_stage")
- def infer_param_pipeline_stage(self):
- """
- Infer pipeline stages of all parameters in the cell.
-
- Note:
- - The interface is deprecated from version 2.3 and will be removed in a future version.
-
- Returns:
- The params belong to current stage in pipeline parallel.
-
- Raises:
- RuntimeError: If there is a parameter does not belong to any stage.
- """
- from mindspore.parallel._utils import _get_global_rank, _get_device_num
- logger.warning(f"This interface may be deleted in the future.")
- stage_num = context.get_auto_parallel_context("pipeline_stages")
- device_num = _get_device_num()
- rank_id = _get_global_rank()
- per_stage_devices = device_num // stage_num
- current_stage = rank_id // per_stage_devices
- params = []
- for param in self.trainable_params():
- if not param._pipeline_stage_list: # pylint: disable=W0212
- raise RuntimeError("For 'infer_param_pipeline_stage', the parameter {} does not belong to any stage, "
- "please check whether the cell where the param locates has been set "
- "'pipeline_stage'. Otherwise, the parameter should use 'add_pipeline_stage' "
- "to add its stage information".format(param.name))
- if current_stage in param._pipeline_stage_list:
- params.append(param)
- return params
-
  def place(self, role, rank_id):
  """
  Set the label for all operators in this cell.
@@ -3772,19 +3691,6 @@ class Cell(Cell_):
  for op in all_ops:
  op.place(role, rank_id)

- def _mixed_precision_cast(self, inputs):
- mixed_type = self.get_mixed_precision_type()
- if mixed_type == MixedPrecisionType.NOTSET:
- return inputs
- if mixed_type == MixedPrecisionType.FP16:
- cast_type = mstype.float16
- elif mixed_type == MixedPrecisionType.BF16:
- cast_type = mstype.bfloat16
- else:
- cast_type = mstype.float32
- cast_inputs = self._cast_mixed_precision_inputs(inputs, cast_type)
- return cast_inputs
-
  def _get_attr_from_cell(self, network):
  if not isinstance(network, Cell):
  return
@@ -3793,91 +3699,11 @@ class Cell(Cell_):
  if hasattr(network, "_amp_level"):
  self._amp_level = getattr(network, "_amp_level")

- def _register_parameters_hook(self, forward_hook=None, backward_hook=None, all=False):
+ def _set_jit_graph_name(self, key):
  """
- Register the forward hook for parameters and register the backward hook for the corresponding gradient.
-
- .. warning::
- This is an experimental prototype that is subject to change and/or deletion.
-
- Note:
- - The `_register_parameters_hook(forward_hook, backward_hook)` only work in graph mode
- - The `forward_hook` must be defined as the following code.
- `parameters`: the tuple of the trainble parameters of the Cell, each element in the tuple shuould be
- in the format of `(param_name, Parameter)`.
- - The `forward_hook` should have the following signature:
- forward_hook(parameters) -> None.
- - The `backward_hook` must be defined as the following code.
- `gradients`: the tuple of the gradients corresponding to the trainble parameters of the Cell, each
- element in the tuple shuould be in the format of `(param_name, gradient)`.
- - The `backward_hook` should have the following signature:
- backward_hook(parameters) -> New gradients.
-
- Args:
- forward_hook (function, optional): Python function or ``None``, Forward hook function. Default: ``None``
- backward_hook (function, optional): Python function or ``None``, Backward hook function. Default ``None``
- all (bool, optional): bool, whether to set hooks for all sub cells recursively. Default: ``False``
-
- Returns:
- None
-
- Raises:
- RuntimeError: If the `forward_hook` or `backward_hook ` has unspoorted syntax under GRAPH MODE.
- TypeError: If the `forward_hook` or `backward_hook` is not defined as required.
-
- Supported Platforms:
- ``Ascend`` ``GPU`` ``CPU``
-
- Examples:
- >>> import mindspore as ms
- >>> from mindspore import Tensor, nn, ops, Parameter
- >>>
- >>> ms.set_context(mode=ms.GRAPH_MODE)
- >>> def parameter_hook(parameters):
- ... print("--- enter parameter hook ---")
- ... for name, param in parameters:
- ... print (name, param)
- ... print("--- leave parameter hook ---")
- ...
- >>> def gradient_hook(gradients):
- ... print("--- enter gradient hook ---")
- ... outs = []
- ... for name, gradient in gradients:
- ... print(name, gradient)
- ... outs.append(gradient * 2) # double gradient
- ... print("--- leave gradient hook ---")
- ... return outs
- ...
- >>> class Net(nn.Cell):
- ... def __init__(self)
- ... super(Net, self).__init__()
- ... self.w = Parameter(Tensor(np.array([3.0], np.float32)), name='w')
- ... def construct(self, x):
- ... return self.w * x
- ...
- >>> grad = ops.GradOperation(get_by_list=True)
- >>> net = Net()
- >>> net._register_parameters_hook(forward_hook=parameter_hook, backward_hook=gradient_hook)
- >>> x = Tensor(np.array([4.0]).astype(np.float32))
- >>> output = grad(net, net.trainable_params())(x)
- --- enter parameter hook ---
- w
- Tensor(shape=[1], dtype=Float32, value=[ 3.00000000e+00])
- --- leave parameter hook ---
- --- enter gradient hook ---
- w
- Tensor(shape=[1], dtype=Float32, value=[ 4.00000000e+00])
- --- leave gradient hook ---
- >>> print("doubled grad: ", output)
- doubled grad: (Tensor(shape=[1], dtype=Float32, value=[ 8.00000000e+00]),)
- """
- if not all:
- self._parameters_forward_hook = forward_hook
- self._parameters_backward_hook = backward_hook
- else:
- for _, cell in self.cells_and_names():
- cell._parameters_forward_hook = forward_hook
- cell._parameters_backward_hook = backward_hook
+ Set jit graph name.
+ """
+ self._jit_graph_name = key


  class GraphCell(Cell):