mindspore-2.1.0-cp38-cp38-manylinux1_x86_64.whl → mindspore-2.2.10-cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mindspore might be problematic.
Files changed (580)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +46 -19
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  25. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  26. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +38 -0
  31. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +12 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +61 -71
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +74 -104
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +13 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +28 -5
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +141 -88
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/cell_wrapper.py +84 -34
  268. mindspore/nn/wrap/grad_reducer.py +8 -5
  269. mindspore/nn/wrap/loss_scale.py +105 -42
  270. mindspore/numpy/array_creations.py +1 -2
  271. mindspore/numpy/array_ops.py +3 -2
  272. mindspore/numpy/utils_const.py +5 -5
  273. mindspore/offline_debug/convert_async.py +2 -2
  274. mindspore/ops/_grad_experimental/__init__.py +0 -5
  275. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  276. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  277. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  278. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  279. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  280. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  281. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  282. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  283. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  284. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  285. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  286. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  287. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  288. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  289. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  290. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  291. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  292. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  293. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  294. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  295. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  296. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  297. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  298. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  299. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  300. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  301. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  302. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  303. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  304. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  305. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  306. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  307. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  309. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  310. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  311. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  312. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  313. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  314. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  315. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  316. mindspore/ops/_primitive_cache.py +1 -1
  317. mindspore/ops/_tracefunc.py +45 -13
  318. mindspore/ops/_utils/utils.py +6 -1
  319. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  320. mindspore/ops/_vmap/vmap_base.py +3 -3
  321. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  322. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  323. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  324. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  325. mindspore/ops/arg_dtype_cast.py +54 -0
  326. mindspore/ops/composite/base.py +37 -10
  327. mindspore/ops/composite/math_ops.py +5 -4
  328. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  329. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  330. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  331. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  332. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  333. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  334. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  335. mindspore/ops/deprecated.py +304 -0
  336. mindspore/ops/function/__init__.py +4 -1
  337. mindspore/ops/function/array_func.py +174 -193
  338. mindspore/ops/function/clip_func.py +81 -13
  339. mindspore/ops/function/debug_func.py +1 -1
  340. mindspore/ops/function/grad/grad_func.py +18 -9
  341. mindspore/ops/function/image_func.py +10 -4
  342. mindspore/ops/function/linalg_func.py +5 -5
  343. mindspore/ops/function/math_func.py +575 -386
  344. mindspore/ops/function/nn_func.py +568 -260
  345. mindspore/ops/function/random_func.py +88 -57
  346. mindspore/ops/function/sparse_func.py +1 -1
  347. mindspore/ops/function/sparse_unary_func.py +14 -12
  348. mindspore/ops/function/vmap_func.py +6 -5
  349. mindspore/ops/functional.py +15 -10
  350. mindspore/ops/op_info_register.py +244 -25
  351. mindspore/ops/operations/__init__.py +28 -19
  352. mindspore/ops/operations/_grad_ops.py +72 -7
  353. mindspore/ops/operations/_inner_ops.py +350 -17
  354. mindspore/ops/operations/_quant_ops.py +4 -8
  355. mindspore/ops/operations/_sequence_ops.py +42 -0
  356. mindspore/ops/operations/array_ops.py +68 -282
  357. mindspore/ops/operations/comm_ops.py +107 -59
  358. mindspore/ops/operations/custom_ops.py +94 -70
  359. mindspore/ops/operations/debug_ops.py +8 -4
  360. mindspore/ops/operations/image_ops.py +18 -12
  361. mindspore/ops/operations/inner_ops.py +26 -3
  362. mindspore/ops/operations/math_ops.py +189 -141
  363. mindspore/ops/operations/nn_ops.py +794 -489
  364. mindspore/ops/operations/other_ops.py +0 -22
  365. mindspore/ops/operations/random_ops.py +53 -111
  366. mindspore/ops/operations/sparse_ops.py +3 -1
  367. mindspore/ops/primitive.py +24 -18
  368. mindspore/parallel/_auto_parallel_context.py +68 -8
  369. mindspore/parallel/_cost_model_context.py +2 -2
  370. mindspore/parallel/_offload_context.py +17 -3
  371. mindspore/parallel/_parallel_serialization.py +12 -5
  372. mindspore/parallel/_ps_context.py +12 -0
  373. mindspore/parallel/_tensor.py +18 -13
  374. mindspore/parallel/_transformer/layers.py +5 -3
  375. mindspore/parallel/_transformer/loss.py +1 -0
  376. mindspore/parallel/_transformer/moe.py +2 -2
  377. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  378. mindspore/parallel/_transformer/transformer.py +23 -3
  379. mindspore/parallel/_utils.py +11 -7
  380. mindspore/parallel/algo_parameter_config.py +85 -5
  381. mindspore/parallel/checkpoint_transform.py +19 -12
  382. mindspore/parallel/shard.py +21 -14
  383. mindspore/profiler/common/struct_type.py +3 -3
  384. mindspore/profiler/common/util.py +4 -2
  385. mindspore/profiler/envprofiling.py +1 -1
  386. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  387. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  388. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  389. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  390. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  391. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  392. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  393. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  394. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  395. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  396. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  397. mindspore/profiler/parser/flops_parser.py +15 -11
  398. mindspore/profiler/parser/framework_parser.py +38 -22
  399. mindspore/profiler/parser/hccl_parser.py +16 -12
  400. mindspore/profiler/parser/integrator.py +22 -11
  401. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  402. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  403. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  404. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  405. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  406. mindspore/profiler/parser/optime_parser.py +1 -1
  407. mindspore/profiler/parser/profiler_info.py +21 -2
  408. mindspore/profiler/parser/step_trace_parser.py +11 -14
  409. mindspore/profiler/profiling.py +179 -89
  410. mindspore/rewrite/api/node.py +102 -19
  411. mindspore/rewrite/api/node_type.py +5 -1
  412. mindspore/rewrite/api/pattern_engine.py +1 -1
  413. mindspore/rewrite/api/scoped_value.py +9 -17
  414. mindspore/rewrite/api/symbol_tree.py +131 -47
  415. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  416. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  417. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  418. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  419. mindspore/rewrite/common/rewrite_elog.py +5 -1
  420. mindspore/rewrite/namer.py +33 -24
  421. mindspore/rewrite/namespace.py +14 -5
  422. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  423. mindspore/rewrite/node/call_function.py +79 -0
  424. mindspore/rewrite/node/cell_container.py +135 -0
  425. mindspore/rewrite/node/control_flow.py +88 -0
  426. mindspore/rewrite/{node.py → node/node.py} +273 -234
  427. mindspore/rewrite/node/node_manager.py +254 -0
  428. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  429. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  430. mindspore/rewrite/parsers/assign_parser.py +216 -221
  431. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  432. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  433. mindspore/rewrite/parsers/constant_parser.py +9 -6
  434. mindspore/rewrite/parsers/container_parser.py +9 -7
  435. mindspore/rewrite/parsers/for_parser.py +36 -15
  436. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  437. mindspore/rewrite/parsers/if_parser.py +28 -24
  438. mindspore/rewrite/parsers/module_parser.py +196 -25
  439. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  440. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  441. mindspore/rewrite/parsers/return_parser.py +6 -6
  442. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  443. mindspore/rewrite/sparsify/utils.py +1 -1
  444. mindspore/rewrite/symbol_tree.py +523 -578
  445. mindspore/rewrite/symbol_tree_builder.py +9 -193
  446. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  447. mindspore/run_check/_check_version.py +6 -4
  448. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  449. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  450. mindspore/scipy/linalg.py +1 -1
  451. mindspore/scipy/optimize/minimize.py +7 -3
  452. mindspore/train/_utils.py +7 -3
  453. mindspore/train/amp.py +323 -123
  454. mindspore/train/anf_ir_pb2.py +14 -2
  455. mindspore/train/callback/_backup_and_restore.py +2 -12
  456. mindspore/train/callback/_callback.py +29 -4
  457. mindspore/train/callback/_checkpoint.py +23 -8
  458. mindspore/train/callback/_early_stop.py +2 -2
  459. mindspore/train/callback/_landscape.py +4 -4
  460. mindspore/train/callback/_loss_monitor.py +2 -2
  461. mindspore/train/callback/_on_request_exit.py +2 -2
  462. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  463. mindspore/train/callback/_summary_collector.py +15 -8
  464. mindspore/train/callback/_time_monitor.py +58 -5
  465. mindspore/train/data_sink.py +5 -11
  466. mindspore/train/dataset_helper.py +84 -57
  467. mindspore/train/loss_scale_manager.py +2 -2
  468. mindspore/train/metrics/__init__.py +3 -3
  469. mindspore/train/metrics/cosine_similarity.py +1 -1
  470. mindspore/train/metrics/hausdorff_distance.py +3 -2
  471. mindspore/train/metrics/mean_surface_distance.py +3 -2
  472. mindspore/train/metrics/metric.py +39 -19
  473. mindspore/train/metrics/roc.py +2 -2
  474. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  475. mindspore/train/mind_ir_pb2.py +85 -36
  476. mindspore/train/model.py +187 -47
  477. mindspore/train/serialization.py +487 -161
  478. mindspore/train/summary/_summary_adapter.py +1 -1
  479. mindspore/train/summary/_writer_pool.py +3 -2
  480. mindspore/train/summary/summary_record.py +37 -17
  481. mindspore/train/train_thor/convert_utils.py +3 -3
  482. mindspore/train/train_thor/dataset_helper.py +1 -1
  483. mindspore/version.py +1 -1
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +488 -528
  486. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
  487. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  489. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  490. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  491. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  492. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  493. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  494. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  495. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  497. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  498. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  499. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  500. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  501. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  502. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  503. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  504. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  505. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  506. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  507. mindspore/_extends/graph_kernel/expander.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  509. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  510. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  511. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  512. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  513. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  514. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  515. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  516. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  517. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  518. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  519. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  520. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  521. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  522. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  523. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  524. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  525. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  526. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  527. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  528. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  529. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  530. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  531. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  532. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  534. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  535. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  536. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  537. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  538. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  539. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  540. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  543. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  544. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  545. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  546. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  547. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  548. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  549. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  550. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  551. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  552. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  553. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  554. mindspore/dataset/datapreprocess/__init__.py +0 -20
  555. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  556. mindspore/include/api/net.h +0 -142
  557. mindspore/nn/lr_scheduler.py +0 -262
  558. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  559. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  560. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  561. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  562. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  563. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  567. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  568. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  569. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  571. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  575. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  578. mindspore/rewrite/node_visitor.py +0 -44
  579. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  580. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
mindspore/train/amp.py CHANGED
@@ -19,8 +19,8 @@ import mindspore as ms
 from mindspore import nn
 from mindspore import _checkparam as validator
 from mindspore.common import dtype as mstype
-from mindspore.nn.wrap.cell_wrapper import _TrainPipelineAccuStepCell
-from mindspore.nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
+from mindspore.nn.wrap.cell_wrapper import _TrainGradAccuStepCell
+from mindspore.nn.wrap.loss_scale import _TrainGradAccuWithLossScaleCell
 from mindspore.ops import functional as F
 from mindspore.parallel._utils import _get_pipeline_stages
 from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
@@ -30,9 +30,6 @@ from mindspore.ops import Primitive
 from mindspore import log as logger


-STREE = None
-
-
 AMP_WHITE_LIST = [
     nn.Conv1d,
     nn.Conv2d,
@@ -64,17 +61,19 @@ AMP_BLACK_LIST = [
     nn.LayerNorm
 ]

+MS_AMP_BY_REWRITE = False
+_amp_cast_op = P.Cast

 class _OutputTo16(nn.Cell):
     """Wrap cell for amp. Cast network output back to float16."""
-    def __init__(self, backbone):
+    def __init__(self, backbone, dtype=mstype.float16):
         super(_OutputTo16, self).__init__(auto_prefix=False)
         self._backbone = backbone
-        if isinstance(backbone, nn.Cell) and backbone.jit_config_dict:
-            self._jit_config_dict = backbone.jit_config_dict
+        self.dtype = dtype
+        self._get_attr_from_cell(backbone)

     def construct(self, *args, **kwargs):
-        return F.cast(self._backbone(*args, **kwargs), mstype.float16)
+        return F.cast(self._backbone(*args, **kwargs), self.dtype)


 class _OutputTo32(nn.Cell):
@@ -82,63 +81,73 @@ class _OutputTo32(nn.Cell):
     def __init__(self, backbone):
         super(_OutputTo32, self).__init__(auto_prefix=False)
         self._backbone = backbone
-        if isinstance(backbone, nn.Cell) and backbone.jit_config_dict:
-            self._jit_config_dict = backbone.jit_config_dict
+        self._get_attr_from_cell(backbone)

     def construct(self, *args, **kwargs):
         out = self._backbone(*args, **kwargs)
         return F.mixed_precision_cast(mstype.float32, out)


-def _allow_mix_precision(node, allowed_list) -> bool:
+
+def _allow_mix_precision(node, allowed_list, dtype) -> bool:
     """
     Check whether current node need do mix precision. Follow conditions need to be satisfied:
     1) Type of node is one of (Primitive, nn.Cell)
-    2) Node is not P.Cast()
+    2) Node is not Cast Op
     3) to_float(mindspore.float16) is not set in Cell
     """
-    if node.get_instance() in allowed_list:
+    node_inst = node.get_instance()
+    if node_inst in allowed_list:
         return True
+    if node.get_targets() is None:
+        return False
     if not issubclass(node.get_instance_type(), (Primitive, nn.Cell)):
         return False
-    if isinstance(node.get_instance(), P.Cast):
+    if isinstance(node_inst, _amp_cast_op):
         return False
     if issubclass(node.get_instance_type(), nn.Cell):
-        # if cell is already in allowed_list, it means to_float(mindspore.float16) is set by amp.
-        # if cell is not in allowed_list, but has to_float(mindspore.float16),
-        # it means to_float(mindspore.float16) is set by user.
-        if hasattr(node.get_instance(), "to_float_fp16") and node.get_instance().to_float_fp16:
+        # if cell is already in allowed_list, it means to_float() is set by amp.
+        # if cell is not in allowed_list, but has to_float(),
+        # it means to_float() is set by user.
+        to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
+        if hasattr(node_inst, to_float_flag) and getattr(node_inst, to_float_flag):
            return False
        allowed_list.append(node.get_instance())
    return True


-def _insert_cast_operator_process(node, stree):
+def _insert_cast_operator_process(node, dtype):
     """insert cast for operators in white_list."""
+    dtype_str = "mindspore.bfloat16" if dtype == mstype.bfloat16 else "mindspore.float16"
     new_cast_node = None
-    # insert cast float16 before the primitive operators
+    stree = node.get_symbol_tree()
+    # insert cast fp16/bf16 before the primitive operators
     if issubclass(node.get_instance_type(), Primitive):
         for idx, arg in enumerate(node.get_args()):
             position = stree.before(node)
-            new_node = P.Cast()
-            cast_args = ms.rewrite.ScopedValue.create_name_values([arg.value, "mindspore.float16"], [arg.scope, ""])
-            cast_targets = ms.rewrite.ScopedValue.create_name_values([arg.value], [arg.scope])
+            new_node = _amp_cast_op()
+            cast_args = ms.rewrite.ScopedValue.create_name_values([arg.value, dtype_str], [arg.scope, ""])
+            arg_provider = node.get_handler().get_arg_providers()[idx]
+            if arg_provider and len(arg_provider[0].get_target_users(arg_provider[1])) > 1:
+                cast_targets = [stree.unique_name(str(arg))]
+            else:
+                cast_targets = ms.rewrite.ScopedValue.create_name_values([arg.value], [arg.scope])
             new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
                                                              targets=cast_targets,
                                                              args=cast_args,
                                                              name='incast_{}{}'.format(node.get_name(), idx))
             stree.insert(position, new_cast_node)
             node.set_arg_by_node(idx, new_cast_node)
-    # insert cast float16 before the Cell operators
+    # insert cast fp16/bf16 before the Cell operators
     elif issubclass(node.get_instance_type(), nn.Cell):
-        node.get_instance().to_float(mstype.float16)
+        node.get_instance().to_float(dtype)
     # ignore if subclass is not one of (Primitive, nn.Cell)
     else:
         return

     # insert cast float32 after the operators
     position = stree.after(node)
-    new_node = P.Cast()
+    new_node = _amp_cast_op()
     cast_args = ms.rewrite.ScopedValue.create_name_values([node.get_targets()[0].value,
                                                            "mindspore.float32"])
     new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
@@ -156,49 +165,102 @@ def _insert_cast_operator_process(node, stree):
             user.set_arg_by_node(idx, new_cast_node)


-def _insert_cast_operator_white_list(stree, white_list):
+def _insert_cast_operator_white_list(stree, white_list, dtype):
     """insert cast for operators in white_list."""
     allowed_list = []
-    # Ignore if net called ".to_float(mindspore.float16)"
+    # Ignore if net called ".to_float(dtype)"
     net = stree.get_handler().get_origin_network()
-    if isinstance(net, nn.Cell) and hasattr(net, "to_float_fp16") and net.to_float_fp16:
+    to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
+    if isinstance(net, nn.Cell) and hasattr(net, to_float_flag) and getattr(net, to_float_flag):
         return
-    for node in stree.nodes():
-        if node.get_targets() is None:
-            continue
+    node_list = []
+    node_list.extend(list(stree.nodes()))
+    while node_list:
+        node = node_list.pop()
         if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
+            if MS_AMP_BY_REWRITE:
+                _insert_cast_for_cell_container(node, dtype, allowed_list, white_list=white_list)
             for n in node.get_handler().node_list:
                 if n.get_node_type() == ms.rewrite.NodeType.Tree:
                     _insert_cast_operator_white_list(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)),
-                                                     white_list)
+                                                     white_list, dtype)
         elif node.get_node_type() == ms.rewrite.NodeType.Tree:
             substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
-            _insert_cast_operator_white_list(substree, white_list)
-        elif node.get_instance_type() in white_list and _allow_mix_precision(node, allowed_list):
-            _insert_cast_operator_process(node, stree)
+            _insert_cast_operator_white_list(substree, white_list, dtype)
+        elif node.get_node_type() in [ms.rewrite.NodeType.CallFunction, ms.rewrite.NodeType.ControlFlow]:
+            if isinstance(node.get_handler(), ms.rewrite.node.NodeManager):
+                nodes = [ms.rewrite.Node(n) for n in node.get_handler().nodes()]
+                node_list.extend(nodes)
+        elif node.get_instance_type() in white_list and _allow_mix_precision(node, allowed_list, dtype):
+            _insert_cast_operator_process(node, dtype)


-def _need_removed_cast_pair(node):
+def _insert_cast_for_cell_container(cell_container, dtype, allowed_list, *, white_list=None, black_list=None):
+    """
+    Insert cast for cell containers.
+    Only one of white_list and black_list can be set.
+    """
+
+    class CastNet(nn.Cell):
+        """Cast net"""
+        def __init__(self, dtype):
+            super().__init__()
+            self.cast = _amp_cast_op()
+            self.dtype = dtype
+
+        def construct(self, x):
+            return self.cast(x, self.dtype)
+
+    cast_flag = False
+    current_node = None
+    stree = cell_container.get_symbol_tree()
+    for node in cell_container.get_handler().nodes():
+        current_node = ms.rewrite.Node(node)
+        if (white_list is not None and current_node.get_instance_type() in white_list) or \
+           (black_list is not None and current_node.get_instance_type() not in black_list) and \
+           (_allow_mix_precision(current_node, allowed_list, dtype)):
+            cast_flag = True
+            current_node.get_instance().to_float(dtype)
+        elif cast_flag:
+            # cast next node back to float32
+            current_node.get_instance().to_float(mstype.float32)
+            cast_flag = False
+    if cast_flag and current_node:
+        # if last node in cell_container is casted to fp16/bf16, insert a cast node to cast value back to fp32
+        cast_node = ms.rewrite.Node.create_call_cell(cell=CastNet(mstype.float32),
+                                                     args=[current_node.get_targets()[0]],
+                                                     targets=[current_node.get_targets()[0]],
+                                                     name=f"outcast_{cell_container.get_name()}")
+        stree.insert(stree.after(current_node), cast_node)
+
+
+def _need_removed_cast_pair(node, dtype):
     """check whether the cast pairs should be removed."""
-    cast_dtypes = ms.rewrite.ScopedValue.create_name_values(["mindspore.float16", "mindspore.float32"])
+    dtype_str = "mindspore.bfloat16" if dtype == mstype.bfloat16 else "mindspore.float16"
+    cast_dtypes = ms.rewrite.ScopedValue.create_name_values([dtype_str, "mindspore.float32"])
     cast_dtype_f16 = cast_dtypes[0]
     cast_dtype_f32 = cast_dtypes[1]
-    # current node should be P.Cast()(x, mindspore.float32)
-    if node.get_instance_type() != P.Cast:
+    # current node should be Cast Op to float32
+    if node.get_instance_type() != _amp_cast_op:
         return False
     node_cast_type = node.get_args()[1]
     if node_cast_type != cast_dtype_f32:
         return False
-    # all user nodes should be P.Cast()(x, mindspore.float16) or Cell with to_float(mindspore.float16)
+    # all user nodes should be Cast Op to dtype or Cell with to_float(dtype)
     if not node.get_users():
         return False
+    all_nodes = [ms.rewrite.Node(n) for n in node.get_handler().get_node_manager().nodes()]
     for user in node.get_users():
+        # If ControlFlow node(if statement) exists between current node and user node,
+        # cast pair should not be removed.
+        middle_nodes = all_nodes[all_nodes.index(node): all_nodes.index(user)]
+        if any([n.get_node_type() == ms.rewrite.NodeType.ControlFlow for n in middle_nodes]):
+            return False
         if isinstance(user.get_instance(), nn.Cell):
-            if not hasattr(user.get_instance(), "to_float_fp16"):
-                return False
-            if not user.get_instance().to_float_fp16:
+            to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
+            if not (hasattr(user.get_instance(), to_float_flag) and getattr(user.get_instance(), to_float_flag)):
                 return False
-        elif user.get_instance_type() == P.Cast:
+        elif user.get_instance_type() == _amp_cast_op:
             user_cast_type = user.get_args()[1]
             if user_cast_type != cast_dtype_f16:
                 return False
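As context for the insertion helpers above (`_insert_cast_operator_process`, `_insert_cast_operator_white_list`, `_insert_cast_for_cell_container`): they are built on the mindspore.rewrite calls that appear in this diff (SymbolTree.create, before/after, Node.create_call_cell, insert, set_arg_by_node). Below is a minimal, hypothetical sketch of the same insertion pattern applied to a toy network; SimpleNet, the node lookup, and the hard-coded names are illustrative only and are not part of the package.

import mindspore as ms
from mindspore import nn
from mindspore.ops import operations as P


class SimpleNet(nn.Cell):
    """Toy cell used only to illustrate rewrite-based cast insertion."""
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(4, 4)

    def construct(self, x):
        x = self.dense(x)
        return x


stree = ms.rewrite.SymbolTree.create(SimpleNet())
for node in stree.nodes():
    if node.get_instance_type() == nn.Dense:
        # Insert x = Cast(x, mindspore.float16) before the call and re-point its first
        # argument, analogous to the "incast_" nodes the amp pass creates for Primitive ops.
        cast_args = ms.rewrite.ScopedValue.create_name_values(["x", "mindspore.float16"])
        cast_node = ms.rewrite.Node.create_call_cell(P.Cast(),
                                                     targets=["x"],
                                                     args=cast_args,
                                                     name="incast_dense")
        stree.insert(stree.before(node), cast_node)
        node.set_arg_by_node(0, cast_node)
new_net = stree.get_network()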
@@ -207,11 +269,13 @@ def _need_removed_cast_pair(node):
     return True


-def _removed_cast_pair_process(stree, cast_f32_node):
+def _removed_cast_pair_process(cast_f32_node):
     """remove the duplicated cast operators."""
-    for user_node in cast_f32_node.get_users():
-        # remove cast f16 nodes
-        if user_node.get_instance_type() == P.Cast:
+    stree = cast_f32_node.get_symbol_tree()
+    cast_f32_users = cast_f32_node.get_users()
+    # remove cast f16 nodes
+    for user_node in cast_f32_users:
+        if user_node.get_instance_type() == _amp_cast_op:
             cast_f16_node = user_node
             # modify arguments using cast_f16's target[0] to cast_f32's args[0], which is f16 type
             for cast_f16_user in cast_f16_node.get_users():
@@ -229,34 +293,78 @@ def _removed_cast_pair_process(stree, cast_f32_node):
     stree.erase(cast_f32_node)


-def _remove_duplicated_cast(stree):
+def _remove_duplicated_cast(stree, dtype):
     """remove the duplicated cast operators."""
-    for node in stree.nodes():
-        if node.get_targets() is None:
-            continue
+    node_list = []
+    node_list.extend(list(stree.nodes()))
+    while node_list:
+        node = node_list.pop()
         if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
             for n in node.get_handler().node_list:
                 if n.get_node_type() == ms.rewrite.NodeType.Tree:
-                    _remove_duplicated_cast(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)))
+                    _remove_duplicated_cast(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)), dtype)
         elif node.get_node_type() == ms.rewrite.NodeType.Tree:
             substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
-            _remove_duplicated_cast(substree)
-        elif _need_removed_cast_pair(node):
-            _removed_cast_pair_process(stree, node)
+            _remove_duplicated_cast(substree, dtype)
+        elif node.get_node_type() in [ms.rewrite.NodeType.CallFunction, ms.rewrite.NodeType.ControlFlow]:
+            if isinstance(node.get_handler(), ms.rewrite.node.NodeManager):
+                nodes = [ms.rewrite.Node(n) for n in node.get_handler().nodes()]
+                node_list.extend(nodes)
+        elif _need_removed_cast_pair(node, dtype):
+            _removed_cast_pair_process(node)


-def _auto_white_list(network, white_list):
+def _auto_white_list(network, white_list, dtype):
     """process the white list of network."""
-    global STREE
-    STREE = ms.rewrite.SymbolTree.create(network)
-    _insert_cast_operator_white_list(STREE, white_list)
-    _remove_duplicated_cast(STREE)
-    return STREE.get_network()
+    stree = ms.rewrite.SymbolTree.create(network)
+    _insert_cast_operator_white_list(stree, white_list, dtype)
+    _remove_duplicated_cast(stree, dtype)
+    return stree.get_network()


-def _auto_black_list(network, black_list):
+def _insert_cast_operator_black_list(stree, black_list, dtype):
+    """insert cast for operators not in black_list."""
+    allowed_list = []
+    # Ignore if net called ".to_float(dtype)"
+    net = stree.get_handler().get_origin_network()
+    to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
+    if isinstance(net, nn.Cell) and hasattr(net, to_float_flag) and getattr(net, to_float_flag):
+        return
+    for node in stree.nodes(all_nodes=True):
+        if node.get_targets() is None:
+            continue
+        if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
+            _insert_cast_for_cell_container(node, dtype, allowed_list, black_list=black_list)
+        elif isinstance(node.get_handler().get_node_manager(), ms.rewrite.node.CellContainer):
+            # nodes in CellContainer are processed by _insert_cast_for_cell_container
+            continue
+        elif node.get_instance_type() not in black_list and _allow_mix_precision(node, allowed_list, dtype):
+            _insert_cast_operator_process(node, dtype)
+
+
+def _remove_duplicated_cast_rewrite(stree, dtype):
+    """remove the duplicated cast operators."""
+    for node in stree.nodes(all_nodes=True):
+        if _need_removed_cast_pair(node, dtype):
+            user_nodes = node.get_users()
+            # remove cast f16 nodes
+            for user_node in user_nodes:
+                if user_node.get_instance_type() == _amp_cast_op:
+                    stree.erase(user_node)
+            # remove the cast f32 node
+            stree.erase(node)
+
+
+def _auto_black_list_rewrite(network, black_list, dtype):
+    stree = ms.rewrite.SymbolTree.create(network)
+    _insert_cast_operator_black_list(stree, black_list, dtype)
+    _remove_duplicated_cast_rewrite(stree, dtype)
+    return stree.get_network()
+
+
+def _auto_black_list(network, black_list, dtype):
     """process the black list of network."""
-    network.to_float(mstype.float16)
+    network.to_float(dtype)
     cells = network.name_cells()
     change = False
     for name in cells:
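For intuition about the pairs that `_need_removed_cast_pair`, `_removed_cast_pair_process` and `_remove_duplicated_cast` target: a cast back to float32 whose consumers immediately cast to the working dtype again is a no-op chain. A hypothetical before/after sketch follows; the tensors and ops are illustrative only, not taken from the diff.

import numpy as np
import mindspore as ms
from mindspore import ops, Tensor

x = Tensor(np.ones((2, 2)), ms.float16)
w = Tensor(np.ones((2, 2)), ms.float16)

# Before removal: the float32 "outcast" is immediately undone by the next "incast".
y = ops.Cast()(x, ms.float32)   # cast inserted after the previous low-precision op
z = ops.Cast()(y, ms.float16)   # cast inserted before the next low-precision op
out_before = ops.Add()(z, w)

# After removal: the downstream op consumes the float16 value directly.
out_after = ops.Add()(x, w)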
@@ -264,30 +372,27 @@ def _auto_black_list(network, black_list):
         if subcell == network:
             continue
         if isinstance(subcell, tuple(black_list)):
-            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
+            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32), dtype)
             change = True
         else:
-            _auto_black_list(subcell, black_list)
+            _auto_black_list(subcell, black_list, dtype)
     if isinstance(network, nn.SequentialCell) and change:
         network.cell_list = list(network.cells())
+    return network


-def auto_mixed_precision(network, amp_level="O0"):
+def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
     """
     Returns a network processed with auto mixed precision.

     This interface will automatically perform mixed-precision processing on the input network, and the cells
-    and operators in the processed network will add precision conversion operations to calculate with float16 accuracy.
-    Inputs and parameters of cells and operators are converted to float16 type, and calculation results are converted
-    back to float32 type.
+    and operators in the processed network will add precision conversion operations to calculate with lower
+    precision: ``mstype.float16`` or ``mstype.bfloat16`` . Inputs and parameters of cells and operators are
+    converted to lower precision float, and calculation results are converted back to full precision float,
+    i.e. ``mstype.float32`` .

     The framework has a set of built-in blacklists and whitelists, and the `amp_level` determines which cells and
-    operators are specifically converted:
-
-    - When `amp_level="O0"` , no precision conversion is performed.
-    - When `amp_level="O1"` , only the cells and operators in the whitelist will be converted.
-    - When `amp_level="O2"` , all cells and operators except those in the blacklist will be converted.
-    - When `amp_level="O3"` , all cells and operators in the network are converted.
+    operators are specifically converted.

     The current built-in whitelist contents are:

@@ -305,26 +410,38 @@ def auto_mixed_precision(network, amp_level="O0"):
305
410
  :class:`mindspore.nn.LayerNorm`]
306
411
 
307
412
  For details on automatic mixed precision, refer to
308
- `Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/r2.1/advanced/mixed_precision.html>`_ .
413
+ `Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/r2.2/advanced/mixed_precision.html>`_ .
414
+
415
+ Note:
416
+ - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
417
+ can result in a larger network hierarchy and slower performance.
418
+ - If interfaces like `Model` and `build_train_network` is used to train the network which is converted by
419
+ mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
420
+ need to be configured to ``O0`` to avoid the duplicated accuracy conversion.
309
421
 
310
422
  Args:
311
423
  network (Cell): Definition of the network.
312
424
  amp_level (str): Supports ["O0", "O1", "O2", "O3"]. Default: ``"O0"`` .
313
425
 
314
426
  - "O0": Do not change.
315
- - "O1": Convert cells and operators in whitelist to float16 precision operations, and keep float32
427
+ - "O1": Convert cells and operators in whitelist to lower precision operations, and keep full
316
428
  precision operations for the rest.
317
- - "O2": Keep float32 precision operations for cells and operators in blacklist, and convert the rest
318
- to float16 precision operations.
319
- - "O3": Cast network to float16.
429
+ - "O2": Keep full precision operations for cells and operators in blacklist, and convert the rest
430
+ to lower precision operations.
431
+ - "O3": Cast network to lower precision.
432
+
433
+ dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or ``mstype.bfloat16`` ,
434
+ default: ``mstype.float16`` .
320
435
 
321
436
  Raises:
322
- ValueError: If amp level is not supported.
437
+ TypeError: If `network` is not a Cell.
438
+ ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
439
+ ValueError: If `amp_level` is not within the supported range.
323
440
 
324
441
  Examples:
325
442
  >>> from mindspore import amp
326
443
  >>> # Define the network structure of LeNet5. Refer to
327
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
444
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
328
445
  >>> network = LeNet5()
329
446
  >>> amp_level = "O1"
330
447
  >>> net = amp.auto_mixed_precision(network, amp_level)
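The Args section above introduces a `dtype` argument alongside `amp_level`. A minimal usage sketch, assuming `dtype` is accepted as a keyword argument (the updated function signature is not shown in this hunk) and reusing the `LeNet5` definition linked above:

>>> from mindspore import amp
>>> from mindspore import dtype as mstype
>>> network = LeNet5()
>>> # Hypothetical call: convert whitelist cells/operators to bfloat16 instead of the default float16.
>>> net = amp.auto_mixed_precision(network, amp_level="O1", dtype=mstype.bfloat16)
>>> # Per the Note above, if this converted network is later trained through Model or
>>> # build_train_network, configure their amp_level to "O0" so precision is not converted twice.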
@@ -332,20 +449,37 @@ def auto_mixed_precision(network, amp_level="O0"):
332
449
  if not isinstance(network, nn.Cell):
333
450
  raise TypeError("The network type should be Cell.")
334
451
 
452
+ if dtype not in (mstype.float16, mstype.bfloat16):
453
+ raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")
454
+
335
455
  if amp_level == "O0":
336
456
  return network
337
457
 
338
- if amp_level == "O1":
339
- return _auto_white_list(network, AMP_WHITE_LIST)
458
+ # Warn if an amp level has already been configured on the network
459
+ if getattr(network, "_amp_level") in ("O1", "O2", "O3"):
460
+ logger.warning(f"The network's auto mixed-precision level is adjusted from {getattr(network, '_amp_level')} "
461
+ f"to {amp_level}, and repeated calls to mixed-precision interfaces can cause performance "
462
+ f"degradation.")
340
463
 
341
- if amp_level == "O2":
342
- _auto_black_list(network, AMP_BLACK_LIST)
464
+ if amp_level == "O1":
465
+ network = _auto_white_list(network, AMP_WHITE_LIST, dtype)
466
+ elif amp_level == "O2":
467
+ if MS_AMP_BY_REWRITE:
468
+ network = _auto_black_list_rewrite(network, AMP_BLACK_LIST, dtype)
469
+ else:
470
+ network = _auto_black_list(network, AMP_BLACK_LIST, dtype)
471
+ network = _OutputTo32(network)
343
472
  elif amp_level == "O3":
344
- network.to_float(mstype.float16)
473
+ if MS_AMP_BY_REWRITE:
474
+ network = _auto_black_list_rewrite(network, [], dtype)
475
+ else:
476
+ network.to_float(dtype)
477
+ network = _OutputTo32(network)
345
478
  else:
346
479
  raise ValueError("The amp level {} is not supported".format(amp_level))
347
- if amp_level in ("O2", "O3"):
348
- network = _OutputTo32(network)
480
+
481
+ setattr(network, "_amp_level", amp_level)
482
+
349
483
  return network
350
484
 
351
485
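The dispatch above records the applied level on the returned network via `setattr(network, "_amp_level", amp_level)` and logs a warning when a level from ("O1", "O2", "O3") is already present. A short sketch of the resulting behaviour, assuming `LeNet5` from the linked example and assuming `nn.Cell` pre-initializes `_amp_level` (the bare `getattr` above relies on the attribute already existing):

>>> from mindspore import amp
>>> network = LeNet5()
>>> net = amp.auto_mixed_precision(network, "O1")
>>> # A second conversion still succeeds, but logs a performance warning because
>>> # the first call already tagged the network with _amp_level="O1".
>>> net = amp.auto_mixed_precision(net, "O2")
>>> print(getattr(net, "_amp_level"))  # prints "O2": the wrapper returned for "O2"/"O3" carries the new tag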
 
@@ -436,8 +570,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
436
570
  super(WithLossCell, self).__init__(auto_prefix=False)
437
571
  self._backbone = backbone
438
572
  self._loss_fn = loss_fn
439
- if isinstance(backbone, nn.Cell) and backbone.jit_config_dict:
440
- self._jit_config_dict = backbone.jit_config_dict
573
+ self._get_attr_from_cell(backbone)
441
574
 
442
575
  def construct(self, data, label):
443
576
  out = self._backbone(data)
@@ -452,6 +585,39 @@ def _add_loss_network(network, loss_fn, cast_model_type):
452
585
  return network
453
586
 
454
587
 
588
+ def _is_grad_accumulation(mcell):
589
+ if mcell.cls_name == "GradAccumulationCell":
590
+ return True
591
+ for cell in mcell.cells():
592
+ if _is_grad_accumulation(cell):
593
+ return True
594
+ return False
595
+
596
+
597
+ def _auto_mixed_precision_process(network, config, level):
598
+ """Auto mixed precision process."""
599
+ if MS_AMP_BY_REWRITE:
600
+ if config["cast_model_type"] == mstype.float16 or level == "O2":
601
+ level = "O2" if config["keep_batchnorm_fp32"] else "O3"
602
+ elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
603
+ # cast_model_type set by kwargs
604
+ level = "O0"
605
+ network = auto_mixed_precision(network, level)
606
+ else:
607
+ if config["cast_model_type"] == mstype.float16:
608
+ network.to_float(mstype.float16)
609
+
610
+ if config["keep_batchnorm_fp32"]:
611
+ _do_keep_batchnorm_fp32(network)
612
+ elif not config["keep_batchnorm_fp32"] and level == "O2":
613
+ network.to_float(mstype.float16)
614
+ elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
615
+ pass
616
+ else:
617
+ network = auto_mixed_precision(network, level)
618
+ return network
619
+
620
+
455
621
  def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs):
456
622
  """
457
623
  Build the mixed precision training cell automatically.
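The new `_auto_mixed_precision_process` helper folds the legacy `cast_model_type` / `keep_batchnorm_fp32` kwargs into an amp level when the rewrite backend is enabled. Below is a standalone restatement of just that mapping branch; plain strings stand in for mstype values so the sketch runs without MindSpore, and it is illustrative, not the library function:

def map_legacy_kwargs_to_level(cast_model_type, keep_batchnorm_fp32, level):
    # Mirrors the MS_AMP_BY_REWRITE branch shown above.
    if cast_model_type == "float16" or level == "O2":
        # A full cast was requested: keep BatchNorm in fp32 -> "O2", otherwise "O3".
        return "O2" if keep_batchnorm_fp32 else "O3"
    if cast_model_type == "float32" and level in ("O2", "O3"):
        # cast_model_type forced back to fp32 via kwargs: disable conversion.
        return "O0"
    return level

print(map_legacy_kwargs_to_level("float16", True, "O2"))   # -> "O2"
print(map_legacy_kwargs_to_level("float16", False, "O3"))  # -> "O3"
print(map_legacy_kwargs_to_level("float32", False, "O3"))  # -> "O0"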
@@ -510,7 +676,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
510
676
  Examples:
511
677
  >>> from mindspore import amp, nn
512
678
  >>> # Define the network structure of LeNet5. Refer to
513
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
679
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
514
680
  >>> network = LeNet5()
515
681
  >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
516
682
  >>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
@@ -525,17 +691,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
525
691
  _check_kwargs(kwargs)
526
692
  config = dict(_config_level.get(level), **kwargs)
527
693
 
528
- if config["cast_model_type"] == mstype.float16:
529
- network.to_float(mstype.float16)
530
-
531
- if config["keep_batchnorm_fp32"]:
532
- _do_keep_batchnorm_fp32(network)
533
- elif not config["keep_batchnorm_fp32"] and level == "O2":
534
- network.to_float(mstype.float16)
535
- elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
536
- pass
537
- else:
538
- network = auto_mixed_precision(network, level)
694
+ network = _auto_mixed_precision_process(network, config, level)
539
695
 
540
696
  if loss_fn:
541
697
  network = _add_loss_network(network, loss_fn, config["cast_model_type"])
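With the inline handling replaced by `_auto_mixed_precision_process`, the public entry point is unchanged. A short usage sketch, reusing the docstring's `LeNet5` setup and assuming `keep_batchnorm_fp32` is still accepted as a keyword argument (its documentation sits outside this hunk):

>>> from mindspore import amp, nn
>>> network = LeNet5()
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
>>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
>>> # level="O2" is now resolved inside _auto_mixed_precision_process(); with the rewrite
>>> # backend it collapses to auto_mixed_precision(network, "O2"), or "O3" if keep_batchnorm_fp32 is False.
>>> train_net = amp.build_train_network(network, net_opt, net_loss, level="O2", keep_batchnorm_fp32=True)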
@@ -551,8 +707,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
551
707
  raise ValueError("Only `loss_scale_manager=None` or "
552
708
  "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`"
553
709
  "are supported on device `CPU`. ")
554
- if _get_pipeline_stages() > 1:
555
- network = _TrainPipelineWithLossScaleCell(network, optimizer,
710
+ if _get_pipeline_stages() > 1 or _is_grad_accumulation(network):
711
+ network = _TrainGradAccuWithLossScaleCell(network, optimizer,
556
712
  scale_sense=update_cell).set_train()
557
713
  elif enable_boost:
558
714
  network = boost.BoostTrainOneStepWithLossScaleCell(network, optimizer,
@@ -561,8 +717,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
561
717
  network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
562
718
  scale_sense=update_cell).set_train()
563
719
  return network
564
- if _get_pipeline_stages() > 1:
565
- network = _TrainPipelineAccuStepCell(network, optimizer).set_train()
720
+ if _get_pipeline_stages() > 1 or _is_grad_accumulation(network):
721
+ network = _TrainGradAccuStepCell(network, optimizer).set_train()
566
722
  elif enable_boost:
567
723
  network = boost.BoostTrainOneStepCell(network, optimizer, loss_scale).set_train()
568
724
  else:
@@ -586,6 +742,23 @@ def get_white_list():
586
742
 
587
743
  Returns:
588
744
  list, A copy of internal white list.
745
+
746
+ Examples:
747
+ >>> from mindspore import amp
748
+ >>> white_list = amp.get_white_list()
749
+ >>> print(white_list)
750
+ [<class 'mindspore.nn.layer.conv.Conv1d'>, <class 'mindspore.nn.layer.conv.Conv2d'>,
751
+ <class 'mindspore.nn.layer.conv.Conv3d'>, <class 'mindspore.nn.layer.conv.Conv1dTranspose'>,
752
+ <class 'mindspore.nn.layer.conv.Conv2dTranspose'>, <class 'mindspore.nn.layer.conv.Conv3dTranspose'>,
753
+ <class 'mindspore.nn.layer.basic.Dense'>, <class 'mindspore.nn.layer.rnn_cells.LSTMCell'>,
754
+ <class 'mindspore.nn.layer.rnn_cells.RNNCell'>, <class 'mindspore.nn.layer.rnn_cells.GRUCell'>,
755
+ <class 'mindspore.ops.operations.nn_ops.Conv2D'>, <class 'mindspore.ops.operations.nn_ops.Conv3D'>,
756
+ <class 'mindspore.ops.operations.nn_ops.Conv2DTranspose'>,
757
+ <class 'mindspore.ops.operations.nn_ops.Conv3DTranspose'>,
758
+ <class 'mindspore.ops.operations.nn_ops.Conv2DBackpropInput'>,
759
+ <class 'mindspore.ops.operations.math_ops.MatMul'>, <class 'mindspore.ops.operations.math_ops.BatchMatMul'>,
760
+ <class 'mindspore.ops.operations.nn_ops.PReLU'>, <class 'mindspore.ops.operations.nn_ops.ReLU'>,
761
+ <class 'mindspore.ops.operations.math_ops.Ger'>]
589
762
  """
590
763
  white_list = AMP_WHITE_LIST.copy()
591
764
  return white_list
@@ -602,24 +775,31 @@ def get_black_list():
602
775
 
603
776
  Returns:
604
777
  list, A copy of internal black list.
778
+
779
+ Examples:
780
+ >>> from mindspore import amp
781
+ >>> black_list = amp.get_black_list()
782
+ >>> print(black_list)
783
+ [<class 'mindspore.nn.layer.normalization.BatchNorm1d'>, <class 'mindspore.nn.layer.normalization.BatchNorm2d'>,
784
+ <class 'mindspore.nn.layer.normalization.BatchNorm3d'>, <class 'mindspore.nn.layer.normalization.LayerNorm'>]
605
785
  """
606
786
  black_list = AMP_BLACK_LIST.copy()
607
787
  return black_list
608
788
 
609
789
 
610
- def custom_mixed_precision(network, *, white_list=None, black_list=None):
790
+ def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=mstype.float16):
611
791
  """
612
792
  Custom mixed precision by setting whitelist or blacklist.
613
793
  When the `white_list` is provided, primitives and cells in `white_list` will perform the precision conversion.
614
- When the `black_list` is provided, cells that are not in `black_list` will perform the pereision
615
- conversion.
794
+ When the `black_list` is provided, cells that are not in `black_list` will perform the precision conversion.
616
795
  Only one of `white_list` and `black_list` should be provided.
617
796
 
618
797
  Note:
619
- - After using `custom_mixed_precision` for precision conversion, it is not supported to use other interfaces
620
- for precision conversion again. If interfaces like `Model` and `build_train_network` is used to train
621
- the converted network, `amp_level` need to be configured to ``O0`` to avoid the duplicated accuracy
622
- conversion.
798
+ - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
799
+ can result in a larger network hierarchy and slower performance.
800
+ - If interfaces like `Model` and `build_train_network` are used to train a network that has been converted by
801
+ mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
802
+ needs to be configured to ``O0`` to avoid duplicated precision conversion.
623
803
  - Primitives are not supported in the blacklist yet.
624
804
 
625
805
  Args:
@@ -628,6 +808,8 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None):
628
808
  white list is not used.
629
809
  black_list (list[Cell], optional): Black list of custom mixed precision. Defaults: ``None`` , means
630
810
  black list is not used.
811
+ dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or ``mstype.bfloat16`` ,
812
+ default: ``mstype.float16`` .
631
813
 
632
814
  Returns:
633
815
  network (Cell), A network supporting mixed precision.
@@ -635,12 +817,13 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None):
635
817
  Raises:
636
818
  TypeError: The network type is not Cell.
637
819
  ValueError: Neither `white_list` nor `black_list` is provided.
820
+ ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
638
821
  ValueError: Both `white_list` and `black_list` are provided.
639
822
 
640
823
  Examples:
641
824
  >>> from mindspore import amp, nn
642
825
  >>> # Define the network structure of LeNet5. Refer to
643
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
826
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
644
827
  >>> net = LeNet5()
645
828
  >>> custom_white_list = amp.get_white_list()
646
829
  >>> custom_white_list.append(nn.Flatten)
@@ -656,13 +839,19 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None):
656
839
  raise ValueError("For custom_mixed_precision, the white_list or black_list cannot be provided "
657
840
  "at the same time, please provide one or the other.")
658
841
 
842
+ if dtype not in (mstype.float16, mstype.bfloat16):
843
+ raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")
844
+
659
845
  if white_list is not None:
660
846
  _list_check(white_list, "white_list")
661
- return _auto_white_list(network, white_list)
662
-
663
- _list_check(black_list, "black_list")
664
- _auto_black_list(network, black_list)
665
- network = _OutputTo32(network)
847
+ network = _auto_white_list(network, white_list, dtype)
848
+ else:
849
+ _list_check(black_list, "black_list")
850
+ if MS_AMP_BY_REWRITE:
851
+ network = _auto_black_list_rewrite(network, black_list, dtype)
852
+ else:
853
+ network = _auto_black_list(network, black_list, dtype)
854
+ network = _OutputTo32(network)
666
855
  return network
667
856
 
668
857
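To complement the white-list example in the docstring above, a minimal sketch of the black-list path, which (as shown) now either goes through the rewrite backend or wraps the result in `_OutputTo32`; `nn.Softmax` is only a hypothetical extra entry and `LeNet5` is the linked example network:

>>> from mindspore import amp, nn
>>> from mindspore import dtype as mstype
>>> net = LeNet5()
>>> custom_black_list = amp.get_black_list()
>>> custom_black_list.append(nn.Softmax)  # hypothetical: also keep Softmax in full precision
>>> net = amp.custom_mixed_precision(net, black_list=custom_black_list, dtype=mstype.float16)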
 
@@ -693,3 +882,14 @@ def _list_check(custom_list: list, list_name: str):
693
882
  for elem in AMP_BLACK_LIST:
694
883
  if elem not in custom_list:
695
884
  logger.warning(f"{elem} is removed from internal black list.")
885
+
886
+ def _config_amp(*, enable_rewrite: bool = None, cast_op: type = None): # pylint: disable=unused-variable
887
+ """Configure auto mixed precision."""
888
+ global MS_AMP_BY_REWRITE
889
+ global _amp_cast_op
890
+
891
+ if enable_rewrite is not None:
892
+ MS_AMP_BY_REWRITE = enable_rewrite
893
+
894
+ if cast_op is not None:
895
+ _amp_cast_op = cast_op