mindspore-2.1.0-cp38-cp38-manylinux1_x86_64.whl → mindspore-2.2.0-cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (550)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +49 -16
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  20. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  21. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  22. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  23. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  24. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  25. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  26. mindspore/_check_jit_forbidden_api.py +3 -1
  27. mindspore/_checkparam.py +26 -32
  28. mindspore/_extends/graph_kernel/__init__.py +0 -1
  29. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  30. mindspore/_extends/graph_kernel/splitter.py +1 -9
  31. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  32. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  33. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  34. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  35. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
  36. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  37. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  38. mindspore/_extends/parse/__init__.py +12 -15
  39. mindspore/_extends/parse/namespace.py +7 -33
  40. mindspore/_extends/parse/parser.py +61 -71
  41. mindspore/_extends/parse/resources.py +1 -1
  42. mindspore/_extends/parse/standard_method.py +72 -95
  43. mindspore/_extends/parse/trope.py +1 -1
  44. mindspore/_extends/remote/kernel_build_server.py +24 -7
  45. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  46. mindspore/_install_custom.py +43 -0
  47. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  48. mindspore/amp.py +47 -11
  49. mindspore/bin/cache_admin +0 -0
  50. mindspore/bin/cache_server +0 -0
  51. mindspore/boost/boost.py +1 -8
  52. mindspore/boost/boost_cell_wrapper.py +3 -2
  53. mindspore/boost/grad_accumulation.py +1 -1
  54. mindspore/boost/group_loss_scale_manager.py +8 -7
  55. mindspore/common/__init__.py +5 -3
  56. mindspore/common/_jit_fallback_utils.py +6 -0
  57. mindspore/common/_register_for_adapter.py +2 -0
  58. mindspore/common/_register_for_tensor.py +2 -2
  59. mindspore/common/_stub_tensor.py +13 -0
  60. mindspore/common/_utils.py +13 -0
  61. mindspore/common/api.py +173 -258
  62. mindspore/common/auto_dynamic_shape.py +498 -0
  63. mindspore/common/dtype.py +18 -11
  64. mindspore/common/dump.py +6 -4
  65. mindspore/common/initializer.py +14 -14
  66. mindspore/common/jit_config.py +33 -15
  67. mindspore/common/lazy_inline.py +126 -7
  68. mindspore/common/mindir_util.py +101 -0
  69. mindspore/common/parameter.py +51 -41
  70. mindspore/common/seed.py +4 -4
  71. mindspore/common/sparse_tensor.py +13 -14
  72. mindspore/common/tensor.py +240 -145
  73. mindspore/communication/__init__.py +7 -4
  74. mindspore/communication/_comm_helper.py +83 -4
  75. mindspore/communication/management.py +152 -84
  76. mindspore/config/op_info.config +13 -2
  77. mindspore/config/super_bar_config.json +4 -2
  78. mindspore/context.py +143 -59
  79. mindspore/dataset/__init__.py +5 -5
  80. mindspore/dataset/audio/__init__.py +2 -2
  81. mindspore/dataset/audio/transforms.py +52 -52
  82. mindspore/dataset/callback/ds_callback.py +16 -2
  83. mindspore/dataset/core/config.py +68 -51
  84. mindspore/dataset/engine/cache_client.py +28 -5
  85. mindspore/dataset/engine/datasets.py +250 -112
  86. mindspore/dataset/engine/datasets_audio.py +43 -211
  87. mindspore/dataset/engine/datasets_standard_format.py +11 -35
  88. mindspore/dataset/engine/datasets_text.py +43 -67
  89. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  90. mindspore/dataset/engine/datasets_vision.py +219 -1029
  91. mindspore/dataset/engine/iterators.py +11 -4
  92. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  93. mindspore/dataset/engine/obs/util.py +3 -0
  94. mindspore/dataset/engine/samplers.py +1 -1
  95. mindspore/dataset/engine/validators.py +19 -5
  96. mindspore/dataset/text/__init__.py +3 -3
  97. mindspore/dataset/text/transforms.py +101 -127
  98. mindspore/dataset/text/utils.py +205 -138
  99. mindspore/dataset/transforms/__init__.py +1 -1
  100. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  101. mindspore/dataset/transforms/transforms.py +95 -40
  102. mindspore/dataset/utils/browse_dataset.py +8 -2
  103. mindspore/dataset/utils/line_reader.py +17 -19
  104. mindspore/dataset/vision/__init__.py +3 -3
  105. mindspore/dataset/vision/c_transforms.py +6 -3
  106. mindspore/dataset/vision/transforms.py +409 -287
  107. mindspore/dataset/vision/utils.py +13 -14
  108. mindspore/dataset/vision/validators.py +11 -1
  109. mindspore/experimental/map_parameter.py +14 -0
  110. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  111. mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
  112. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  113. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  114. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  115. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  116. mindspore/gen_ops.py +273 -0
  117. mindspore/include/OWNERS +0 -1
  118. mindspore/include/api/data_type.h +2 -1
  119. mindspore/include/api/graph.h +0 -15
  120. mindspore/include/api/kernel.h +2 -0
  121. mindspore/include/api/kernel_api.h +37 -12
  122. mindspore/include/api/model.h +0 -14
  123. mindspore/include/api/types.h +37 -4
  124. mindspore/include/c_api/ms/abstract.h +67 -0
  125. mindspore/include/c_api/ms/attribute.h +197 -0
  126. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  127. mindspore/include/c_api/ms/base/macros.h +32 -0
  128. mindspore/include/c_api/ms/base/status.h +33 -0
  129. mindspore/include/c_api/ms/base/types.h +282 -0
  130. mindspore/include/c_api/ms/context.h +102 -0
  131. mindspore/include/c_api/ms/graph.h +160 -0
  132. mindspore/include/c_api/ms/node.h +606 -0
  133. mindspore/include/c_api/ms/tensor.h +161 -0
  134. mindspore/include/c_api/ms/value.h +84 -0
  135. mindspore/include/dataset/constants.h +6 -5
  136. mindspore/include/dataset/execute.h +23 -13
  137. mindspore/include/dataset/text.h +26 -26
  138. mindspore/include/dataset/transforms.h +13 -13
  139. mindspore/include/dataset/vision.h +60 -60
  140. mindspore/include/dataset/vision_ascend.h +5 -6
  141. mindspore/include/dataset/vision_lite.h +17 -17
  142. mindspore/include/mindapi/base/type_id.h +1 -0
  143. mindspore/include/mindapi/base/types.h +1 -0
  144. mindspore/lib/libdnnl.so.2 +0 -0
  145. mindspore/lib/libjemalloc.so.2 +0 -0
  146. mindspore/lib/libmindspore.so +0 -0
  147. mindspore/lib/libmindspore_backend.so +0 -0
  148. mindspore/lib/libmindspore_common.so +0 -0
  149. mindspore/lib/libmindspore_core.so +0 -0
  150. mindspore/lib/libmindspore_glog.so.0 +0 -0
  151. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  152. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  153. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  154. mindspore/lib/libmindspore_shared_lib.so +0 -0
  155. mindspore/lib/libnnacl.so +0 -0
  156. mindspore/lib/libopencv_core.so.4.5 +0 -0
  157. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  158. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  159. mindspore/lib/libps_cache.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  163. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  164. mindspore/lib/plugin/ascend/libakg.so +0 -0
  165. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  166. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  167. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  168. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  169. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  170. mindspore/lib/plugin/cpu/libakg.so +0 -0
  171. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  172. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  173. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  174. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  175. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  176. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  177. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  178. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  179. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  180. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  181. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  182. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  183. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  184. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  185. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  186. mindspore/nn/__init__.py +0 -2
  187. mindspore/nn/cell.py +316 -74
  188. mindspore/nn/dynamic_lr.py +21 -21
  189. mindspore/nn/layer/activation.py +21 -28
  190. mindspore/nn/layer/basic.py +15 -13
  191. mindspore/nn/layer/channel_shuffle.py +1 -1
  192. mindspore/nn/layer/container.py +271 -9
  193. mindspore/nn/layer/conv.py +310 -207
  194. mindspore/nn/layer/dense.py +8 -5
  195. mindspore/nn/layer/embedding.py +33 -27
  196. mindspore/nn/layer/flash_attention.py +82 -41
  197. mindspore/nn/layer/image.py +8 -6
  198. mindspore/nn/layer/math.py +13 -18
  199. mindspore/nn/layer/normalization.py +107 -66
  200. mindspore/nn/layer/padding.py +1 -1
  201. mindspore/nn/layer/pooling.py +131 -109
  202. mindspore/nn/layer/rnn_cells.py +22 -17
  203. mindspore/nn/layer/rnns.py +13 -16
  204. mindspore/nn/layer/thor_layer.py +1 -1
  205. mindspore/nn/layer/transformer.py +221 -154
  206. mindspore/nn/learning_rate_schedule.py +9 -1
  207. mindspore/nn/loss/loss.py +235 -174
  208. mindspore/nn/optim/ada_grad.py +2 -1
  209. mindspore/nn/optim/adadelta.py +1 -0
  210. mindspore/nn/optim/adafactor.py +2 -1
  211. mindspore/nn/optim/adam.py +7 -4
  212. mindspore/nn/optim/adamax.py +3 -2
  213. mindspore/nn/optim/adasum.py +2 -2
  214. mindspore/nn/optim/asgd.py +2 -3
  215. mindspore/nn/optim/ftrl.py +6 -5
  216. mindspore/nn/optim/lamb.py +7 -4
  217. mindspore/nn/optim/lars.py +1 -1
  218. mindspore/nn/optim/lazyadam.py +5 -3
  219. mindspore/nn/optim/momentum.py +2 -1
  220. mindspore/nn/optim/optimizer.py +53 -4
  221. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  222. mindspore/nn/optim/rmsprop.py +4 -3
  223. mindspore/nn/optim/rprop.py +23 -12
  224. mindspore/nn/optim/sgd.py +26 -11
  225. mindspore/nn/optim/thor.py +9 -7
  226. mindspore/nn/probability/bijector/bijector.py +5 -5
  227. mindspore/nn/probability/bijector/power_transform.py +27 -27
  228. mindspore/nn/probability/bijector/softplus.py +3 -3
  229. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  230. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  231. mindspore/nn/probability/distribution/beta.py +3 -3
  232. mindspore/nn/probability/distribution/categorical.py +7 -7
  233. mindspore/nn/probability/distribution/cauchy.py +0 -1
  234. mindspore/nn/probability/distribution/distribution.py +3 -3
  235. mindspore/nn/probability/distribution/gamma.py +3 -3
  236. mindspore/nn/probability/distribution/geometric.py +4 -4
  237. mindspore/nn/probability/distribution/gumbel.py +4 -4
  238. mindspore/nn/probability/distribution/log_normal.py +2 -2
  239. mindspore/nn/probability/distribution/logistic.py +2 -2
  240. mindspore/nn/probability/distribution/poisson.py +4 -4
  241. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  242. mindspore/nn/probability/distribution/uniform.py +6 -6
  243. mindspore/nn/wrap/cell_wrapper.py +78 -34
  244. mindspore/nn/wrap/grad_reducer.py +8 -5
  245. mindspore/nn/wrap/loss_scale.py +105 -42
  246. mindspore/numpy/array_creations.py +1 -2
  247. mindspore/numpy/array_ops.py +3 -2
  248. mindspore/offline_debug/convert_async.py +2 -2
  249. mindspore/ops/_grad_experimental/__init__.py +0 -5
  250. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
  251. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  252. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  253. mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
  254. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  255. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
  256. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  257. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  258. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  259. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  260. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  261. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  262. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  263. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  264. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  265. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  266. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  267. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  268. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  269. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  270. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  271. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  272. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  273. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  274. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  275. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  276. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  277. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  278. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  279. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  280. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  281. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  282. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  283. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  284. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  285. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  286. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  287. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  288. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  289. mindspore/ops/_primitive_cache.py +1 -1
  290. mindspore/ops/_tracefunc.py +45 -13
  291. mindspore/ops/_utils/utils.py +4 -1
  292. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  293. mindspore/ops/_vmap/vmap_base.py +3 -3
  294. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  295. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  296. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  297. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  298. mindspore/ops/arg_dtype_cast.py +54 -0
  299. mindspore/ops/composite/base.py +37 -10
  300. mindspore/ops/composite/math_ops.py +5 -4
  301. mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
  302. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  303. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  304. mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
  305. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  306. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  308. mindspore/ops/deprecated.py +304 -0
  309. mindspore/ops/function/__init__.py +4 -1
  310. mindspore/ops/function/array_func.py +167 -189
  311. mindspore/ops/function/clip_func.py +81 -13
  312. mindspore/ops/function/debug_func.py +1 -1
  313. mindspore/ops/function/grad/grad_func.py +18 -8
  314. mindspore/ops/function/image_func.py +10 -4
  315. mindspore/ops/function/linalg_func.py +5 -5
  316. mindspore/ops/function/math_func.py +575 -386
  317. mindspore/ops/function/nn_func.py +470 -251
  318. mindspore/ops/function/random_func.py +86 -56
  319. mindspore/ops/function/sparse_func.py +1 -1
  320. mindspore/ops/function/sparse_unary_func.py +14 -12
  321. mindspore/ops/function/vmap_func.py +6 -5
  322. mindspore/ops/functional.py +15 -10
  323. mindspore/ops/op_info_register.py +235 -19
  324. mindspore/ops/operations/__init__.py +25 -17
  325. mindspore/ops/operations/_grad_ops.py +52 -7
  326. mindspore/ops/operations/_inner_ops.py +213 -12
  327. mindspore/ops/operations/_quant_ops.py +4 -8
  328. mindspore/ops/operations/_sequence_ops.py +42 -0
  329. mindspore/ops/operations/array_ops.py +64 -280
  330. mindspore/ops/operations/comm_ops.py +105 -57
  331. mindspore/ops/operations/custom_ops.py +10 -3
  332. mindspore/ops/operations/debug_ops.py +8 -4
  333. mindspore/ops/operations/image_ops.py +18 -12
  334. mindspore/ops/operations/math_ops.py +185 -138
  335. mindspore/ops/operations/nn_ops.py +716 -492
  336. mindspore/ops/operations/other_ops.py +0 -22
  337. mindspore/ops/operations/random_ops.py +53 -111
  338. mindspore/ops/operations/sparse_ops.py +3 -1
  339. mindspore/ops/primitive.py +24 -18
  340. mindspore/parallel/_auto_parallel_context.py +68 -8
  341. mindspore/parallel/_cost_model_context.py +2 -2
  342. mindspore/parallel/_offload_context.py +17 -3
  343. mindspore/parallel/_parallel_serialization.py +2 -2
  344. mindspore/parallel/_ps_context.py +12 -0
  345. mindspore/parallel/_tensor.py +14 -12
  346. mindspore/parallel/_transformer/layers.py +5 -3
  347. mindspore/parallel/_transformer/loss.py +1 -0
  348. mindspore/parallel/_transformer/moe.py +2 -2
  349. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  350. mindspore/parallel/_transformer/transformer.py +23 -3
  351. mindspore/parallel/_utils.py +11 -7
  352. mindspore/parallel/algo_parameter_config.py +85 -5
  353. mindspore/parallel/checkpoint_transform.py +6 -10
  354. mindspore/parallel/shard.py +4 -4
  355. mindspore/profiler/common/struct_type.py +3 -3
  356. mindspore/profiler/common/util.py +3 -2
  357. mindspore/profiler/envprofiling.py +1 -1
  358. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  359. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  360. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  361. mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
  362. mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
  363. mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
  364. mindspore/profiler/parser/ascend_op_generator.py +5 -5
  365. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  366. mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
  367. mindspore/profiler/parser/base_timeline_generator.py +9 -7
  368. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
  369. mindspore/profiler/parser/flops_parser.py +15 -11
  370. mindspore/profiler/parser/framework_parser.py +37 -21
  371. mindspore/profiler/parser/hccl_parser.py +16 -12
  372. mindspore/profiler/parser/integrator.py +22 -11
  373. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  374. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  375. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  376. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  377. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  378. mindspore/profiler/parser/optime_parser.py +1 -1
  379. mindspore/profiler/parser/profiler_info.py +2 -2
  380. mindspore/profiler/parser/step_trace_parser.py +11 -14
  381. mindspore/profiler/profiling.py +139 -71
  382. mindspore/rewrite/api/node.py +102 -19
  383. mindspore/rewrite/api/node_type.py +5 -1
  384. mindspore/rewrite/api/scoped_value.py +9 -17
  385. mindspore/rewrite/api/symbol_tree.py +131 -47
  386. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  387. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  388. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  389. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  390. mindspore/rewrite/common/rewrite_elog.py +5 -1
  391. mindspore/rewrite/namer.py +33 -24
  392. mindspore/rewrite/namespace.py +14 -5
  393. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  394. mindspore/rewrite/node/call_function.py +79 -0
  395. mindspore/rewrite/node/cell_container.py +135 -0
  396. mindspore/rewrite/node/control_flow.py +88 -0
  397. mindspore/rewrite/{node.py → node/node.py} +273 -234
  398. mindspore/rewrite/node/node_manager.py +254 -0
  399. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  400. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  401. mindspore/rewrite/parsers/assign_parser.py +216 -221
  402. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  403. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  404. mindspore/rewrite/parsers/constant_parser.py +9 -6
  405. mindspore/rewrite/parsers/container_parser.py +9 -7
  406. mindspore/rewrite/parsers/for_parser.py +36 -15
  407. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  408. mindspore/rewrite/parsers/if_parser.py +28 -24
  409. mindspore/rewrite/parsers/module_parser.py +196 -25
  410. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  411. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  412. mindspore/rewrite/parsers/return_parser.py +6 -6
  413. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  414. mindspore/rewrite/sparsify/utils.py +1 -1
  415. mindspore/rewrite/symbol_tree.py +525 -577
  416. mindspore/rewrite/symbol_tree_builder.py +9 -193
  417. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  418. mindspore/run_check/_check_version.py +2 -2
  419. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  420. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  421. mindspore/scipy/linalg.py +1 -1
  422. mindspore/scipy/optimize/minimize.py +7 -3
  423. mindspore/train/_utils.py +7 -3
  424. mindspore/train/amp.py +323 -123
  425. mindspore/train/anf_ir_pb2.py +14 -2
  426. mindspore/train/callback/_backup_and_restore.py +2 -12
  427. mindspore/train/callback/_callback.py +29 -4
  428. mindspore/train/callback/_checkpoint.py +23 -8
  429. mindspore/train/callback/_early_stop.py +2 -2
  430. mindspore/train/callback/_landscape.py +4 -4
  431. mindspore/train/callback/_loss_monitor.py +2 -2
  432. mindspore/train/callback/_on_request_exit.py +2 -2
  433. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  434. mindspore/train/callback/_summary_collector.py +14 -7
  435. mindspore/train/callback/_time_monitor.py +58 -5
  436. mindspore/train/data_sink.py +5 -11
  437. mindspore/train/dataset_helper.py +83 -57
  438. mindspore/train/loss_scale_manager.py +2 -2
  439. mindspore/train/metrics/__init__.py +3 -3
  440. mindspore/train/metrics/cosine_similarity.py +1 -1
  441. mindspore/train/metrics/hausdorff_distance.py +3 -2
  442. mindspore/train/metrics/mean_surface_distance.py +3 -2
  443. mindspore/train/metrics/metric.py +39 -19
  444. mindspore/train/metrics/roc.py +2 -2
  445. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  446. mindspore/train/mind_ir_pb2.py +85 -36
  447. mindspore/train/model.py +185 -45
  448. mindspore/train/serialization.py +390 -150
  449. mindspore/train/summary/_writer_pool.py +3 -2
  450. mindspore/train/summary/summary_record.py +14 -10
  451. mindspore/train/train_thor/convert_utils.py +3 -3
  452. mindspore/train/train_thor/dataset_helper.py +1 -1
  453. mindspore/version.py +1 -1
  454. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
  455. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +458 -518
  456. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  457. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  458. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  459. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  460. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  461. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  462. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  463. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  464. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  465. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  466. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  467. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  468. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  469. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  470. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  471. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  472. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  473. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  474. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  475. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  476. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  477. mindspore/_extends/graph_kernel/expander.py +0 -80
  478. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  479. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  480. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  481. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  482. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  483. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  484. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  485. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  486. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  487. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  488. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  489. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  490. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  491. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  492. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  493. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  494. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  495. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  496. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  497. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  498. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  499. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  500. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  501. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  502. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  503. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  504. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  505. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  506. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  507. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  509. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  510. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  511. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  512. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  513. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  514. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  515. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  516. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  517. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  518. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  519. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  520. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  521. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  522. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  523. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  524. mindspore/dataset/datapreprocess/__init__.py +0 -20
  525. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  526. mindspore/include/api/net.h +0 -142
  527. mindspore/nn/lr_scheduler.py +0 -262
  528. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  529. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  530. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  531. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  532. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  533. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  534. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  537. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  539. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  541. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  542. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  543. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  546. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  547. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  548. mindspore/rewrite/node_visitor.py +0 -44
  549. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  550. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
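
The renamed entries above (items 110-115, together with the removal of mindspore/nn/lr_scheduler.py at item 527) move the experimental optimizers from mindspore.nn.optim_ex to mindspore.experimental.optim and add a new lr_scheduler.py under the new package. A minimal migration sketch for code that used the old import path; the class names (AdamW, StepLR) are assumed from the file names in the list, not confirmed by this diff:

    # Hypothetical import update for mindspore 2.2.0; class names are assumed
    # from the renamed files (adamw.py, lr_scheduler.py), not verified here.
    # 2.1.0:
    #   from mindspore.nn.optim_ex import AdamW
    # 2.2.0:
    from mindspore.experimental.optim import AdamW
    from mindspore.experimental.optim.lr_scheduler import StepLR  # lr_scheduler.py is new in 2.2.0
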
mindspore/nn/cell.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright 2020-2022 Huawei Technologies Co., Ltd
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -33,8 +33,7 @@ from mindspore import context
  from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
  from mindspore import _checkparam as Validator
  from mindspore.common import dtype as mstype
- from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache, \
- _AutoIdentifyDynamicShape
+ from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
  from mindspore.common.api import _generate_branch_control_input
  from mindspore.common.parameter import Parameter, ParameterTuple
  from mindspore.common.tensor import Tensor
@@ -65,6 +64,15 @@ class Cell(Cell_):
  graph in GRAPH_MODE (static graph mode) and used as the basic module of neural networks in
  PYNATIVE_MODE (dynamic graph mode).

+ .. note::
+ Cell is the inference mode by default. For a class that inherits a Cell,
+ if the training and inference have different structures, the subclass performs the inference branch by default.
+ To set the training mode, refer to `mindspore.nn.Cell.set_train` .
+
+ .. warning::
+ In the subclass of Cell, it's not allowed to define a method named 'cast' and not allowed to define an attribute
+ named 'phase' or 'cells', otherwise, an error will be raised.
+
  Args:
  auto_prefix (bool, optional): Whether to automatically generate NameSpace for Cell and its child cells. It also
  affects the names of parameters in the `Cell`. If set to ``True`` , the parameter name will be
@@ -156,11 +164,9 @@ class Cell(Cell_):
  self.saved_dynamic_shape = None
  self._jit_config_dict = dict()
  self.grad_ops_label = False
- self.to_float_fp16 = False
- self.ge_init = False
  self.ge_sync_data = False
- self.auto_identify_dynamic_shape = _AutoIdentifyDynamicShape()
-
+ self._is_check_and_refresh = False
+ self._amp_level = ""

  def __getstate__(self):
  base = Cell_.__getstate__(self)
@@ -192,6 +198,23 @@ class Cell(Cell_):
  def param_prefix(self):
  """
  Param prefix is the prefix of current cell's direct child parameter.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.dense = nn.Dense(2, 2)
+ ...
+ ... def construct(self, x):
+ ... x = self.dense(x)
+ ... return x
+ >>> net = Net()
+ >>> net.update_cell_prefix()
+ >>> print(net.dense.param_prefix)
+ dense
  """
  return self._param_prefix

@@ -202,7 +225,7 @@ class Cell(Cell_):

  Tutorial Examples:
  - `Cell and Parameter - Custom Cell Reverse
- <https://mindspore.cn/tutorials/en/r2.1/advanced/modules/layer.html#custom-cell-reverse>`_
+ <https://mindspore.cn/tutorials/en/r2.2/advanced/modules/layer.html#custom-cell-reverse>`_
  """
  return self._bprop_debug

@@ -309,6 +332,21 @@ class Cell(Cell_):
  for item in self.trainable_params():
  item.add_pipeline_stage(value)

+ @property
+ def pipeline_segment(self):
+ return self._pipeline_segment
+
+ @pipeline_segment.setter
+ def pipeline_segment(self, value):
+ if not isinstance(value, int) or isinstance(value, bool):
+ raise TypeError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
+ "must be int type, but got type : {}".format(type(value)))
+
+ if value < 0:
+ raise ValueError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
+ "can not be less than 0, but got {}".format(value))
+ self._pipeline_segment = value
+
  @property
  def parallel_parameter_merge_net_dict(self):
  return self._parallel_parameter_merge_net_dict
@@ -345,7 +383,7 @@ class Cell(Cell_):
  if '_params_list' in self.__dict__:
  params_list = self.__dict__['_params_list']
  if name in params_list:
- return ParameterTuple(params_list[name])
+ return params_list[name]
  raise AttributeError("The '{}' object has no attribute '{}'.".format(type(self).__name__, name))

  def __del__(self):
@@ -365,11 +403,11 @@ class Cell(Cell_):
  del self._params[name]
  elif name in self._cells:
  del self._cells[name]
+ elif '_params_list' in self.__dict__ and name in self._params_list:
+ del self._params_list[name]
+ elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
+ del self._tensor_list[name]
  else:
- if '_params_list' in self.__dict__ and name in self._params_list:
- del self._params_list[name]
- elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
- del self._tensor_list[name]
  object.__delattr__(self, name)
  self._attr_synced = False

@@ -381,8 +419,8 @@ class Cell(Cell_):
  res.append(self._cast_mixed_precision_inputs(item, dst_type))
  elif isinstance(item, float):
  res.append(self.cast(item, dst_type))
- elif hasattr(item, "dtype") and item.dtype in {mstype.float16, mstype.float32, mstype.float64} and \
- item.dtype != dst_type:
+ elif hasattr(item, "dtype") and item.dtype in \
+ {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
  res.append(self.cast(item, dst_type))
  else:
  res.append(item)
@@ -629,7 +667,10 @@ class Cell(Cell_):
  if PackFunc.is_tracing():
  return self._run_tracefunc(*args, **kwargs)

- self.check_names_and_refresh_name()
+ if hasattr(self, '_is_check_and_refresh') and not self._is_check_and_refresh:
+ self.check_names_and_refresh_name()
+ self._is_check_and_refresh = True
+
  # Run in Graph mode.
  if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
  self._check_construct_args(*args)
@@ -886,14 +927,14 @@ class Cell(Cell_):
  >>> import mindspore as ms
  >>> from mindspore import nn, Tensor
  >>>
- >>> class reluNet(nn.Cell):
+ >>> class ReluNet(nn.Cell):
  ... def __init__(self):
- ... super(reluNet, self).__init__()
+ ... super(ReluNet, self).__init__()
  ... self.relu = nn.ReLU()
  ... def construct(self, x):
  ... return self.relu(x)
  >>>
- >>> net = reluNet()
+ >>> net = ReluNet()
  >>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
  >>> net.set_inputs(input_dyn)
  >>> input1 = Tensor(np.random.random([3, 10]), dtype=ms.float32)
@@ -902,13 +943,10 @@ class Cell(Cell_):
  if self.grad_ops_label:
  logger.warning(f'For Cell, set_inputs must be set before the gradient function of the network is '
  f'generated.')
- for ele in inputs:
- if isinstance(ele, str):
- raise TypeError(f"For element in 'set_inputs', the type must not be str.")
  self._dynamic_shape_inputs = inputs
  self._check_construct_args(*inputs)
  if context._get_mode() == context.PYNATIVE_MODE:
- _pynative_executor.set_dynamic_input(self)
+ _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)

  def get_inputs(self):
  """
@@ -919,6 +957,26 @@ class Cell(Cell_):

  .. warning::
  This is an experimental API that is subject to change or deletion.
+
+ Examples:
+ >>> import numpy as np
+ >>> import mindspore as ms
+ >>> from mindspore import nn, Tensor
+ >>>
+ >>> class ReluNet(nn.Cell):
+ ... def __init__(self):
+ ... super(ReluNet, self).__init__()
+ ... self.relu = nn.ReLU()
+ ... def construct(self, x):
+ ... return self.relu(x)
+ >>>
+ >>> net = ReluNet()
+ >>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
+ >>> net.set_inputs(input_dyn)
+ >>> get_inputs = net.get_inputs()
+ >>> print(get_inputs)
+ (Tensor(shape=[3, -1], dtype=Float32, value= ),)
+
  """

  return self._dynamic_shape_inputs
@@ -936,9 +994,8 @@ class Cell(Cell_):
  self._dynamic_shape_inputs = convert_inputs_to_dynamic(*args)

  if self._dynamic_shape_inputs is None:
- compile_args = self.auto_identify_dynamic_shape.auto_dynamic_generate_compile_args(args)
  _cell_graph_executor.compile(self, phase=self.phase,
- jit_config_dict=self._jit_config_dict, *compile_args, **kwargs)
+ jit_config_dict=self._jit_config_dict, *args, **kwargs)
  else:
  self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)
  self.saved_dynamic_shape = self._dynamic_shape_inputs
@@ -994,6 +1051,23 @@ class Cell(Cell_):
  Raises:
  KeyError: If the name of parameter is null or contains dot.
  TypeError: If the type of parameter is not Parameter.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn, Parameter
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.relu = nn.ReLU()
+ ...
+ ... def construct(self, x):
+ ... x = self.relu(x)
+ ... return x
+ >>> net = Net()
+ >>> net.insert_param_to_cell("bias", Parameter(Tensor([1, 2, 3])))
+ >>> print(net.bias)
+ Parameter(name=bias, shape=(3,), dtype=Int64, requires_grad=True)
  """
  if not param_name:
  raise KeyError("For 'insert_param_to_cell', the argument 'param_name' should not be None.")
@@ -1007,6 +1081,9 @@ class Cell(Cell_):
  if not isinstance(param, Parameter) and param is not None:
  raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must be 'Parameter' if not None, "
  f"but got {type(param)}.")
+ if param is None:
+ raise TypeError(f"For 'insert_param_to_cell', the argument 'param' must not be None, "
+ f"but got None.")
  if isinstance(param, Parameter) and param.name == PARAMETER_NAME_DEFAULT:
  param.name = param_name
  self._params[param_name] = param
@@ -1048,6 +1125,18 @@ class Cell(Cell_):
  KeyError: Child Cell's name is incorrect or duplicated with the other child name.
  TypeError: If type of `child_name` is not str.
  TypeError: Child Cell's type is incorrect.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> net1 = nn.ReLU()
+ >>> net2 = nn.Dense(2, 2)
+ >>> net1.insert_child_to_cell("child", net2)
+ >>> print(net1)
+ ReLU<
+ (child): Dense<input_channels=2, output_channels=2, has_bias=True>
+ >
  """
  if not isinstance(child_name, str):
  raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, "
@@ -1118,6 +1207,25 @@ class Cell(Cell_):

  Returns:
  Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.dense = nn.Dense(2, 2)
+ ...
+ ... def construct(self, x):
+ ... x = self.dense(x)
+ ... return x
+ >>> net = Net()
+ >>> print(net.init_parameters_data())
+ {Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True):
+ Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True),
+ Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True):
+ Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True)}
  """
  replace = dict()

@@ -1163,6 +1271,24 @@ class Cell(Cell_):

  Returns:
  OrderedDict, return parameters dictionary.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn, Parameter
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.dense = nn.Dense(2, 2)
+ ...
+ ... def construct(self, x):
+ ... x = self.dense(x)
+ ... return x
+ >>> net = Net()
+ >>> print(net.parameters_dict())
+ OrderedDict([('dense.weight', Parameter(name=dense.weight, shape=(2, 2), dtype=Float32,
+ requires_grad=True)), ('dense.bias', Parameter(name=dense.bias, shape=(2,), dtype=Float32,
+ requires_grad=True))])
  """
  param_dict = OrderedDict()
  for param in self.get_parameters(expand=recurse):
@@ -1238,7 +1364,7 @@ class Cell(Cell_):

  Tutorial Examples:
  - `Model Training - Optimizer
- <https://mindspore.cn/tutorials/en/r2.1/beginner/train.html#optimizer>`_
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/train.html#optimizer>`_
  """
  return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))

@@ -1263,6 +1389,7 @@ class Cell(Cell_):
  Returns an iterator over cell parameters.

  Yields parameters of this cell. If `expand` is ``true`` , yield parameters of this cell and all subcells.
+ For more details about subcells, please see the example below.

  Args:
  expand (bool): If ``true`` , yields parameters of this cell and all subcells. Otherwise, only yield
@@ -1272,11 +1399,34 @@ class Cell(Cell_):
  Iteration, all parameters at the cell.

  Examples:
- >>> from mindspore import nn
- >>> net = nn.Dense(3, 4)
- >>> parameters = []
- >>> for item in net.get_parameters():
- ... parameters.append(item)
+ >>> import mindspore as ms
+ >>> from mindspore import nn, ops, Tensor
+ >>> import numpy as np
+ >>> class TestNet(nn.Cell):
+ ... def __init__(self):
+ ... super().__init__()
+ ... self.my_w1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32))
+ ... self.my_w2 = ms.Parameter(Tensor(np.ones([16]), ms.float32))
+ ... def construct(self, x):
+ ... x += self.my_w1
+ ... x = ops.reshape(x, (16,)) - self.my_w2
+ ... return x
+ >>> class TestNet2(nn.Cell):
+ ... def __init__(self):
+ ... super().__init__()
+ ... self.my_t1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32))
+ ... # self.subcell is a subcell of TestNet2, when using expand=True, the parameters of TestNet will
+ ... # also be gathered.
+ ... self.subcell = TestNet()
+ ... def construct(self, x):
+ ... x += self.my_w1
+ ... x = ops.reshape(x, (16,)) - self.my_w2
+ ... return x
+ >>> net = TestNet2()
+ >>> print([p for p in net.get_parameters(expand=True)])
+ [Parameter (name=my_t1, shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w1,
+ shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w2, shape=(16,), dtype=Float32,
+ requires_grad=True)]
  """
  for _, param in self.parameters_and_names(expand=expand):
  yield param
@@ -1325,7 +1475,7 @@ class Cell(Cell_):

  Tutorial Examples:
  - `Building a Network - Model Parameters
- <https://mindspore.cn/tutorials/en/r2.1/beginner/model.html#model-parameters>`_
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/model.html#model-parameters>`_
  """
  cells = []
  if expand:
@@ -1337,7 +1487,7 @@ class Cell(Cell_):
  for cell_name, cell in cells:
  params = cell._params.items()
  for par_name, par in params:
- if par.inited_param is not None:
+ if par is not None and par.inited_param is not None:
  par = par.inited_param
  if par is not None and id(par) not in params_set:
  params_set.add(id(par))
@@ -1394,6 +1544,22 @@ class Cell(Cell_):

  Returns:
  Iteration, the immediate cells in the cell.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.dense = nn.Dense(2, 2)
+ ...
+ ... def construct(self, x):
+ ... x = self.dense(x)
+ ... return x
+ >>> net = Net()
+ >>> print(net.cells())
+ odict_values([Dense<input_channels=2, output_channels=2, has_bias=True>])
  """
  return self.name_cells().values()

@@ -1439,6 +1605,22 @@ class Cell(Cell_):

  Returns:
  Dict, all the child cells and corresponding names in the cell.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.dense = nn.Dense(2, 2)
+ ...
+ ... def construct(self, x):
+ ... x = self.dense(x)
+ ... return x
+ >>> net = Net()
+ >>> print(net.name_cells())
+ OrderedDict([('dense', Dense<input_channels=2, output_channels=2, has_bias=True>)])
  """
  value_set = set()
  cells = OrderedDict()
@@ -1454,13 +1636,8 @@ class Cell(Cell_):
  Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16)
  if "fp32" in flags and flags.get("fp32", False):
  Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32)
-
- def _add_mixed_precision_flag_recursive(self, **flags):
- """Add mixed precision flag to each cell"""
- if "fp16" in flags and flags.get("fp16", False):
- self._set_mixed_precision_type_recursive(MixedPrecisionType.FP16)
- if "fp32" in flags and flags.get("fp32", False):
- self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
+ if "bf16" in flags and flags.get("bf16", False):
+ Cell_.set_mixed_precision_type(self, MixedPrecisionType.BF16)

  def apply(self, fn):
  """
@@ -1503,6 +1680,23 @@ class Cell(Cell_):
  Args:
  flags (dict): Network configuration information, currently it is used for the binding of network and
  dataset. Users can also customize network attributes by this parameter.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.relu = nn.ReLU()
+ ...
+ ... def construct(self, x):
+ ... x = self.relu(x)
+ ... return x
+ >>> net = Net()
+ >>> net.add_flags(sink_mode=True)
+ >>> print(net.sink_mode)
+ True
  """
  if not hasattr(self, "_func_graph_flags"):
  self._func_graph_flags = {}
@@ -1518,9 +1712,25 @@ class Cell(Cell_):
  Args:
  flags (dict): Network configuration information, currently it is used for the binding of network and
  dataset. Users can also customize network attributes by this parameter.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.relu = nn.ReLU()
+ ...
+ ... def construct(self, x):
+ ... x = self.relu(x)
+ ... return x
+ >>> net = Net()
+ >>> net.add_flags_recursive(sink_mode=True)
+ >>> print(net.sink_mode)
+ True
  """
  self.add_flags(**flags)
- self._add_mixed_precision_flag_recursive(**flags)
  for cell in self.cells():
  cell.add_flags_recursive(**flags)
  return self
@@ -1532,17 +1742,28 @@ class Cell(Cell_):
  def get_flags(self):
  """
  Get the self_defined attributes of the cell, which can be added by `add_flags` method.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.relu = nn.ReLU()
+ ...
+ ... def construct(self, x):
+ ... x = self.relu(x)
+ ... return x
+ >>> net = Net()
+ >>> net.add_flags(sink_mode=True)
+ >>> print(net.get_flags())
+ {'sink_mode':True}
  """
  if not hasattr(self, "_func_graph_flags"):
  self._func_graph_flags = {}
  return self._func_graph_flags

- def _set_mixed_precision_type_recursive(self, mixed_type):
- """Set mixed precision type to each cell"""
- Cell_.set_mixed_precision_type(self, mixed_type)
- for cell in self.cells():
- cell._set_mixed_precision_type_recursive(mixed_type)
-
  def to_float(self, dst_type):
  """
  Add cast on all inputs of cell and child cells to run with certain float type.
@@ -1555,13 +1776,13 @@ class Cell(Cell_):

  Args:
  dst_type (:class:`mindspore.dtype`): Transfer cell to run with dst_type.
- dst_type can be `mstype.float16` or `mstype.float32`.
+ dst_type can be `mstype.float16` , `mstype.float32` or `mstype.bfloat16`.

  Returns:
  Cell, the cell itself.

  Raises:
- ValueError: If dst_type is not mstype.float32 or mstype.float16.
+ ValueError: If dst_type is not `mstype.float32` , `mstype.float16` or `mstype.bfloat16`.

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
@@ -1573,19 +1794,15 @@ class Cell(Cell_):
  >>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
  >>> net.to_float(mstype.float16)
  Conv2d<input_channels=120, output_channels=240, kernel_size=(4, 4), stride=(1, 1), pad_mode=same,
- padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
- """
- if dst_type not in (mstype.float16, mstype.float32):
- raise ValueError("For 'to_float', the argument 'dst_type' must be mstype.float32 or mstype.float16, "
- "but got type: {} and value: {}.".format(type(dst_type), dst_type))
- if dst_type == mstype.float16:
- self._set_mixed_precision_type_recursive(MixedPrecisionType.FP16)
- self.to_float_fp16 = True
- else:
- self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
- self.to_float_fp16 = False
- flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32}
+ padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW>
+ """
+ if dst_type not in (mstype.float16, mstype.float32, mstype.bfloat16):
+ raise ValueError("For 'to_float', the argument 'dst_type' must be mstype.float32, mstype.float16 or "
+ "mstype.bfloat16, but got type: {} and value: {}.".format(type(dst_type), dst_type))
+ flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32,
+ 'bf16': dst_type == mstype.bfloat16}
  self._add_init_args(**flags)
+ self.add_flags_recursive(**flags)
  return self

  def set_boost(self, boost_type):
@@ -1594,7 +1811,7 @@ class Cell(Cell_):
  accelerate the algorithm in the algorithm library.

  If `boost_type` is not in the algorithm library, please view the algorithm in the algorithm library through
- `algorithm library <https://gitee.com/mindspore/mindspore/tree/r2.1/mindspore/python/mindspore/boost>`_.
+ `algorithm library <https://gitee.com/mindspore/mindspore/tree/r2.2/mindspore/python/mindspore/boost>`_.

  Note:
  Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
@@ -1651,12 +1868,12 @@ class Cell(Cell_):

  Tutorial Examples:
  - `Model Training - Implementing Training and Evaluation
- <https://mindspore.cn/tutorials/en/r2.1/beginner/train.html#implementing-training-and-evaluation>`_
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/train.html#training-and-evaluation>`_
  """
- if mode is False:
- self._phase = 'predict'
- else:
+ if mode:
  self._phase = 'train'
+ else:
+ self._phase = 'predict'
  self.add_flags_recursive(training=mode)
  return self

@@ -1685,16 +1902,27 @@ class Cell(Cell_):

  Args:
  jit_config (JitConfig): Jit config for compile. For details, please refer to :class:`mindspore.JitConfig`.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.relu = nn.ReLU()
+ ...
+ ... def construct(self, x):
+ ... x = self.relu(x)
+ ... return x
+ >>> net = Net()
+ >>> jitconfig = ms.JitConfig()
+ >>> net.set_jit_config(jitconfig)
  """
  if self._jit_config_dict:
  logger.warning("For Cell, jit config can only be set once, ignore this setting.")
  else:
  self._jit_config_dict = jit_config.jit_config_dict
- enable_ge = os.getenv("MS_ENABLE_GE") == '1'
- enable_jit_level_o3 = self._jit_config_dict.get('jit_level') == "O3"
- if (not enable_ge and enable_jit_level_o3) or (enable_ge and not enable_jit_level_o3):
- raise RuntimeError("GE and jit_level=O3 should be used together, but got MS_ENABLE_GE={}, jie_level={}".
- format(os.getenv("MS_ENABLE_GE"), self.jit_config_dict.get('jit_level')))

  def flatten_weights(self, fusion_size=0):
  """
@@ -2290,12 +2518,13 @@ class Cell(Cell_):
  def _run_tracefunc(self, *args, **kwargs):
  """ Run Packed Cell in Pack."""
  args = self._mixed_precision_cast(args)
- if hasattr(self, "bprop") or hasattr(self, "_pipeline_stage") or self.get_flags():
+ need_subgraph = hasattr(self, "bprop") or hasattr(self, "_pipeline_stage") or self.get_flags()
+ if not PackFunc.current.is_pynative_mode and need_subgraph:
  expander = PackExpander.get_instance()
  args = expander.begin_subgraph(self, *args)
  args = [_convert_tensor(a) for a in args]
  output = self._run_construct(args, kwargs)
- ret = expander.end_subgraph(output)
+ ret = expander.end_subgraph(self, output)
  output = _convert_tensor(ret)
  else:
  with _SetMixedPrecision(self):
@@ -2306,10 +2535,23 @@ class Cell(Cell_):
  mixed_type = self.get_mixed_precision_type()
  if mixed_type == MixedPrecisionType.NOTSET:
  return inputs
- cast_type = mstype.float16 if mixed_type == MixedPrecisionType.FP16 else mstype.float32
+ if mixed_type == MixedPrecisionType.FP16:
+ cast_type = mstype.float16
+ elif mixed_type == MixedPrecisionType.BF16:
+ cast_type = mstype.bfloat16
+ else:
+ cast_type = mstype.float32
  cast_inputs = self._cast_mixed_precision_inputs(inputs, cast_type)
  return cast_inputs

+ def _get_attr_from_cell(self, network):
+ if not isinstance(network, Cell):
+ return
+ if hasattr(network, "jit_config_dict"):
+ self._jit_config_dict = network.jit_config_dict
+ if hasattr(network, "_amp_level"):
+ self._amp_level = getattr(network, "_amp_level")
+

  class GraphCell(Cell):
  """