mindspore-2.7.0rc1-cp310-cp310-win_amd64.whl → mindspore-2.7.1-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of mindspore might be problematic.
Files changed (370)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +5 -2
  3. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +2 -2
  7. mindspore/_extends/builtin_operations.py +3 -3
  8. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  9. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  10. mindspore/_extends/parse/__init__.py +3 -3
  11. mindspore/_extends/parse/compile_config.py +24 -1
  12. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
  13. mindspore/_extends/parse/parser.py +28 -22
  14. mindspore/_extends/parse/resources.py +1 -1
  15. mindspore/_extends/parse/standard_method.py +23 -2
  16. mindspore/_extends/parse/trope.py +2 -1
  17. mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
  18. mindspore/amp.py +0 -18
  19. mindspore/avcodec-59.dll +0 -0
  20. mindspore/avdevice-59.dll +0 -0
  21. mindspore/avfilter-8.dll +0 -0
  22. mindspore/avformat-59.dll +0 -0
  23. mindspore/avutil-57.dll +0 -0
  24. mindspore/boost/base.py +29 -2
  25. mindspore/common/__init__.py +18 -12
  26. mindspore/common/_decorator.py +3 -2
  27. mindspore/common/_grad_function.py +3 -1
  28. mindspore/common/_tensor_cpp_method.py +1 -1
  29. mindspore/common/_tensor_docs.py +371 -96
  30. mindspore/common/_utils.py +7 -43
  31. mindspore/common/api.py +434 -135
  32. mindspore/common/dtype.py +98 -57
  33. mindspore/common/dump.py +7 -108
  34. mindspore/common/dynamic_shape/__init__.py +0 -0
  35. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
  36. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  37. mindspore/common/file_system.py +59 -9
  38. mindspore/common/hook_handle.py +82 -3
  39. mindspore/common/jit_config.py +5 -1
  40. mindspore/common/jit_trace.py +27 -12
  41. mindspore/common/lazy_inline.py +5 -3
  42. mindspore/common/np_dtype.py +3 -3
  43. mindspore/common/parameter.py +17 -127
  44. mindspore/common/recompute.py +4 -13
  45. mindspore/common/tensor.py +50 -217
  46. mindspore/communication/_comm_helper.py +11 -1
  47. mindspore/communication/comm_func.py +138 -4
  48. mindspore/communication/management.py +85 -1
  49. mindspore/config/op_info.config +0 -15
  50. mindspore/context.py +20 -106
  51. mindspore/dataset/__init__.py +1 -1
  52. mindspore/dataset/audio/transforms.py +1 -1
  53. mindspore/dataset/core/config.py +35 -1
  54. mindspore/dataset/engine/datasets.py +338 -319
  55. mindspore/dataset/engine/datasets_user_defined.py +38 -22
  56. mindspore/dataset/engine/datasets_vision.py +1 -1
  57. mindspore/dataset/engine/validators.py +1 -15
  58. mindspore/dataset/transforms/c_transforms.py +2 -2
  59. mindspore/dataset/transforms/transforms.py +3 -3
  60. mindspore/dataset/vision/__init__.py +1 -1
  61. mindspore/dataset/vision/py_transforms.py +8 -8
  62. mindspore/dataset/vision/transforms.py +17 -5
  63. mindspore/dataset/vision/utils.py +632 -21
  64. mindspore/device_context/ascend/op_tuning.py +35 -1
  65. mindspore/dnnl.dll +0 -0
  66. mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
  67. mindspore/graph/custom_pass.py +55 -0
  68. mindspore/include/api/cell.h +28 -4
  69. mindspore/include/api/cfg.h +24 -7
  70. mindspore/include/api/context.h +1 -0
  71. mindspore/include/api/delegate.h +0 -2
  72. mindspore/include/api/dual_abi_helper.h +100 -19
  73. mindspore/include/api/graph.h +14 -1
  74. mindspore/include/api/kernel.h +16 -3
  75. mindspore/include/api/kernel_api.h +9 -1
  76. mindspore/include/api/metrics/accuracy.h +9 -0
  77. mindspore/include/api/model.h +5 -1
  78. mindspore/include/api/model_group.h +4 -0
  79. mindspore/include/api/model_parallel_runner.h +2 -0
  80. mindspore/include/api/status.h +48 -10
  81. mindspore/include/api/types.h +6 -1
  82. mindspore/include/dataset/constants.h +9 -0
  83. mindspore/include/dataset/execute.h +2 -2
  84. mindspore/jpeg62.dll +0 -0
  85. mindspore/mindrecord/__init__.py +3 -3
  86. mindspore/mindrecord/common/exceptions.py +1 -0
  87. mindspore/mindrecord/config.py +1 -1
  88. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  89. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  90. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  91. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  92. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  93. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  94. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  95. mindspore/mindrecord/filereader.py +4 -4
  96. mindspore/mindrecord/filewriter.py +5 -5
  97. mindspore/mindrecord/mindpage.py +2 -2
  98. mindspore/mindrecord/tools/cifar10.py +4 -3
  99. mindspore/mindrecord/tools/cifar100.py +1 -1
  100. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  101. mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
  102. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  103. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  104. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  105. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  106. mindspore/mindspore_backend_common.dll +0 -0
  107. mindspore/mindspore_backend_manager.dll +0 -0
  108. mindspore/mindspore_cluster.dll +0 -0
  109. mindspore/mindspore_common.dll +0 -0
  110. mindspore/mindspore_core.dll +0 -0
  111. mindspore/mindspore_cpu.dll +0 -0
  112. mindspore/mindspore_dump.dll +0 -0
  113. mindspore/mindspore_frontend.dll +0 -0
  114. mindspore/mindspore_glog.dll +0 -0
  115. mindspore/mindspore_hardware_abstract.dll +0 -0
  116. mindspore/mindspore_memory_pool.dll +0 -0
  117. mindspore/mindspore_ms_backend.dll +0 -0
  118. mindspore/mindspore_ops.dll +0 -0
  119. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  120. mindspore/mindspore_profiler.dll +0 -0
  121. mindspore/mindspore_pyboost.dll +0 -0
  122. mindspore/mindspore_pynative.dll +0 -0
  123. mindspore/mindspore_runtime_pipeline.dll +0 -0
  124. mindspore/mindspore_runtime_utils.dll +0 -0
  125. mindspore/mindspore_tools.dll +0 -0
  126. mindspore/mint/__init__.py +15 -10
  127. mindspore/mint/distributed/__init__.py +4 -0
  128. mindspore/mint/distributed/distributed.py +392 -69
  129. mindspore/mint/nn/__init__.py +2 -16
  130. mindspore/mint/nn/functional.py +4 -110
  131. mindspore/mint/nn/layer/__init__.py +0 -2
  132. mindspore/mint/nn/layer/_functions.py +1 -2
  133. mindspore/mint/nn/layer/activation.py +0 -6
  134. mindspore/mint/nn/layer/basic.py +0 -47
  135. mindspore/mint/nn/layer/conv.py +10 -10
  136. mindspore/mint/nn/layer/normalization.py +11 -16
  137. mindspore/mint/nn/layer/pooling.py +0 -4
  138. mindspore/nn/__init__.py +1 -3
  139. mindspore/nn/cell.py +231 -239
  140. mindspore/nn/layer/activation.py +4 -2
  141. mindspore/nn/layer/basic.py +56 -14
  142. mindspore/nn/layer/container.py +16 -0
  143. mindspore/nn/layer/embedding.py +4 -169
  144. mindspore/nn/layer/image.py +1 -1
  145. mindspore/nn/layer/normalization.py +2 -1
  146. mindspore/nn/layer/thor_layer.py +4 -85
  147. mindspore/nn/optim/ada_grad.py +0 -1
  148. mindspore/nn/optim/adafactor.py +0 -1
  149. mindspore/nn/optim/adam.py +32 -127
  150. mindspore/nn/optim/adamax.py +0 -1
  151. mindspore/nn/optim/asgd.py +0 -1
  152. mindspore/nn/optim/ftrl.py +8 -102
  153. mindspore/nn/optim/lamb.py +1 -4
  154. mindspore/nn/optim/lars.py +0 -3
  155. mindspore/nn/optim/lazyadam.py +25 -218
  156. mindspore/nn/optim/momentum.py +5 -43
  157. mindspore/nn/optim/optimizer.py +6 -55
  158. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  159. mindspore/nn/optim/rmsprop.py +0 -1
  160. mindspore/nn/optim/rprop.py +0 -1
  161. mindspore/nn/optim/sgd.py +0 -1
  162. mindspore/nn/optim/tft_wrapper.py +2 -4
  163. mindspore/nn/optim/thor.py +0 -2
  164. mindspore/nn/probability/bijector/bijector.py +7 -8
  165. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  166. mindspore/nn/probability/bijector/power_transform.py +20 -21
  167. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  168. mindspore/nn/probability/bijector/softplus.py +13 -14
  169. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  170. mindspore/nn/wrap/cell_wrapper.py +39 -5
  171. mindspore/nn/wrap/grad_reducer.py +4 -89
  172. mindspore/numpy/array_creations.py +4 -4
  173. mindspore/numpy/fft.py +9 -9
  174. mindspore/numpy/utils_const.py +1 -1
  175. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  176. mindspore/onnx/onnx_export.py +137 -0
  177. mindspore/opencv_core4110.dll +0 -0
  178. mindspore/opencv_imgcodecs4110.dll +0 -0
  179. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  180. mindspore/ops/__init__.py +2 -0
  181. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  182. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  183. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  184. mindspore/ops/_op_impl/cpu/__init__.py +1 -5
  185. mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
  186. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
  187. mindspore/ops/auto_generate/gen_extend_func.py +6 -11
  188. mindspore/ops/auto_generate/gen_ops_def.py +385 -154
  189. mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
  190. mindspore/ops/communication.py +97 -0
  191. mindspore/ops/composite/__init__.py +5 -2
  192. mindspore/ops/composite/base.py +16 -2
  193. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  194. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  195. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  196. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  197. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  198. mindspore/ops/function/__init__.py +2 -0
  199. mindspore/ops/function/array_func.py +24 -18
  200. mindspore/ops/function/comm_func.py +3883 -0
  201. mindspore/ops/function/debug_func.py +7 -6
  202. mindspore/ops/function/grad/grad_func.py +4 -12
  203. mindspore/ops/function/math_func.py +89 -86
  204. mindspore/ops/function/nn_func.py +92 -313
  205. mindspore/ops/function/random_func.py +9 -18
  206. mindspore/ops/functional.py +4 -1
  207. mindspore/ops/functional_overload.py +377 -30
  208. mindspore/ops/operations/__init__.py +2 -5
  209. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  210. mindspore/ops/operations/_inner_ops.py +12 -50
  211. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  212. mindspore/ops/operations/array_ops.py +5 -50
  213. mindspore/ops/operations/comm_ops.py +95 -17
  214. mindspore/ops/operations/custom_ops.py +237 -22
  215. mindspore/ops/operations/debug_ops.py +33 -35
  216. mindspore/ops/operations/manually_defined/ops_def.py +39 -318
  217. mindspore/ops/operations/math_ops.py +5 -5
  218. mindspore/ops/operations/nn_ops.py +3 -3
  219. mindspore/ops/operations/sparse_ops.py +0 -83
  220. mindspore/ops/primitive.py +4 -27
  221. mindspore/ops/tensor_method.py +88 -10
  222. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  223. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  224. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  225. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  226. mindspore/ops_generate/common/gen_constants.py +11 -10
  227. mindspore/ops_generate/common/op_proto.py +18 -1
  228. mindspore/ops_generate/common/template.py +102 -245
  229. mindspore/ops_generate/common/template_utils.py +212 -0
  230. mindspore/ops_generate/gen_custom_ops.py +69 -0
  231. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  232. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  233. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  234. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  235. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  236. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  237. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  238. mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
  239. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  240. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  241. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  242. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  243. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  244. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  245. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  246. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  247. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  248. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  249. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  250. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  251. mindspore/parallel/_auto_parallel_context.py +5 -15
  252. mindspore/parallel/_cell_wrapper.py +1 -1
  253. mindspore/parallel/_parallel_serialization.py +4 -6
  254. mindspore/parallel/_ps_context.py +2 -2
  255. mindspore/parallel/_utils.py +34 -17
  256. mindspore/parallel/auto_parallel.py +23 -9
  257. mindspore/parallel/checkpoint_transform.py +20 -2
  258. mindspore/parallel/cluster/process_entity/_api.py +28 -33
  259. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  260. mindspore/parallel/cluster/run.py +5 -3
  261. mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
  262. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  263. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  264. mindspore/parallel/function/reshard_func.py +6 -5
  265. mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
  266. mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
  267. mindspore/parallel/shard.py +7 -21
  268. mindspore/parallel/strategy.py +336 -0
  269. mindspore/parallel/transform_safetensors.py +127 -20
  270. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
  271. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
  272. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  273. mindspore/profiler/common/constant.py +5 -0
  274. mindspore/profiler/common/file_manager.py +9 -0
  275. mindspore/profiler/common/msprof_cmd_tool.py +40 -4
  276. mindspore/profiler/common/path_manager.py +65 -24
  277. mindspore/profiler/common/profiler_context.py +27 -14
  278. mindspore/profiler/common/profiler_info.py +3 -3
  279. mindspore/profiler/common/profiler_meta_data.py +1 -0
  280. mindspore/profiler/common/profiler_op_analyse.py +10 -6
  281. mindspore/profiler/common/profiler_path_manager.py +13 -0
  282. mindspore/profiler/common/util.py +30 -3
  283. mindspore/profiler/dynamic_profiler.py +91 -46
  284. mindspore/profiler/envprofiler.py +30 -5
  285. mindspore/profiler/experimental_config.py +18 -2
  286. mindspore/profiler/platform/cpu_profiler.py +10 -4
  287. mindspore/profiler/platform/npu_profiler.py +34 -7
  288. mindspore/profiler/profiler.py +193 -145
  289. mindspore/profiler/profiler_action_controller.py +1 -1
  290. mindspore/profiler/profiler_interface.py +2 -2
  291. mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
  292. mindspore/run_check/_check_version.py +108 -24
  293. mindspore/runtime/__init__.py +9 -6
  294. mindspore/runtime/executor.py +35 -0
  295. mindspore/runtime/memory.py +113 -0
  296. mindspore/runtime/thread_bind_core.py +1 -1
  297. mindspore/swresample-4.dll +0 -0
  298. mindspore/swscale-6.dll +0 -0
  299. mindspore/tinyxml2.dll +0 -0
  300. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  301. mindspore/tools/data_dump.py +130 -0
  302. mindspore/tools/sdc_detect.py +91 -0
  303. mindspore/tools/stress_detect.py +63 -0
  304. mindspore/train/__init__.py +6 -6
  305. mindspore/train/_utils.py +8 -21
  306. mindspore/train/amp.py +6 -7
  307. mindspore/train/callback/_callback.py +2 -1
  308. mindspore/train/callback/_checkpoint.py +1 -17
  309. mindspore/train/callback/_flops_collector.py +10 -6
  310. mindspore/train/callback/_train_fault_tolerance.py +72 -25
  311. mindspore/train/data_sink.py +5 -9
  312. mindspore/train/dataset_helper.py +5 -5
  313. mindspore/train/model.py +41 -230
  314. mindspore/train/serialization.py +160 -401
  315. mindspore/train/train_thor/model_thor.py +2 -2
  316. mindspore/turbojpeg.dll +0 -0
  317. mindspore/utils/__init__.py +6 -3
  318. mindspore/utils/dlpack.py +92 -0
  319. mindspore/utils/dryrun.py +1 -1
  320. mindspore/utils/runtime_execution_order_check.py +10 -0
  321. mindspore/utils/sdc_detect.py +14 -12
  322. mindspore/utils/stress_detect.py +43 -0
  323. mindspore/utils/utils.py +152 -16
  324. mindspore/version.py +1 -1
  325. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  326. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
  327. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  328. mindspore/communication/_hccl_management.py +0 -297
  329. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
  330. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  331. mindspore/experimental/llm_boost/atb/__init__.py +0 -23
  332. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  333. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  334. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  335. mindspore/experimental/llm_boost/register.py +0 -130
  336. mindspore/experimental/llm_boost/utils.py +0 -31
  337. mindspore/include/OWNERS +0 -7
  338. mindspore/mindspore_cpu_res_manager.dll +0 -0
  339. mindspore/mindspore_ops_kernel_common.dll +0 -0
  340. mindspore/mindspore_res_manager.dll +0 -0
  341. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  342. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  343. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  344. mindspore/nn/reinforcement/tensor_array.py +0 -145
  345. mindspore/opencv_core452.dll +0 -0
  346. mindspore/opencv_imgcodecs452.dll +0 -0
  347. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  348. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  349. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  350. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  351. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  352. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  353. mindspore/ops/operations/_tensor_array.py +0 -359
  354. mindspore/ops/operations/rl_ops.py +0 -288
  355. mindspore/parallel/_offload_context.py +0 -275
  356. mindspore/parallel/_recovery_context.py +0 -115
  357. mindspore/parallel/_transformer/__init__.py +0 -35
  358. mindspore/parallel/_transformer/layers.py +0 -765
  359. mindspore/parallel/_transformer/loss.py +0 -251
  360. mindspore/parallel/_transformer/moe.py +0 -693
  361. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  362. mindspore/parallel/_transformer/transformer.py +0 -3124
  363. mindspore/parallel/mpi/_mpi_config.py +0 -116
  364. mindspore/profiler/common/validator/validate_path.py +0 -84
  365. mindspore/train/memory_profiling_pb2.py +0 -298
  366. mindspore/utils/hooks.py +0 -81
  367. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  368. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  369. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  370. {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py CHANGED
@@ -39,10 +39,11 @@ from typing import (
 
  import weakref
  import mindspore as ms
+ import mindspore.ops as ops
  from mindspore._checkparam import args_type_check, check_hook_fn
- from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
+ from mindspore.common.dynamic_shape._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
  from mindspore import log as logger
- from mindspore.common.hook_handle import HookHandle
+ from mindspore.common.hook_handle import HookHandle, _update_hook_version
  from mindspore import context
  from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
  from mindspore import _checkparam as Validator
@@ -92,9 +93,8 @@ def register_cell_buffer_registration_hook(hook: Callable[..., None],):
  A handle that can be used to remove the added hook by calling
  `handle.remove()`.
  """
- from mindspore.utils.hooks import _RemovableHandle
- handle = _RemovableHandle(_global_buffer_registration_hooks)
- _global_buffer_registration_hooks[handle.id] = hook
+ handle = HookHandle(_global_buffer_registration_hooks)
+ _global_buffer_registration_hooks[handle.handle_id] = hook
  return handle
 
 
@@ -155,7 +155,8 @@ class Cell(Cell_):
  IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time',
  '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase', '_bprop_debug',
  '_forward_pre_hook', '_forward_hook', '_backward_pre_hook', '_backward_hook',
- '_cell_backward_pre_hook', '_cell_backward_hook', '_param_prefix', 'requires_grad', 'cell_type']
+ '_cell_backward_pre_hook', '_cell_backward_hook', '_param_prefix',
+ 'requires_grad', 'cell_type', '_in_strategy', '_out_strategy']
  total_instance_count = 0
  _buffers: Dict[str, Optional[Tensor]]
  global_cells = weakref.WeakKeyDictionary()
@@ -191,6 +192,7 @@ class Cell(Cell_):
  super().__setattr__("_auto_prefix", auto_prefix)
  super().__setattr__("_scope", None)
  super().__setattr__("_phase", 'train')
+ super().__setattr__("_compile_phase", None)
  super().__setattr__("_parameter_layout_dict", None)
  super().__setattr__("_parallel_parameter_name_list", None)
  super().__setattr__("_parallel_parameter_merge_net_dict", None)
@@ -206,6 +208,7 @@ class Cell(Cell_):
  super().__setattr__("mixed_precision_type", None)
  super().__setattr__("_lazy_construct_sig", None)
  super().__setattr__("_jit_graph_name", '')
+ super().__setattr__("_compiled", False)
  init_pipeline()
 
  # call gc to release GE session resources used by non-used cell objects
@@ -239,6 +242,8 @@ class Cell(Cell_):
  super().__setattr__("_amp_level", "")
  super().__setattr__("_init_flag", False)
  super().__setattr__("_shard_fn", None)
+ super().__setattr__("_in_strategy", None)
+ super().__setattr__("_out_strategy", None)
  super().__setattr__("has_bprop", False)
  if hasattr(self, "bprop"):
  super().__setattr__("has_bprop", True)
@@ -426,6 +431,13 @@ class Cell(Cell_):
  """
  return self._bprop_debug
 
+ @property
+ def compiled(self):
+ """
+ Get whether `Cell` is compiled in graph mode.
+ """
+ return self._compiled
+
  @bprop_debug.setter
  def bprop_debug(self, value):
  """
@@ -482,6 +494,19 @@ class Cell(Cell_):
  raise TypeError(f"For 'Cell', the property 'phase' must be string type, but got type {type(value)}.")
  self._phase = value
 
+ @property
+ def compile_phase(self):
+ return self._compile_phase
+
+ @compile_phase.setter
+ def compile_phase(self, value):
+ if not isinstance(value, str):
+ raise TypeError(f"For 'Cell', 'compile_phase' must be string type, but got type {type(value)}.")
+ self._compile_phase = value
+ for cell in self._cells.values():
+ if cell is not None:
+ cell.compile_phase = value
+
  @property
  def parameter_layout_dict(self):
  """
@@ -546,10 +571,23 @@ class Cell(Cell_):
 
  @property
  def pipeline_segment(self):
+ """
+ `pipeline_segment` represents the pipeline segment of current Cell.
+ """
  return self._pipeline_segment
 
  @pipeline_segment.setter
  def pipeline_segment(self, value):
+ """
+ Set the `pipeline_segment` of a Cell. Only effective in zero_bubble_v scheduler.
+
+ Args:
+ value (int): The pipeline segment of a parameter.
+
+ Raises:
+ TypeError: If `value` is not int type or is a bool type.
+ ValueError: If `value` is not a positive integer.
+ """
  if not isinstance(value, int) or isinstance(value, bool):
  raise TypeError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
  "must be int type, but got type : {}".format(type(value)))
@@ -1027,12 +1065,13 @@ class Cell(Cell_):
  if self._forward_pre_hook:
  args, kwargs = self._run_forward_pre_hook(args, kwargs)
 
+ if self._backward_hook:
+ args = self._cell_backward_hook(args)
+
  if self._shard_fn is not None:
  output = self._shard_fn(*args, **kwargs)
  elif _pynative_executor.requires_grad():
- if self._backward_hook:
- output = self._backward_hook_construct(*args, **kwargs)
- elif self._recompute_cell is not None:
+ if self._recompute_cell is not None:
  output = self._recompute_cell(*args, **kwargs)
  elif self.has_bprop:
  output = self._call_custom_bprop(*args, **kwargs)
@@ -1044,8 +1083,11 @@ class Cell(Cell_):
  if self._forward_hook:
  output = self._run_forward_hook(args, kwargs, output)
 
- if self._backward_pre_hook and _pynative_executor.requires_grad():
- output = self._run_backward_pre_hook(output)
+ if self._backward_hook:
+ output = self._cell_backward_hook(output)
+
+ if self._backward_pre_hook:
+ output = self._cell_backward_pre_hook(output)
 
  return output
 
@@ -1080,23 +1122,6 @@ class Cell(Cell_):
  f"{default_args} default argument, total {positional_args + default_args}, "
  f"but got {len(args)}.")
 
- # pylint: disable=E0203
- def _hook_fn_registered(self):
- '''Hook function in graph mode'''
- # Check super().__init__() in graph mode.
- try:
- if self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook:
- return True
- except AttributeError as e:
- raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
- f"Please use 'super().__init__()'.") from e
- if not self._is_recursion_hook:
- self._is_recursion_hook = True
- for cell in self.cells():
- if cell._hook_fn_registered():
- return True
- return False
-
  def _get_prims_recursively(self):
  all_prims = list()
  for _, value in self._primitives.items():
@@ -1122,9 +1147,6 @@ class Cell(Cell_):
  >>> net = nn.Dense(3, 4)
  >>> net.set_data_parallel()
  """
- if context._get_mode() == context.PYNATIVE_MODE:
- raise ValueError("set_data_parallel: does not support PyNative mode.")
-
  all_prims = self._get_prims_recursively()
  for prim in all_prims:
  prim.add_prim_attr("strategy_gen_mode", "data_parallel")
@@ -1203,8 +1225,6 @@ class Cell(Cell_):
  ... out = self.blocks[i](out)
  ... return out
  """
- if context._get_mode() == context.PYNATIVE_MODE:
- raise ValueError("The Cell offload does not support PyNative mode now.")
  if isinstance(backward_prefetch, str):
  Validator.check_string(backward_prefetch, ['Auto'], 'backward_prefetch', self.cls_name)
  else:
@@ -1212,11 +1232,10 @@ class Cell(Cell_):
  for prim in self._get_prims_recursively():
  prim._offload(backward_prefetch=backward_prefetch)
 
- def shard(self, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
+ def shard(self, in_strategy, out_strategy=None, parameter_plan=None):
  """
  Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be
- generated by sharding propagation. In PyNative mode, use this method to specify a Cell for distributed
- execution in graph mode. In Graph mode, use this method to specify distribution strategy for a Cell,
+ generated by sharding propagation. In Graph mode, use this method to specify distribution strategy for a Cell,
  strategy for others will be set by sharding propagation.
  in_strategy and out_strategy define the input and output layout respectively.
  in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of
@@ -1228,11 +1247,14 @@ class Cell(Cell_):
  In other parallel modes, strategies set here will be ignored.
  - If the input contain Parameter, its strategy should be set in `in_strategy`.
 
+ .. warning::
+ The method is currently not supported in PyNative mode.
+
  Args:
  in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple. Tuple
  defines the layout of the corresponding input.
  out_strategy (Union[None, tuple]): Define the layout of outputs similar with in_strategy.
- It is not in use right now. Default: ``None`` .
+ Default: ``None`` .
  parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
  defines the layout of the parameter like "param_name: layout".
  The key is a parameter name of type 'str'.
@@ -1240,14 +1262,6 @@ class Cell(Cell_):
  If the parameter name is incorrect or the corresponding parameter
  has been set, the parameter setting will be ignored.
  Default: ``None`` .
- device (str): Select a certain device target. It is not in use right now.
- Support [ ``"CPU"`` , ``"GPU"`` , ``"Ascend"`` ]. Default: ``"Ascend"`` .
- level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
- over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
- use right now. Support [ ``"0"`` , ``"1"`` , ``"2"`` ]. Default: ``0`` .
-
- Returns:
- Function, return the cell construct function that will be executed under auto parallel process.
 
  Examples:
  >>> import mindspore.nn as nn
@@ -1265,19 +1279,34 @@ class Cell(Cell_):
  ... def __init__(self):
  ... self.block1 = Block()
  ... self.block2 = Block()
- ... self.block2_shard = self.block2.shard(in_strategy=((2, 1),),
- ... parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})
+ ... self.block2.shard(in_strategy=((2, 1),), parameter_plan={'self.block2.dense1.weight': (4, 1)})
  ... def construct(self, x):
  ... x = self.block1(x)
- ... x = self.block2_shard(x)
+ ... x = self.block2(x)
  ... return x
  """
  if ms.communication.management.get_group_size() == 1:
- return self
+ return
+
  shard_fn = Shard()
- fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
- self._shard_fn = fn
- return fn
+ self._shard_fn = shard_fn(self, in_strategy, out_strategy, parameter_plan)
+
+ if self._in_strategy is not None: # pylint: disable=E0203
+ msg = (
+ "For '%s', 'Shard' has been configured more than once. "
+ "The existing in_strategy is %s and the existing out_strategy is %s. "
+ "The new in_strategy %s and out_strategy %s may not take effect. "
+ "It is recommended to configure 'Shard' only once."
+ ) % (
+ self._cell_tag,
+ self._in_strategy, # pylint: disable=E0203
+ self._out_strategy, # pylint: disable=E0203
+ shard_fn.in_strategy,
+ shard_fn.out_strategy,
+ )
+ logger.warning(msg)
+ self._in_strategy = shard_fn.in_strategy
+ self._out_strategy = shard_fn.out_strategy
 
  def _init_check(self):
  for param in self.get_parameters(expand=False):
@@ -1286,9 +1315,13 @@ class Cell(Cell_):
  self._init_flag = True
 
  def _self_check(self):
- if not self._is_check_and_refresh:
- self.check_names_and_refresh_name()
- self._is_check_and_refresh = True
+ try:
+ if not self._is_check_and_refresh: # pylint: disable=E0203
+ self.check_names_and_refresh_name()
+ self._is_check_and_refresh = True
+ except AttributeError as e:
+ raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
+ f"Please use 'super().__init__()'.") from e
 
  def _predict(self, *args, **kwargs):
  '''Graph executor for predict'''
@@ -1309,6 +1342,7 @@ class Cell(Cell_):
  def __call__(self, *args, **kwargs):
  # Run in Graph mode.
  if context._get_mode() == context.GRAPH_MODE and os.getenv("MS_JIT") != '0':
+ self._compiled = True
  if kwargs:
  bound_arguments = self._construct_sig.bind(*args, **kwargs)
  bound_arguments.apply_defaults()
@@ -1319,11 +1353,8 @@ class Cell(Cell_):
  if predict_compiled:
  return res
  self._check_construct_args(*args)
-
- if self._hook_fn_registered():
- logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
- f"function, please use context.set_context to set pynative mode.")
  self._self_check()
+ self.__compile_cell_hook__ = True
  out = self.compile_and_run(*args, **kwargs)
  return out
 
@@ -1421,16 +1452,7 @@ class Cell(Cell_):
  exist_names.add(item.name)
  self.insert_param_to_cell(item.name, item, check_name_contain_dot=False)
 
- if context._get_mode() == context.PYNATIVE_MODE:
- if name in self.__dict__:
- del self.__dict__[name]
- params = self.__dict__.get('_params')
- if name in params:
- del params[name]
- params_list = self.__dict__.get('_params_list')
- params_list[name] = value
- else:
- object.__setattr__(self, name, value)
+ object.__setattr__(self, name, value)
 
  def _set_attr_for_parameter_in_list_or_tuple(self, name, value):
  """Set attr for parameter in list or tuple."""
@@ -1609,8 +1631,6 @@ class Cell(Cell_):
  _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
  else:
  self._check_construct_args(*inputs)
- # TODO(tronzhang): It may error for no actually args here. So just set in fullmode,
- # which means that incremental mode is lacking dynamic input.
  else:
  self._dynamic_shape_inputs = _process_dyn_args(self.construct, kwargs)
 
@@ -1699,6 +1719,7 @@ class Cell(Cell_):
  _init_auto_parallel_context(self)
  compile_args = self._get_compile_args(args)
  self._has_mutable_args_list = _get_mutable_flags(compile_args)
+ _cell_graph_executor.set_real_args(args, kwargs)
  _cell_graph_executor.compile(self, *compile_args, phase=self.phase,
  jit_config_dict=self._jit_config_dict, **kwargs)
  _clear_auto_parallel_context(self)
@@ -2581,23 +2602,7 @@ class Cell(Cell_):
  else:
  self._jit_config_dict = jit_config.jit_config_dict
 
- def flatten_weights(self, fusion_size=0):
- """
- Reset data for weight parameters so that they are using contiguous memory chunks grouped by data type.
-
- Note:
- By default, parameters with same data type will using a single contiguous memory chunk. but for
- some models with huge number of parameters, splitting a large memory chunk into several smaller
- memory chunks has the potential for performance gains, if this is the case, we can use 'fusion_size'
- to limit the maximum memory chunk size.
-
- Args:
- fusion_size (int): Maximum memory chunk size in bytes, ``0`` for unlimited. Default: ``0`` .
- """
- if fusion_size < 0:
- raise ValueError(f"Negative 'fusion_size' {fusion_size} is invalid.")
- Tensor._flatten_tensors(self.trainable_params(), fusion_size) # pylint: disable=W0212
-
+ @jit_forbidden_register
  def register_forward_pre_hook(self, hook_fn, with_kwargs=False):
  """
  Register forward pre hook function for Cell object.
@@ -2617,7 +2622,6 @@ class Cell(Cell_):
  `with_kwargs` is ``True`` .
 
  Note:
- - The feature does not take effect in graph mode or in PyNative mode with functions decorated by jit.
  - The `hook_fn` can modify the forward inputs by returning new inputs. If `with_kwargs` is ``Flase`` , a
  single value (whick will be wrapped into a tuple unless already a tuple) or a tuple of args should be
  returned. If `with_kwargs` is ``True`` , both `args` and `kwargs` should be returned.
@@ -2668,15 +2672,15 @@ class Cell(Cell_):
  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
  value= [ 2.00000000e+00]))
  """
- if context._get_mode() == context.GRAPH_MODE:
- return HookHandle()
  check_hook_fn(hook_fn)
  handle = HookHandle(self._forward_pre_hook, extra_dict=self._forward_pre_hook_with_kwargs)
  self._forward_pre_hook[handle.handle_id] = hook_fn
  if with_kwargs:
  self._forward_pre_hook_with_kwargs[handle.handle_id] = True
+ _update_hook_version()
  return handle
 
+ @jit_forbidden_register
  def _run_forward_pre_hook(self, args, kwargs):
  """
  Running forward pre hook function registered on Cell object.
@@ -2700,6 +2704,35 @@ class Cell(Cell_):
  args = ret
  return args, kwargs
 
+ def _jit_forward_pre_hook(self, inputs):
+ """
+ Compile forward pre hook function registered on Cell object.
+
+ Args:
+ inputs: The input objects of cell object.
+
+ Returns:
+ - **outputs** - New input objects or none.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU`` ``CPU``
+ """
+ forward_pre_hook_inputs = inputs
+ for fn in self._forward_pre_hook.values():
+ ret = fn(self, forward_pre_hook_inputs)
+ if ret is not None:
+ if not isinstance(ret, tuple):
+ forward_pre_hook_inputs = (ret,)
+ else:
+ forward_pre_hook_inputs = ret
+
+ if len(forward_pre_hook_inputs) != len(inputs):
+ raise TypeError(
+ "The forward pre hook return value size is {} not equal to input size {}".format(
+ len(forward_pre_hook_inputs), len(inputs)))
+ return forward_pre_hook_inputs
+
+ @jit_forbidden_register
  def register_forward_hook(self, hook_fn, with_kwargs=False):
  """
  Register forward hook function for Cell object.
@@ -2720,7 +2753,6 @@ class Cell(Cell_):
  - `output`: Output generated by the `construct` function.
 
  Note:
- - The feature does not take effect in graph mode or in PyNative mode with functions decorated by jit.
  - The `hook_fn` can modify the forward outputs by returning new outputs.
  - In order to prevent running failed when switching to graph mode, it is not recommended to call it in the
  `construct` function of Cell object.
@@ -2773,15 +2805,44 @@ class Cell(Cell_):
  """
  if self.has_bprop:
  return HookHandle()
- if context._get_mode() == context.GRAPH_MODE:
- return HookHandle()
  check_hook_fn(hook_fn)
  handle = HookHandle(self._forward_hook, extra_dict=self._forward_hook_with_kwargs)
  self._forward_hook[handle.handle_id] = hook_fn
  if with_kwargs:
  self._forward_hook_with_kwargs[handle.handle_id] = True
+ _update_hook_version()
  return handle
 
+ def _jit_forward_hook(self, inputs, output):
+ """
+ Compile forward hook function registered on Cell object.
+
+ Args:
+ inputs: The input objects of Cell object.
+ output: The output object of Cell object.
+
+ Returns:
+ - **output** - New output object or none.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU`` ``CPU``
+ """
+ forward_hook_output = output
+ for fn in self._forward_hook.values():
+ ret = fn(self, inputs, forward_hook_output)
+ if ret is not None:
+ forward_hook_output = ret
+
+ if isinstance(output, tuple):
+ if not isinstance(forward_hook_output, tuple):
+ forward_hook_output = (forward_hook_output,)
+ if len(forward_hook_output) != len(output):
+ raise TypeError(
+ "The forward hook return value size is {} not equal to output size {}".format(
+ len(forward_hook_output), len(output)))
+ return forward_hook_output
+
+ @jit_forbidden_register
  def _run_forward_hook(self, args, kwargs, output):
  """
  Running forward hook function registered on Cell object.
@@ -2795,12 +2856,12 @@ class Cell(Cell_):
  output = ret
  return output
 
+ @jit_forbidden_register
  def register_backward_pre_hook(self, hook_fn):
  """
  Register the backward pre hook function.
 
  Note:
- - The `register_backward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
  - The 'hook_fn' must be defined as the following code.
  `cell` is the Cell object. `grad_output` is the gradient passed to the Cell.
  - The 'hook_fn' should have the following signature:
@@ -2849,44 +2910,17 @@ class Cell(Cell_):
  >>> print(output)
  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
  """
- if context._get_mode() == context.GRAPH_MODE:
- return HookHandle()
  check_hook_fn(hook_fn)
- handle = HookHandle(self._backward_pre_hook)
+ handle = HookHandle(self._backward_pre_hook, extra_dict=None)
  self._backward_pre_hook[handle.handle_id] = hook_fn
- if self._cell_backward_pre_hook is None:
+ if self._cell_backward_pre_hook is None: # pylint: disable=E0203
  # Generate a CellBackwardHook prim, and add function for it
  self._cell_backward_pre_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
  self, self._backward_pre_hook)
  self._cell_backward_pre_hook.register_backward_pre_hook()
+ _update_hook_version()
  return handle
 
- def _run_backward_pre_hook(self, outputs):
- """
- Running backward pre hook function registered on Cell object.
-
- Args:
- outputs: The output objects of cell object.
-
- Returns:
- - **outputs** - New backward gradient or None.
-
- Supported Platforms:
- ``Ascend`` ``GPU`` ``CPU``
- """
- if isinstance(outputs, tuple):
- ret = self._cell_backward_pre_hook(*outputs)
- else:
- ret = self._cell_backward_pre_hook(outputs)
- if isinstance(outputs, tuple):
- if len(outputs) == 1:
- ret = (ret,)
- if len(ret) != len(outputs):
- raise TypeError(
- "The backward pre hook return value size is {} not equal to output size {}".format(
- len(ret), len(outputs)))
- return ret
-
  def get_extra_state(self) -> Any:
  """Return any extra state to include in the cell's state_dict.
 
@@ -2939,9 +2973,8 @@ class Cell(Cell_):
  A handle that can be used to remove the added hook by calling
  `handle.remove()`.
  """
- from mindspore.utils.hooks import _RemovableHandle
- handle = _RemovableHandle(self._state_dict_hooks)
- self._state_dict_hooks[handle.id] = hook
+ handle = HookHandle(self._state_dict_hooks)
+ self._state_dict_hooks[handle.handle_id] = hook
  return handle
 
  @jit_forbidden_register
@@ -2987,9 +3020,8 @@ class Cell(Cell_):
  >>> print("extra_param" in net_state_dict)
  True
  """
- from mindspore.utils.hooks import _RemovableHandle
- handle = _RemovableHandle(self._state_dict_pre_hooks)
- self._state_dict_pre_hooks[handle.id] = hook
+ handle = HookHandle(self._state_dict_pre_hooks)
+ self._state_dict_pre_hooks[handle.handle_id] = hook
  return handle
 
  def _save_to_state_dict(self, destination, prefix, keep_vars):
@@ -3135,9 +3167,8 @@ class Cell(Cell_):
  A handle that can be used to remove the added hook by calling
  `handle.remove()`.
  """
- from mindspore.utils.hooks import _RemovableHandle
- handle = _RemovableHandle(self._load_state_dict_pre_hooks)
- self._load_state_dict_pre_hooks[handle.id] = hook
+ handle = HookHandle(self._load_state_dict_pre_hooks)
+ self._load_state_dict_pre_hooks[handle.handle_id] = hook
  return handle
 
  @jit_forbidden_register
@@ -3169,9 +3200,8 @@ class Cell(Cell_):
  A handle that can be used to remove the added hook by calling
  `handle.remove()`.
  """
- from mindspore.utils.hooks import _RemovableHandle
- handle = _RemovableHandle(self._load_state_dict_post_hooks)
- self._load_state_dict_post_hooks[handle.id] = hook
+ handle = HookHandle(self._load_state_dict_post_hooks)
+ self._load_state_dict_post_hooks[handle.handle_id] = hook
  return handle
 
  def _load_from_state_dict(
@@ -3407,12 +3437,12 @@ class Cell(Cell_):
  )
  return _IncompatibleKeys(missing_keys, unexpected_keys)
 
+ @jit_forbidden_register
  def register_backward_hook(self, hook_fn):
  """
  Register the backward hook function.
 
  Note:
- - The `register_backward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
  - The 'hook_fn' must be defined as the following code.
  `cell` is the registered Cell object. `grad_input` is the gradient computed and passed to
  the next Cell or primitive, which can be return a new gradient or None. `grad_output` is the gradient
@@ -3464,83 +3494,17 @@ class Cell(Cell_):
  >>> print(output)
  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
  """
- if context._get_mode() == context.GRAPH_MODE:
- return HookHandle()
  check_hook_fn(hook_fn)
- handle = HookHandle(self._backward_hook)
+ handle = HookHandle(self._backward_hook, extra_dict=None)
  self._backward_hook[handle.handle_id] = hook_fn
- if self._cell_backward_hook is None:
+ if self._cell_backward_hook is None: # pylint: disable=E0203
  # Generate a CellBackwardHook prim, and add function for it
  self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
  self, self._backward_hook)
  self._cell_backward_hook.register_backward_hook()
+ _update_hook_version()
  return handle
 
- def _backward_hook_construct(self, *inputs, **kwargs):
- """
- Backward hook construct method to replace original construct method.
-
- Args:
- inputs: The input objects of Cell object.
- kwargs (dict): Dictionary of variable keyword parameters.
-
- Returns:
- - **outputs** - The output objects of Cell object.
-
- Supported Platforms:
- ``Ascend`` ``GPU`` ``CPU``
- """
- # cell_backward_hook has CellBackwardHook op, so keep input args as they are.
- outputs = self._cell_backward_hook(*inputs)
- # If the inputs have more than two args, the outputs will also have more than two args and will be wrapped into
- # a tuple, so need to do unwrapping. If inputs is empty, we also need to unwrap it.
- # Because when output of runop method is one, it will not wrap a tuple, we need not unwrap it.
- is_need_unwrap = False
- if isinstance(outputs, tuple) and len(inputs) != 1:
- is_need_unwrap = True
-
- if self._recompute_cell is not None:
- if is_need_unwrap:
- outputs = self._recompute_cell(*outputs, **kwargs)
- else:
- outputs = self._recompute_cell(outputs, **kwargs)
- elif self.has_bprop:
- if is_need_unwrap:
- outputs = self._call_custom_bprop(*outputs, **kwargs)
- else:
- outputs = self._call_custom_bprop(outputs, **kwargs)
- else:
- if is_need_unwrap:
- outputs = self.construct(*outputs, **kwargs)
- else:
- outputs = self.construct(outputs, **kwargs)
- if isinstance(outputs, tuple):
- new_outputs = self._cell_backward_hook(*outputs)
- else:
- new_outputs = self._cell_backward_hook(outputs)
- # if outputs is (X,) and new_outpus is X
- if isinstance(outputs, tuple) and len(outputs) == 1:
- new_outputs = (new_outputs,)
- return new_outputs
-
- def set_param_ps(self, recurse=True, init_in_server=False):
- """
- Set whether the trainable parameters are updated by parameter server and whether the
- trainable parameters are initialized on server.
-
- Note:
- It only works when a running task is in the parameter server mode.
- It is only supported in graph mode.
-
- Args:
- recurse (bool): Whether sets the trainable parameters of subcells. Default: ``True`` .
- init_in_server (bool): Whether trainable parameters updated by parameter server are
- initialized on server. Default: ``False`` .
- """
- params = self.trainable_params(recurse)
- for param in params:
- param.set_param_ps(init_in_server)
-
  def set_comm_fusion(self, fusion_type, recurse=True):
  """
  Set `comm_fusion` for all the parameters in this cell. Please refer to the description of
@@ -3601,7 +3565,7 @@ class Cell(Cell_):
  """
  Validator.check_bool(mode)
  Validator.check_bool(output_recompute)
- if not self._has_config_recompute:
+ if not self._has_config_recompute: # pylint: disable=E0203
  self._has_config_recompute = True
  else:
  logger.info("The recompute interface can be configured only once."
@@ -3644,8 +3608,7 @@ class Cell(Cell_):
  introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
  Default: ``False`` .
  """
- if context.get_context("mode") == context.PYNATIVE_MODE:
- self._recompute_cell = recompute_registry.get()(self.construct)
+ self._recompute_cell = recompute_registry.get()(self.construct)
  self._recompute()
  if 'mp_comm_recompute' in kwargs.keys():
  self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
@@ -3662,35 +3625,6 @@ class Cell(Cell_):
  "the key kwargs must be 'mp_comm_recompute', "
  "'parallel_optimizer_comm_recompute', 'recompute_slice_activation'" % key)
 
- def place(self, role, rank_id):
- """
- Set the label for all operators in this cell.
- This label tells MindSpore compiler on which process this cell should be launched.
- And each process's identical label consists of input `role` and `rank_id`.
- So by setting different cells with different labels, which will be launched on different processes,
- users can launch a distributed training or predicting job.
-
- Note:
- - This method is effective only after
- `mindspore.communication.init()` is called for dynamic cluster building.
-
- Args:
- role (str): The role of the process on which this cell will be launched.
- Only 'MS_WORKER' is supported for now.
- rank_id (int): The rank id of the process on which this cell will be launched.
- The rank is unique in processes with the same role.
-
- Examples:
- >>> from mindspore import context
- >>> import mindspore.nn as nn
- >>> context.set_context(mode=context.GRAPH_MODE)
- >>> fc = nn.Dense(2, 3)
- >>> fc.place('MS_WORKER', 0)
- """
- all_ops = self._get_prims_recursively()
- for op in all_ops:
- op.place(role, rank_id)
-
  def _get_attr_from_cell(self, network):
  if not isinstance(network, Cell):
  return
@@ -3705,6 +3639,64 @@ class Cell(Cell_):
  """
  self._jit_graph_name = key
 
+ def _jit_backward_pre_hook(self, grad_output):
+ new_grad_output = grad_output
+ if not isinstance(grad_output, tuple):
+ new_grad_output = (grad_output,)
+
+ for fn in self._backward_pre_hook.values():
+ ret = fn(self, new_grad_output)
+ if ret is not None:
+ if not isinstance(ret, tuple):
+ output = (ret,)
+ else:
+ output = ret
+ else:
+ output = ops.Depend()(new_grad_output, ret)
+ new_grad_output = output
+
+ if not isinstance(grad_output, tuple):
+ if len(new_grad_output) == 1:
+ return new_grad_output[0]
+ raise TypeError(
+ "The backward pre hook return value size is {} not equal to input size 1".format(
+ len(new_grad_output)))
+
+ if len(new_grad_output) != len(grad_output):
+ raise TypeError(
+ "The backward pre hook return value size is {} not equal to input size {}".format(
+ len(new_grad_output), len(grad_output)))
+
+ return new_grad_output
+
+ def _jit_backward_hook(self, grad_input, grad_output):
+ backward_hook_input = grad_input
+ backward_hook_output = grad_output
+ if not isinstance(grad_input, tuple):
+ backward_hook_input = (grad_input,)
+ if not isinstance(grad_output, tuple):
+ backward_hook_output = (grad_output,)
+
+ for fn in self._backward_hook.values():
+ ret = fn(self, backward_hook_input, backward_hook_output)
+ if ret is not None:
+ if not isinstance(ret, tuple):
+ output = (ret,)
+ else:
+ output = ret
+ else:
+ output = ops.Depend()(backward_hook_input, ret)
+
+ backward_hook_input = output
+
+ if not isinstance(grad_input, tuple):
+ return backward_hook_input[0]
+
+ if len(backward_hook_input) != len(grad_input):
+ raise TypeError(
+ "The backward hook return value size is {} not equal to input size {}".format(
+ len(backward_hook_input), len(grad_input)))
+ return backward_hook_input
 
  class GraphCell(Cell):
  """