mindspore-2.1.0-cp37-cp37m-manylinux1_x86_64.whl → mindspore-2.2.11-cp37-cp37m-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (589)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +61 -95
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/__init__.py +4 -2
  268. mindspore/nn/wrap/cell_wrapper.py +87 -34
  269. mindspore/nn/wrap/grad_reducer.py +8 -5
  270. mindspore/nn/wrap/loss_scale.py +105 -42
  271. mindspore/numpy/array_creations.py +1 -2
  272. mindspore/numpy/array_ops.py +3 -2
  273. mindspore/numpy/utils_const.py +5 -5
  274. mindspore/offline_debug/convert_async.py +2 -2
  275. mindspore/ops/_grad_experimental/__init__.py +0 -5
  276. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  277. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  278. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  279. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  280. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  281. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  282. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  283. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  284. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  285. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  286. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  287. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  288. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  289. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  290. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  291. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  292. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  293. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  294. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  295. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  296. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  297. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  298. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  299. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  300. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  301. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  302. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  303. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  304. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  305. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  306. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  307. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  309. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  310. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  311. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  312. mindspore/ops/_primitive_cache.py +1 -1
  313. mindspore/ops/_tracefunc.py +45 -13
  314. mindspore/ops/_utils/utils.py +6 -1
  315. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  316. mindspore/ops/_vmap/vmap_base.py +3 -3
  317. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  318. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  319. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  320. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  321. mindspore/ops/arg_dtype_cast.py +54 -0
  322. mindspore/ops/composite/base.py +37 -10
  323. mindspore/ops/composite/math_ops.py +5 -4
  324. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  325. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  326. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  327. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  328. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  329. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  330. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  331. mindspore/ops/deprecated.py +304 -0
  332. mindspore/ops/function/__init__.py +4 -1
  333. mindspore/ops/function/array_func.py +174 -193
  334. mindspore/ops/function/clip_func.py +81 -13
  335. mindspore/ops/function/debug_func.py +1 -1
  336. mindspore/ops/function/grad/grad_func.py +18 -9
  337. mindspore/ops/function/image_func.py +10 -4
  338. mindspore/ops/function/linalg_func.py +5 -5
  339. mindspore/ops/function/math_func.py +575 -386
  340. mindspore/ops/function/nn_func.py +568 -260
  341. mindspore/ops/function/random_func.py +88 -57
  342. mindspore/ops/function/sparse_func.py +1 -1
  343. mindspore/ops/function/sparse_unary_func.py +14 -12
  344. mindspore/ops/function/vmap_func.py +6 -5
  345. mindspore/ops/functional.py +15 -10
  346. mindspore/ops/op_info_register.py +244 -25
  347. mindspore/ops/operations/__init__.py +31 -19
  348. mindspore/ops/operations/_grad_ops.py +71 -7
  349. mindspore/ops/operations/_inner_ops.py +350 -17
  350. mindspore/ops/operations/_quant_ops.py +4 -8
  351. mindspore/ops/operations/_sequence_ops.py +42 -0
  352. mindspore/ops/operations/array_ops.py +68 -282
  353. mindspore/ops/operations/comm_ops.py +107 -59
  354. mindspore/ops/operations/custom_ops.py +94 -70
  355. mindspore/ops/operations/debug_ops.py +8 -4
  356. mindspore/ops/operations/image_ops.py +18 -12
  357. mindspore/ops/operations/inner_ops.py +26 -3
  358. mindspore/ops/operations/math_ops.py +192 -144
  359. mindspore/ops/operations/nn_ops.py +857 -489
  360. mindspore/ops/operations/other_ops.py +0 -22
  361. mindspore/ops/operations/random_ops.py +53 -111
  362. mindspore/ops/operations/sparse_ops.py +3 -1
  363. mindspore/ops/primitive.py +24 -18
  364. mindspore/parallel/_auto_parallel_context.py +68 -8
  365. mindspore/parallel/_cost_model_context.py +2 -2
  366. mindspore/parallel/_offload_context.py +17 -3
  367. mindspore/parallel/_parallel_serialization.py +12 -5
  368. mindspore/parallel/_ps_context.py +12 -0
  369. mindspore/parallel/_tensor.py +18 -13
  370. mindspore/parallel/_transformer/layers.py +5 -3
  371. mindspore/parallel/_transformer/loss.py +1 -0
  372. mindspore/parallel/_transformer/moe.py +2 -2
  373. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  374. mindspore/parallel/_transformer/transformer.py +23 -3
  375. mindspore/parallel/_utils.py +11 -7
  376. mindspore/parallel/algo_parameter_config.py +85 -5
  377. mindspore/parallel/checkpoint_transform.py +19 -12
  378. mindspore/parallel/shard.py +21 -14
  379. mindspore/profiler/common/struct_type.py +3 -3
  380. mindspore/profiler/common/util.py +4 -2
  381. mindspore/profiler/envprofiling.py +1 -1
  382. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  383. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  384. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  385. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  386. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  387. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  388. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  389. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  390. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  391. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  392. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  393. mindspore/profiler/parser/flops_parser.py +15 -11
  394. mindspore/profiler/parser/framework_parser.py +38 -22
  395. mindspore/profiler/parser/hccl_parser.py +16 -12
  396. mindspore/profiler/parser/integrator.py +22 -11
  397. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  398. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  399. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  400. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  401. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  402. mindspore/profiler/parser/optime_parser.py +1 -1
  403. mindspore/profiler/parser/profiler_info.py +21 -2
  404. mindspore/profiler/parser/step_trace_parser.py +11 -14
  405. mindspore/profiler/profiling.py +179 -89
  406. mindspore/rewrite/api/node.py +102 -19
  407. mindspore/rewrite/api/node_type.py +5 -1
  408. mindspore/rewrite/api/pattern_engine.py +1 -1
  409. mindspore/rewrite/api/scoped_value.py +9 -17
  410. mindspore/rewrite/api/symbol_tree.py +131 -47
  411. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  412. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  413. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  414. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  415. mindspore/rewrite/common/rewrite_elog.py +5 -1
  416. mindspore/rewrite/namer.py +33 -24
  417. mindspore/rewrite/namespace.py +14 -5
  418. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  419. mindspore/rewrite/node/call_function.py +79 -0
  420. mindspore/rewrite/node/cell_container.py +135 -0
  421. mindspore/rewrite/node/control_flow.py +88 -0
  422. mindspore/rewrite/{node.py → node/node.py} +273 -234
  423. mindspore/rewrite/node/node_manager.py +254 -0
  424. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  425. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  426. mindspore/rewrite/parsers/assign_parser.py +216 -221
  427. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  428. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  429. mindspore/rewrite/parsers/constant_parser.py +9 -6
  430. mindspore/rewrite/parsers/container_parser.py +9 -7
  431. mindspore/rewrite/parsers/for_parser.py +42 -21
  432. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  433. mindspore/rewrite/parsers/if_parser.py +28 -24
  434. mindspore/rewrite/parsers/module_parser.py +196 -25
  435. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  436. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  437. mindspore/rewrite/parsers/return_parser.py +6 -6
  438. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  439. mindspore/rewrite/sparsify/utils.py +1 -1
  440. mindspore/rewrite/symbol_tree.py +523 -578
  441. mindspore/rewrite/symbol_tree_builder.py +9 -193
  442. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  443. mindspore/run_check/_check_version.py +6 -4
  444. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  445. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  446. mindspore/scipy/linalg.py +1 -1
  447. mindspore/scipy/ops.py +55 -5
  448. mindspore/scipy/optimize/__init__.py +3 -2
  449. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  450. mindspore/scipy/optimize/minimize.py +7 -3
  451. mindspore/train/_utils.py +7 -3
  452. mindspore/train/amp.py +323 -123
  453. mindspore/train/anf_ir_pb2.py +14 -2
  454. mindspore/train/callback/_backup_and_restore.py +2 -12
  455. mindspore/train/callback/_callback.py +29 -4
  456. mindspore/train/callback/_checkpoint.py +23 -8
  457. mindspore/train/callback/_early_stop.py +2 -2
  458. mindspore/train/callback/_landscape.py +4 -4
  459. mindspore/train/callback/_loss_monitor.py +2 -2
  460. mindspore/train/callback/_on_request_exit.py +2 -2
  461. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  462. mindspore/train/callback/_summary_collector.py +15 -8
  463. mindspore/train/callback/_time_monitor.py +58 -5
  464. mindspore/train/data_sink.py +5 -11
  465. mindspore/train/dataset_helper.py +84 -57
  466. mindspore/train/loss_scale_manager.py +2 -2
  467. mindspore/train/metrics/__init__.py +3 -3
  468. mindspore/train/metrics/cosine_similarity.py +1 -1
  469. mindspore/train/metrics/hausdorff_distance.py +3 -2
  470. mindspore/train/metrics/mean_surface_distance.py +3 -2
  471. mindspore/train/metrics/metric.py +39 -19
  472. mindspore/train/metrics/roc.py +2 -2
  473. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  474. mindspore/train/mind_ir_pb2.py +85 -36
  475. mindspore/train/model.py +187 -47
  476. mindspore/train/serialization.py +487 -161
  477. mindspore/train/summary/_summary_adapter.py +1 -1
  478. mindspore/train/summary/_writer_pool.py +3 -2
  479. mindspore/train/summary/summary_record.py +37 -17
  480. mindspore/train/train_thor/convert_utils.py +3 -3
  481. mindspore/train/train_thor/dataset_helper.py +1 -1
  482. mindspore/version.py +1 -1
  483. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  486. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  487. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  489. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  490. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  491. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  492. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  493. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  494. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  495. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  497. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  498. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  499. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  500. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  501. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  502. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  503. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  504. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  505. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  506. mindspore/_extends/graph_kernel/expander.py +0 -80
  507. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  508. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  509. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  510. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  511. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  512. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  513. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  514. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  515. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  516. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  517. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  518. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  519. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  520. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  521. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  522. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  523. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  524. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  525. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  526. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  527. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  528. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  529. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  530. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  531. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  532. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  534. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  535. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  536. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  537. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  538. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  539. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  540. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  543. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  544. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  545. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  546. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  547. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  548. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  549. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  550. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  551. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  552. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  553. mindspore/dataset/datapreprocess/__init__.py +0 -20
  554. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  555. mindspore/include/api/net.h +0 -142
  556. mindspore/nn/lr_scheduler.py +0 -262
  557. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  558. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  559. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  560. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  561. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  562. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  563. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  564. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  565. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  566. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  567. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  568. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  569. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  570. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  571. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  575. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  578. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  579. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  580. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  582. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  583. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  584. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  585. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  586. mindspore/rewrite/node_visitor.py +0 -44
  587. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  588. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  589. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
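A side note on entries 120-125 in the list above: the optimizer files move from mindspore/nn/optim_ex to mindspore/experimental/optim, and a new lr_scheduler.py appears under the new path. A minimal, hedged sketch of what that move implies for import paths (module names are inferred from the renamed file paths shown above; verify against your installed version):

    # Hypothetical compatibility shim based only on the file renames listed above
    # (mindspore/{nn/optim_ex -> experimental/optim}/...); not an official API guarantee.
    try:
        from mindspore.experimental import optim      # 2.2.11 layout
    except ImportError:
        from mindspore.nn import optim_ex as optim    # assumed 2.1.0 layout

    # e.g. optim.AdamW / optim.SGD then resolve against whichever layout is installed.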
mindspore/nn/cell.py CHANGED
@@ -1,4 +1,4 @@
- # Copyright 2020-2022 Huawei Technologies Co., Ltd
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -33,8 +33,7 @@ from mindspore import context
  from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
  from mindspore import _checkparam as Validator
  from mindspore.common import dtype as mstype
- from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache, \
- _AutoIdentifyDynamicShape
+ from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
  from mindspore.common.api import _generate_branch_control_input
  from mindspore.common.parameter import Parameter, ParameterTuple
  from mindspore.common.tensor import Tensor
@@ -65,6 +64,15 @@ class Cell(Cell_):
  graph in GRAPH_MODE (static graph mode) and used as the basic module of neural networks in
  PYNATIVE_MODE (dynamic graph mode).

+ .. note::
+ Cell is the inference mode by default. For a class that inherits a Cell,
+ if the training and inference have different structures, the subclass performs the inference branch by default.
+ To set the training mode, refer to `mindspore.nn.Cell.set_train` .
+
+ .. warning::
+ In the subclass of Cell, it's not allowed to define a method named 'cast' and not allowed to define an attribute
+ named 'phase' or 'cells', otherwise, an error will be raised.
+
  Args:
  auto_prefix (bool, optional): Whether to automatically generate NameSpace for Cell and its child cells. It also
  affects the names of parameters in the `Cell`. If set to ``True`` , the parameter name will be
@@ -156,11 +164,9 @@ class Cell(Cell_):
  self.saved_dynamic_shape = None
  self._jit_config_dict = dict()
  self.grad_ops_label = False
- self.to_float_fp16 = False
- self.ge_init = False
  self.ge_sync_data = False
- self.auto_identify_dynamic_shape = _AutoIdentifyDynamicShape()
-
+ self._is_check_and_refresh = False
+ self._amp_level = ""

  def __getstate__(self):
  base = Cell_.__getstate__(self)
@@ -192,6 +198,23 @@ class Cell(Cell_):
  def param_prefix(self):
  """
  Param prefix is the prefix of current cell's direct child parameter.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.dense = nn.Dense(2, 2)
+ ...
+ ... def construct(self, x):
+ ... x = self.dense(x)
+ ... return x
+ >>> net = Net()
+ >>> net.update_cell_prefix()
+ >>> print(net.dense.param_prefix)
+ dense
  """
  return self._param_prefix

@@ -202,7 +225,7 @@ class Cell(Cell_):

  Tutorial Examples:
  - `Cell and Parameter - Custom Cell Reverse
- <https://mindspore.cn/tutorials/en/r2.1/advanced/modules/layer.html#custom-cell-reverse>`_
+ <https://mindspore.cn/tutorials/en/r2.2/advanced/modules/layer.html#custom-cell-reverse>`_
  """
  return self._bprop_debug

@@ -309,6 +332,21 @@ class Cell(Cell_):
  for item in self.trainable_params():
  item.add_pipeline_stage(value)

+ @property
+ def pipeline_segment(self):
+ return self._pipeline_segment
+
+ @pipeline_segment.setter
+ def pipeline_segment(self, value):
+ if not isinstance(value, int) or isinstance(value, bool):
+ raise TypeError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
+ "must be int type, but got type : {}".format(type(value)))
+
+ if value < 0:
+ raise ValueError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
+ "can not be less than 0, but got {}".format(value))
+ self._pipeline_segment = value
+
  @property
  def parallel_parameter_merge_net_dict(self):
  return self._parallel_parameter_merge_net_dict
@@ -345,7 +383,7 @@ class Cell(Cell_):
  if '_params_list' in self.__dict__:
  params_list = self.__dict__['_params_list']
  if name in params_list:
- return ParameterTuple(params_list[name])
+ return params_list[name]
  raise AttributeError("The '{}' object has no attribute '{}'.".format(type(self).__name__, name))

  def __del__(self):
@@ -365,11 +403,11 @@ class Cell(Cell_):
  del self._params[name]
  elif name in self._cells:
  del self._cells[name]
+ elif '_params_list' in self.__dict__ and name in self._params_list:
+ del self._params_list[name]
+ elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
+ del self._tensor_list[name]
  else:
- if '_params_list' in self.__dict__ and name in self._params_list:
- del self._params_list[name]
- elif '_tensor_list' in self.__dict__ and name in self._tensor_list:
- del self._tensor_list[name]
  object.__delattr__(self, name)
  self._attr_synced = False

@@ -381,8 +419,8 @@ class Cell(Cell_):
  res.append(self._cast_mixed_precision_inputs(item, dst_type))
  elif isinstance(item, float):
  res.append(self.cast(item, dst_type))
- elif hasattr(item, "dtype") and item.dtype in {mstype.float16, mstype.float32, mstype.float64} and \
- item.dtype != dst_type:
+ elif hasattr(item, "dtype") and item.dtype in \
+ {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
  res.append(self.cast(item, dst_type))
  else:
  res.append(item)
@@ -629,7 +667,10 @@ class Cell(Cell_):
  if PackFunc.is_tracing():
  return self._run_tracefunc(*args, **kwargs)

- self.check_names_and_refresh_name()
+ if hasattr(self, '_is_check_and_refresh') and not self._is_check_and_refresh:
+ self.check_names_and_refresh_name()
+ self._is_check_and_refresh = True
+
  # Run in Graph mode.
  if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
  self._check_construct_args(*args)
@@ -886,14 +927,14 @@ class Cell(Cell_):
  >>> import mindspore as ms
  >>> from mindspore import nn, Tensor
  >>>
- >>> class reluNet(nn.Cell):
+ >>> class ReluNet(nn.Cell):
  ... def __init__(self):
- ... super(reluNet, self).__init__()
+ ... super(ReluNet, self).__init__()
  ... self.relu = nn.ReLU()
  ... def construct(self, x):
  ... return self.relu(x)
  >>>
- >>> net = reluNet()
+ >>> net = ReluNet()
  >>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
  >>> net.set_inputs(input_dyn)
  >>> input1 = Tensor(np.random.random([3, 10]), dtype=ms.float32)
@@ -902,13 +943,10 @@ class Cell(Cell_):
  if self.grad_ops_label:
  logger.warning(f'For Cell, set_inputs must be set before the gradient function of the network is '
  f'generated.')
- for ele in inputs:
- if isinstance(ele, str):
- raise TypeError(f"For element in 'set_inputs', the type must not be str.")
  self._dynamic_shape_inputs = inputs
  self._check_construct_args(*inputs)
  if context._get_mode() == context.PYNATIVE_MODE:
- _pynative_executor.set_dynamic_input(self)
+ _pynative_executor.set_dynamic_input(self, *self._dynamic_shape_inputs)

  def get_inputs(self):
  """
@@ -919,6 +957,26 @@ class Cell(Cell_):

  .. warning::
  This is an experimental API that is subject to change or deletion.
+
+ Examples:
+ >>> import numpy as np
+ >>> import mindspore as ms
+ >>> from mindspore import nn, Tensor
+ >>>
+ >>> class ReluNet(nn.Cell):
+ ... def __init__(self):
+ ... super(ReluNet, self).__init__()
+ ... self.relu = nn.ReLU()
+ ... def construct(self, x):
+ ... return self.relu(x)
+ >>>
+ >>> net = ReluNet()
+ >>> input_dyn = Tensor(shape=[3, None], dtype=ms.float32)
+ >>> net.set_inputs(input_dyn)
+ >>> get_inputs = net.get_inputs()
+ >>> print(get_inputs)
+ (Tensor(shape=[3, -1], dtype=Float32, value= ),)
+
  """

  return self._dynamic_shape_inputs
@@ -936,9 +994,8 @@ class Cell(Cell_):
  self._dynamic_shape_inputs = convert_inputs_to_dynamic(*args)

  if self._dynamic_shape_inputs is None:
- compile_args = self.auto_identify_dynamic_shape.auto_dynamic_generate_compile_args(args)
  _cell_graph_executor.compile(self, phase=self.phase,
- jit_config_dict=self._jit_config_dict, *compile_args, **kwargs)
+ jit_config_dict=self._jit_config_dict, *args, **kwargs)
  else:
  self._check_compile_dynamic_shape(self._dynamic_shape_inputs, args)
  self.saved_dynamic_shape = self._dynamic_shape_inputs
@@ -994,6 +1051,23 @@ class Cell(Cell_):
  Raises:
  KeyError: If the name of parameter is null or contains dot.
  TypeError: If the type of parameter is not Parameter.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn, Parameter
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.relu = nn.ReLU()
+ ...
+ ... def construct(self, x):
+ ... x = self.relu(x)
+ ... return x
+ >>> net = Net()
+ >>> net.insert_param_to_cell("bias", Parameter(Tensor([1, 2, 3])))
+ >>> print(net.bias)
+ Parameter(name=bias, shape=(3,), dtype=Int64, requires_grad=True)
  """
  if not param_name:
  raise KeyError("For 'insert_param_to_cell', the argument 'param_name' should not be None.")
@@ -1048,6 +1122,18 @@ class Cell(Cell_):
  KeyError: Child Cell's name is incorrect or duplicated with the other child name.
  TypeError: If type of `child_name` is not str.
  TypeError: Child Cell's type is incorrect.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> net1 = nn.ReLU()
+ >>> net2 = nn.Dense(2, 2)
+ >>> net1.insert_child_to_cell("child", net2)
+ >>> print(net1)
+ ReLU<
+ (child): Dense<input_channels=2, output_channels=2, has_bias=True>
+ >
  """
  if not isinstance(child_name, str):
  raise TypeError(f"For 'insert_child_to_cell', the type of parameter 'child_name' must be str, "
@@ -1118,6 +1204,25 @@ class Cell(Cell_):

  Returns:
  Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.dense = nn.Dense(2, 2)
+ ...
+ ... def construct(self, x):
+ ... x = self.dense(x)
+ ... return x
+ >>> net = Net()
+ >>> print(net.init_parameters_data())
+ {Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True):
+ Parameter (name=dense.weight, shape=(2,2), dtype=Float32, requires_grad=True),
+ Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True):
+ Parameter (name=dense.bias, shape=(2,), dtype=Float32, requires_grad=True)}
  """
  replace = dict()

@@ -1163,6 +1268,24 @@ class Cell(Cell_):

  Returns:
  OrderedDict, return parameters dictionary.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn, Parameter
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.dense = nn.Dense(2, 2)
+ ...
+ ... def construct(self, x):
+ ... x = self.dense(x)
+ ... return x
+ >>> net = Net()
+ >>> print(net.parameters_dict())
+ OrderedDict([('dense.weight', Parameter(name=dense.weight, shape=(2, 2), dtype=Float32,
+ requires_grad=True)), ('dense.bias', Parameter(name=dense.bias, shape=(2,), dtype=Float32,
+ requires_grad=True))])
  """
  param_dict = OrderedDict()
  for param in self.get_parameters(expand=recurse):
@@ -1238,7 +1361,7 @@ class Cell(Cell_):

  Tutorial Examples:
  - `Model Training - Optimizer
- <https://mindspore.cn/tutorials/en/r2.1/beginner/train.html#optimizer>`_
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/train.html#optimizer>`_
  """
  return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))

@@ -1263,6 +1386,7 @@ class Cell(Cell_):
  Returns an iterator over cell parameters.

  Yields parameters of this cell. If `expand` is ``true`` , yield parameters of this cell and all subcells.
+ For more details about subcells, please see the example below.

  Args:
  expand (bool): If ``true`` , yields parameters of this cell and all subcells. Otherwise, only yield
@@ -1272,11 +1396,34 @@ class Cell(Cell_):
  Iteration, all parameters at the cell.

  Examples:
- >>> from mindspore import nn
- >>> net = nn.Dense(3, 4)
- >>> parameters = []
- >>> for item in net.get_parameters():
- ... parameters.append(item)
+ >>> import mindspore as ms
+ >>> from mindspore import nn, ops, Tensor
+ >>> import numpy as np
+ >>> class TestNet(nn.Cell):
+ ... def __init__(self):
+ ... super().__init__()
+ ... self.my_w1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32))
+ ... self.my_w2 = ms.Parameter(Tensor(np.ones([16]), ms.float32))
+ ... def construct(self, x):
+ ... x += self.my_w1
+ ... x = ops.reshape(x, (16,)) - self.my_w2
+ ... return x
+ >>> class TestNet2(nn.Cell):
+ ... def __init__(self):
+ ... super().__init__()
+ ... self.my_t1 = ms.Parameter(Tensor(np.ones([4, 4]), ms.float32))
+ ... # self.subcell is a subcell of TestNet2, when using expand=True, the parameters of TestNet will
+ ... # also be gathered.
+ ... self.subcell = TestNet()
+ ... def construct(self, x):
+ ... x += self.my_w1
+ ... x = ops.reshape(x, (16,)) - self.my_w2
+ ... return x
+ >>> net = TestNet2()
+ >>> print([p for p in net.get_parameters(expand=True)])
+ [Parameter (name=my_t1, shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w1,
+ shape=(4, 4), dtype=Float32, requires_grad=True), Parameter (name=subcell.my_w2, shape=(16,), dtype=Float32,
+ requires_grad=True)]
  """
  for _, param in self.parameters_and_names(expand=expand):
  yield param
@@ -1325,7 +1472,7 @@ class Cell(Cell_):

  Tutorial Examples:
  - `Building a Network - Model Parameters
- <https://mindspore.cn/tutorials/en/r2.1/beginner/model.html#model-parameters>`_
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/model.html#model-parameters>`_
  """
  cells = []
  if expand:
@@ -1337,7 +1484,7 @@ class Cell(Cell_):
  for cell_name, cell in cells:
  params = cell._params.items()
  for par_name, par in params:
- if par.inited_param is not None:
+ if par is not None and par.inited_param is not None:
  par = par.inited_param
  if par is not None and id(par) not in params_set:
  params_set.add(id(par))
@@ -1394,6 +1541,22 @@ class Cell(Cell_):

  Returns:
  Iteration, the immediate cells in the cell.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.dense = nn.Dense(2, 2)
+ ...
+ ... def construct(self, x):
+ ... x = self.dense(x)
+ ... return x
+ >>> net = Net()
+ >>> print(net.cells())
+ odict_values([Dense<input_channels=2, output_channels=2, has_bias=True>])
  """
  return self.name_cells().values()

@@ -1439,6 +1602,22 @@ class Cell(Cell_):

  Returns:
  Dict, all the child cells and corresponding names in the cell.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.dense = nn.Dense(2, 2)
+ ...
+ ... def construct(self, x):
+ ... x = self.dense(x)
+ ... return x
+ >>> net = Net()
+ >>> print(net.name_cells())
+ OrderedDict([('dense', Dense<input_channels=2, output_channels=2, has_bias=True>)])
  """
  value_set = set()
  cells = OrderedDict()
@@ -1454,13 +1633,8 @@ class Cell(Cell_):
  Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16)
  if "fp32" in flags and flags.get("fp32", False):
  Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32)
-
- def _add_mixed_precision_flag_recursive(self, **flags):
- """Add mixed precision flag to each cell"""
- if "fp16" in flags and flags.get("fp16", False):
- self._set_mixed_precision_type_recursive(MixedPrecisionType.FP16)
- if "fp32" in flags and flags.get("fp32", False):
- self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
+ if "bf16" in flags and flags.get("bf16", False):
+ Cell_.set_mixed_precision_type(self, MixedPrecisionType.BF16)

  def apply(self, fn):
  """
@@ -1503,6 +1677,23 @@ class Cell(Cell_):
  Args:
  flags (dict): Network configuration information, currently it is used for the binding of network and
  dataset. Users can also customize network attributes by this parameter.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.relu = nn.ReLU()
+ ...
+ ... def construct(self, x):
+ ... x = self.relu(x)
+ ... return x
+ >>> net = Net()
+ >>> net.add_flags(sink_mode=True)
+ >>> print(net.sink_mode)
+ True
  """
  if not hasattr(self, "_func_graph_flags"):
  self._func_graph_flags = {}
@@ -1518,9 +1709,25 @@ class Cell(Cell_):
  Args:
  flags (dict): Network configuration information, currently it is used for the binding of network and
  dataset. Users can also customize network attributes by this parameter.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.relu = nn.ReLU()
+ ...
+ ... def construct(self, x):
+ ... x = self.relu(x)
+ ... return x
+ >>> net = Net()
+ >>> net.add_flags_recursive(sink_mode=True)
+ >>> print(net.sink_mode)
+ True
  """
  self.add_flags(**flags)
- self._add_mixed_precision_flag_recursive(**flags)
  for cell in self.cells():
  cell.add_flags_recursive(**flags)
  return self
@@ -1532,17 +1739,28 @@ class Cell(Cell_):
  def get_flags(self):
  """
  Get the self_defined attributes of the cell, which can be added by `add_flags` method.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.relu = nn.ReLU()
+ ...
+ ... def construct(self, x):
+ ... x = self.relu(x)
+ ... return x
+ >>> net = Net()
+ >>> net.add_flags(sink_mode=True)
+ >>> print(net.get_flags())
+ {'sink_mode':True}
  """
  if not hasattr(self, "_func_graph_flags"):
  self._func_graph_flags = {}
  return self._func_graph_flags

- def _set_mixed_precision_type_recursive(self, mixed_type):
- """Set mixed precision type to each cell"""
- Cell_.set_mixed_precision_type(self, mixed_type)
- for cell in self.cells():
- cell._set_mixed_precision_type_recursive(mixed_type)
-
  def to_float(self, dst_type):
  """
  Add cast on all inputs of cell and child cells to run with certain float type.
@@ -1555,13 +1773,13 @@ class Cell(Cell_):

  Args:
  dst_type (:class:`mindspore.dtype`): Transfer cell to run with dst_type.
- dst_type can be `mstype.float16` or `mstype.float32`.
+ dst_type can be `mstype.float16` , `mstype.float32` or `mstype.bfloat16`.

  Returns:
  Cell, the cell itself.

  Raises:
- ValueError: If dst_type is not mstype.float32 or mstype.float16.
+ ValueError: If dst_type is not `mstype.float32` , `mstype.float16` or `mstype.bfloat16`.

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
@@ -1573,19 +1791,15 @@ class Cell(Cell_):
  >>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
  >>> net.to_float(mstype.float16)
  Conv2d<input_channels=120, output_channels=240, kernel_size=(4, 4), stride=(1, 1), pad_mode=same,
- padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
- """
- if dst_type not in (mstype.float16, mstype.float32):
- raise ValueError("For 'to_float', the argument 'dst_type' must be mstype.float32 or mstype.float16, "
- "but got type: {} and value: {}.".format(type(dst_type), dst_type))
- if dst_type == mstype.float16:
- self._set_mixed_precision_type_recursive(MixedPrecisionType.FP16)
- self.to_float_fp16 = True
- else:
- self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
- self.to_float_fp16 = False
- flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32}
+ padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=None, format=NCHW>
+ """
+ if dst_type not in (mstype.float16, mstype.float32, mstype.bfloat16):
+ raise ValueError("For 'to_float', the argument 'dst_type' must be mstype.float32, mstype.float16 or "
+ "mstype.bfloat16, but got type: {} and value: {}.".format(type(dst_type), dst_type))
+ flags = {'fp16': dst_type == mstype.float16, 'fp32': dst_type == mstype.float32,
+ 'bf16': dst_type == mstype.bfloat16}
  self._add_init_args(**flags)
+ self.add_flags_recursive(**flags)
  return self

  def set_boost(self, boost_type):
@@ -1594,7 +1808,7 @@ class Cell(Cell_):
  accelerate the algorithm in the algorithm library.

  If `boost_type` is not in the algorithm library, please view the algorithm in the algorithm library through
- `algorithm library <https://gitee.com/mindspore/mindspore/tree/r2.1/mindspore/python/mindspore/boost>`_.
+ `algorithm library <https://gitee.com/mindspore/mindspore/tree/r2.2/mindspore/python/mindspore/boost>`_.

  Note:
  Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
@@ -1651,12 +1865,12 @@ class Cell(Cell_):

  Tutorial Examples:
  - `Model Training - Implementing Training and Evaluation
- <https://mindspore.cn/tutorials/en/r2.1/beginner/train.html#implementing-training-and-evaluation>`_
+ <https://mindspore.cn/tutorials/en/r2.2/beginner/train.html#training-and-evaluation>`_
  """
- if mode is False:
- self._phase = 'predict'
- else:
+ if mode:
  self._phase = 'train'
+ else:
+ self._phase = 'predict'
  self.add_flags_recursive(training=mode)
  return self

@@ -1685,16 +1899,27 @@ class Cell(Cell_):

  Args:
  jit_config (JitConfig): Jit config for compile. For details, please refer to :class:`mindspore.JitConfig`.
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import Tensor, nn
+ ...
+ >>> class Net(nn.Cell):
+ ... def __init__(self):
+ ... super(Net, self).__init__()
+ ... self.relu = nn.ReLU()
+ ...
+ ... def construct(self, x):
+ ... x = self.relu(x)
+ ... return x
+ >>> net = Net()
+ >>> jitconfig = ms.JitConfig()
+ >>> net.set_jit_config(jitconfig)
  """
  if self._jit_config_dict:
  logger.warning("For Cell, jit config can only be set once, ignore this setting.")
  else:
  self._jit_config_dict = jit_config.jit_config_dict
- enable_ge = os.getenv("MS_ENABLE_GE") == '1'
- enable_jit_level_o3 = self._jit_config_dict.get('jit_level') == "O3"
- if (not enable_ge and enable_jit_level_o3) or (enable_ge and not enable_jit_level_o3):
- raise RuntimeError("GE and jit_level=O3 should be used together, but got MS_ENABLE_GE={}, jie_level={}".
- format(os.getenv("MS_ENABLE_GE"), self.jit_config_dict.get('jit_level')))

  def flatten_weights(self, fusion_size=0):
  """
@@ -2290,12 +2515,13 @@ class Cell(Cell_):
  def _run_tracefunc(self, *args, **kwargs):
  """ Run Packed Cell in Pack."""
  args = self._mixed_precision_cast(args)
- if hasattr(self, "bprop") or hasattr(self, "_pipeline_stage") or self.get_flags():
+ need_subgraph = hasattr(self, "bprop") or hasattr(self, "_pipeline_stage") or self.get_flags()
+ if not PackFunc.current.is_pynative_mode and need_subgraph:
  expander = PackExpander.get_instance()
  args = expander.begin_subgraph(self, *args)
  args = [_convert_tensor(a) for a in args]
  output = self._run_construct(args, kwargs)
- ret = expander.end_subgraph(output)
+ ret = expander.end_subgraph(self, output)
  output = _convert_tensor(ret)
  else:
  with _SetMixedPrecision(self):
@@ -2306,10 +2532,23 @@ class Cell(Cell_):
  mixed_type = self.get_mixed_precision_type()
  if mixed_type == MixedPrecisionType.NOTSET:
  return inputs
- cast_type = mstype.float16 if mixed_type == MixedPrecisionType.FP16 else mstype.float32
+ if mixed_type == MixedPrecisionType.FP16:
+ cast_type = mstype.float16
+ elif mixed_type == MixedPrecisionType.BF16:
+ cast_type = mstype.bfloat16
+ else:
+ cast_type = mstype.float32
  cast_inputs = self._cast_mixed_precision_inputs(inputs, cast_type)
  return cast_inputs

+ def _get_attr_from_cell(self, network):
+ if not isinstance(network, Cell):
+ return
+ if hasattr(network, "jit_config_dict"):
+ self._jit_config_dict = network.jit_config_dict
+ if hasattr(network, "_amp_level"):
+ self._amp_level = getattr(network, "_amp_level")
+

  class GraphCell(Cell):
  """