mindspore-2.2.14-cp39-cp39-win_amd64.whl → mindspore-2.4.0-cp39-cp39-win_amd64.whl

This diff shows the content changes between two publicly released versions of the package, as they appear in their respective public registries. The information is provided for informational purposes only.

Potentially problematic release: this version of mindspore might be problematic; see the registry page for details.

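Because this diff spans a jump from 2.2.14 to 2.4.0 of the cp39 Windows wheel, a quick way to confirm which version is actually installed, and that the install works, is MindSpore's built-in installation check. A minimal sketch, assuming the 2.4.0 wheel has been installed into a Python 3.9 environment:

    import mindspore

    # Print the installed version; for this wheel it should report 2.4.0.
    print(mindspore.__version__)

    # run_check() runs a small sanity computation and reports whether
    # MindSpore was installed successfully on this machine.
    mindspore.run_check()
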
Files changed (1217)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +8 -5
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +124 -25
  9. mindspore/_extends/builtin_operations.py +2 -1
  10. mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
  11. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
  12. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
  13. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
  14. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  15. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
  16. mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
  17. mindspore/_extends/parse/__init__.py +18 -14
  18. mindspore/_extends/parse/compile_config.py +299 -0
  19. mindspore/_extends/parse/namespace.py +2 -2
  20. mindspore/_extends/parse/parser.py +182 -68
  21. mindspore/_extends/parse/resources.py +45 -14
  22. mindspore/_extends/parse/standard_method.py +192 -252
  23. mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
  24. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  25. mindspore/_extends/remote/kernel_build_server.py +2 -0
  26. mindspore/_profiler.py +30 -0
  27. mindspore/amp.py +67 -26
  28. mindspore/atlprov.dll +0 -0
  29. mindspore/avcodec-59.dll +0 -0
  30. mindspore/avdevice-59.dll +0 -0
  31. mindspore/avfilter-8.dll +0 -0
  32. mindspore/avformat-59.dll +0 -0
  33. mindspore/avutil-57.dll +0 -0
  34. mindspore/boost/adasum.py +1 -1
  35. mindspore/boost/base.py +1 -1
  36. mindspore/boost/boost_cell_wrapper.py +2 -2
  37. mindspore/boost/grad_freeze.py +2 -2
  38. mindspore/boost/group_loss_scale_manager.py +1 -1
  39. mindspore/boost/less_batch_normalization.py +9 -6
  40. mindspore/c1.dll +0 -0
  41. mindspore/c1xx.dll +0 -0
  42. mindspore/c2.dll +0 -0
  43. mindspore/common/__init__.py +20 -7
  44. mindspore/common/_jit_fallback_utils.py +2 -3
  45. mindspore/common/_pijit_context.py +190 -0
  46. mindspore/common/_register_for_adapter.py +7 -0
  47. mindspore/common/_register_for_recompute.py +48 -0
  48. mindspore/common/_register_for_tensor.py +10 -10
  49. mindspore/common/_stub_tensor.py +7 -1
  50. mindspore/common/_tensor_overload.py +139 -0
  51. mindspore/common/_utils.py +5 -17
  52. mindspore/common/api.py +449 -129
  53. mindspore/common/auto_dynamic_shape.py +27 -14
  54. mindspore/common/dtype.py +17 -10
  55. mindspore/common/dump.py +8 -11
  56. mindspore/common/file_system.py +48 -0
  57. mindspore/common/generator.py +254 -0
  58. mindspore/common/hook_handle.py +65 -30
  59. mindspore/common/initializer.py +1 -1
  60. mindspore/common/jit_config.py +34 -14
  61. mindspore/common/lazy_inline.py +72 -19
  62. mindspore/common/mindir_util.py +12 -2
  63. mindspore/common/mutable.py +79 -14
  64. mindspore/common/no_inline.py +54 -0
  65. mindspore/common/np_dtype.py +25 -0
  66. mindspore/common/parameter.py +73 -21
  67. mindspore/common/recompute.py +292 -0
  68. mindspore/common/seed.py +9 -9
  69. mindspore/common/sparse_tensor.py +276 -24
  70. mindspore/common/symbol.py +122 -0
  71. mindspore/common/tensor.py +668 -514
  72. mindspore/communication/__init__.py +6 -11
  73. mindspore/communication/_comm_helper.py +43 -3
  74. mindspore/communication/comm_func.py +1395 -0
  75. mindspore/communication/management.py +117 -104
  76. mindspore/config/op_info.config +22 -54
  77. mindspore/context.py +455 -71
  78. mindspore/dataset/__init__.py +5 -5
  79. mindspore/dataset/audio/__init__.py +6 -6
  80. mindspore/dataset/audio/transforms.py +711 -158
  81. mindspore/dataset/callback/ds_callback.py +2 -2
  82. mindspore/dataset/core/config.py +7 -0
  83. mindspore/dataset/core/validator_helpers.py +7 -0
  84. mindspore/dataset/engine/cache_client.py +2 -2
  85. mindspore/dataset/engine/datasets.py +201 -116
  86. mindspore/dataset/engine/datasets_audio.py +14 -14
  87. mindspore/dataset/engine/datasets_standard_format.py +83 -3
  88. mindspore/dataset/engine/datasets_text.py +39 -39
  89. mindspore/dataset/engine/datasets_user_defined.py +230 -141
  90. mindspore/dataset/engine/datasets_vision.py +78 -74
  91. mindspore/dataset/engine/iterators.py +29 -0
  92. mindspore/dataset/engine/obs/util.py +7 -0
  93. mindspore/dataset/engine/offload.py +5 -7
  94. mindspore/dataset/engine/queue.py +138 -66
  95. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  96. mindspore/dataset/engine/validators.py +41 -15
  97. mindspore/dataset/text/__init__.py +2 -5
  98. mindspore/dataset/text/transforms.py +408 -121
  99. mindspore/dataset/text/utils.py +9 -9
  100. mindspore/dataset/transforms/__init__.py +0 -3
  101. mindspore/dataset/transforms/transforms.py +261 -76
  102. mindspore/dataset/utils/browse_dataset.py +9 -9
  103. mindspore/dataset/utils/line_reader.py +2 -0
  104. mindspore/dataset/vision/__init__.py +7 -10
  105. mindspore/dataset/vision/c_transforms.py +10 -10
  106. mindspore/dataset/vision/py_transforms_util.py +1 -1
  107. mindspore/dataset/vision/transforms.py +2844 -549
  108. mindspore/dataset/vision/utils.py +161 -10
  109. mindspore/dataset/vision/validators.py +16 -3
  110. mindspore/dnnl.dll +0 -0
  111. mindspore/dpcmi.dll +0 -0
  112. mindspore/{rewrite/ast_creator_register.py → experimental/es/__init__.py} +5 -20
  113. mindspore/experimental/es/embedding_service.py +883 -0
  114. mindspore/experimental/es/embedding_service_layer.py +581 -0
  115. mindspore/experimental/llm_boost/__init__.py +21 -0
  116. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  117. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  118. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  119. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  120. mindspore/experimental/llm_boost/register.py +129 -0
  121. mindspore/experimental/llm_boost/utils.py +31 -0
  122. mindspore/experimental/optim/__init__.py +12 -2
  123. mindspore/experimental/optim/adadelta.py +161 -0
  124. mindspore/experimental/optim/adagrad.py +168 -0
  125. mindspore/experimental/optim/adam.py +35 -34
  126. mindspore/experimental/optim/adamax.py +170 -0
  127. mindspore/experimental/optim/adamw.py +124 -15
  128. mindspore/experimental/optim/asgd.py +153 -0
  129. mindspore/experimental/optim/lr_scheduler.py +66 -121
  130. mindspore/experimental/optim/nadam.py +157 -0
  131. mindspore/experimental/optim/optimizer.py +18 -8
  132. mindspore/experimental/optim/radam.py +194 -0
  133. mindspore/experimental/optim/rmsprop.py +154 -0
  134. mindspore/experimental/optim/rprop.py +164 -0
  135. mindspore/experimental/optim/sgd.py +28 -19
  136. mindspore/hal/__init__.py +40 -0
  137. mindspore/hal/_ascend.py +57 -0
  138. mindspore/hal/_base.py +57 -0
  139. mindspore/hal/_cpu.py +56 -0
  140. mindspore/hal/_gpu.py +57 -0
  141. mindspore/hal/contiguous_tensors_handle.py +175 -0
  142. mindspore/hal/device.py +356 -0
  143. mindspore/hal/event.py +179 -0
  144. mindspore/hal/memory.py +326 -0
  145. mindspore/hal/stream.py +357 -0
  146. mindspore/include/api/data_type.h +2 -2
  147. mindspore/include/api/dual_abi_helper.h +16 -3
  148. mindspore/include/api/model.h +4 -3
  149. mindspore/include/api/model_group.h +13 -1
  150. mindspore/include/api/status.h +14 -0
  151. mindspore/include/api/types.h +10 -10
  152. mindspore/include/c_api/model_c.h +173 -0
  153. mindspore/include/c_api/types_c.h +19 -0
  154. mindspore/include/dataset/config.h +2 -2
  155. mindspore/include/dataset/constants.h +2 -2
  156. mindspore/include/dataset/execute.h +3 -5
  157. mindspore/include/dataset/vision.h +58 -2
  158. mindspore/jpeg62.dll +0 -0
  159. mindspore/log.py +3 -3
  160. mindspore/mindrecord/__init__.py +5 -1
  161. mindspore/mindrecord/config.py +809 -0
  162. mindspore/mindrecord/filereader.py +25 -0
  163. mindspore/mindrecord/filewriter.py +138 -103
  164. mindspore/mindrecord/mindpage.py +40 -6
  165. mindspore/mindrecord/shardutils.py +3 -2
  166. mindspore/mindrecord/shardwriter.py +7 -0
  167. mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
  168. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
  169. mindspore/mindrecord/tools/csv_to_mr.py +4 -9
  170. mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
  171. mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
  172. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
  173. mindspore/mindspore_backend.dll +0 -0
  174. mindspore/mindspore_common.dll +0 -0
  175. mindspore/mindspore_core.dll +0 -0
  176. mindspore/mindspore_glog.dll +0 -0
  177. mindspore/mindspore_np_dtype.dll +0 -0
  178. mindspore/mindspore_ops.dll +0 -0
  179. mindspore/mint/__init__.py +1586 -0
  180. mindspore/mint/distributed/__init__.py +31 -0
  181. mindspore/mint/distributed/distributed.py +254 -0
  182. mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
  183. mindspore/mint/nn/__init__.py +757 -0
  184. mindspore/mint/nn/functional.py +679 -0
  185. mindspore/mint/nn/layer/__init__.py +39 -0
  186. mindspore/mint/nn/layer/activation.py +133 -0
  187. mindspore/mint/nn/layer/normalization.py +477 -0
  188. mindspore/mint/nn/layer/pooling.py +110 -0
  189. mindspore/mint/optim/__init__.py +24 -0
  190. mindspore/mint/optim/adamw.py +206 -0
  191. mindspore/mint/special/__init__.py +63 -0
  192. mindspore/msobj140.dll +0 -0
  193. mindspore/mspdb140.dll +0 -0
  194. mindspore/mspdbcore.dll +0 -0
  195. mindspore/mspdbst.dll +0 -0
  196. mindspore/mspft140.dll +0 -0
  197. mindspore/msvcdis140.dll +0 -0
  198. mindspore/msvcp140_1.dll +0 -0
  199. mindspore/msvcp140_2.dll +0 -0
  200. mindspore/msvcp140_atomic_wait.dll +0 -0
  201. mindspore/msvcp140_codecvt_ids.dll +0 -0
  202. mindspore/multiprocessing/__init__.py +73 -0
  203. mindspore/nn/cell.py +461 -323
  204. mindspore/nn/dynamic_lr.py +2 -2
  205. mindspore/nn/layer/activation.py +292 -135
  206. mindspore/nn/layer/basic.py +288 -83
  207. mindspore/nn/layer/channel_shuffle.py +3 -16
  208. mindspore/nn/layer/container.py +3 -3
  209. mindspore/nn/layer/conv.py +75 -66
  210. mindspore/nn/layer/embedding.py +221 -45
  211. mindspore/nn/layer/image.py +4 -7
  212. mindspore/nn/layer/math.py +1 -1
  213. mindspore/nn/layer/normalization.py +150 -68
  214. mindspore/nn/layer/padding.py +64 -87
  215. mindspore/nn/layer/pooling.py +175 -12
  216. mindspore/nn/layer/rnn_cells.py +6 -16
  217. mindspore/nn/layer/rnns.py +6 -5
  218. mindspore/nn/layer/thor_layer.py +1 -2
  219. mindspore/nn/layer/timedistributed.py +1 -1
  220. mindspore/nn/layer/transformer.py +55 -53
  221. mindspore/nn/learning_rate_schedule.py +6 -5
  222. mindspore/nn/loss/__init__.py +2 -2
  223. mindspore/nn/loss/loss.py +145 -88
  224. mindspore/nn/optim/__init__.py +2 -1
  225. mindspore/nn/optim/ada_grad.py +4 -2
  226. mindspore/nn/optim/adadelta.py +4 -2
  227. mindspore/nn/optim/adafactor.py +1 -1
  228. mindspore/nn/optim/adam.py +102 -181
  229. mindspore/nn/optim/adamax.py +4 -2
  230. mindspore/nn/optim/adasum.py +3 -3
  231. mindspore/nn/optim/asgd.py +4 -2
  232. mindspore/nn/optim/ftrl.py +31 -61
  233. mindspore/nn/optim/lamb.py +5 -3
  234. mindspore/nn/optim/lars.py +2 -2
  235. mindspore/nn/optim/lazyadam.py +6 -4
  236. mindspore/nn/optim/momentum.py +13 -25
  237. mindspore/nn/optim/optimizer.py +6 -3
  238. mindspore/nn/optim/proximal_ada_grad.py +4 -2
  239. mindspore/nn/optim/rmsprop.py +9 -3
  240. mindspore/nn/optim/rprop.py +4 -2
  241. mindspore/nn/optim/sgd.py +5 -3
  242. mindspore/nn/optim/tft_wrapper.py +127 -0
  243. mindspore/nn/optim/thor.py +2 -2
  244. mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
  245. mindspore/nn/probability/distribution/beta.py +2 -2
  246. mindspore/nn/probability/distribution/categorical.py +4 -6
  247. mindspore/nn/probability/distribution/cauchy.py +2 -2
  248. mindspore/nn/probability/distribution/exponential.py +2 -2
  249. mindspore/nn/probability/distribution/geometric.py +1 -1
  250. mindspore/nn/probability/distribution/gumbel.py +2 -2
  251. mindspore/nn/probability/distribution/logistic.py +1 -1
  252. mindspore/nn/probability/distribution/poisson.py +2 -2
  253. mindspore/nn/probability/distribution/uniform.py +2 -2
  254. mindspore/nn/reinforcement/_tensors_queue.py +13 -1
  255. mindspore/nn/wrap/__init__.py +2 -1
  256. mindspore/nn/wrap/cell_wrapper.py +46 -12
  257. mindspore/nn/wrap/grad_reducer.py +148 -8
  258. mindspore/nn/wrap/loss_scale.py +44 -7
  259. mindspore/numpy/__init__.py +2 -0
  260. mindspore/numpy/array_creations.py +67 -68
  261. mindspore/numpy/array_ops.py +70 -66
  262. mindspore/numpy/dtypes.py +3 -3
  263. mindspore/numpy/fft.py +966 -0
  264. mindspore/numpy/logic_ops.py +11 -10
  265. mindspore/numpy/math_ops.py +147 -152
  266. mindspore/numpy/utils.py +3 -0
  267. mindspore/numpy/utils_const.py +4 -4
  268. mindspore/opencv_core452.dll +0 -0
  269. mindspore/opencv_imgcodecs452.dll +0 -0
  270. mindspore/opencv_imgproc452.dll +0 -0
  271. mindspore/ops/__init__.py +9 -6
  272. mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
  273. mindspore/ops/_grad_experimental/grad_comm_ops.py +135 -36
  274. mindspore/ops/_grad_experimental/grad_math_ops.py +61 -298
  275. mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
  276. mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
  277. mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
  278. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  279. mindspore/ops/_op_impl/__init__.py +0 -1
  280. mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
  281. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
  282. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
  283. mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
  284. mindspore/ops/_op_impl/cpu/__init__.py +1 -3
  285. mindspore/ops/_op_impl/cpu/adam.py +2 -2
  286. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
  287. mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
  288. mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
  289. mindspore/ops/_vmap/vmap_array_ops.py +162 -101
  290. mindspore/ops/_vmap/vmap_base.py +8 -1
  291. mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
  292. mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
  293. mindspore/ops/_vmap/vmap_image_ops.py +70 -13
  294. mindspore/ops/_vmap/vmap_math_ops.py +147 -59
  295. mindspore/ops/_vmap/vmap_nn_ops.py +292 -117
  296. mindspore/ops/_vmap/vmap_other_ops.py +1 -1
  297. mindspore/ops/auto_generate/__init__.py +31 -0
  298. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  299. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  300. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  301. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  302. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  303. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  304. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  305. mindspore/ops/composite/__init__.py +5 -2
  306. mindspore/ops/composite/base.py +201 -66
  307. mindspore/ops/composite/math_ops.py +10 -49
  308. mindspore/ops/composite/multitype_ops/_compile_utils.py +192 -618
  309. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
  310. mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
  311. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
  312. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
  313. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
  314. mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
  315. mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
  316. mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
  317. mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
  318. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
  319. mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
  320. mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
  321. mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
  322. mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
  323. mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
  324. mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
  325. mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
  326. mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
  327. mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
  328. mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
  329. mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
  330. mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
  331. mindspore/ops/composite/multitype_ops/not_in_impl.py +8 -3
  332. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
  333. mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
  334. mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
  335. mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
  336. mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
  337. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
  338. mindspore/ops/deprecated.py +14 -3
  339. mindspore/ops/function/__init__.py +53 -11
  340. mindspore/ops/function/array_func.py +1269 -1821
  341. mindspore/ops/function/clip_func.py +19 -31
  342. mindspore/ops/function/debug_func.py +114 -5
  343. mindspore/ops/function/fft_func.py +44 -0
  344. mindspore/ops/function/grad/grad_func.py +30 -22
  345. mindspore/ops/function/image_func.py +27 -21
  346. mindspore/ops/function/linalg_func.py +35 -68
  347. mindspore/ops/function/math_func.py +1170 -2697
  348. mindspore/ops/function/nn_func.py +2116 -1128
  349. mindspore/ops/function/other_func.py +8 -8
  350. mindspore/ops/function/parameter_func.py +5 -93
  351. mindspore/ops/function/random_func.py +435 -113
  352. mindspore/ops/function/reshard_func.py +104 -0
  353. mindspore/ops/function/sparse_func.py +4 -4
  354. mindspore/ops/function/sparse_unary_func.py +9 -16
  355. mindspore/ops/function/spectral_func.py +1 -1
  356. mindspore/ops/function/vmap_func.py +16 -15
  357. mindspore/ops/functional.py +355 -346
  358. mindspore/ops/op_info_register.py +18 -45
  359. mindspore/ops/operations/__init__.py +38 -24
  360. mindspore/ops/operations/_grad_ops.py +21 -927
  361. mindspore/ops/operations/_infer_ops.py +19 -0
  362. mindspore/ops/operations/_inner_ops.py +173 -607
  363. mindspore/ops/operations/_rl_inner_ops.py +2 -2
  364. mindspore/ops/operations/_scalar_ops.py +5 -480
  365. mindspore/ops/operations/_sequence_ops.py +6 -36
  366. mindspore/ops/operations/_tensor_array.py +8 -8
  367. mindspore/ops/operations/array_ops.py +106 -2837
  368. mindspore/ops/operations/comm_ops.py +799 -127
  369. mindspore/ops/operations/custom_ops.py +124 -119
  370. mindspore/ops/operations/debug_ops.py +142 -41
  371. mindspore/ops/operations/image_ops.py +1 -217
  372. mindspore/ops/operations/inner_ops.py +5 -40
  373. mindspore/ops/operations/linalg_ops.py +1 -49
  374. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  375. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  376. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  377. mindspore/ops/operations/math_ops.py +666 -4972
  378. mindspore/ops/operations/nn_ops.py +205 -2213
  379. mindspore/ops/operations/other_ops.py +60 -49
  380. mindspore/ops/operations/random_ops.py +50 -54
  381. mindspore/ops/operations/reshard_ops.py +53 -0
  382. mindspore/ops/operations/sparse_ops.py +4 -4
  383. mindspore/ops/primitive.py +216 -103
  384. mindspore/ops_generate/__init__.py +27 -0
  385. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  386. mindspore/ops_generate/arg_handler.py +197 -0
  387. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  388. mindspore/ops_generate/gen_constants.py +36 -0
  389. mindspore/ops_generate/gen_ops.py +1099 -0
  390. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  391. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  392. mindspore/ops_generate/gen_utils.py +209 -0
  393. mindspore/ops_generate/op_proto.py +145 -0
  394. mindspore/ops_generate/pyboost_utils.py +367 -0
  395. mindspore/ops_generate/template.py +261 -0
  396. mindspore/parallel/__init__.py +8 -4
  397. mindspore/parallel/_auto_parallel_context.py +100 -10
  398. mindspore/parallel/_cell_wrapper.py +99 -9
  399. mindspore/parallel/_cost_model_context.py +1 -1
  400. mindspore/parallel/_dp_allreduce_fusion.py +159 -159
  401. mindspore/parallel/_parallel_serialization.py +67 -23
  402. mindspore/parallel/_ps_context.py +1 -1
  403. mindspore/parallel/_recovery_context.py +1 -1
  404. mindspore/parallel/_tensor.py +99 -22
  405. mindspore/parallel/_transformer/__init__.py +1 -1
  406. mindspore/parallel/_transformer/layers.py +1 -1
  407. mindspore/parallel/_transformer/loss.py +1 -1
  408. mindspore/parallel/_transformer/moe.py +1 -1
  409. mindspore/parallel/_transformer/op_parallel_config.py +1 -1
  410. mindspore/parallel/_transformer/transformer.py +2 -2
  411. mindspore/parallel/_utils.py +173 -6
  412. mindspore/parallel/algo_parameter_config.py +8 -10
  413. mindspore/parallel/checkpoint_transform.py +204 -38
  414. mindspore/parallel/cluster/__init__.py +15 -0
  415. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  416. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  417. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  418. mindspore/parallel/cluster/run.py +136 -0
  419. mindspore/parallel/mpi/__init__.py +1 -1
  420. mindspore/parallel/mpi/_mpi_config.py +1 -1
  421. mindspore/parallel/parameter_broadcast.py +151 -0
  422. mindspore/parallel/shard.py +279 -37
  423. mindspore/parallel/transform_safetensors.py +993 -0
  424. mindspore/pgodb140.dll +0 -0
  425. mindspore/pgort140.dll +0 -0
  426. mindspore/profiler/__init__.py +4 -2
  427. mindspore/profiler/common/constant.py +29 -0
  428. mindspore/profiler/common/process_pool.py +41 -0
  429. mindspore/profiler/common/registry.py +47 -0
  430. mindspore/profiler/common/singleton.py +28 -0
  431. mindspore/profiler/common/util.py +153 -0
  432. mindspore/profiler/dynamic_profiler.py +694 -0
  433. mindspore/profiler/envprofiling.py +18 -20
  434. mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
  435. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  436. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  437. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  438. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  439. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  440. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  441. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  442. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  443. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  444. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  445. mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
  446. mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
  447. mindspore/profiler/parser/ascend_flops_generator.py +20 -4
  448. mindspore/profiler/parser/ascend_hccl_generator.py +29 -278
  449. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  450. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  451. mindspore/profiler/parser/ascend_msprof_exporter.py +148 -146
  452. mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
  453. mindspore/profiler/parser/ascend_op_generator.py +92 -42
  454. mindspore/profiler/parser/ascend_timeline_generator.py +298 -133
  455. mindspore/profiler/parser/base_timeline_generator.py +25 -25
  456. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  457. mindspore/profiler/parser/framework_parser.py +4 -393
  458. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  459. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  460. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  461. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  462. mindspore/profiler/parser/integrator.py +3 -1
  463. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  464. mindspore/profiler/parser/minddata_parser.py +72 -3
  465. mindspore/profiler/parser/profiler_info.py +94 -7
  466. mindspore/profiler/profiler.py +153 -0
  467. mindspore/profiler/profiling.py +631 -508
  468. mindspore/rewrite/__init__.py +2 -14
  469. mindspore/rewrite/api/node.py +122 -36
  470. mindspore/rewrite/api/pattern_engine.py +2 -3
  471. mindspore/rewrite/api/scoped_value.py +16 -15
  472. mindspore/rewrite/api/symbol_tree.py +45 -29
  473. mindspore/rewrite/ast_helpers/__init__.py +3 -6
  474. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  475. mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
  476. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  477. mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
  478. mindspore/rewrite/common/__init__.py +1 -2
  479. mindspore/rewrite/common/config.py +24 -0
  480. mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
  481. mindspore/rewrite/{namer.py → common/namer.py} +63 -18
  482. mindspore/rewrite/common/namespace.py +118 -0
  483. mindspore/rewrite/node/__init__.py +5 -5
  484. mindspore/rewrite/node/call_function.py +23 -7
  485. mindspore/rewrite/node/cell_container.py +7 -3
  486. mindspore/rewrite/node/control_flow.py +53 -28
  487. mindspore/rewrite/node/node.py +212 -196
  488. mindspore/rewrite/node/node_manager.py +51 -22
  489. mindspore/rewrite/node/node_topological_manager.py +3 -23
  490. mindspore/rewrite/parsers/__init__.py +12 -0
  491. mindspore/rewrite/parsers/arguments_parser.py +8 -9
  492. mindspore/rewrite/parsers/assign_parser.py +637 -413
  493. mindspore/rewrite/parsers/attribute_parser.py +3 -4
  494. mindspore/rewrite/parsers/class_def_parser.py +115 -148
  495. mindspore/rewrite/parsers/constant_parser.py +5 -5
  496. mindspore/rewrite/parsers/container_parser.py +4 -6
  497. mindspore/rewrite/parsers/expr_parser.py +55 -0
  498. mindspore/rewrite/parsers/for_parser.py +31 -98
  499. mindspore/rewrite/parsers/function_def_parser.py +13 -5
  500. mindspore/rewrite/parsers/if_parser.py +28 -10
  501. mindspore/rewrite/parsers/module_parser.py +8 -182
  502. mindspore/rewrite/parsers/parser.py +1 -5
  503. mindspore/rewrite/parsers/parser_register.py +1 -1
  504. mindspore/rewrite/parsers/return_parser.py +5 -10
  505. mindspore/rewrite/parsers/while_parser.py +59 -0
  506. mindspore/rewrite/sparsify/utils.py +1 -1
  507. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  508. mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +705 -186
  509. mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
  510. mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
  511. mindspore/run_check/_check_version.py +40 -115
  512. mindspore/run_check/run_check.py +1 -1
  513. mindspore/safeguard/rewrite_obfuscation.py +597 -263
  514. mindspore/swresample-4.dll +0 -0
  515. mindspore/swscale-6.dll +0 -0
  516. mindspore/tbbmalloc.dll +0 -0
  517. mindspore/tinyxml2.dll +0 -0
  518. mindspore/train/__init__.py +7 -5
  519. mindspore/train/_utils.py +204 -4
  520. mindspore/train/amp.py +335 -295
  521. mindspore/train/anf_ir_pb2.py +14 -2
  522. mindspore/train/callback/__init__.py +5 -2
  523. mindspore/train/callback/_backup_and_restore.py +5 -5
  524. mindspore/train/callback/_callback.py +4 -4
  525. mindspore/train/callback/_checkpoint.py +220 -43
  526. mindspore/train/callback/_cluster_monitor.py +201 -0
  527. mindspore/train/callback/_early_stop.py +2 -2
  528. mindspore/train/callback/_flops_collector.py +239 -0
  529. mindspore/train/callback/_landscape.py +15 -9
  530. mindspore/train/callback/_loss_monitor.py +5 -5
  531. mindspore/train/callback/_on_request_exit.py +136 -33
  532. mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
  533. mindspore/train/callback/_summary_collector.py +12 -12
  534. mindspore/train/callback/_tft_register.py +352 -0
  535. mindspore/train/callback/_time_monitor.py +3 -3
  536. mindspore/train/data_sink.py +6 -5
  537. mindspore/train/dataset_helper.py +66 -23
  538. mindspore/train/loss_scale_manager.py +2 -2
  539. mindspore/train/metrics/accuracy.py +7 -7
  540. mindspore/train/metrics/confusion_matrix.py +8 -6
  541. mindspore/train/metrics/cosine_similarity.py +6 -4
  542. mindspore/train/metrics/error.py +2 -2
  543. mindspore/train/metrics/metric.py +3 -3
  544. mindspore/train/metrics/perplexity.py +2 -1
  545. mindspore/train/metrics/roc.py +4 -4
  546. mindspore/train/metrics/topk.py +2 -2
  547. mindspore/train/mind_ir_pb2.py +116 -37
  548. mindspore/train/model.py +382 -76
  549. mindspore/train/serialization.py +787 -288
  550. mindspore/train/summary/_summary_adapter.py +1 -1
  551. mindspore/train/summary/summary_record.py +51 -28
  552. mindspore/train/train_thor/convert_utils.py +3 -3
  553. mindspore/turbojpeg.dll +0 -0
  554. mindspore/utils/__init__.py +21 -0
  555. mindspore/utils/utils.py +60 -0
  556. mindspore/vcmeta.dll +0 -0
  557. mindspore/vcruntime140.dll +0 -0
  558. mindspore/vcruntime140_1.dll +0 -0
  559. mindspore/version.py +1 -1
  560. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/METADATA +8 -4
  561. mindspore-2.4.0.dist-info/RECORD +1406 -0
  562. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +1 -0
  563. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
  564. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
  565. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
  566. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
  567. mindspore/gen_ops.py +0 -273
  568. mindspore/include/c_api/ms/abstract.h +0 -67
  569. mindspore/include/c_api/ms/attribute.h +0 -197
  570. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  571. mindspore/include/c_api/ms/base/macros.h +0 -32
  572. mindspore/include/c_api/ms/base/status.h +0 -33
  573. mindspore/include/c_api/ms/base/types.h +0 -282
  574. mindspore/include/c_api/ms/context.h +0 -102
  575. mindspore/include/c_api/ms/graph.h +0 -160
  576. mindspore/include/c_api/ms/node.h +0 -606
  577. mindspore/include/c_api/ms/tensor.h +0 -161
  578. mindspore/include/c_api/ms/value.h +0 -84
  579. mindspore/mindspore_shared_lib.dll +0 -0
  580. mindspore/nn/layer/flash_attention.py +0 -189
  581. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  582. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  583. mindspore/ops/_op_impl/cpu/concat.py +0 -39
  584. mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
  585. mindspore/ops/_op_impl/tbe/__init__.py +0 -47
  586. mindspore/ops/_op_impl/tbe/abs.py +0 -38
  587. mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
  588. mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
  589. mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
  590. mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
  591. mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
  592. mindspore/ops/_op_impl/tbe/acos.py +0 -37
  593. mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
  594. mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
  595. mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
  596. mindspore/ops/_op_impl/tbe/acosh.py +0 -37
  597. mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
  598. mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
  599. mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
  600. mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
  601. mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
  602. mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
  603. mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
  604. mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
  605. mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
  606. mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
  607. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
  608. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
  609. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
  610. mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
  611. mindspore/ops/_op_impl/tbe/add.py +0 -42
  612. mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
  613. mindspore/ops/_op_impl/tbe/add_n.py +0 -39
  614. mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
  615. mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
  616. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
  617. mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
  618. mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
  619. mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
  620. mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
  621. mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
  622. mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
  623. mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
  624. mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
  625. mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
  626. mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
  627. mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
  628. mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
  629. mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
  630. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
  631. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
  632. mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
  633. mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
  634. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
  635. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
  636. mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
  637. mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
  638. mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
  639. mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
  640. mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
  641. mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
  642. mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
  643. mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
  644. mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
  645. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
  646. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
  647. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
  648. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
  649. mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
  650. mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
  651. mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
  652. mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
  653. mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
  654. mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
  655. mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
  656. mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
  657. mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
  658. mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
  659. mindspore/ops/_op_impl/tbe/asin.py +0 -37
  660. mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
  661. mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
  662. mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
  663. mindspore/ops/_op_impl/tbe/asinh.py +0 -37
  664. mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
  665. mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
  666. mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
  667. mindspore/ops/_op_impl/tbe/assign.py +0 -79
  668. mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
  669. mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
  670. mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
  671. mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
  672. mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
  673. mindspore/ops/_op_impl/tbe/atan.py +0 -37
  674. mindspore/ops/_op_impl/tbe/atan2.py +0 -38
  675. mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
  676. mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
  677. mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
  678. mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
  679. mindspore/ops/_op_impl/tbe/atanh.py +0 -37
  680. mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
  681. mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
  682. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
  683. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
  684. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
  685. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
  686. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
  687. mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
  688. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
  689. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
  690. mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
  691. mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
  692. mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
  693. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
  694. mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
  695. mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
  696. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
  697. mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
  698. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
  699. mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
  700. mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
  701. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
  702. mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
  703. mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
  704. mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
  705. mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
  706. mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
  707. mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
  708. mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
  709. mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
  710. mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
  711. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
  712. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
  713. mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
  714. mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
  715. mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
  716. mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
  717. mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
  718. mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
  719. mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
  720. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
  721. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
  722. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
  723. mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
  724. mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
  725. mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
  726. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
  727. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
  728. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
  729. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
  730. mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
  731. mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
  732. mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
  733. mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
  734. mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
  735. mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
  736. mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
  737. mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
  738. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
  739. mindspore/ops/_op_impl/tbe/cast.py +0 -55
  740. mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
  741. mindspore/ops/_op_impl/tbe/cdist.py +0 -38
  742. mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
  743. mindspore/ops/_op_impl/tbe/ceil.py +0 -37
  744. mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
  745. mindspore/ops/_op_impl/tbe/celu.py +0 -39
  746. mindspore/ops/_op_impl/tbe/centralization.py +0 -39
  747. mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
  748. mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
  749. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
  750. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
  751. mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
  752. mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
  753. mindspore/ops/_op_impl/tbe/concat.py +0 -40
  754. mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
  755. mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
  756. mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
  757. mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
  758. mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
  759. mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
  760. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
  761. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
  762. mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
  763. mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
  764. mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
  765. mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
  766. mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
  767. mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
  768. mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
  769. mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
  770. mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
  771. mindspore/ops/_op_impl/tbe/cos.py +0 -37
  772. mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
  773. mindspore/ops/_op_impl/tbe/cosh.py +0 -37
  774. mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
  775. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
  776. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
  777. mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
  778. mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
  779. mindspore/ops/_op_impl/tbe/cummin.py +0 -41
  780. mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
  781. mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
  782. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
  783. mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
  784. mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
  785. mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
  786. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
  787. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
  788. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
  789. mindspore/ops/_op_impl/tbe/diag.py +0 -38
  790. mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
  791. mindspore/ops/_op_impl/tbe/dilation.py +0 -40
  792. mindspore/ops/_op_impl/tbe/div.py +0 -41
  793. mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
  794. mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
  795. mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
  796. mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
  797. mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
  798. mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
  799. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
  800. mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
  801. mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
  802. mindspore/ops/_op_impl/tbe/elu.py +0 -38
  803. mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
  804. mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
  805. mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
  806. mindspore/ops/_op_impl/tbe/equal.py +0 -42
  807. mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
  808. mindspore/ops/_op_impl/tbe/erf.py +0 -37
  809. mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
  810. mindspore/ops/_op_impl/tbe/erfc.py +0 -37
  811. mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
  812. mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
  813. mindspore/ops/_op_impl/tbe/exp.py +0 -40
  814. mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
  815. mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
  816. mindspore/ops/_op_impl/tbe/expm1.py +0 -37
  817. mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
  818. mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
  819. mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
  820. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
  821. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
  822. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
  823. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
  824. mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
  825. mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
  826. mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
  827. mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
  828. mindspore/ops/_op_impl/tbe/fill.py +0 -56
  829. mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
  830. mindspore/ops/_op_impl/tbe/flatten.py +0 -48
  831. mindspore/ops/_op_impl/tbe/floor.py +0 -37
  832. mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
  833. mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
  834. mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
  835. mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
  836. mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
  837. mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
  838. mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
  839. mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
  840. mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
  841. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
  842. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
  843. mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
  844. mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
  845. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  846. mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
  847. mindspore/ops/_op_impl/tbe/gelu.py +0 -37
  848. mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
  849. mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
  850. mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
  851. mindspore/ops/_op_impl/tbe/ger.py +0 -43
  852. mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
  853. mindspore/ops/_op_impl/tbe/greater.py +0 -43
  854. mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
  855. mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
  856. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
  857. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
  858. mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
  859. mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
  860. mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
  861. mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
  862. mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
  863. mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
  864. mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
  865. mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
  866. mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
  867. mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
  868. mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
  869. mindspore/ops/_op_impl/tbe/im2col.py +0 -42
  870. mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
  871. mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
  872. mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
  873. mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
  874. mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
  875. mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
  876. mindspore/ops/_op_impl/tbe/inv.py +0 -38
  877. mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
  878. mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
  879. mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
  880. mindspore/ops/_op_impl/tbe/invert.py +0 -37
  881. mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
  882. mindspore/ops/_op_impl/tbe/iou.py +0 -38
  883. mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
  884. mindspore/ops/_op_impl/tbe/is_close.py +0 -40
  885. mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
  886. mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
  887. mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
  888. mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
  889. mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
  890. mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
  891. mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
  892. mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
  893. mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
  894. mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
  895. mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
  896. mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
  897. mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
  898. mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
  899. mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
  900. mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
  901. mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
  902. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
  903. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
  904. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
  905. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
  906. mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
  907. mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
  908. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
  909. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
  910. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
  911. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
  912. mindspore/ops/_op_impl/tbe/lerp.py +0 -38
  913. mindspore/ops/_op_impl/tbe/less.py +0 -41
  914. mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
  915. mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
  916. mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
  917. mindspore/ops/_op_impl/tbe/log.py +0 -40
  918. mindspore/ops/_op_impl/tbe/log1p.py +0 -37
  919. mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
  920. mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
  921. mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
  922. mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
  923. mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
  924. mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
  925. mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
  926. mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
  927. mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
  928. mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
  929. mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
  930. mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
  931. mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
  932. mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
  933. mindspore/ops/_op_impl/tbe/lrn.py +0 -41
  934. mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
  935. mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
  936. mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
  937. mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
  938. mindspore/ops/_op_impl/tbe/matmul.py +0 -53
  939. mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
  940. mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
  941. mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
  942. mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
  943. mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
  944. mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
  945. mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
  946. mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
  947. mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
  948. mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
  949. mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
  950. mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
  951. mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
  952. mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
  953. mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
  954. mindspore/ops/_op_impl/tbe/maximum.py +0 -39
  955. mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
  956. mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
  957. mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
  958. mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
  959. mindspore/ops/_op_impl/tbe/minimum.py +0 -40
  960. mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
  961. mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
  962. mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
  963. mindspore/ops/_op_impl/tbe/mish.py +0 -37
  964. mindspore/ops/_op_impl/tbe/mod.py +0 -41
  965. mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
  966. mindspore/ops/_op_impl/tbe/mul.py +0 -37
  967. mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
  968. mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
  969. mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
  970. mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
  971. mindspore/ops/_op_impl/tbe/neg.py +0 -39
  972. mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
  973. mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
  974. mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
  975. mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
  976. mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
  977. mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
  978. mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
  979. mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
  980. mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
  981. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
  982. mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
  983. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
  984. mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
  985. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
  986. mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
  987. mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
  988. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
  989. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
  990. mindspore/ops/_op_impl/tbe/pack.py +0 -58
  991. mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
  992. mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
  993. mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
  994. mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
  995. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
  996. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
  997. mindspore/ops/_op_impl/tbe/pdist.py +0 -36
  998. mindspore/ops/_op_impl/tbe/pooling.py +0 -46
  999. mindspore/ops/_op_impl/tbe/population_count.py +0 -38
  1000. mindspore/ops/_op_impl/tbe/pow.py +0 -41
  1001. mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
  1002. mindspore/ops/_op_impl/tbe/prelu.py +0 -37
  1003. mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
  1004. mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
  1005. mindspore/ops/_op_impl/tbe/range.py +0 -39
  1006. mindspore/ops/_op_impl/tbe/real_div.py +0 -38
  1007. mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
  1008. mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
  1009. mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
  1010. mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
  1011. mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
  1012. mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
  1013. mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
  1014. mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
  1015. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
  1016. mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
  1017. mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
  1018. mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
  1019. mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
  1020. mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
  1021. mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
  1022. mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
  1023. mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
  1024. mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
  1025. mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
  1026. mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
  1027. mindspore/ops/_op_impl/tbe/relu.py +0 -39
  1028. mindspore/ops/_op_impl/tbe/relu6.py +0 -38
  1029. mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
  1030. mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
  1031. mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
  1032. mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
  1033. mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
  1034. mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
  1035. mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
  1036. mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
  1037. mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
  1038. mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
  1039. mindspore/ops/_op_impl/tbe/renorm.py +0 -39
  1040. mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
  1041. mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
  1042. mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
  1043. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
  1044. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
  1045. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
  1046. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
  1047. mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
  1048. mindspore/ops/_op_impl/tbe/rint.py +0 -37
  1049. mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
  1050. mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
  1051. mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
  1052. mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
  1053. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
  1054. mindspore/ops/_op_impl/tbe/roll.py +0 -42
  1055. mindspore/ops/_op_impl/tbe/round.py +0 -38
  1056. mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
  1057. mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
  1058. mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
  1059. mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
  1060. mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
  1061. mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
  1062. mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
  1063. mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
  1064. mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
  1065. mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
  1066. mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
  1067. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
  1068. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
  1069. mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
  1070. mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
  1071. mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
  1072. mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
  1073. mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
  1074. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
  1075. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
  1076. mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
  1077. mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
  1078. mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
  1079. mindspore/ops/_op_impl/tbe/select.py +0 -38
  1080. mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
  1081. mindspore/ops/_op_impl/tbe/selu.py +0 -39
  1082. mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
  1083. mindspore/ops/_op_impl/tbe/sgd.py +0 -62
  1084. mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
  1085. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
  1086. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
  1087. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
  1088. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
  1089. mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
  1090. mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
  1091. mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
  1092. mindspore/ops/_op_impl/tbe/sign.py +0 -38
  1093. mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
  1094. mindspore/ops/_op_impl/tbe/sin.py +0 -37
  1095. mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
  1096. mindspore/ops/_op_impl/tbe/sinh.py +0 -37
  1097. mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
  1098. mindspore/ops/_op_impl/tbe/slice.py +0 -58
  1099. mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
  1100. mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
  1101. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
  1102. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
  1103. mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
  1104. mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
  1105. mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
  1106. mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
  1107. mindspore/ops/_op_impl/tbe/softmax.py +0 -37
  1108. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
  1109. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
  1110. mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
  1111. mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
  1112. mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
  1113. mindspore/ops/_op_impl/tbe/softplus.py +0 -37
  1114. mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
  1115. mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
  1116. mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
  1117. mindspore/ops/_op_impl/tbe/softsign.py +0 -37
  1118. mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
  1119. mindspore/ops/_op_impl/tbe/sort.py +0 -38
  1120. mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
  1121. mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
  1122. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
  1123. mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
  1124. mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
  1125. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
  1126. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
  1127. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
  1128. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
  1129. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
  1130. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
  1131. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
  1132. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
  1133. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
  1134. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
  1135. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
  1136. mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
  1137. mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
  1138. mindspore/ops/_op_impl/tbe/split_d.py +0 -38
  1139. mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
  1140. mindspore/ops/_op_impl/tbe/split_v.py +0 -39
  1141. mindspore/ops/_op_impl/tbe/splitv.py +0 -39
  1142. mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
  1143. mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
  1144. mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
  1145. mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
  1146. mindspore/ops/_op_impl/tbe/square.py +0 -38
  1147. mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
  1148. mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
  1149. mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
  1150. mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
  1151. mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
  1152. mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
  1153. mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
  1154. mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
  1155. mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
  1156. mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
  1157. mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
  1158. mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
  1159. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
  1160. mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
  1161. mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
  1162. mindspore/ops/_op_impl/tbe/sub.py +0 -39
  1163. mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
  1164. mindspore/ops/_op_impl/tbe/tan.py +0 -38
  1165. mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
  1166. mindspore/ops/_op_impl/tbe/tanh.py +0 -37
  1167. mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
  1168. mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
  1169. mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
  1170. mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
  1171. mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
  1172. mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
  1173. mindspore/ops/_op_impl/tbe/tile.py +0 -37
  1174. mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
  1175. mindspore/ops/_op_impl/tbe/top_k.py +0 -42
  1176. mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
  1177. mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
  1178. mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
  1179. mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
  1180. mindspore/ops/_op_impl/tbe/transpose.py +0 -60
  1181. mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
  1182. mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
  1183. mindspore/ops/_op_impl/tbe/trunc.py +0 -39
  1184. mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
  1185. mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
  1186. mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
  1187. mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
  1188. mindspore/ops/_op_impl/tbe/unpack.py +0 -38
  1189. mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
  1190. mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
  1191. mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
  1192. mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
  1193. mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
  1194. mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
  1195. mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
  1196. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
  1197. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
  1198. mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
  1199. mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
  1200. mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
  1201. mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
  1202. mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
  1203. mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
  1204. mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
  1205. mindspore/ops/_tracefunc.py +0 -241
  1206. mindspore/ops/arg_dtype_cast.py +0 -54
  1207. mindspore/ops/silent_check.py +0 -162
  1208. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  1209. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  1210. mindspore/rewrite/api/tree_node_helper.py +0 -60
  1211. mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
  1212. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
  1213. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
  1214. mindspore/rewrite/namespace.py +0 -53
  1215. mindspore-2.2.14.dist-info/RECORD +0 -1924
  1216. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  1217. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright 2020-2023 Huawei Technologies Co., Ltd
+ # Copyright 2020-2024 Huawei Technologies Co., Ltd
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -20,10 +20,9 @@ import inspect
  import os
  import time
  from collections import OrderedDict
- from types import FunctionType, MethodType
  import numpy

- from mindspore._checkparam import args_type_check
+ from mindspore._checkparam import args_type_check, check_hook_fn
  from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
  from mindspore import log as logger
  from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
@@ -33,8 +32,9 @@ from mindspore import context
  from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
  from mindspore import _checkparam as Validator
  from mindspore.common import dtype as mstype
- from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
- from mindspore.common.api import _generate_branch_control_input
+ from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache, _no_grad
+ from mindspore.common.api import _generate_branch_control_input, _convert_python_data, _get_args_for_run_predict
+ from mindspore.common.api import _process_dyn_args, _generate_dyn_compile_args
  from mindspore.common.parameter import Parameter, ParameterTuple
  from mindspore.common.tensor import Tensor
  from mindspore.ops.operations import Cast
@@ -43,16 +43,7 @@ from mindspore.ops.operations import _inner_ops as inner
  from mindspore.parallel.shard import Shard
  from mindspore._check_jit_forbidden_api import jit_forbidden_register
  from mindspore.common._decorator import deprecated
- from mindspore._c_expression import PackExpander
- from mindspore.ops._tracefunc import _convert_tensor, _SetMixedPrecision, PackFunc
-
-
- def _check_args(args):
-     """Check the input args's type"""
-     for item in args:
-         if isinstance(item, Tensor) and item.has_init:
-             item.init_data()
-
+ from mindspore.common._register_for_recompute import recompute_registry


  class Cell(Cell_):
      """
@@ -89,7 +80,7 @@ class Cell(Cell_):

      Examples:
          >>> import mindspore.nn as nn
-         >>> import mindspore.ops as ops
+         >>> from mindspore import ops
          >>> class MyCell(nn.Cell):
          ...     def __init__(self, forward_net):
          ...         super(MyCell, self).__init__(auto_prefix=False)
@@ -109,17 +100,19 @@
      """

      IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_create_time',
-                    '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_tensor_list', '_phase',
-                    '_forward_pre_hook', '_forward_hook', '_enable_forward_pre_hook', '_enable_forward_hook',
-                    '_bprop_debug', '_enable_backward_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
+                    '_func_graph_flags', '_parameter_layout_dict', '_params_list', '_phase', '_bprop_debug',
+                    '_forward_pre_hook', '_forward_hook', '_backward_pre_hook', '_backward_hook',
+                    '_cell_backward_pre_hook', '_cell_backward_hook', '_is_run', '_param_prefix',
                     '_attr_synced', 'pynative', 'requires_grad', 'cell_type']
+     total_instance_count = 0

      def __init__(self, auto_prefix=True, flags=None):
          Cell_.__init__(self, self._cell_tag)
+         Cell.total_instance_count += 1
+         self.instance_count = Cell.total_instance_count
          self._params = OrderedDict()
          self._cells = OrderedDict()
          self._params_list = OrderedDict()
-         self._tensor_list = OrderedDict()
          self._primitives = OrderedDict()
          self.training = False
          self.requires_grad = False
@@ -135,11 +128,15 @@
          self._create_time = int(time.time() * 1e9)
          self.arguments_key = ""
          self.compile_cache = set()
+         self.phase_cache = dict()
          cells_compile_cache[id(self)] = self.compile_cache
          self.parameter_broadcast_done = False
          self._id = 1
          self.exist_names = set("")
          self.exist_objs = set()
+         self._recompute_cell = None
+         self.mixed_precision_type = None
+         self.sig = inspect.signature(self.construct)
          init_pipeline()

          # call gc to release GE session resources used by non-used cell objects
@@ -149,24 +146,33 @@
          if flags:
              self.add_flags(**flags)
          self._bprop_debug = False
+
+         # hook
          self._forward_pre_hook = OrderedDict()
          self._forward_hook = OrderedDict()
-         self._enable_forward_pre_hook = False
-         self._enable_forward_hook = False
-         self._enable_backward_hook = False
+         self._backward_pre_hook = OrderedDict()
+         self._cell_backward_pre_hook = None
+         self._backward_hook = OrderedDict()
          self._cell_backward_hook = None
          self._is_recursion_hook = False
+
          self.cell_type = None
          self.cast = Cast()
          self._has_config_recompute = False
          self._user_parameters = []
          self._dynamic_shape_inputs = None
+         self._compile_args = None
          self.saved_dynamic_shape = None
          self._jit_config_dict = dict()
          self.grad_ops_label = False
          self.ge_sync_data = False
          self._is_check_and_refresh = False
          self._amp_level = ""
+         self._init_flag = False
+         self._shard_fn = None
+         self.has_bprop = False
+         if hasattr(self, "bprop"):
+             self.has_bprop = True

      def __getstate__(self):
          base = Cell_.__getstate__(self)
@@ -224,8 +230,9 @@
          Get whether cell custom bprop debug is enabled.

          Tutorial Examples:
              - `Custom Neural Network Layers - Custom Cell Reverse
-               <https://mindspore.cn/tutorials/en/r2.2/advanced/modules/layer.html#custom-cell-reverse>`_
+               <https://mindspore.cn/docs/en/master/model_train/custom_program/network_custom.html
+               #custom-cell-reverse>`_
          """
          return self._bprop_debug

@@ -317,10 +324,23 @@

      @property
      def pipeline_stage(self):
+         """
+         `pipeline_stage` represents the pipeline stage of the current Cell.
+         """
          return self._pipeline_stage

      @pipeline_stage.setter
      def pipeline_stage(self, value):
+         """
+         Set the `pipeline_stage` of a Cell.
+
+         Args:
+             value (int): The pipeline stage of a parameter.
+
+         Raises:
+             TypeError: If `value` is not int type or is a bool type.
+             ValueError: If `value` is not a positive integer.
+         """
          if not isinstance(value, int) or isinstance(value, bool):
              raise TypeError("For 'Cell', the property 'pipeline_stage' "
                              "must be int type, but got type : {}".format(type(value)))
@@ -362,6 +382,10 @@
      def jit_config_dict(self):
          return self._jit_config_dict

+     @property
+     def enable_backward_hook(self):
+         return self._enable_backward_hook
+
      def get_func_graph_proto(self):
          """Return graph binary proto."""
          exec_id = ".".join([self.phase, str(self.create_time), str(id(self))])
@@ -376,10 +400,6 @@
          cells = self.__dict__['_cells']
          if name in cells:
              return cells[name]
-         if '_tensor_list' in self.__dict__:
-             tensor_list = self.__dict__['_tensor_list']
-             if name in tensor_list:
-                 return tensor_list[name]
          if '_params_list' in self.__dict__:
              params_list = self.__dict__['_params_list']
              if name in params_list:
@@ -391,12 +411,9 @@
          # while deepcopy a cell instance, the copied cell instance can't be added to cells_compile_cache
          # here using pop(id(self), None) to avoid KeyError exception
          cells_compile_cache.pop(id(self), None)
-         try:
-             if self.compile_cache:
-                 _cell_graph_executor.del_net_res(self, self.compile_cache)
-         except AttributeError as e:
-             raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
-                                  f"Please use 'super().__init__()'.") from e
+         if hasattr(self, "compile_cache") and self.compile_cache:
+             _cell_graph_executor.del_net_res(self, self.compile_cache)
+         Cell.total_instance_count -= 1

      def __delattr__(self, name):
          if name in self._params:
@@ -405,8 +422,6 @@
              del self._cells[name]
          elif '_params_list' in self.__dict__ and name in self._params_list:
              del self._params_list[name]
-         elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
-             del self._tensor_list[name]
          else:
              object.__delattr__(self, name)
          self._attr_synced = False
@@ -420,7 +435,7 @@
          elif isinstance(item, float):
              res.append(self.cast(item, dst_type))
          elif hasattr(item, "dtype") and item.dtype in \
-                 {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
+                 {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
              res.append(self.cast(item, dst_type))
          else:
              res.append(item)
@@ -470,18 +485,28 @@
              output = self._run_construct(cast_inputs, kwargs)
          return output

-     def _run_construct(self, cast_inputs, kwargs):
+     def _run_construct(self, *inputs, **kwargs):
          """Run the construct function"""
-         if self._enable_forward_pre_hook:
-             cast_inputs = self._run_forward_pre_hook(cast_inputs)
-         if self._enable_backward_hook:
-             output = self._backward_hook_construct(*cast_inputs, **kwargs)
-         elif hasattr(self, "_shard_fn"):
-             output = self._shard_fn(*cast_inputs, **kwargs)
+         if self._forward_pre_hook:
+             inputs = self._run_forward_pre_hook(inputs)
+
+         if self._backward_hook:
+             output = self._backward_hook_construct(*inputs, **kwargs)
+         elif self._shard_fn is not None:
+             output = self._shard_fn(*inputs, **kwargs)
+         elif self._recompute_cell is not None:
+             output = self._recompute_cell(*inputs, **kwargs)
+         elif self.has_bprop and _pynative_executor.requires_grad():
+             output = self._call_custom_bprop(*inputs, **kwargs)
          else:
-             output = self.construct(*cast_inputs, **kwargs)
-         if self._enable_forward_hook:
-             output = self._run_forward_hook(cast_inputs, output)
+             output = self.construct(*inputs, **kwargs)
+
+         if self._forward_hook:
+             output = self._run_forward_hook(inputs, output)
+
+         if self._backward_pre_hook:
+             output = self._run_backward_pre_hook(output)
+
          return output

      def _check_construct_args(self, *args):
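
Note the control flow above: `_run_construct` is only entered when a hook, shard function, recompute wrapper, or custom `bprop` is attached; otherwise `__call__` (further down in this diff) dispatches straight to `construct`. A hedged PyNative-mode sketch of how registering and removing a hook toggles that routing:

    import mindspore as ms
    import mindspore.nn as nn

    ms.set_context(mode=ms.PYNATIVE_MODE)

    def pre_hook(cell, inputs):
        # Returning None leaves the inputs unchanged.
        print("entering", type(cell).__name__)

    net = nn.ReLU()
    handle = net.register_forward_pre_hook(pre_hook)
    out = net(ms.Tensor([1.0, -1.0]))  # hook dict non-empty: _run_construct path
    handle.remove()
    out = net(ms.Tensor([1.0, -1.0]))  # hook dict empty again: plain construct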
@@ -519,7 +544,7 @@
          '''Hook function in graph mode'''
          # Check super().__init__() in graph mode.
          try:
-             if self._enable_forward_pre_hook or self._enable_forward_hook or self._enable_backward_hook:
+             if self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook:
                  return True
          except AttributeError as e:
              raise AttributeError(f"The '{type(self).__name__}' object does not inherit attribute from 'cell'. "
@@ -566,22 +591,22 @@
      def shard(self, in_strategy, out_strategy=None, parameter_plan=None, device="Ascend", level=0):
          """
          Defining the input and output layouts of this cell and the parallel strategies of remaining ops will be
-         generated by sharding propagation. In PyNative mode, use this method
-         to specify a Cell for distributed execution in graph mode.
+         generated by sharding propagation. In PyNative mode, use this method to specify a Cell for distributed
+         execution in graph mode. In Graph mode, use this method to specify the distribution strategy for a Cell;
+         strategies for the others will be set by sharding propagation.
          in_strategy and out_strategy define the input and output layout respectively.
          in_strategy/out_strategy should be a tuple, each element of which corresponds to the desired layout of
-         this input/output, and None represents data_parallel,
-         which can refer to the description of `mindspore.ops.Primitive.shard`.
+         this input/output, which can refer to the description of `mindspore.ops.Primitive.shard`.
          The parallel strategies of remaining operators are derived from the strategy specified by the input and output.

          Note:
-             Only effective in PYNATIVE_MODE and in either ParallelMode.AUTO_PARALLEL with
-             search_mode in auto_parallel_context set as sharding_propagation.
+             If Cell.shard is called, the parallel mode in `set_auto_parallel_context` (parallel_mode) will be set to
+             "auto_parallel" and the search mode (search_mode) to "sharding_propagation".
              If the input contains a Parameter, its strategy should be set in `in_strategy`.

          Args:
-             in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple or None. Tuple
-                 defines the layout of the corresponding input and None represents a data parallel strategy.
+             in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple. Tuple
+                 defines the layout of the corresponding input.
              out_strategy (Union[None, tuple]): Define the layout of outputs similar with in_strategy.
                  It is not in use right now. Default: ``None`` .
              parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
@@ -598,7 +623,7 @@
                  use right now. Support [ ``"0"`` , ``"1"`` , ``"2"`` ]. Default: ``0`` .

          Returns:
-             Cell, the cell itself.
+             Function, return the cell construct function that will be executed under the auto parallel process.

          Examples:
              >>> import mindspore.nn as nn
@@ -616,22 +641,21 @@
              ...     def __init__(self):
              ...         self.block1 = Block()
              ...         self.block2 = Block()
-             ...         self.block2.shard(in_strategy=((2, 1),), out_strategy=(None,),
-             ...                           parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})
+             ...         self.block2_shard = self.block2.shard(in_strategy=((2, 1),),
+             ...                                               parameter_plan={'self.block2.shard.dense1.weight': (4, 1)})
              ...     def construct(self, x):
              ...         x = self.block1(x)
-             ...         x = self.block2(x)
+             ...         x = self.block2_shard(x)
              ...         return x
          """
-         if context.get_context("mode") != context.PYNATIVE_MODE or \
-                 context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel"]:
-             raise AssertionError(f"Cell shard only supports auto parallel under PyNative mode. "
-                                  f"Please check if you call Cell.shard in the script.")
+         if context.get_auto_parallel_context("parallel_mode") not in ["auto_parallel", "semi_auto_parallel"]:
+             raise AssertionError(f"Cell shard only supports auto parallel or semi_auto_parallel. "
+                                  f"Please check the parallel mode in the parallel context.")

          shard_fn = Shard()
          fn = shard_fn(self, in_strategy, out_strategy, parameter_plan, device, level)
-         object.__setattr__(self, "_shard_fn", fn)
-         return self
+         self._shard_fn = fn
+         return fn

      def auto_cast_inputs(self, inputs):
          """
@@ -654,65 +678,113 @@

          return cast_inputs

-     def __call__(self, *args, **kwargs):
-         if self.__class__.construct is Cell.construct:
-             raise AttributeError("For 'Cell', the method 'construct' is not defined.")
-
-         if kwargs:
-             bound_arguments = inspect.signature(self.construct).bind(*args, **kwargs)
-             bound_arguments.apply_defaults()
-             args = bound_arguments.args
-             kwargs = bound_arguments.kwargs
+     def _init_check(self):
+         for param in self.get_parameters(expand=False):
+             if param.has_init:
+                 param.init_data()
+         self._init_flag = True

-         if PackFunc.is_tracing():
-             return self._run_tracefunc(*args, **kwargs)
-
-         if hasattr(self, '_is_check_and_refresh') and not self._is_check_and_refresh:
+     def _self_check(self):
+         if not self._is_check_and_refresh:
              self.check_names_and_refresh_name()
              self._is_check_and_refresh = True

+     def _predict(self, *args, **kwargs):
+         if not hasattr(self, "phase"):
+             return False, None
+         if (self.phase == "prefill" or self.phase == 'increment') and self.phase in self.phase_cache:
+             new_args = _get_args_for_run_predict(self, args, kwargs, self._compile_args)
+             res = _cell_graph_executor._graph_executor(tuple(new_args), self.phase_cache[self.phase])
+             res = _convert_python_data(res)
+             return True, res
+         return False, None
+
+     def __call__(self, *args, **kwargs):
          # Run in Graph mode.
-         if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
+         if context._get_mode() == context.GRAPH_MODE and os.getenv("MS_JIT") != '0':
+             if kwargs:
+                 bound_arguments = self.sig.bind(*args, **kwargs)
+                 bound_arguments.apply_defaults()
+                 args = bound_arguments.args
+                 kwargs = bound_arguments.kwargs
+
+             predict_compiled, res = self._predict(*args, **kwargs)
+             if predict_compiled:
+                 return res
              self._check_construct_args(*args)
+
              if self._hook_fn_registered():
                  logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
                                 f"function, please use context.set_context to set pynative mode.")
+             self._self_check()
              out = self.compile_and_run(*args, **kwargs)
              return out

          # Run in PyNative mode.
-         if _pynative_executor.is_first_cell():
-             _pynative_executor._optimizer = getattr(self, "optimizer", None)
-             _pynative_executor._top_cell = self
-             # There many Casts in parameter_broadcast. Enable build faster.
-             self._do_parameter_broadcast()
+         if not (self._init_flag or self._is_check_and_refresh):
+             self._init_check()
+             self._self_check()

-         _check_args(args)
-         self._check_cell_flags_in_pynative()
+         if not (self.requires_grad or self._dynamic_shape_inputs or self.mixed_precision_type):
+             if not (self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook or
+                     self._shard_fn or self._recompute_cell or (self.has_bprop and _pynative_executor.requires_grad())):
+                 return self.construct(*args, **kwargs)

-         if self.requires_grad and _pynative_executor.enable_grad():
-             _pynative_executor.set_grad_flag(True)
+             return self._run_construct(*args, **kwargs)

-         if self._dynamic_shape_inputs is not None:
-             self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)
+         return self._complex_call(*args, **kwargs)

-         try:
+     def _complex_call(self, *args, **kwargs):
+         """
+         PyNative call with requires_grad or hooks
+         """
+         self._call_pre_process(*args, **kwargs)
+
+         if not (self._forward_pre_hook or self._forward_hook or self._backward_pre_hook or self._backward_hook or
+                 self._shard_fn or self._recompute_cell or self.has_bprop):
+             output = self.construct(*args, **kwargs)
+         else:
+             output = self._run_construct(*args, **kwargs)
+
+         self._call_post_process(output, *args, **kwargs)
+
+         return output
+
+     def _call_pre_process(self, *args, **kwargs):
+         """
+         Process cell info before call construct
+         """
+         if self.requires_grad:
+             _pynative_executor.set_grad_flag(True)
              _pynative_executor.new_graph(self, *args, **kwargs)
-             output = self._run_construct(args, kwargs)
+         elif self._dynamic_shape_inputs is not None:
+             _pynative_executor.set_cell_use_dynamic_shape_process(True)
+
+         # Set mixed precision
+         if self.mixed_precision_type is not None:
+             _pynative_executor.set_mixed_precision_type(self.mixed_precision_type)
+
+     def _call_post_process(self, output, *args, **kwargs):
+         """
+         Process cell info after call construct
+         """
+         if self.requires_grad:
              _pynative_executor.end_graph(self, output, *args, **kwargs)
-         except Exception as err:
-             _pynative_executor.clear_res()
-             raise err
+         elif self._dynamic_shape_inputs is not None:
+             _pynative_executor.set_cell_use_dynamic_shape_process(False)

-         if isinstance(output, Parameter):
-             output = output.data
-         return output
+         # mixed precision reset
+         if self.mixed_precision_type is not None:
+             _pynative_executor.set_mixed_precision_type(MixedPrecisionType.NOTSET, False)

-     def _check_cell_flags_in_pynative(self):
-         """Check the flags added to cell in pynative mode"""
-         if hasattr(self, "_func_graph_flags") and self._func_graph_flags.get("output_no_recompute"):
-             raise TypeError("Recompute is not supported in PyNative mode currently, you can use "
-                             "'context.set_context(mode=context.GRAPH_MODE)' or @jit to set graph mode.")
+     def _call_custom_bprop(self, *args, **kwargs):
+         """
+         Call custom bprop for cell bprop.
+         """
+         with _no_grad():
+             output = self.construct(*args, **kwargs)
+         _pynative_executor.call_custom_bprop(self, output, *args, **kwargs)
+         return output

      def _add_attr(self, name, value):
          if name and name[:2] != '__' and name not in Cell.IGNORE_LIST:
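
The new `_call_custom_bprop` path runs the forward pass under `_no_grad` and then hands the user-defined `bprop` to the executor. A hedged PyNative-mode sketch of a Cell whose `bprop` this path would pick up (the gradient value is deliberately wrong to show the override taking effect):

    import mindspore as ms
    from mindspore import Tensor, nn, ops

    ms.set_context(mode=ms.PYNATIVE_MODE)

    class MulCell(nn.Cell):
        def construct(self, x):
            return 2 * x

        def bprop(self, x, out, dout):
            # Custom gradient: report 3 instead of the true 2.
            return (3 * dout,)

    net = MulCell()
    grad = ops.grad(net)(Tensor([1.0], ms.float32))  # -> [3.0], from bprop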
@@ -830,15 +902,6 @@
          else:
              self.insert_param_to_cell(name, None)

-     def _set_attr_for_tensor(self, name, value):
-         if context._get_mode() == context.PYNATIVE_MODE:
-             tensor_list = self.__dict__.get('_tensor_list')
-             if name in self.__dict__:
-                 del self.__dict__[name]
-             tensor_list[name] = value
-         else:
-             object.__setattr__(self, name, value)
-
      def __setattr__(self, name, value):
          cells = self.__dict__.get('_cells')
          params = self.__dict__.get('_params')
@@ -856,8 +919,6 @@
              if value is not None:
                  raise TypeError(f"For 'Cell', the type of {name} must be cell, but got {type(value).__name__}.")
              self._cells[name] = None
-         elif isinstance(value, Tensor):
-             self._set_attr_for_tensor(name, value)
          else:
              if isinstance(value, Primitive):
                  value.set_prim_instance_name(name)
@@ -910,14 +971,25 @@
          """
          logger.warning("'set_parallel_input_with_inputs' function is deprecated.")

-     def set_inputs(self, *inputs):
+     def set_inputs(self, *inputs, **kwargs):
          """
          Save set inputs for computation graph. The number of inputs should be the same with that of the datasets. When
          using Model for dynamic shape, please make sure that all networks and loss functions passed to the Model are
-         configured with set_inputs. The inputs can be Tensor of either dynamic or static shape.
+         configured with set_inputs. The shape of the input Tensor can be either dynamic or static.
+
+         .. note::
+             There are two modes:
+
+             - Full mode: arguments will be used as all compile inputs for graph-compiling.
+             - Incremental mode: arguments will be set to some of the Cell inputs, which will be substituted into the
+               input at the corresponding position for graph-compiling.
+
+             Only one of inputs or kwargs can be set: inputs for full mode and kwargs for incremental mode.

          Args:
-             inputs (tuple): Inputs of the Cell object.
+             inputs (tuple): Full mode arguments.
+             kwargs (dict): Incremental mode arguments. The acceptable key is the name of a parameter defined
+                 in `self.construct`.

          .. warning::
              This is an experimental API that is subject to change or deletion.
@@ -937,16 +1009,30 @@
              >>> net = ReluNet()
              >>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
              >>> net.set_inputs(input_dyn)
-             >>> input1 = Tensor(np.random.random([3, 10]), dtype=ms.float32)
-             >>> output = net(input1)
+             >>> input = Tensor(np.random.random([3, 10]), dtype=ms.float32)
+             >>> output = net(input)
+             >>>
+             >>> net2 = ReluNet()
+             >>> net2.set_inputs(x=input_dyn)
+             >>> output = net2(input)
          """
          if self.grad_ops_label:
              logger.warning(f'For Cell, set_inputs must be set before the gradient function of the network is '
                             f'generated.')
-         self._dynamic_shape_inputs = inputs
-         self._check_construct_args(*inputs)
-         if context._get_mode() == context.PYNATIVE_MODE:
-             _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
+         if kwargs and inputs:
+             raise ValueError('For Cell, set_inputs should only set inputs or kwargs(inputs: %s, kwargs: %s)!'
+                              % (inputs, kwargs))
+
+         if not kwargs:
+             self._dynamic_shape_inputs = inputs
+             if context._get_mode() == context.PYNATIVE_MODE:
+                 _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)
+             else:
+                 self._check_construct_args(*inputs)
+                 # TODO(tronzhang): It may error for no actually args here. So just set in fullmode,
+                 # which means that incremental mode is lacking dynamic input.
+         else:
+             self._dynamic_shape_inputs = _process_dyn_args(self.construct, kwargs)

      def get_inputs(self):
          """
@@ -981,6 +1067,48 @@

          return self._dynamic_shape_inputs

+     def _check_parameter_consistency(self, set_inputs, net_inputs):
+         """Check consistency for parameter."""
+         for index, (set_input, net_input) in enumerate(zip(set_inputs, net_inputs)):
+             if isinstance(set_input, Tensor):
+                 if not isinstance(net_input, Tensor):
+                     raise TypeError(
+                         f"For 'set_inputs' and tuple(list) in 'set_inputs', the type of {index + 1}th input must "
+                         f"be Tensor, but got {type(net_input)}.")
+                 if isinstance(set_input, Parameter) != isinstance(net_input, Parameter):
+                     raise TypeError(
+                         f"For 'set_inputs' and tuple(list) in 'set_inputs', the {index + 1}th input must be the same "
+                         f"as expected, but got expected: {type(set_input)} and input: {type(net_input)}.")
+             elif isinstance(set_input, (tuple, list)):
+                 if not isinstance(net_input, (tuple, list)):
+                     raise TypeError(
+                         f"The {index + 1}th input type of 'set_inputs' or tuple(list) in "
+                         f"'set_inputs' must be tuple or list, but got {type(net_input)}.")
+                 self._check_parameter_consistency(set_input, net_input)
+
+     def _get_compile_args(self, args):
+         """Get compile arguments."""
+         # this is used only for test
+         set_by_auto_dynamic = False
+         if is_auto_dynamic():
+             if self._dynamic_shape_inputs is None:
+                 set_by_auto_dynamic = True
+             else:
+                 if isinstance(self._dynamic_shape_inputs, (list, tuple)) and self._dynamic_shape_inputs[0] is None:
+                     set_by_auto_dynamic = True
+         if set_by_auto_dynamic:
+             self._dynamic_shape_inputs = convert_inputs_to_dynamic(*args)
+
+         if self._dynamic_shape_inputs is not None:
+             logger.debug("Compiled Graph with dynamic shape")
+             compile_args = _generate_dyn_compile_args(args, self._dynamic_shape_inputs)
+             _cell_graph_executor._graph_executor.check_argument_consistency(compile_args, args, "set_inputs")
+             self._check_parameter_consistency(compile_args, args)
+             Validator.check_symbolic_shape(compile_args, args)
+             self.saved_dynamic_shape = compile_args
+             return compile_args
+         return args
+
      def compile(self, *args, **kwargs):
          """
          Compile Cell as a computation graph, the input must be consistent with the input defined in construct.
@@ -989,19 +1117,9 @@
          Args:
              args (tuple): Args of the Cell object.
              kwargs (dict): Kwargs of the Cell object.
          """
-         # this is used only for test
-         if is_auto_dynamic() and (self._dynamic_shape_inputs is None or self._dynamic_shape_inputs[0] is None):
-             self._dynamic_shape_inputs = convert_inputs_to_dynamic(*args)
-
-         if self._dynamic_shape_inputs is None:
-             _cell_graph_executor.compile(self, phase=self.phase,
-                                          jit_config_dict=self._jit_config_dict, *args, **kwargs)
-         else:
-             self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)
-             self.saved_dynamic_shape = self._dynamic_shape_inputs
-             _cell_graph_executor.compile(self, *self._dynamic_shape_inputs, phase=self.phase,
-                                          jit_config_dict=self._jit_config_dict, **kwargs)
-             logger.debug("Compiled Graph with dynamic shape")
+         self._compile_args = self._get_compile_args(args)
+         _cell_graph_executor.compile(self, *self._compile_args, phase=self.phase,
+                                      jit_config_dict=self._jit_config_dict, **kwargs)

      def compile_and_run(self, *args, **kwargs):
          """
@@ -1019,7 +1137,7 @@
          """
          self.compile(*args, **kwargs)
          self.add_flags(ge_sync_data=False)
-         new_args = _get_args_for_run(self, args, kwargs)
+         new_args = _get_args_for_run(self, args, kwargs, self._compile_args)
          return _cell_graph_executor(self, *new_args, phase=self.phase)

      def auto_parallel_compile_and_run(self):
@@ -1033,6 +1151,7 @@

      def exec_checkpoint_graph(self):
          """Executes GE saving checkpoint graph operation."""
+         logger.warning("'exec_checkpoint_graph' function is deprecated.")
          self.add_flags(ge_sync_data=True)
          _cell_graph_executor(self, phase='save')

@@ -1070,14 +1189,14 @@
          Parameter(name=bias, shape=(3,), dtype=Int64, requires_grad=True)
          """
          if not param_name:
-             raise KeyError("For 'insert_param_to_cell', the argument 'param_name' should not be None.")
+             raise KeyError(f"For 'insert_param_to_cell', the argument 'param_name' should not be None.")
          if check_name_contain_dot and '.' in param_name:
-             raise KeyError("For 'insert_param_to_cell', the argument 'param_name' should not contain \".\"")
+             raise KeyError(f"For 'insert_param_to_cell', the argument 'param_name' should not contain '.' ")
          if '_params' not in self.__dict__:
-             raise AttributeError("For 'insert_param_to_cell', please call Cell.__init__() firstly.")
+             raise AttributeError(f"For 'insert_param_to_cell', please call Cell.__init__() firstly.")
          if hasattr(self, param_name) and param_name not in self._params:
-             raise KeyError("For 'insert_param_to_cell', the {} parameter already exists in the network. Cannot "
-                            "insert another parameter with the same name.".format(param_name))
+             raise KeyError(f"For 'insert_param_to_cell', the {param_name} parameter already exists in the network. "
+                            f"Cannot insert another parameter with the same name.")
          if not isinstance(param, Parameter) and param is not None:
              raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must be 'Parameter' if not None, "
                              f"but got {type(param)}.")
@@ -1139,11 +1258,11 @@
              raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, "
                              f"but got {type(child_name)}.")
          if not child_name or '.' in child_name:
-             raise KeyError("For 'insert_child_to_cell', the parameter 'child_name' can not be None and "
-                            "can not contain '.'")
+             raise KeyError(f"For 'insert_child_to_cell', the parameter 'child_name' can not be None and "
+                            "can not contain '.' ")
          if hasattr(self, child_name) and child_name not in self._cells:
-             raise KeyError("For 'insert_child_to_cell', the {} child cell already exists in the network. Cannot "
-                            "insert another child cell with the same name.".format(child_name))
+             raise KeyError(f"For 'insert_child_to_cell', the {child_name} child cell already exists in the network. "
+                            f"Cannot insert another child cell with the same name.")
          if not isinstance(child_cell, Cell) and child_cell is not None:
              raise TypeError(f"For 'insert_child_to_cell', the argument 'child_cell' must be 'Cell' if not None, "
                              f"but got type {type(child_cell)}.")
@@ -1163,7 +1282,7 @@
          Returns:
              Tensor, returns the computed result.
          """
-         return None
+         raise AttributeError("For 'Cell', the method 'construct' is not defined.")

      def remove_redundant_parameters(self):
          """
@@ -1361,7 +1480,7 @@

          Tutorial Examples:
              - `Model Training - Optimizer
-               <https://mindspore.cn/tutorials/en/r2.2/beginner/train.html#optimizer>`_
+               <https://mindspore.cn/tutorials/en/master/beginner/train.html#optimizer>`_
          """
          return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))

@@ -1472,7 +1591,7 @@

          Tutorial Examples:
              - `Building a Network - Model Parameters
-               <https://mindspore.cn/tutorials/en/r2.2/beginner/model.html#model-parameters>`_
+               <https://mindspore.cn/tutorials/en/master/beginner/model.html#model-parameters>`_
          """
          cells = []
          if expand:
@@ -1630,10 +1749,13 @@
      def _add_mixed_precision_flag(self, **flags):
          """Add mixed precision flag to current cell"""
          if "fp16" in flags and flags.get("fp16", False):
+             self.mixed_precision_type = MixedPrecisionType.FP16
              Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16)
          if "fp32" in flags and flags.get("fp32", False):
+             self.mixed_precision_type = MixedPrecisionType.FP32
              Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32)
          if "bf16" in flags and flags.get("bf16", False):
+             self.mixed_precision_type = MixedPrecisionType.BF16
              Cell_.set_mixed_precision_type(self, MixedPrecisionType.BF16)

      def apply(self, fn):
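
These flags are usually reached indirectly through the public mixed-precision APIs rather than set by hand. A hedged sketch using `Cell.to_float`, the documented entry point (whether it routes through exactly these fp16/bf16 flags internally is an assumption here):

    import mindspore as ms
    import mindspore.nn as nn

    net = nn.Dense(16, 16)
    # Cast this cell's computation to float16; the cell's mixed-precision
    # state is mirrored in Python by mixed_precision_type above.
    net = net.to_float(ms.float16)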
@@ -1698,6 +1820,9 @@
          if not hasattr(self, "_func_graph_flags"):
              self._func_graph_flags = {}
          self._func_graph_flags.update({**flags})
+         if context._get_mode() == context.PYNATIVE_MODE and self._func_graph_flags.get("output_no_recompute"):
+             raise TypeError("Recompute is not supported in PyNative mode currently, you can use "
+                             "'context.set_context(mode=context.GRAPH_MODE)' or @jit to set graph mode.")
          self.__dict__.update({**flags})
          self._add_mixed_precision_flag(**flags)
          return self
@@ -1808,7 +1933,7 @@
          accelerate the algorithm in the algorithm library.

          If `boost_type` is not in the algorithm library, please view the algorithm in the algorithm library through
-         `algorithm library <https://gitee.com/mindspore/mindspore/tree/r2.2/mindspore/python/mindspore/boost>`_.
+         `algorithm library <https://gitee.com/mindspore/mindspore/tree/master/mindspore/python/mindspore/boost>`_.

          Note:
              Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
@@ -1865,7 +1990,7 @@

          Tutorial Examples:
              - `Model Training - Implementing Training and Evaluation
-               <https://mindspore.cn/tutorials/en/r2.2/beginner/train.html#training-and-evaluation>`_
+               <https://mindspore.cn/tutorials/en/master/beginner/train.html#training-and-evaluation>`_
          """
          if mode:
              self._phase = 'train'
@@ -1945,11 +2070,11 @@
          Note:
              - The `register_forward_pre_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
              - 'hook_fn' must be defined as the following code.
-               `cell_id` is the information of registered Cell object, including name and ID. `inputs` is the forward
+               `cell` is the object of the registered Cell. `inputs` is the forward
                input objects passed to the Cell. The 'hook_fn' can modify the forward input objects by returning new
                forward input objects.
              - It should have the following signature:
-               hook_fn(cell_id, inputs) -> new input objects or none.
+               hook_fn(cell, inputs) -> new input objects or none.
              - In order to prevent running failed when switching to graph mode, it is not recommended to write it in the
                `construct` function of Cell object. In the pynative mode, if the `register_forward_pre_hook` function is
                called in the `construct` function of the Cell object, a hook function will be added at each run time of
@@ -1959,8 +2084,8 @@
              hook_fn (function): Python function. Forward pre hook function.

          Returns:
-             Handle, it is an instance of `mindspore.common.hook_handle.HookHandle` and corresponding to the `hook_fn` .
-             The handle can be used to remove the added `hook_fn` by calling `handle.remove()` .
+             A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
+             `handle.remove()` .

          Raises:
              TypeError: If the `hook_fn` is not a function of python.
@@ -1973,7 +2098,7 @@
              >>> import mindspore as ms
              >>> from mindspore import Tensor, nn, ops
              >>> ms.set_context(mode=ms.PYNATIVE_MODE)
-             >>> def forward_pre_hook_fn(cell_id, inputs):
+             >>> def forward_pre_hook_fn(cell, inputs):
              ...     print("forward inputs: ", inputs)
              ...
              >>> class Net(nn.Cell):
@@ -1995,24 +2120,12 @@
              (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
              value= [ 2.00000000e+00]))
          """
-         if context.get_context("mode") != context.PYNATIVE_MODE:
-             logger.warning(f"'register_forward_pre_hook' function is only supported in pynative mode, you can use "
-                            f"context.set_context to set pynative mode.")
+         if context._get_mode() == context.GRAPH_MODE:
              return HookHandle()
-
-         if not isinstance(hook_fn, (FunctionType, MethodType)):
-             raise TypeError(f"When using 'register_forward_pre_hook(hook_fn)', the type of 'hook_fn' must be python "
-                             f"function, but got {type(hook_fn)}.")
-         if hook_fn.__code__.co_name == "staging_specialize":
-             raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
-
-         self._enable_forward_pre_hook = True
-         _pynative_executor.set_hook_changed(self)
-         if not hasattr(self, '_forward_pre_hook_key'):
-             self._forward_pre_hook_key = -1
-         self._forward_pre_hook_key += 1
-         self._forward_pre_hook[self._forward_pre_hook_key] = hook_fn
-         handle = HookHandle(self, self._forward_pre_hook_key, "_forward_pre_hook")
+         if not check_hook_fn("register_forward_pre_hook", hook_fn):
+             return HookHandle()
+         handle = HookHandle(self._forward_pre_hook)
+         self._forward_pre_hook[handle.handle_id] = hook_fn
          return handle

      def _run_forward_pre_hook(self, inputs):
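
A hedged usage sketch of the new signature, with a pre-hook that rewrites the inputs; note that `_run_forward_pre_hook` (below) now enforces that a returned tuple keeps the original arity:

    import mindspore as ms
    from mindspore import Tensor, nn

    ms.set_context(mode=ms.PYNATIVE_MODE)

    def double_inputs(cell, inputs):
        # New-style hook: receives the Cell object, not a cell_id string.
        return tuple(2 * x for x in inputs)  # must match the input arity

    relu = nn.ReLU()
    handle = relu.register_forward_pre_hook(double_inputs)
    print(relu(Tensor([-1.0, 1.0], ms.float32)))  # inputs doubled first
    handle.remove()  # removal now just deletes the OrderedDict entry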
@@ -2028,15 +2141,23 @@ class Cell(Cell_):
2028
2141
  Supported Platforms:
2029
2142
  ``Ascend`` ``GPU`` ``CPU``
2030
2143
  """
2031
- cell_id = self.cls_name + "(" + str(id(self)) + ")"
2144
+ forward_pre_hook_inputs = inputs
2032
2145
  for fn in self._forward_pre_hook.values():
2033
- ret = fn(cell_id, inputs)
2146
+ ret = fn(self, forward_pre_hook_inputs)
2034
2147
  if ret is not None:
2035
2148
  if not isinstance(ret, tuple):
2036
- inputs = (ret,)
2149
+ forward_pre_hook_inputs = (ret,)
2037
2150
  else:
2038
- inputs = ret
2039
- return inputs
2151
+ forward_pre_hook_inputs = ret
2152
+
2153
+ if isinstance(inputs, tuple):
2154
+ if not isinstance(forward_pre_hook_inputs, tuple):
2155
+ forward_pre_hook_inputs = (forward_pre_hook_inputs,)
2156
+ if len(forward_pre_hook_inputs) != len(inputs):
2157
+ raise TypeError(
2158
+ "The forward pre hook return value size is {} not equal to input size {}".format(
2159
+ len(forward_pre_hook_inputs), len(inputs)))
2160
+ return forward_pre_hook_inputs
2040
2161
 
2041
2162
  def register_forward_hook(self, hook_fn):
2042
2163
  """
@@ -2045,11 +2166,11 @@ class Cell(Cell_):
2045
2166
  Note:
2046
2167
  - The `register_forward_hook(hook_fn)` does not work in graph mode or functions decorated with 'jit'.
2047
2168
  - 'hook_fn' must be defined as the following code.
2048
- `cell_id` is the information of registered Cell object, including name and ID. `inputs` is the forward
2169
+ `cell` is the object of registered Cell. `inputs` is the forward
2049
2170
  input objects passed to the Cell. `output` is the forward output object of the Cell. The 'hook_fn' can
2050
2171
  modify the forward output object by returning new forward output object.
2051
2172
  - It should have the following signature:
2052
- hook_fn(cell_id, inputs, output) -> new output object or none.
2173
+ hook_fn(cell, inputs, output) -> new output object or none.
2053
2174
  - In order to prevent running failed when switching to graph mode, it is not recommended to write it in the
2054
2175
  `construct` function of Cell object. In the pynative mode, if the `register_forward_hook` function is
2055
2176
  called in the `construct` function of the Cell object, a hook function will be added at each run time of
@@ -2059,8 +2180,8 @@ class Cell(Cell_):
2059
2180
  hook_fn (function): Python function. Forward hook function.
2060
2181
 
2061
2182
  Returns:
2062
- Handle, it is an instance of `mindspore.common.hook_handle.HookHandle` and corresponding to the `hook_fn` .
2063
- The handle can be used to remove the added `hook_fn` by calling `handle.remove()` .
2183
+ A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
2184
+ `handle.remove()` .
2064
2185
 
2065
2186
  Raises:
2066
2187
  TypeError: If the `hook_fn` is not a function of python.
@@ -2073,7 +2194,7 @@ class Cell(Cell_):
2073
2194
  >>> import mindspore as ms
2074
2195
  >>> from mindspore import Tensor, nn, ops
2075
2196
  >>> ms.set_context(mode=ms.PYNATIVE_MODE)
2076
- >>> def forward_hook_fn(cell_id, inputs, output):
2197
+ >>> def forward_hook_fn(cell, inputs, output):
2077
2198
  ... print("forward inputs: ", inputs)
2078
2199
  ... print("forward output: ", output)
2079
2200
  ...
@@ -2097,24 +2218,12 @@ class Cell(Cell_):
             (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
             value= [ 2.00000000e+00]))
         """
-        if context.get_context("mode") != context.PYNATIVE_MODE:
-            logger.warning(f"'register_forward_hook' function is only supported in pynative mode, you can use "
-                           f"context.set_context to set pynative mode.")
+        if context._get_mode() == context.GRAPH_MODE:
             return HookHandle()
-
-        if not isinstance(hook_fn, (FunctionType, MethodType)):
-            raise TypeError(f"When using 'register_forward_hook(hook_fn)', the type of 'hook_fn' must be python "
-                            f"function, but got {type(hook_fn)}.")
-        if hook_fn.__code__.co_name == "staging_specialize":
-            raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
-
-        self._enable_forward_hook = True
-        _pynative_executor.set_hook_changed(self)
-        if not hasattr(self, '_forward_hook_key'):
-            self._forward_hook_key = -1
-        self._forward_hook_key += 1
-        self._forward_hook[self._forward_hook_key] = hook_fn
-        handle = HookHandle(self, self._forward_hook_key, "_forward_hook")
+        if not check_hook_fn("register_forward_hook", hook_fn):
+            return HookHandle()
+        handle = HookHandle(self._forward_hook)
+        self._forward_hook[handle.handle_id] = hook_fn
         return handle

     def _run_forward_hook(self, inputs, output):
@@ -2131,12 +2240,110 @@ class Cell(Cell_):
         Supported Platforms:
             ``Ascend`` ``GPU`` ``CPU``
         """
-        cell_id = self.cls_name + "(" + str(id(self)) + ")"
+        forward_hook_output = output
         for fn in self._forward_hook.values():
-            ret = fn(cell_id, inputs, output)
+            ret = fn(self, inputs, forward_hook_output)
             if ret is not None:
-                output = ret
-        return output
+                forward_hook_output = ret
+
+        if isinstance(output, tuple):
+            if not isinstance(forward_hook_output, tuple):
+                forward_hook_output = (forward_hook_output,)
+            if len(forward_hook_output) != len(output):
+                raise TypeError(
+                    "The forward hook return value size {} is not equal to output size {}".format(
+                        len(forward_hook_output), len(output)))
+        return forward_hook_output
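(Aside: a short sketch of the forward-hook contract enforced above, not part of the diff; the signature hook_fn(cell, inputs, output) matches the docstring in this hunk. When the cell returns a tuple, a replacement output must be a tuple of the same length.)

import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

ms.set_context(mode=ms.PYNATIVE_MODE)

def halve_output(cell, inputs, output):
    # Returning a new object replaces the forward output; returning None keeps it.
    return output * 0.5

relu = nn.ReLU()
handle = relu.register_forward_hook(halve_output)
print(relu(Tensor(np.array([2.0, 4.0], np.float32))))  # [1. 2.]
handle.remove()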
+
+    def register_backward_pre_hook(self, hook_fn):
+        """
+        Register the backward pre hook function.
+
+        Note:
+            - The `register_backward_pre_hook(hook_fn)` does not work in graph mode or in functions decorated
+              with 'jit'.
+            - The 'hook_fn' must be defined as the following code.
+              `cell` is the registered Cell object. `grad_output` is the gradient passed to the Cell.
+            - The 'hook_fn' should have the following signature:
+              hook_fn(cell, grad_output) -> new grad_output gradient or None.
+            - The 'hook_fn' is executed in the python environment. To avoid failures when switching to graph
+              mode, it is not recommended to write it in the `construct` function of the Cell object. In pynative
+              mode, if the `register_backward_pre_hook` function is called in the `construct` function of the
+              Cell object, a hook function will be added at each run of the Cell object.
+
+        Args:
+            hook_fn (function): Python function. Backward pre hook function.
+
+        Returns:
+            A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by
+            calling `handle.remove()` .
+
+        Raises:
+            TypeError: If the `hook_fn` is not a Python function.
+
+        Supported Platforms:
+            ``Ascend`` ``GPU`` ``CPU``
+
+        Examples:
+            >>> import numpy as np
+            >>> import mindspore as ms
+            >>> from mindspore import Tensor, nn, ops
+            >>> ms.set_context(mode=ms.PYNATIVE_MODE)
+            >>> def backward_pre_hook_fn(cell, grad_output):
+            ...     print("backward input: ", grad_output)
+            ...
+            >>> class Net(nn.Cell):
+            ...     def __init__(self):
+            ...         super(Net, self).__init__()
+            ...         self.relu = nn.ReLU()
+            ...         self.handle = self.relu.register_backward_pre_hook(backward_pre_hook_fn)
+            ...
+            ...     def construct(self, x):
+            ...         x = x + x
+            ...         x = self.relu(x)
+            ...         return x
+            >>> grad = ops.GradOperation(get_all=True)
+            >>> net = Net()
+            >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)))
+            backward input: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),)
+            >>> print(output)
+            (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
+        """
+        if context._get_mode() == context.GRAPH_MODE:
+            return HookHandle()
+        if not check_hook_fn("register_backward_pre_hook", hook_fn):
+            return HookHandle()
+        handle = HookHandle(self._backward_pre_hook)
+        self._backward_pre_hook[handle.handle_id] = hook_fn
+        if self._cell_backward_pre_hook is None:
+            # Create a CellBackwardHook primitive and attach the registered hook functions to it.
+            self._cell_backward_pre_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
+                                                                  self, self._backward_pre_hook)
+            self._cell_backward_pre_hook.register_backward_pre_hook()
+        return handle
+
+    def _run_backward_pre_hook(self, outputs):
+        """
+        Run the backward pre hook functions registered on the Cell object.
+
+        Args:
+            outputs: The output objects of the Cell object.
+
+        Returns:
+            - **outputs** - New backward gradient or None.
+
+        Supported Platforms:
+            ``Ascend`` ``GPU`` ``CPU``
+        """
+        ret = self._cell_backward_pre_hook(outputs)
+        if isinstance(outputs, tuple):
+            if not isinstance(ret, tuple):
+                ret = (ret,)
+            if len(ret) != len(outputs):
+                raise TypeError(
+                    "The backward pre hook return value size {} is not equal to output size {}".format(
+                        len(ret), len(outputs)))
+        return ret
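(Aside: a minimal sketch of the new API, not part of the diff, based on the docstring and the size check in _run_backward_pre_hook above: the hook may return None to keep the incoming gradients, or a tuple of the same length to replace them.)

import numpy as np
import mindspore as ms
from mindspore import Tensor, nn, ops

ms.set_context(mode=ms.PYNATIVE_MODE)

def scale_grad(cell, grad_output):
    # Replace the gradients flowing into the cell's backward pass.
    return tuple(g * 0.1 for g in grad_output)

relu = nn.ReLU()
handle = relu.register_backward_pre_hook(scale_grad)
grad_fn = ops.GradOperation(get_all=True)(relu)
print(grad_fn(Tensor(np.ones([2], np.float32))))  # gradients scaled by 0.1
handle.remove()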

     def register_backward_hook(self, hook_fn):
         """
@@ -2145,11 +2352,11 @@ class Cell(Cell_):
         Note:
             - The `register_backward_hook(hook_fn)` does not work in graph mode or in functions decorated with 'jit'.
             - The 'hook_fn' must be defined as the following code.
-              `cell_id` is the information of registered Cell object, including name and ID. `grad_input` is the
-              gradient passed to the Cell. `grad_output` is the gradient computed and passed to the next Cell or
-              primitive, which may be modified by returning a new output gradient.
+              `cell` is the registered Cell object. `grad_input` is the gradient computed and passed to the next
+              Cell or primitive; the hook can replace it by returning a new gradient, or keep it by returning None.
+              `grad_output` is the gradient passed to the Cell.
             - The 'hook_fn' should have the following signature:
-              hook_fn(cell_id, grad_input, grad_output) -> New output gradient or none.
+              hook_fn(cell, grad_input, grad_output) -> new grad_input gradient or None.
             - The 'hook_fn' is executed in the python environment. To avoid failures when switching to graph mode,
               it is not recommended to write it in the `construct` function of the Cell object. In pynative
               mode, if the `register_backward_hook` function is called in the `construct` function of the Cell object,
@@ -2159,8 +2366,8 @@ class Cell(Cell_):
             hook_fn (function): Python function. Backward hook function.

         Returns:
-            Handle, it is an instance of `mindspore.common.hook_handle.HookHandle` and corresponding to the `hook_fn` .
-            The handle can be used to remove the added `hook_fn` by calling `handle.remove()` .
+            A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
+            `handle.remove()` .

         Raises:
             TypeError: If the `hook_fn` is not a Python function.
@@ -2173,9 +2380,9 @@ class Cell(Cell_):
             >>> import mindspore as ms
             >>> from mindspore import Tensor, nn, ops
             >>> ms.set_context(mode=ms.PYNATIVE_MODE)
-            >>> def backward_hook_fn(cell_id, grad_input, grad_output):
-            ...     print("backward input: ", grad_input)
-            ...     print("backward output: ", grad_output)
+            >>> def backward_hook_fn(cell, grad_input, grad_output):
+            ...     print("backward input: ", grad_output)
+            ...     print("backward output: ", grad_input)
             ...
             >>> class Net(nn.Cell):
             ...     def __init__(self):
@@ -2195,22 +2402,17 @@ class Cell(Cell_):
             >>> print(output)
             (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
         """
-        if context.get_context("mode") != context.PYNATIVE_MODE:
-            logger.warning(f"'register_backward_hook' function is only supported in pynative mode, you can use "
-                           f"context.set_context to set pynative mode.")
+        if context._get_mode() == context.GRAPH_MODE:
             return HookHandle()
-
-        if not isinstance(hook_fn, (FunctionType, MethodType)):
-            raise TypeError(f"When using 'register_backward_hook(hook_fn)', the type of 'hook_fn' must be python "
-                            f"function, but got {type(hook_fn)}.")
+        if not check_hook_fn("register_backward_hook", hook_fn):
+            return HookHandle()
+        handle = HookHandle(self._backward_hook)
+        self._backward_hook[handle.handle_id] = hook_fn
         if self._cell_backward_hook is None:
-            self._enable_backward_hook = True
-            self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")")
-            backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn)
-            handle = HookHandle(self, backward_hook_key, "_cell_backward_hook")
-        else:
-            backward_hook_key = self._cell_backward_hook.register_backward_hook(hook_fn)
-            handle = HookHandle(self, backward_hook_key, "_cell_backward_hook")
+            # Create a CellBackwardHook primitive and attach the registered hook functions to it.
+            self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")",
+                                                              self, self._backward_hook)
+            self._cell_backward_hook.register_backward_hook()
         return handle
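(Aside: a hedged usage sketch, not part of the diff, following the 2.4 signature documented above, where a backward hook may replace grad_input, the gradient propagated to the preceding cell or primitive.)

import numpy as np
import mindspore as ms
from mindspore import Tensor, nn, ops

ms.set_context(mode=ms.PYNATIVE_MODE)

def replace_grad_input(cell, grad_input, grad_output):
    # Returning a value replaces grad_input; returning None keeps it.
    return tuple(g * 2 for g in grad_input)

relu = nn.ReLU()
handle = relu.register_backward_hook(replace_grad_input)
grad_fn = ops.GradOperation(get_all=True)(relu)
print(grad_fn(Tensor(np.ones([2], np.float32))))  # doubled gradients
handle.remove()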

     def _backward_hook_construct(self, *inputs, **kwargs):
@@ -2227,15 +2429,31 @@ class Cell(Cell_):
         Supported Platforms:
             ``Ascend`` ``GPU`` ``CPU``
         """
-        if len(inputs) > 1:
-            inputs = self._cell_backward_hook(inputs)
-        else:
-            inputs = self._cell_backward_hook(*inputs)
-            inputs = (inputs,)
-        if isinstance(inputs, tuple):
-            outputs = self.construct(*inputs, **kwargs)
+        # self._cell_backward_hook wraps the CellBackwardHook op, so the input args are passed through unchanged.
+        outputs = self._cell_backward_hook(*inputs)
+        # With more than one input arg the hook op wraps its outputs into a tuple, which must be unpacked before
+        # calling construct; an empty input list is wrapped (and needs unpacking) as well. A single input comes
+        # back unwrapped, so no unpacking is needed in that case.
+        is_need_unwrap = False
+        if isinstance(outputs, tuple) and len(inputs) != 1:
+            is_need_unwrap = True
+
+        if self._recompute_cell is not None:
+            if is_need_unwrap:
+                outputs = self._recompute_cell(*outputs, **kwargs)
+            else:
+                outputs = self._recompute_cell(outputs, **kwargs)
+        elif self.has_bprop:
+            if is_need_unwrap:
+                outputs = self._call_custom_bprop(*outputs, **kwargs)
+            else:
+                outputs = self._call_custom_bprop(outputs, **kwargs)
         else:
-            outputs = self.construct(inputs, **kwargs)
+            if is_need_unwrap:
+                outputs = self.construct(*outputs, **kwargs)
+            else:
+                outputs = self.construct(outputs, **kwargs)
+
         outputs = self._cell_backward_hook(outputs)
         return outputs

@@ -2365,6 +2583,9 @@ class Cell(Cell_):
                 introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
                 Default: ``False`` .
         """
+        if context.get_context("mode") == context.PYNATIVE_MODE:
+            self._recompute_cell = recompute_registry.get()(self.construct)
+            return
         self._recompute()
         if 'mp_comm_recompute' in kwargs.keys():
             self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
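(Aside: the PYNATIVE_MODE branch added above wraps construct with a function obtained from recompute_registry instead of setting graph-mode recompute flags. A minimal, hedged usage sketch:)

import mindspore as ms
from mindspore import nn

ms.set_context(mode=ms.PYNATIVE_MODE)

block = nn.Dense(16, 16)
# In PyNative mode this now installs a recompute wrapper around block.construct,
# so the block's activations are recomputed during the backward pass rather than stored.
block.recompute()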
@@ -2384,16 +2605,13 @@ class Cell(Cell_):
                             "the key kwargs must be 'mp_comm_recompute', "
                             "'parallel_optimizer_comm_recompute', 'recompute_slice_activation'" % key)

+    @deprecated("2.3", "infer_param_pipeline_stage")
     def infer_param_pipeline_stage(self):
         """
         Infer pipeline stages of all parameters in the cell.

         Note:
-            - If a parameter does not belong to any cell which has been set pipeline_stage,
-              the parameter should use add_pipeline_stage to add its pipeline_stage information.
-            - If a parameter P has been used by two operators in different stages "stageA" and "stageB",
-              the parameter P should use P.add_pipeline_stage(stageA) and P.add_pipeline_stage(stageB)
-              to add its stage information before using infer_param_pipeline_stage.
+            - This interface is deprecated since version 2.3 and will be removed in a future version.

         Returns:
             The parameters belonging to the current stage in pipeline parallel.
@@ -2448,86 +2666,6 @@ class Cell(Cell_):
         for op in all_ops:
             op.place(role, rank_id)

-    def _check_dynamic_tensor(self, set_input, net_input, index):
-        """
-        Check if tensor is correctly set for dynamic shape.
-
-        Args:
-            set_input (Tensor): Tensor set for dynamic shape.
-            net_input (Tensor): Input tensor of the Cell object.
-            index (int): Tensor index for set inputs.
-        """
-        if not isinstance(net_input, Tensor):
-            raise TypeError(
-                f"For 'set_inputs' and tuple(list) in 'set_inputs', the type of {index + 1}th input must be Tensor, "
-                f"but got {type(net_input)}.")
-        is_param_set_input = isinstance(set_input, Parameter)
-        is_param_net_input = isinstance(net_input, Parameter)
-        if (is_param_set_input and not is_param_net_input) or (is_param_net_input and not is_param_set_input):
-            raise TypeError(
-                f"For 'set_inputs' and tuple(list) in 'set_inputs', the {index + 1}th input must be the same "
-                f"as network's input, but got 'set_inputs': {type(set_input)} and network's input: {type(net_input)}.")
-        if set_input.dtype != net_input.dtype:
-            raise TypeError(
-                f"For 'set_inputs' and tuple(list) in 'set_inputs', the dtype of {index + 1}th input must be the same "
-                f"as network's input, but got 'set_inputs': {set_input.dtype} and network's input: {net_input.dtype}.")
-        if -2 not in set_input.shape:
-            if net_input.dim() != 0 and set_input.dim() != net_input.dim():
-                raise ValueError(
-                    f"For 'set_inputs' and tuple(list) in 'set_inputs', the dims of {index + 1}th input must be the "
-                    f"same as network's input, but got 'set_inputs': {set_input.dim()} and network's input: "
-                    f"{net_input.dim()}.")
-            if not all([ele1 in (-1, ele2) for ele1, ele2 in zip(set_input.shape, net_input.shape)]):
-                raise ValueError(
-                    f"For 'set_inputs' and tuple(list) in 'set_inputs', the shape of {index + 1}th input must be the "
-                    f"same as network's input, but got 'set_inputs': {set_input.shape} and network's input: "
-                    f"{net_input.shape}.")
-
-    def _check_compile_dynamic_shape(self, set_inputs, net_inputs):
-        """
-        Check if the graph has been compiled with dynamic shape.
-
-        Args:
-            net_inputs (tuple): Inputs of the Cell object.
-        """
-        set_inputs_len = len(set_inputs)
-        net_inputs_len = len(net_inputs)
-        if set_inputs_len != net_inputs_len:
-            raise ValueError("The length of 'set_inputs' or tuple(list) in 'set_inputs' must be equal to network's "
-                             f"inputs, but got 'set_inputs': {set_inputs_len} and network's input: {net_inputs_len}.")
-        for index, (set_input, net_input) in enumerate(zip(set_inputs, net_inputs)):
-            if isinstance(set_input, Tensor):
-                self._check_dynamic_tensor(set_input, net_input, index)
-            elif isinstance(set_input, (tuple, list)):
-                if not isinstance(net_input, (tuple, list)):
-                    raise TypeError(
-                        f"The {index + 1}th input type of 'set_inputs' or tuple(list) in 'set_inputs' must be tuple or "
-                        f"list, but got {type(net_input)}.")
-                self._check_compile_dynamic_shape(set_input, net_input)
-            else:
-                if context._get_mode() == context.PYNATIVE_MODE and set_input is None:
-                    continue
-                if net_input != set_input:
-                    raise ValueError(
-                        f"The {index + 1}th input of 'set_inputs' or tuple(list) in 'set_inputs' must be the same as "
-                        f"network's input, but got set_inputs: {set_input} and network's input: {net_input}.")
-
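(Aside: the removed checks above encoded the set_inputs contract: each entry must match the real input's type, dtype, and rank, with -1 (or None at the Python API level) marking a dynamic dimension. A hedged sketch of the call they used to validate:)

import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

net = nn.Dense(3, 4)
# None marks a dynamic dimension; dtype and rank must still match the real inputs.
net.set_inputs(Tensor(shape=[None, 3], dtype=ms.float32))
out = net(Tensor(np.ones((8, 3), np.float32)))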
-    def _run_tracefunc(self, *args, **kwargs):
-        """ Run Packed Cell in Pack."""
-        args = self._mixed_precision_cast(args)
-        need_subgraph = hasattr(self, "bprop") or hasattr(self, "_pipeline_stage") or self.get_flags()
-        if not PackFunc.current.is_pynative_mode and need_subgraph:
-            expander = PackExpander.get_instance()
-            args = expander.begin_subgraph(self, *args)
-            args = [_convert_tensor(a) for a in args]
-            output = self._run_construct(args, kwargs)
-            ret = expander.end_subgraph(self, output)
-            output = _convert_tensor(ret)
-        else:
-            with _SetMixedPrecision(self):
-                output = self._run_construct(args, kwargs)
-        return output
-
     def _mixed_precision_cast(self, inputs):
         mixed_type = self.get_mixed_precision_type()
         if mixed_type == MixedPrecisionType.NOTSET:
@@ -2633,7 +2771,7 @@ class GraphCell(Cell):
         self._add_attr("graph_load_from_mindir", self.graph)
         if not self.obf_random_seed:
             return self.compile_and_run(*args, **kwargs)
-        append_input = Tensor((numpy.ones((1, 1)) * self._branch_control_input).astype(numpy.int32))
+        append_input = Tensor((numpy.ones((1,)) * self._branch_control_input).astype(numpy.int32))
         return self.compile_and_run(*args, append_input, **kwargs)
