mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mindspore might be problematic.

Files changed (870)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2020-2022 Huawei Technologies Co., Ltd
+# Copyright 2020-2023 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,8 +23,8 @@ from collections import OrderedDict
 from types import FunctionType, MethodType
 import numpy
 
-import mindspore.dataset as ds
 from mindspore._checkparam import args_type_check
+from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
 from mindspore import log as logger
 from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
 from mindspore.common.hook_handle import HookHandle
@@ -42,29 +42,16 @@ from mindspore.ops.primitive import Primitive
 from mindspore.ops.operations import _inner_ops as inner
 from mindspore.parallel.shard import Shard
 from mindspore._check_jit_forbidden_api import jit_forbidden_register
+from mindspore.common._decorator import deprecated
+from mindspore._c_expression import PackExpander
+from mindspore.ops._tracefunc import _convert_tensor, _SetMixedPrecision, PackFunc
 
 
 def _check_args(args):
     """Check the input args's type"""
-    index = 1
     for item in args:
         if isinstance(item, Tensor) and item.has_init:
             item.init_data()
-        elif isinstance(item, numpy.ndarray):
-            suffix = "th"
-            if index == 1:
-                suffix = "st"
-            elif index == 2:
-                suffix = "nd"
-            elif index == 3:
-                suffix = "rd"
-
-            input_index = str(index) + suffix
-            raise TypeError(f"For 'Cell', inputs should not be numpy array. Only support bool, int, float, None, "
-                            f"Tensor, Parameter, mstype.Number(mstype.bool, mstype.int, mstype.float, mstype.uint"
-                            f"), and tuple or list containing only these types, and dict whose values are these "
-                            f"types, but the {input_index} arg type is {type(item)}.")
-        index += 1
 
 
 class Cell(Cell_):
@@ -77,15 +64,25 @@ class Cell(Cell_):
     graph in GRAPH_MODE (static graph mode) and used as the basic module of neural networks in
     PYNATIVE_MODE (dynamic graph mode).
 
+    .. note::
+        Cell is the inference mode by default. For a class that inherits a Cell,
+        if the training and inference have different structures, the subclass performs the inference branch by default.
+        To set the training mode, refer to `mindspore.nn.Cell.set_train` .
+
+    .. warning::
+        In the subclass of Cell, it's not allowed to define a method named 'cast' and not allowed to define an attribute
+        named 'phase' or 'cells', otherwise, an error will be raised.
+
     Args:
         auto_prefix (bool, optional): Whether to automatically generate NameSpace for Cell and its child cells. It also
-            affects the names of parameters in the `Cell`. If set to True, the parameter name will be
-            automatically prefixed, otherwise not. In general, the backbone network should be set to True,
-            otherwise the duplicate name problem will appear. The cell to train the backbone network, such as
-            optimizer and :class:`mindspore.nn.TrainOneStepCell`, should be set to False, otherwise the
-            parameter name in backbone will be changed by mistake. Default: True.
+            affects the names of parameters in the `Cell`. If set to ``True`` , the parameter name will be
+            automatically prefixed, otherwise not. In general, the backbone network should be set to
+            ``True`` , otherwise the duplicate name problem will appear. The cell to train the backbone
+            network, such as optimizer and :class:`mindspore.nn.TrainOneStepCell`, should be set to
+            ``False`` , otherwise the parameter name in backbone will be changed by mistake.
+            Default: ``True`` .
         flags (dict, optional): Network configuration information, currently it is used for the binding of network
-            and dataset. Users can also customize network attributes by this parameter. Default: None.
+            and dataset. Users can also customize network attributes by this parameter. Default: ``None`` .
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
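The `.. note::` added above documents that a Cell runs its inference branch until training mode is switched on. A minimal sketch of the toggle, assuming a toy network called directly; the layers here are illustrative and not part of this diff:

    import numpy as np
    import mindspore as ms
    from mindspore import nn, Tensor

    net = nn.SequentialCell([nn.Dense(4, 4), nn.Dropout(p=0.5)])
    x = Tensor(np.ones((2, 4)), ms.float32)

    out_eval = net(x)      # default state: inference branch, dropout is a no-op
    net.set_train(True)    # flips the training flag on the cell and its children
    out_train = net(x)     # training branch: dropout is applied
    net.set_train(False)   # back to the inference branch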
@@ -167,7 +164,9 @@ class Cell(Cell_):
         self.saved_dynamic_shape = None
         self._jit_config_dict = dict()
         self.grad_ops_label = False
-        self.to_float_fp16 = False
+        self.ge_sync_data = False
+        self._is_check_and_refresh = False
+        self._amp_level = ""
 
     def __getstate__(self):
         base = Cell_.__getstate__(self)
@@ -199,6 +198,23 @@ class Cell(Cell_):
     def param_prefix(self):
         """
         Param prefix is the prefix of current cell's direct child parameter.
+
+        Examples:
+            >>> import mindspore as ms
+            >>> from mindspore import Tensor, nn
+            ...
+            >>> class Net(nn.Cell):
+            ...     def __init__(self):
+            ...         super(Net, self).__init__()
+            ...         self.dense = nn.Dense(2, 2)
+            ...
+            ...     def construct(self, x):
+            ...         x = self.dense(x)
+            ...         return x
+            >>> net = Net()
+            >>> net.update_cell_prefix()
+            >>> print(net.dense.param_prefix)
+            dense
         """
         return self._param_prefix
 
@@ -206,6 +222,10 @@ class Cell(Cell_):
  def bprop_debug(self):
  """
  Get whether cell custom bprop debug is enabled.
+
+ Tutorial Examples:
+ - `Cell and Parameter - Custom Cell Reverse
+ <https://mindspore.cn/tutorials/en/r2.2/advanced/modules/layer.html#custom-cell-reverse>`_
  """
  return self._bprop_debug

@@ -220,7 +240,7 @@ class Cell(Cell_):
  and add to graph when bprop debug is false.

  Args:
- value (bool): Specifies whether to enable bprop debug. Default: False.
+ value (bool): Specifies whether to enable bprop debug. Default: ``False``.
  """
  if not isinstance(value, bool):
  raise TypeError(f"For 'Cell', the property 'bprop_debug' must be bool type, but got type {type(value)}.")
@@ -312,6 +332,21 @@ class Cell(Cell_):
  for item in self.trainable_params():
  item.add_pipeline_stage(value)

+ @property
+ def pipeline_segment(self):
+ return self._pipeline_segment
+
+ @pipeline_segment.setter
+ def pipeline_segment(self, value):
+ if not isinstance(value, int) or isinstance(value, bool):
+ raise TypeError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
+ "must be int type, but got type: {}".format(type(value)))
+
+ if value < 0:
+ raise ValueError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
+ "cannot be less than 0, but got {}".format(value))
+ self._pipeline_segment = value
+
  @property
  def parallel_parameter_merge_net_dict(self):
  return self._parallel_parameter_merge_net_dict
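The new `pipeline_segment` setter accepts only non-negative integers and deliberately excludes bool, which Python treats as a subclass of int. A minimal sketch of the validation behavior (hypothetical usage; assumes a build where this property exists, and pipeline parallelism must be configured for the value to have any effect):

    import mindspore.nn as nn

    net = nn.Dense(2, 2)        # any Cell subclass exposes the property
    net.pipeline_segment = 1    # accepted: non-negative int
    try:
        net.pipeline_segment = True   # rejected: bool is explicitly excluded
    except TypeError as e:
        print(e)
    try:
        net.pipeline_segment = -1     # rejected: negative values are invalid
    except ValueError as e:
        print(e)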
@@ -348,13 +383,14 @@ class Cell(Cell_):
  if '_params_list' in self.__dict__:
  params_list = self.__dict__['_params_list']
  if name in params_list:
- return ParameterTuple(params_list[name])
+ return params_list[name]
  raise AttributeError("The '{}' object has no attribute '{}'.".format(type(self).__name__, name))

  def __del__(self):
- # while deepcopy a cell instance, the copied cell instance can't be added to cells_compile_cache
- # here using pop(id(self), None) to avoid KeyError exception
- cells_compile_cache.pop(id(self), None)
+ if isinstance(cells_compile_cache, dict):
+ # while deepcopy a cell instance, the copied cell instance can't be added to cells_compile_cache
+ # here using pop(id(self), None) to avoid KeyError exception
+ cells_compile_cache.pop(id(self), None)
  try:
  if self.compile_cache:
  _cell_graph_executor.del_net_res(self, self.compile_cache)
@@ -367,11 +403,11 @@ class Cell(Cell_):
  del self._params[name]
  elif name in self._cells:
  del self._cells[name]
+ elif '_params_list' in self.__dict__ and name in self._params_list:
+ del self._params_list[name]
+ elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
+ del self._tensor_list[name]
  else:
- if '_params_list' in self.__dict__ and name in self._params_list:
- del self._params_list[name]
- elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
- del self._tensor_list[name]
  object.__delattr__(self, name)
  self._attr_synced = False

@@ -383,7 +419,8 @@ class Cell(Cell_):
  res.append(self._cast_mixed_precision_inputs(item, dst_type))
  elif isinstance(item, float):
  res.append(self.cast(item, dst_type))
- elif hasattr(item, "dtype") and item.dtype in {mstype.float16, mstype.float32, mstype.float64}:
+ elif hasattr(item, "dtype") and item.dtype in \
+ {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
  res.append(self.cast(item, dst_type))
  else:
  res.append(item)
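The revised branch casts only floating-point inputs (now including bfloat16) and skips the cast when the input already has the destination dtype. A pure-Python sketch of that predicate, using strings in place of `mstype` dtypes, for illustration only:

    # Hypothetical mirror of the cast condition above.
    CASTABLE = {"float16", "float32", "float64", "bfloat16"}

    def needs_cast(item_dtype, dst_type):
        # Cast only floating-point inputs, and only when the dtype differs.
        return item_dtype in CASTABLE and item_dtype != dst_type

    assert needs_cast("float32", "float16")         # cast down to fp16
    assert not needs_cast("float16", "float16")     # new: no redundant cast
    assert not needs_cast("int32", "float16")       # non-float inputs untouched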
@@ -438,7 +475,7 @@ class Cell(Cell_):
  if self._enable_forward_pre_hook:
  cast_inputs = self._run_forward_pre_hook(cast_inputs)
  if self._enable_backward_hook:
- output = self._backward_hook_construct(*cast_inputs)
+ output = self._backward_hook_construct(*cast_inputs, **kwargs)
  elif hasattr(self, "_shard_fn"):
  output = self._shard_fn(*cast_inputs, **kwargs)
  else:
@@ -546,19 +583,19 @@ class Cell(Cell_):
  in_strategy (tuple): Define the layout of inputs, each element of the tuple should be a tuple or None. Tuple
  defines the layout of the corresponding input and None represents a data parallel strategy.
  out_strategy (Union[None, tuple]): Define the layout of outputs similar with in_strategy.
- It is not in use right now. Default: None.
+ It is not in use right now. Default: ``None`` .
  parameter_plan (Union[dict, None]): Define the layout for the specified parameters. Each element in dict
  defines the layout of the parameter like "param_name: layout".
  The key is a parameter name of type 'str'.
  The value is a 1-D integer tuple, indicating the corresponding layout.
  If the parameter name is incorrect or the corresponding parameter
  has been set, the parameter setting will be ignored.
- Default: None.
+ Default: ``None`` .
  device (string): Select a certain device target. It is not in use right now.
- Support ["CPU", "GPU", "Ascend"]. Default: "Ascend".
+ Support [ ``"CPU"`` , ``"GPU"`` , ``"Ascend"`` ]. Default: ``"Ascend"`` .
  level (int): Option for parallel strategy infer algorithm, namely the object function, maximize computation
  over communication ratio, maximize speed performance, minimize memory usage etc. It is not in
- use right now. Support ["0", "1", "2"]. Default: "0".
+ use right now. Support [ ``"0"`` , ``"1"`` , ``"2"`` ]. Default: ``"0"`` .

  Returns:
  Cell, the cell itself.
@@ -627,6 +664,13 @@ class Cell(Cell_):
  args = bound_arguments.args
  kwargs = bound_arguments.kwargs

+ if PackFunc.is_tracing():
+ return self._run_tracefunc(*args, **kwargs)
+
+ if hasattr(self, '_is_check_and_refresh') and not self._is_check_and_refresh:
+ self.check_names_and_refresh_name()
+ self._is_check_and_refresh = True
+
  # Run in Graph mode.
  if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
  self._check_construct_args(*args)
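Per the branch above, graph compilation is entered only when the `MS_JIT` environment variable is not '0' and the context is GRAPH_MODE; with `MS_JIT=0` execution falls through to the PyNative path. A hedged sketch, with the behavior assumed from the condition shown rather than from separate documentation:

    import os
    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    os.environ["MS_JIT"] = "0"          # skip the graph-compile branch above
    ms.set_context(mode=ms.GRAPH_MODE)  # GRAPH_MODE alone no longer forces compile
    net = nn.ReLU()
    print(net(Tensor(np.array([-1.0, 2.0], np.float32))))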
@@ -646,7 +690,7 @@ class Cell(Cell_):
  _check_args(args)
  self._check_cell_flags_in_pynative()

- if self.requires_grad:
+ if self.requires_grad and _pynative_executor.enable_grad():
  _pynative_executor.set_grad_flag(True)

  if self._dynamic_shape_inputs is not None:
@@ -881,16 +925,16 @@ class Cell(Cell_):
  Examples:
  >>> import numpy as np
  >>> import mindspore as ms
- >>> from mindspore import nn, Tensor, context
+ >>> from mindspore import nn, Tensor
  >>>
- >>> class reluNet(nn.Cell):
+ >>> class ReluNet(nn.Cell):
  ... def __init__(self):
- ... super(reluNet, self).__init__()
+ ... super(ReluNet, self).__init__()
  ... self.relu = nn.ReLU()
  ... def construct(self, x):
  ... return self.relu(x)
  >>>
- >>> net = reluNet()
+ >>> net = ReluNet()
  >>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
  >>> net.set_inputs(input_dyn)
  >>> input1 = Tensor(np.random.random([3, 10]), dtype=ms.float32)
@@ -899,15 +943,10 @@ class Cell(Cell_):
  if self.grad_ops_label:
  logger.warning(f'For Cell, set_inputs must be set before the gradient function of the network is '
  f'generated.')
- for ele in inputs:
- if isinstance(ele, str):
- raise TypeError(f"For element in 'set_inputs', the type must not be str.")
  self._dynamic_shape_inputs = inputs
  self._check_construct_args(*inputs)
- if self._dynamic_shape_inputs:
- ds.config.set_dynamic_shape(True)
  if context._get_mode() == context.PYNATIVE_MODE:
- _pynative_executor.set_dynamic_input(self)
+ _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)

  def get_inputs(self):
  """
@@ -918,6 +957,26 @@ class Cell(Cell_):

  .. warning::
  This is an experimental API that is subject to change or deletion.
+
+ Examples:
+ >>> import numpy as np
+ >>> import mindspore as ms
+ >>> from mindspore import nn, Tensor
+ >>>
+ >>> class ReluNet(nn.Cell):
+ ... def __init__(self):
+ ... super(ReluNet, self).__init__()
+ ... self.relu = nn.ReLU()
+ ... def construct(self, x):
+ ... return self.relu(x)
+ >>>
+ >>> net = ReluNet()
+ >>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
+ >>> net.set_inputs(input_dyn)
+ >>> get_inputs = net.get_inputs()
+ >>> print(get_inputs)
+ (Tensor(shape=[3, -1], dtype=Float32, value= ),)
+
  """

  return self._dynamic_shape_inputs
@@ -930,6 +989,10 @@ class Cell(Cell_):
  args (tuple): Args of the Cell object.
  kwargs (dict): Kwargs of the Cell object.
  """
+ # This branch is only used for testing.
+ if is_auto_dynamic() and (self._dynamic_shape_inputs is None or self._dynamic_shape_inputs[0] is None):
+ self._dynamic_shape_inputs = convert_inputs_to_dynamic(*args)
+
  if self._dynamic_shape_inputs is None:
  _cell_graph_executor.compile(self, phase=self.phase,
  jit_config_dict=self._jit_config_dict, *args, **kwargs)
@@ -955,7 +1018,7 @@ class Cell(Cell_):
  Object, the result of executing.
  """
  self.compile(*args, **kwargs)
-
+ self.add_flags(ge_sync_data=False)
  new_args = _get_args_for_run(self, args, kwargs)
  return _cell_graph_executor(self, *new_args, phase=self.phase)

@@ -969,7 +1032,8 @@ class Cell(Cell_):
  logger.warning("'auto_parallel_compile_and_run' function is deprecated.")

  def exec_checkpoint_graph(self):
- """Executes saving checkpoint graph operation."""
+ """Executes GE saving checkpoint graph operation."""
+ self.add_flags(ge_sync_data=True)
  _cell_graph_executor(self, phase='save')

  def insert_param_to_cell(self, param_name, param, check_name_contain_dot=True):
@@ -982,11 +1046,28 @@ class Cell(Cell_):
  Args:
  param_name (str): Name of the parameter.
  param (Parameter): Parameter to be inserted to the cell.
- check_name_contain_dot (bool): Determines whether the name input is compatible. Default: True.
+ check_name_contain_dot (bool): Determines whether the name input is compatible. Default: ``True`` .

  Raises:
  KeyError: If the name of parameter is null or contains dot.
  TypeError: If the type of parameter is not Parameter.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn, Parameter
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.relu = nn.ReLU()
+ ...
+ ... def construct(self, x):
+ ... x = self.relu(x)
+ ... return x
+ >>> net = Net()
+ >>> net.insert_param_to_cell("bias", Parameter(Tensor([1, 2, 3])))
+ >>> print(net.bias)
+ Parameter(name=bias, shape=(3,), dtype=Int64, requires_grad=True)
  """
  if not param_name:
  raise KeyError("For 'insert_param_to_cell', the argument 'param_name' should not be None.")
@@ -1000,6 +1081,9 @@ class Cell(Cell_):
  if not isinstance(param, Parameter) and param is not None:
  raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must be 'Parameter' if not None, "
  f"but got {type(param)}.")
+ if param is None:
+ raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must not be None, "
+ f"but got None.")
  if isinstance(param, Parameter) and param.name == PARAMETER_NAME_DEFAULT:
  param.name = param_name
  self._params[param_name] = param
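With the added guard, passing `param=None`, which the earlier `if not None` check tolerated, now raises TypeError up front. A short sketch:

    from mindspore import nn

    net = nn.ReLU()
    try:
        net.insert_param_to_cell("bias", None)   # now rejected explicitly
    except TypeError as e:
        print(e)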
@@ -1041,6 +1125,18 @@ class Cell(Cell_):
  KeyError: Child Cell's name is incorrect or duplicated with the other child name.
  TypeError: If type of `child_name` is not str.
  TypeError: Child Cell's type is incorrect.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> net1 = nn.ReLU()
+ >>> net2 = nn.Dense(2, 2)
+ >>> net1.insert_child_to_cell("child", net2)
+ >>> print(net1)
+ ReLU<
+ (child): Dense<input_channels=2, output_channels=2, has_bias=True>
+ >
  """
  if not isinstance(child_name, str):
  raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, "
@@ -1107,10 +1203,29 @@ class Cell(Cell_):
  `init_parameters_data`, do not save these results.

  Args:
- auto_parallel_mode (bool): If running in auto_parallel_mode. Default: False.
+ auto_parallel_mode (bool): If running in auto_parallel_mode. Default: ``False`` .

  Returns:
  Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.dense = nn.Dense(2, 2)
+ ...
+ ... def construct(self, x):
+ ... x = self.dense(x)
+ ... return x
+ >>> net = Net()
+ >>> print(net.init_parameters_data())
+ {Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True):
+ Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True),
+ Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True):
+ Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True)}
  """
  replace = dict()

@@ -1152,10 +1267,28 @@ class Cell(Cell_):
  Gets the parameters dictionary of this cell.

  Args:
- recurse (bool): Whether contains the parameters of subcells. Default: True.
+ recurse (bool): Whether contains the parameters of subcells. Default: ``True`` .

  Returns:
  OrderedDict, return parameters dictionary.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn, Parameter
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.dense = nn.Dense(2, 2)
+ ...
+ ... def construct(self, x):
+ ... x = self.dense(x)
+ ... return x
+ >>> net = Net()
+ >>> print(net.parameters_dict())
+ OrderedDict([('dense.weight', Parameter(name=dense.weight, shape=(2, 2), dtype=Float32,
+ requires_grad=True)), ('dense.bias', Parameter(name=dense.bias, shape=(2,), dtype=Float32,
+ requires_grad=True))])
  """
  param_dict = OrderedDict()
  for param in self.get_parameters(expand=recurse):
@@ -1167,7 +1300,7 @@ class Cell(Cell_):
  Gets the parameters broadcast dictionary of this cell.

  Args:
- recurse (bool): Whether contains the parameters of subcells. Default: True.
+ recurse (bool): Whether contains the parameters of subcells. Default: ``True`` .

  Returns:
  OrderedDict, return parameters broadcast dictionary.
@@ -1185,11 +1318,11 @@ class Cell(Cell_):
  Adds the `prefix` string to the names of parameters.

  Args:
- prefix (str): The prefix string. Default: ''.
- recurse (bool): Whether contains the parameters of subcells. Default: True.
+ prefix (str): The prefix string. Default: ``''`` .
+ recurse (bool): Whether contains the parameters of subcells. Default: ``True`` .
  """

- Validator.check_str_by_regular(prefix)
+ Validator.check_str_and_none_by_regular(prefix)
  for name, param in self.parameters_and_names(expand=recurse):
  if prefix != '':
  param.is_init = False
@@ -1205,7 +1338,7 @@ class Cell(Cell_):

  Args:
  prefix (str): The prefix string. Default: ''.
- recurse (bool): Whether contains the parameters of subcells. Default: True.
+ recurse (bool): Whether contains the parameters of subcells. Default: ``True``.
  """

  Validator.check_str_by_regular(prefix)
@@ -1224,10 +1357,14 @@ class Cell(Cell_):
  Returns a list of all trainable parameters.

  Args:
- recurse (bool): Whether contains the trainable parameters of subcells. Default: True.
+ recurse (bool): Whether contains the trainable parameters of subcells. Default: ``True`` .

  Returns:
  List, the list of trainable parameters.
+
+ Tutorial Examples:
+ - `Model Training - Optimizer
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/train.html#optimizer>`_
  """
  return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))
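A typical consumer of `trainable_params` is an optimizer, as in the tutorial linked above; a minimal sketch:

    import mindspore.nn as nn

    net = nn.Dense(2, 2)
    # Only parameters with requires_grad=True are handed to the optimizer.
    optim = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)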
@@ -1239,7 +1376,7 @@ class Cell(Cell_):
  Returns a list of all untrainable parameters.

  Args:
- recurse (bool): Whether contains the untrainable parameters of subcells. Default: True.
+ recurse (bool): Whether contains the untrainable parameters of subcells. Default: ``True`` .

  Returns:
  List, the list of untrainable parameters.
@@ -1251,25 +1388,58 @@ class Cell(Cell_):
  """
  Returns an iterator over cell parameters.

- Yields parameters of this cell. If `expand` is true, yield parameters of this cell and all subcells.
+ Yields parameters of this cell. If `expand` is ``true`` , yield parameters of this cell and all subcells.
+ For more details about subcells, please see the example below.

  Args:
- expand (bool): If true, yields parameters of this cell and all subcells. Otherwise, only yield parameters
- that are direct members of this cell. Default: True.
+ expand (bool): If ``true`` , yields parameters of this cell and all subcells. Otherwise, only yield
+ parameters that are direct members of this cell. Default: ``True`` .

  Returns:
  Iteration, all parameters at the cell.

  Examples:
- >>> from mindspore import nn
- >>> net = nn.Dense(3, 4)
- >>> parameters = []
- >>> for item in net.get_parameters():
- ... parameters.append(item)
+ >>> import mindspore as ms
+ >>> from mindspore import nn, ops, Tensor
+ >>> import numpy as np
+ >>> class TestNet(nn.Cell):
+ ... def __init__(self):
+ ... super().__init__()
+ ... self.my_w1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32))
+ ... self.my_w2 = ms.Parameter(Tensor(np.ones([16]), ms.float32))
+ ... def construct(self, x):
+ ... x += self.my_w1
+ ... x = ops.reshape(x, (16,)) - self.my_w2
+ ... return x
+ >>> class TestNet2(nn.Cell):
+ ... def __init__(self):
+ ... super().__init__()
+ ... self.my_t1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32))
+ ... # self.subcell is a subcell of TestNet2; when using expand=True, the parameters of TestNet will
+ ... # also be gathered.
+ ... self.subcell = TestNet()
+ ... def construct(self, x):
+ ... x += self.my_t1
+ ... return self.subcell(x)
+ >>> net = TestNet2()
+ >>> print([p for p in net.get_parameters(expand=True)])
+ [Parameter (name=my_t1, shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w1,
+ shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w2, shape=(16,), dtype=Float32,
+ requires_grad=True)]
  """
  for _, param in self.parameters_and_names(expand=expand):
  yield param

+ # pylint: disable=missing-docstring
+ def check_names_and_refresh_name(self):
+ if not hasattr(self, "_params"):
+ return
+ all_name = [i.name for i in dict(self.parameters_and_names()).values()]
+ if len(set(all_name)) < len(all_name):
+ self.update_parameters_name()
+ self.check_names()
+
  def check_names(self):
  """
  Check the names of cell parameters.
@@ -1288,9 +1458,9 @@ class Cell(Cell_):
  Includes the parameter's name and itself.

  Args:
- name_prefix (str): Namespace. Default: ''.
+ name_prefix (str): Namespace. Default: ``''`` .
  expand (bool): If true, yields parameters of this cell and all subcells. Otherwise, only yield parameters
- that are direct members of this cell. Default: True.
+ that are direct members of this cell. Default: ``True`` .

  Returns:
  Iteration, all the names and corresponding parameters in the cell.
@@ -1302,6 +1472,10 @@ class Cell(Cell_):
  >>> for m in n.parameters_and_names():
  ... if m[0]:
  ... names.append(m[0])
+
+ Tutorial Examples:
+ - `Building a Network - Model Parameters
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/model.html#model-parameters>`_
  """
  cells = []
  if expand:
@@ -1313,7 +1487,7 @@ class Cell(Cell_):
  for cell_name, cell in cells:
  params = cell._params.items()
  for par_name, par in params:
- if par.inited_param is not None:
+ if par is not None and par.inited_param is not None:
  par = par.inited_param
  if par is not None and id(par) not in params_set:
  params_set.add(id(par))
@@ -1328,8 +1502,8 @@ class Cell(Cell_):
  Returns an iterator over all cells in the network, including the cell's name and itself.

  Args:
- cells (str): Cells to iterate over. Default: None.
- name_prefix (str): Namespace. Default: ''.
+ cells (str): Cells to iterate over. Default: ``None`` .
+ name_prefix (str): Namespace. Default: ``''`` .

  Returns:
  Iteration, all the child cells and corresponding names in the cell.
@@ -1370,6 +1544,22 @@ class Cell(Cell_):

  Returns:
  Iteration, the immediate cells in the cell.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.dense = nn.Dense(2, 2)
+ ...
+ ... def construct(self, x):
+ ... x = self.dense(x)
+ ... return x
+ >>> net = Net()
+ >>> print(net.cells())
+ odict_values([Dense<input_channels=2, output_channels=2, has_bias=True>])
  """
  return self.name_cells().values()

@@ -1415,6 +1605,22 @@ class Cell(Cell_):

  Returns:
  Dict, all the child cells and corresponding names in the cell.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.dense = nn.Dense(2, 2)
+ ...
+ ... def construct(self, x):
+ ... x = self.dense(x)
+ ... return x
+ >>> net = Net()
+ >>> print(net.name_cells())
+ OrderedDict([('dense', Dense<input_channels=2, output_channels=2, has_bias=True>)])
  """
  value_set = set()
  cells = OrderedDict()
@@ -1430,13 +1636,8 @@ class Cell(Cell_):
  Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16)
  if "fp32" in flags and flags.get("fp32", False):
  Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32)
-
- def _add_mixed_precision_flag_recursive(self, **flags):
- """Add mixed precision flag to each cell"""
- if "fp16" in flags and flags.get("fp16", False):
- self._set_mixed_precision_type_recursive(MixedPrecisionType.FP16)
- if "fp32" in flags and flags.get("fp32", False):
- self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
+ if "bf16" in flags and flags.get("bf16", False):
+ Cell_.set_mixed_precision_type(self, MixedPrecisionType.BF16)

  def apply(self, fn):
  """
@@ -1478,7 +1679,24 @@ class Cell(Cell_):

  Args:
  flags (dict): Network configuration information, currently it is used for the binding of network and
- dataset. Users can also customize network attributes by this parameter. Default: None.
+ dataset. Users can also customize network attributes by this parameter.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.relu = nn.ReLU()
+ ...
+ ... def construct(self, x):
+ ... x = self.relu(x)
+ ... return x
+ >>> net = Net()
+ >>> net.add_flags(sink_mode=True)
+ >>> print(net.sink_mode)
+ True
  """
  if not hasattr(self, "_func_graph_flags"):
  self._func_graph_flags = {}
@@ -1493,10 +1711,26 @@ class Cell(Cell_):

  Args:
  flags (dict): Network configuration information, currently it is used for the binding of network and
- dataset. Users can also customize network attributes by this parameter. Default: None.
+ dataset. Users can also customize network attributes by this parameter.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.relu = nn.ReLU()
+ ...
+ ... def construct(self, x):
+ ... x = self.relu(x)
+ ... return x
+ >>> net = Net()
+ >>> net.add_flags_recursive(sink_mode=True)
+ >>> print(net.sink_mode)
+ True
  """
  self.add_flags(**flags)
- self._add_mixed_precision_flag_recursive(**flags)
  for cell in self.cells():
  cell.add_flags_recursive(**flags)
  return self
@@ -1508,17 +1742,28 @@ class Cell(Cell_):
  def get_flags(self):
  """
  Get the self_defined attributes of the cell, which can be added by `add_flags` method.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.relu = nn.ReLU()
+ ...
+ ... def construct(self, x):
+ ... x = self.relu(x)
+ ... return x
+ >>> net = Net()
+ >>> net.add_flags(sink_mode=True)
+ >>> print(net.get_flags())
+ {'sink_mode': True}
  """
  if not hasattr(self, "_func_graph_flags"):
  self._func_graph_flags = {}
  return self._func_graph_flags

- def _set_mixed_precision_type_recursive(self, mixed_type):
- """Set mixed precision type to each cell"""
- Cell_.set_mixed_precision_type(self, mixed_type)
- for cell in self.cells():
- cell._set_mixed_precision_type_recursive(mixed_type)
-
  def to_float(self, dst_type):
  """
  Add cast on all inputs of cell and child cells to run with certain float type.
@@ -1531,13 +1776,13 @@ class Cell(Cell_):

  Args:
  dst_type (:class:`mindspore.dtype`): Transfer cell to run with dst_type.
- dst_type can be `mstype.float16` or `mstype.float32`.
+ dst_type can be `mstype.float16` , `mstype.float32` or `mstype.bfloat16`.

  Returns:
  Cell, the cell itself.

  Raises:
- ValueError: If dst_type is not mstype.float32 or mstype.float16.
+ ValueError: If dst_type is not `mstype.float32` , `mstype.float16` or `mstype.bfloat16`.

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
@@ -1549,19 +1794,15 @@ class Cell(Cell_):
  >>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
  >>> net.to_float(mstype.float16)
  Conv2d<input_channels=120, output_channels=240, kernel_size=(4, 4), stride=(1, 1), pad_mode=same,
- padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
- """
- if dst_type not in (mstype.float16, mstype.float32):
- raise ValueError("For 'to_float', the argument 'dst_type' must be float32 or float16, "
- "but got {}.".format(dst_type))
- if dst_type == mstype.float16:
- self._set_mixed_precision_type_recursive(MixedPrecisionType.FP16)
- self.to_float_fp16 = True
- else:
- self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
- self.to_float_fp16 = False
- flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32}
+ padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW>
+ """
+ if dst_type not in (mstype.float16, mstype.float32, mstype.bfloat16):
+ raise ValueError("For 'to_float', the argument 'dst_type' must be mstype.float32, mstype.float16 or "
+ "mstype.bfloat16, but got type: {} and value: {}.".format(type(dst_type), dst_type))
+ flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32,
+ 'bf16': dst_type == mstype.bfloat16}
  self._add_init_args(**flags)
+ self.add_flags_recursive(**flags)
  return self

  def set_boost(self, boost_type):
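With the rework above, `to_float` no longer walks children itself; it sets exactly one of the `fp16`/`fp32`/`bf16` flags through `add_flags_recursive`, so every child cell carries the same flags. A minimal sketch (the expected flags are an assumption derived from the code above; bfloat16 additionally requires hardware support):

    from mindspore import dtype as mstype
    from mindspore import nn

    net = nn.Conv2d(120, 240, 4, has_bias=False)
    net.to_float(mstype.float16)
    # Expected to include {'fp16': True, 'fp32': False, 'bf16': False}.
    print(net.get_flags())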
@@ -1570,7 +1811,7 @@ class Cell(Cell_):
  accelerate the algorithm in the algorithm library.

  If `boost_type` is not in the algorithm library, please view the algorithm in the algorithm library through
- `algorithm library <https://gitee.com/mindspore/mindspore/tree/r2.0/mindspore/python/mindspore/boost>`_.
+ `algorithm library <https://gitee.com/mindspore/mindspore/tree/r2.2/mindspore/python/mindspore/boost>`_.

  Note:
  Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
@@ -1594,12 +1835,12 @@ class Cell(Cell_):
  def set_grad(self, requires_grad=True):
  """
  Sets the cell flag for gradient. In pynative mode, this parameter specifies whether the network requires
- gradients. If true, the backward network needed to compute the gradients will be generated when the forward
+ gradients. If ``true`` , the backward network needed to compute the gradients will be generated when the forward
  network is executed.

  Args:
  requires_grad (bool): Specifies if the net need to grad, if it is
- true, the cell will construct backward network in pynative mode. Default: True.
+ ``true`` , the cell will construct backward network in pynative mode. Default: ``True`` .

  Returns:
  Cell, the cell itself.
@@ -1620,15 +1861,19 @@ class Cell(Cell_):
  When execute function Model.eval(), framework will call Cell.set_train(False).

  Args:
- mode (bool): Specifies whether the model is training. Default: True.
+ mode (bool): Specifies whether the model is training. Default: ``True`` .

  Returns:
  Cell, the cell itself.
+
+ Tutorial Examples:
+ - `Model Training - Implementing Training and Evaluation
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/train.html#training-and-evaluation>`_
  """
- if mode is False:
- self._phase = 'predict'
- else:
+ if mode:
  self._phase = 'train'
+ else:
+ self._phase = 'predict'
  self.add_flags_recursive(training=mode)
  return self
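Because `set_train` both flips the phase and propagates the `training` flag recursively, layers such as Dropout switch behavior across the whole network at once; a brief sketch:

    from mindspore import nn

    net = nn.SequentialCell([nn.Dense(4, 4), nn.Dropout(p=0.5)])
    net.set_train()        # phase 'train': dropout is active
    net.set_train(False)   # phase 'predict': dropout acts as identity
    print(net.training)    # False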
@@ -1637,7 +1882,7 @@ class Cell(Cell_):
  Set parameter broadcast mode for this cell.

  Args:
- mode (bool): Specifies whether the mode is parameter broadcast. Default: True.
+ mode (bool): Specifies whether the mode is parameter broadcast. Default: ``True`` .
  """
  self.add_flags_recursive(broadcast_flag=mode)
  return self
@@ -1657,16 +1902,27 @@ class Cell(Cell_):

  Args:
  jit_config (JitConfig): Jit config for compile. For details, please refer to :class:`mindspore.JitConfig`.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.relu = nn.ReLU()
+ ...
+ ... def construct(self, x):
+ ... x = self.relu(x)
+ ... return x
+ >>> net = Net()
+ >>> jitconfig = ms.JitConfig()
+ >>> net.set_jit_config(jitconfig)
  """
  if self._jit_config_dict:
  logger.warning("For Cell, jit config can only be set once, ignore this setting.")
  else:
  self._jit_config_dict = jit_config.jit_config_dict
- enable_ge = os.getenv("MS_ENABLE_GE") == '1'
- enable_jit_level_o3 = self._jit_config_dict.get('jit_level') == "O3"
- if (not enable_ge and enable_jit_level_o3) or (enable_ge and not enable_jit_level_o3):
- raise RuntimeError("GE and jit_level=O3 should be used together, but got MS_ENABLE_GE={}, jie_level={}".
- format(os.getenv("MS_ENABLE_GE"), self.jit_config_dict.get('jit_level')))

  def flatten_weights(self, fusion_size=0):
  """
@@ -1679,7 +1935,7 @@ class Cell(Cell_):
  to limit the maximum memory chunk size.

  Args:
- fusion_size (int): Maximum memory chunk size in bytes, 0 for unlimited. Default: 0.
+ fusion_size (int): Maximum memory chunk size in bytes, ``0`` for unlimited. Default: ``0`` .
  """
  if fusion_size < 0:
  raise ValueError(f"Negative 'fusion_size' {fusion_size} is invalid.")
@@ -1718,9 +1974,7 @@ class Cell(Cell_):
  Examples:
  >>> import numpy as np
  >>> import mindspore as ms
- >>> import mindspore.nn as nn
- >>> from mindspore import Tensor
- >>> from mindspore.ops import GradOperation
+ >>> from mindspore import Tensor, nn, ops
  >>> ms.set_context(mode=ms.PYNATIVE_MODE)
  >>> def forward_pre_hook_fn(cell_id, inputs):
  ... print("forward inputs: ", inputs)
@@ -1735,7 +1989,7 @@ class Cell(Cell_):
  ... x = x + x
  ... x = self.mul(x, y)
  ... return x
- >>> grad = GradOperation(get_all=True)
+ >>> grad = ops.GradOperation(get_all=True)
  >>> net = Net()
  >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)))
  forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1],
@@ -1820,9 +2074,7 @@ class Cell(Cell_):
  Examples:
  >>> import numpy as np
  >>> import mindspore as ms
- >>> import mindspore.nn as nn
- >>> from mindspore import Tensor
- >>> from mindspore.ops import GradOperation
+ >>> from mindspore import Tensor, nn, ops
  >>> ms.set_context(mode=ms.PYNATIVE_MODE)
  >>> def forward_hook_fn(cell_id, inputs, output):
  ... print("forward inputs: ", inputs)
@@ -1838,7 +2090,7 @@ class Cell(Cell_):
  ... x = x + x
  ... x = self.mul(x, y)
  ... return x
- >>> grad = GradOperation(get_all=True)
+ >>> grad = ops.GradOperation(get_all=True)
  >>> net = Net()
  >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)))
  forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1],
@@ -1922,9 +2174,7 @@ class Cell(Cell_):
  Examples:
  >>> import numpy as np
  >>> import mindspore as ms
- >>> import mindspore.nn as nn
- >>> from mindspore import Tensor
- >>> from mindspore.ops import GradOperation
+ >>> from mindspore import Tensor, nn, ops
  >>> ms.set_context(mode=ms.PYNATIVE_MODE)
  >>> def backward_hook_fn(cell_id, grad_input, grad_output):
  ... print("backward input: ", grad_input)
@@ -1940,7 +2190,7 @@ class Cell(Cell_):
  ... x = x + x
  ... x = self.relu(x)
  ... return x
- >>> grad = GradOperation(get_all=True)
+ >>> grad = ops.GradOperation(get_all=True)
  >>> net = Net()
  >>> output = grad(net)(Tensor(np.ones([1]).astype(np.float32)))
  backward input: (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),)
@@ -1966,12 +2216,13 @@ class Cell(Cell_):
  handle = HookHandle(self, backward_hook_key, "_cell_backward_hook")
  return handle

- def _backward_hook_construct(self, *inputs):
+ def _backward_hook_construct(self, *inputs, **kwargs):
  """
  Backward hook construct method to replace original construct method.

  Args:
  inputs: The input objects of Cell object.
+ kwargs (dict): Dictionary of variable keyword parameters.

  Returns:
  - **outputs** - The output objects of Cell object.
@@ -1983,10 +2234,11 @@ class Cell(Cell_):
  inputs = self._cell_backward_hook(inputs)
  else:
  inputs = self._cell_backward_hook(*inputs)
+ inputs = (inputs,)
  if isinstance(inputs, tuple):
- outputs = self.construct(*inputs)
+ outputs = self.construct(*inputs, **kwargs)
  else:
- outputs = self.construct(inputs)
+ outputs = self.construct(inputs, **kwargs)
  outputs = self._cell_backward_hook(outputs)
  return outputs
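Forwarding `**kwargs` here means a cell whose `construct` takes keyword arguments can now run with a backward hook registered; previously the keywords were dropped on this path. A hedged PyNative-mode sketch (`Net` and `scale` are illustrative, not from the library):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    ms.set_context(mode=ms.PYNATIVE_MODE)

    def backward_hook_fn(cell_id, grad_input, grad_output):
        print("backward input: ", grad_input)

    class Net(nn.Cell):
        def construct(self, x, scale=1.0):
            # 'scale' reaches construct through the forwarded **kwargs.
            return x * scale

    net = Net()
    net.register_backward_hook(backward_hook_fn)
    out = net(Tensor(np.ones([2], np.float32)), scale=2.0)
    print(out)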
@@ -2000,23 +2252,16 @@ class Cell(Cell_):
  It is only supported in graph mode.

  Args:
- recurse (bool): Whether sets the trainable parameters of subcells. Default: True.
+ recurse (bool): Whether sets the trainable parameters of subcells. Default: ``True`` .
  init_in_server (bool): Whether trainable parameters updated by parameter server are
- initialized on server. Default: False.
+ initialized on server. Default: ``False`` .
  """
  params = self.trainable_params(recurse)
  for param in params:
  param.set_param_ps(init_in_server)

+ @deprecated("1.8", "set_param_fl")
  def set_param_fl(self, push_to_server=False, pull_from_server=False, requires_aggr=True):
- """
- Set the way of parameter and server interaction.
-
- Args:
- push_to_server (bool): Whether the parameter should be pushed to server. Default: False.
- pull_from_server (bool): Whether the parameter should be pulled from server. Default: False.
- requires_aggr (bool): Whether the parameter should be aggregated in the server. Default: True.
- """
  params = self.parameters_and_names()
  for param in params:
  param[1].set_param_fl(push_to_server, pull_from_server, requires_aggr)
@@ -2031,7 +2276,7 @@ class Cell(Cell_):

  Args:
  fusion_type (int): The value of `comm_fusion`.
- recurse (bool): Whether sets the trainable parameters of subcells. Default: True.
+ recurse (bool): Whether sets the trainable parameters of subcells. Default: ``True`` .
  """
  Validator.check_non_negative_int(fusion_type)
  for param in self.trainable_params(recurse):
@@ -2118,10 +2363,10 @@ class Cell(Cell_):

  Args:
  mp_comm_recompute (bool): Specifies whether the model parallel communication operators
- in the cell are recomputed in auto parallel or semi auto parallel mode. Default: True.
+ in the cell are recomputed in auto parallel or semi auto parallel mode. Default: ``True`` .
  parallel_optimizer_comm_recompute (bool): Specifies whether the communication operator allgathers
  introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
- Default: False.
+ Default: ``False`` .
  """
  self._recompute()
  if 'mp_comm_recompute' in kwargs.keys():
@@ -2133,7 +2378,7 @@ class Cell(Cell_):
  "are not support recomputation in pipeline parallel.")
  elif context.get_auto_parallel_context("pipeline_stages") == 1:
  self._parallel_optimizer_comm_recompute(kwargs.get('parallel_optimizer_comm_recompute', False))
- if 'recompute_slice_activation' in kwargs.keys():
+ if 'recompute_slice_activation' in kwargs:
  self._recompute_slice_activation(kwargs.get('recompute_slice_activation', False))

  for key, _ in kwargs.items():
@@ -2217,19 +2462,29 @@ class Cell(Cell_):
  """
  if not isinstance(net_input, Tensor):
  raise TypeError(
- f"The {index + 1}th input type of 'set_inputs' must be Tensor, but got {type(net_input)}.")
+ f"For 'set_inputs' and tuple(list) in 'set_inputs', the type of {index + 1}th input must be Tensor, "
+ f"but got {type(net_input)}.")
+ is_param_set_input = isinstance(set_input, Parameter)
+ is_param_net_input = isinstance(net_input, Parameter)
+ if (is_param_set_input and not is_param_net_input) or (is_param_net_input and not is_param_set_input):
+ raise TypeError(
+ f"For 'set_inputs' and tuple(list) in 'set_inputs', the {index + 1}th input must be the same "
+ f"as network's input, but got 'set_inputs': {type(set_input)} and network's input: {type(net_input)}.")
  if set_input.dtype != net_input.dtype:
- raise ValueError(
- f"The {index + 1}th input type of 'set_inputs' must be the same as network's input, "
- f"but got 'set_inputs': {set_input.dtype} and network's input: {net_input.dtype}.")
- if net_input.dim() != 0 and set_input.dim() != net_input.dim():
- raise ValueError(
- f"The {index + 1}th input dims of 'set_inputs' must be the same as network's input, "
- f"but got 'set_inputs': {set_input.dim()} and network's input: {net_input.dim()}.")
- if not all([ele1 in (-1, ele2) for ele1, ele2 in zip(set_input.shape, net_input.shape)]):
- raise ValueError(
- f"The {index + 1}th input shape of 'set_inputs' must be the same as network's input, "
- f"but got 'set_inputs': {set_input.shape} and network's input: {net_input.shape}.")
+ raise TypeError(
+ f"For 'set_inputs' and tuple(list) in 'set_inputs', the dtype of {index + 1}th input must be the same "
+ f"as network's input, but got 'set_inputs': {set_input.dtype} and network's input: {net_input.dtype}.")
+ if -2 not in set_input.shape:
+ if net_input.dim() != 0 and set_input.dim() != net_input.dim():
+ raise ValueError(
+ f"For 'set_inputs' and tuple(list) in 'set_inputs', the dims of {index + 1}th input must be the "
+ f"same as network's input, but got 'set_inputs': {set_input.dim()} and network's input: "
+ f"{net_input.dim()}.")
+ if not all([ele1 in (-1, ele2) for ele1, ele2 in zip(set_input.shape, net_input.shape)]):
+ raise ValueError(
+ f"For 'set_inputs' and tuple(list) in 'set_inputs', the shape of {index + 1}th input must be the "
+ f"same as network's input, but got 'set_inputs': {set_input.shape} and network's input: "
+ f"{net_input.shape}.")

  def _check_compile_dynamic_shape(self, set_inputs, net_inputs):
  """
@@ -2241,22 +2496,61 @@ class Cell(Cell_):
  set_inputs_len = len(set_inputs)
  net_inputs_len = len(net_inputs)
  if set_inputs_len != net_inputs_len:
- raise ValueError("The length of 'set_inputs' must be equal to network's inputs, "
- f"but got 'set_inputs': {set_inputs_len} and network's input: {net_inputs_len}.")
+ raise ValueError("The length of 'set_inputs' or tuple(list) in 'set_inputs' must be equal to network's "
+ f"inputs, but got 'set_inputs': {set_inputs_len} and network's input: {net_inputs_len}.")
  for index, (set_input, net_input) in enumerate(zip(set_inputs, net_inputs)):
  if isinstance(set_input, Tensor):
  self._check_dynamic_tensor(set_input, net_input, index)
  elif isinstance(set_input, (tuple, list)):
  if not isinstance(net_input, (tuple, list)):
  raise TypeError(
- f"The {index + 1}th input type of 'set_inputs' must be tuple or list, "
- f"but got {type(net_input)}.")
+ f"The {index + 1}th input type of 'set_inputs' or tuple(list) in 'set_inputs' must be tuple or "
+ f"list, but got {type(net_input)}.")
  self._check_compile_dynamic_shape(set_input, net_input)
  else:
+ if context._get_mode() == context.PYNATIVE_MODE and set_input is None:
+ continue
  if net_input != set_input:
  raise ValueError(
- f"The {index + 1}th input of 'set_inputs' must be the same with network's input, but got "
- f"set_inputs: {set_input} and network's input: {net_input}.")
+ f"The {index + 1}th input of 'set_inputs' or tuple(list) in 'set_inputs' must be the same with "
+ f"network's input, but got set_inputs: {set_input} and network's input: {net_input}.")
+
+ def _run_tracefunc(self, *args, **kwargs):
+ """ Run Packed Cell in Pack."""
+ args = self._mixed_precision_cast(args)
+ need_subgraph = hasattr(self, "bprop") or hasattr(self, "_pipeline_stage") or self.get_flags()
+ if not PackFunc.current.is_pynative_mode and need_subgraph:
+ expander = PackExpander.get_instance()
+ args = expander.begin_subgraph(self, *args)
+ args = [_convert_tensor(a) for a in args]
+ output = self._run_construct(args, kwargs)
+ ret = expander.end_subgraph(self, output)
+ output = _convert_tensor(ret)
+ else:
+ with _SetMixedPrecision(self):
+ output = self._run_construct(args, kwargs)
+ return output
+
+ def _mixed_precision_cast(self, inputs):
+ mixed_type = self.get_mixed_precision_type()
+ if mixed_type == MixedPrecisionType.NOTSET:
+ return inputs
+ if mixed_type == MixedPrecisionType.FP16:
+ cast_type = mstype.float16
+ elif mixed_type == MixedPrecisionType.BF16:
+ cast_type = mstype.bfloat16
+ else:
+ cast_type = mstype.float32
+ cast_inputs = self._cast_mixed_precision_inputs(inputs, cast_type)
+ return cast_inputs
+
+ def _get_attr_from_cell(self, network):
+ if not isinstance(network, Cell):
+ return
+ if hasattr(network, "jit_config_dict"):
+ self._jit_config_dict = network.jit_config_dict
+ if hasattr(network, "_amp_level"):
+ self._amp_level = getattr(network, "_amp_level")
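For reference, the shape rule enforced by `_check_dynamic_tensor` earlier in this file treats -1 as a per-dimension wildcard, and a -2 anywhere in the configured shape now skips the rank and shape comparison entirely. A pure-Python sketch of that matching rule, for illustration:

    def shape_compatible(set_shape, net_shape):
        # Hypothetical mirror of the check: every configured dim must be -1
        # (dynamic) or equal to the concrete runtime dim.
        if -2 in set_shape:                    # fully dynamic rank: skip checks
            return True
        if len(set_shape) != len(net_shape):
            return False
        return all(d in (-1, n) for d, n in zip(set_shape, net_shape))

    assert shape_compatible([3, -1], [3, 10])
    assert not shape_compatible([3, -1], [4, 10])
    assert shape_compatible([-2], [8, 8, 3])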


  class GraphCell(Cell):
@@ -2271,11 +2565,11 @@ class GraphCell(Cell):
  params_init (dict): Parameters need to be inited in the graph.
  The key is the parameter name whose type is str, and the value is a Tensor or Parameter.
  If the parameter exists in the graph according to the name, update it's value.
- If the parameter does not exist, ignore it. Default: None.
+ If the parameter does not exist, ignore it. Default: ``None`` .
  obf_random_seed (Union[int, None]): The random seed used for dynamic obfuscation. "dynamic obfuscation" is
  used for model protection, which can refer to :func:`mindspore.obfuscate_model`. If the input `graph` is
  a func_graph loaded from a mindir file obfuscated with `obf_random_seed` , then `obf_random_seed` should be
- provided. `obf_random_seed` should be in (0, 9223372036854775807]. default: None.
+ provided. `obf_random_seed` should be in (0, 9223372036854775807]. Default: ``None`` .

  Raises:
  TypeError: If the `graph` is not a FuncGraph.