mindspore-2.0.0rc1-cp38-cp38-manylinux1_x86_64.whl → mindspore-2.2.0-cp38-cp38-manylinux1_x86_64.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.

Potentially problematic release.

This version of mindspore might be problematic.

Files changed (884)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  201. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  202. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  203. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  204. mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
  205. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  206. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  207. mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
  208. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  209. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  210. mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
  211. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  212. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  213. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  214. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  215. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  216. mindspore/log.py +9 -6
  217. mindspore/mindrecord/filereader.py +33 -4
  218. mindspore/mindrecord/filewriter.py +70 -35
  219. mindspore/mindrecord/mindpage.py +40 -34
  220. mindspore/mindrecord/shardreader.py +1 -1
  221. mindspore/mindrecord/shardsegment.py +1 -1
  222. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  223. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  224. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  225. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  226. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  227. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  228. mindspore/nn/cell.py +463 -169
  229. mindspore/nn/dynamic_lr.py +47 -43
  230. mindspore/nn/layer/activation.py +225 -82
  231. mindspore/nn/layer/basic.py +121 -79
  232. mindspore/nn/layer/channel_shuffle.py +21 -21
  233. mindspore/nn/layer/combined.py +33 -26
  234. mindspore/nn/layer/container.py +277 -22
  235. mindspore/nn/layer/conv.py +441 -304
  236. mindspore/nn/layer/dense.py +19 -13
  237. mindspore/nn/layer/embedding.py +62 -49
  238. mindspore/nn/layer/flash_attention.py +264 -0
  239. mindspore/nn/layer/image.py +50 -39
  240. mindspore/nn/layer/math.py +62 -51
  241. mindspore/nn/layer/normalization.py +219 -167
  242. mindspore/nn/layer/padding.py +58 -70
  243. mindspore/nn/layer/pooling.py +334 -287
  244. mindspore/nn/layer/rnn_cells.py +53 -38
  245. mindspore/nn/layer/rnns.py +59 -56
  246. mindspore/nn/layer/thor_layer.py +52 -44
  247. mindspore/nn/layer/timedistributed.py +6 -4
  248. mindspore/nn/layer/transformer.py +284 -164
  249. mindspore/nn/learning_rate_schedule.py +34 -25
  250. mindspore/nn/loss/__init__.py +3 -2
  251. mindspore/nn/loss/loss.py +554 -311
  252. mindspore/nn/optim/ada_grad.py +12 -9
  253. mindspore/nn/optim/adadelta.py +14 -11
  254. mindspore/nn/optim/adafactor.py +19 -16
  255. mindspore/nn/optim/adam.py +62 -47
  256. mindspore/nn/optim/adamax.py +13 -10
  257. mindspore/nn/optim/adasum.py +12 -8
  258. mindspore/nn/optim/asgd.py +10 -9
  259. mindspore/nn/optim/ftrl.py +20 -17
  260. mindspore/nn/optim/lamb.py +16 -12
  261. mindspore/nn/optim/lars.py +8 -6
  262. mindspore/nn/optim/lazyadam.py +25 -20
  263. mindspore/nn/optim/momentum.py +10 -7
  264. mindspore/nn/optim/optimizer.py +61 -9
  265. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  266. mindspore/nn/optim/rmsprop.py +17 -13
  267. mindspore/nn/optim/rprop.py +30 -17
  268. mindspore/nn/optim/sgd.py +40 -23
  269. mindspore/nn/optim/thor.py +24 -26
  270. mindspore/nn/probability/bijector/bijector.py +11 -11
  271. mindspore/nn/probability/bijector/exp.py +1 -1
  272. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  273. mindspore/nn/probability/bijector/invert.py +1 -1
  274. mindspore/nn/probability/bijector/power_transform.py +29 -29
  275. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  276. mindspore/nn/probability/bijector/softplus.py +5 -5
  277. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  278. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  279. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  280. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  281. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  282. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  283. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  284. mindspore/nn/probability/distribution/beta.py +8 -8
  285. mindspore/nn/probability/distribution/categorical.py +23 -15
  286. mindspore/nn/probability/distribution/cauchy.py +5 -6
  287. mindspore/nn/probability/distribution/distribution.py +3 -3
  288. mindspore/nn/probability/distribution/exponential.py +4 -4
  289. mindspore/nn/probability/distribution/gamma.py +10 -10
  290. mindspore/nn/probability/distribution/geometric.py +8 -8
  291. mindspore/nn/probability/distribution/gumbel.py +8 -9
  292. mindspore/nn/probability/distribution/half_normal.py +5 -5
  293. mindspore/nn/probability/distribution/laplace.py +5 -5
  294. mindspore/nn/probability/distribution/log_normal.py +12 -11
  295. mindspore/nn/probability/distribution/logistic.py +8 -8
  296. mindspore/nn/probability/distribution/normal.py +6 -5
  297. mindspore/nn/probability/distribution/poisson.py +10 -11
  298. mindspore/nn/probability/distribution/student_t.py +8 -9
  299. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  300. mindspore/nn/probability/distribution/uniform.py +11 -11
  301. mindspore/nn/reinforcement/tensor_array.py +2 -2
  302. mindspore/nn/sparse/sparse.py +9 -9
  303. mindspore/nn/wrap/cell_wrapper.py +188 -63
  304. mindspore/nn/wrap/grad_reducer.py +21 -12
  305. mindspore/nn/wrap/loss_scale.py +136 -49
  306. mindspore/numpy/__init__.py +4 -4
  307. mindspore/numpy/array_creations.py +55 -56
  308. mindspore/numpy/array_ops.py +134 -35
  309. mindspore/numpy/logic_ops.py +66 -20
  310. mindspore/numpy/math_ops.py +142 -139
  311. mindspore/numpy/utils_const.py +2 -2
  312. mindspore/offline_debug/convert_async.py +2 -2
  313. mindspore/ops/_grad_experimental/__init__.py +7 -5
  314. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  315. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  316. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  317. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  318. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  319. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  320. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  321. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  322. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  323. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  324. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  325. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  326. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  327. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  328. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  329. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  330. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  331. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  332. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  333. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  334. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  335. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  336. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  337. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  338. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  339. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  340. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  341. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  342. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  343. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  344. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  345. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  346. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  347. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  348. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  349. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  350. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  351. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  352. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  353. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  354. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  355. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  356. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  357. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  358. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  359. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  360. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  361. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  362. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  363. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  364. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  365. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  366. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  367. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  368. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  369. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  370. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  371. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  372. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  373. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  374. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  375. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  376. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  377. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  378. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  379. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  380. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  381. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  382. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  383. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  384. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  385. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  386. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  387. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  388. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  389. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  390. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  391. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  392. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  393. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  394. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  395. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  396. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  397. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  398. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  399. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  400. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  401. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  402. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  403. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  404. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  405. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  406. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  407. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  408. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  409. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  410. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  411. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  412. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  413. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  414. mindspore/ops/_primitive_cache.py +1 -1
  415. mindspore/ops/_tracefunc.py +241 -0
  416. mindspore/ops/_utils/utils.py +10 -2
  417. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  418. mindspore/ops/_vmap/vmap_base.py +5 -4
  419. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  420. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  421. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  422. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  423. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  424. mindspore/ops/arg_dtype_cast.py +54 -0
  425. mindspore/ops/composite/__init__.py +7 -5
  426. mindspore/ops/composite/base.py +78 -34
  427. mindspore/ops/composite/math_ops.py +5 -695
  428. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  429. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  430. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  431. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  432. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  433. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  434. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  436. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  437. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  438. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  440. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  441. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  442. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  443. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  444. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  445. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  446. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  447. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  448. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  449. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  450. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  451. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  452. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  453. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  454. mindspore/ops/deprecated.py +304 -0
  455. mindspore/ops/function/__init__.py +41 -4
  456. mindspore/ops/function/array_func.py +1108 -467
  457. mindspore/ops/function/clip_func.py +94 -27
  458. mindspore/ops/function/debug_func.py +3 -1
  459. mindspore/ops/function/grad/grad_func.py +82 -73
  460. mindspore/ops/function/image_func.py +28 -12
  461. mindspore/ops/function/linalg_func.py +135 -39
  462. mindspore/ops/function/math_func.py +3779 -894
  463. mindspore/ops/function/nn_func.py +1584 -657
  464. mindspore/ops/function/parameter_func.py +13 -3
  465. mindspore/ops/function/random_func.py +247 -153
  466. mindspore/ops/function/sparse_func.py +14 -11
  467. mindspore/ops/function/sparse_unary_func.py +173 -47
  468. mindspore/ops/function/spectral_func.py +8 -4
  469. mindspore/ops/function/vmap_func.py +8 -7
  470. mindspore/ops/functional.py +47 -16
  471. mindspore/ops/op_info_register.py +346 -86
  472. mindspore/ops/operations/__init__.py +38 -22
  473. mindspore/ops/operations/_grad_ops.py +145 -149
  474. mindspore/ops/operations/_inner_ops.py +298 -56
  475. mindspore/ops/operations/_ms_kernel.py +3 -3
  476. mindspore/ops/operations/_quant_ops.py +24 -28
  477. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  478. mindspore/ops/operations/_scalar_ops.py +115 -0
  479. mindspore/ops/operations/_sequence_ops.py +148 -10
  480. mindspore/ops/operations/_tensor_array.py +1 -1
  481. mindspore/ops/operations/_thor_ops.py +2 -2
  482. mindspore/ops/operations/array_ops.py +1239 -561
  483. mindspore/ops/operations/comm_ops.py +166 -90
  484. mindspore/ops/operations/control_ops.py +3 -3
  485. mindspore/ops/operations/custom_ops.py +124 -102
  486. mindspore/ops/operations/debug_ops.py +24 -11
  487. mindspore/ops/operations/image_ops.py +86 -71
  488. mindspore/ops/operations/inner_ops.py +18 -13
  489. mindspore/ops/operations/linalg_ops.py +30 -11
  490. mindspore/ops/operations/math_ops.py +1730 -435
  491. mindspore/ops/operations/nn_ops.py +1953 -943
  492. mindspore/ops/operations/other_ops.py +65 -43
  493. mindspore/ops/operations/random_ops.py +258 -98
  494. mindspore/ops/operations/rl_ops.py +4 -36
  495. mindspore/ops/operations/sparse_ops.py +38 -33
  496. mindspore/ops/operations/spectral_ops.py +8 -4
  497. mindspore/ops/primitive.py +66 -44
  498. mindspore/ops/signature.py +5 -5
  499. mindspore/parallel/_auto_parallel_context.py +80 -19
  500. mindspore/parallel/_cost_model_context.py +42 -0
  501. mindspore/parallel/_offload_context.py +162 -72
  502. mindspore/parallel/_parallel_serialization.py +2 -2
  503. mindspore/parallel/_ps_context.py +16 -4
  504. mindspore/parallel/_recovery_context.py +2 -1
  505. mindspore/parallel/_tensor.py +15 -13
  506. mindspore/parallel/_transformer/layers.py +8 -6
  507. mindspore/parallel/_transformer/loss.py +1 -0
  508. mindspore/parallel/_transformer/moe.py +7 -7
  509. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  510. mindspore/parallel/_transformer/transformer.py +34 -14
  511. mindspore/parallel/_utils.py +36 -14
  512. mindspore/parallel/algo_parameter_config.py +114 -20
  513. mindspore/parallel/checkpoint_transform.py +16 -18
  514. mindspore/parallel/shard.py +16 -13
  515. mindspore/profiler/__init__.py +1 -1
  516. mindspore/profiler/common/struct_type.py +3 -3
  517. mindspore/profiler/common/util.py +3 -2
  518. mindspore/profiler/envprofiling.py +11 -4
  519. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  520. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  521. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  522. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  523. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  524. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  525. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  526. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  527. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  528. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  529. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  530. mindspore/profiler/parser/flops_parser.py +15 -11
  531. mindspore/profiler/parser/framework_parser.py +92 -73
  532. mindspore/profiler/parser/hccl_parser.py +16 -12
  533. mindspore/profiler/parser/integrator.py +22 -11
  534. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  535. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  536. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  537. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  538. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  539. mindspore/profiler/parser/optime_parser.py +1 -1
  540. mindspore/profiler/parser/profiler_info.py +4 -5
  541. mindspore/profiler/parser/step_trace_parser.py +11 -14
  542. mindspore/profiler/profiling.py +678 -377
  543. mindspore/rewrite/api/node.py +211 -54
  544. mindspore/rewrite/api/node_type.py +5 -0
  545. mindspore/rewrite/api/pattern_engine.py +22 -23
  546. mindspore/rewrite/api/scoped_value.py +20 -17
  547. mindspore/rewrite/api/symbol_tree.py +252 -106
  548. mindspore/rewrite/api/tree_node_helper.py +3 -0
  549. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  550. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  551. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  552. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  553. mindspore/rewrite/common/rewrite_elog.py +5 -1
  554. mindspore/rewrite/namer.py +51 -51
  555. mindspore/rewrite/namespace.py +14 -5
  556. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  557. mindspore/rewrite/node/call_function.py +79 -0
  558. mindspore/rewrite/node/cell_container.py +135 -0
  559. mindspore/rewrite/node/control_flow.py +88 -0
  560. mindspore/rewrite/{node.py → node/node.py} +313 -247
  561. mindspore/rewrite/node/node_manager.py +254 -0
  562. mindspore/rewrite/node/node_topological_manager.py +243 -0
  563. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  564. mindspore/rewrite/parsers/assign_parser.py +225 -239
  565. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  566. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  567. mindspore/rewrite/parsers/constant_parser.py +9 -6
  568. mindspore/rewrite/parsers/container_parser.py +9 -7
  569. mindspore/rewrite/parsers/for_parser.py +36 -15
  570. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  571. mindspore/rewrite/parsers/if_parser.py +28 -24
  572. mindspore/rewrite/parsers/module_parser.py +202 -25
  573. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  574. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  575. mindspore/rewrite/parsers/return_parser.py +6 -6
  576. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  577. mindspore/rewrite/sparsify/sparsify.py +4 -1
  578. mindspore/rewrite/sparsify/utils.py +11 -5
  579. mindspore/rewrite/symbol_tree.py +577 -732
  580. mindspore/rewrite/symbol_tree_builder.py +9 -175
  581. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  582. mindspore/run_check/_check_version.py +46 -39
  583. mindspore/run_check/run_check.py +3 -2
  584. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  585. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  586. mindspore/scipy/__init__.py +1 -1
  587. mindspore/scipy/linalg.py +67 -61
  588. mindspore/scipy/ops.py +5 -41
  589. mindspore/scipy/ops_grad.py +3 -2
  590. mindspore/scipy/ops_wrapper.py +5 -5
  591. mindspore/scipy/optimize/line_search.py +8 -8
  592. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  593. mindspore/scipy/optimize/minimize.py +16 -12
  594. mindspore/scipy/utils.py +1 -52
  595. mindspore/scipy/utils_const.py +4 -4
  596. mindspore/train/__init__.py +4 -4
  597. mindspore/train/_utils.py +13 -5
  598. mindspore/train/amp.py +410 -148
  599. mindspore/train/anf_ir_pb2.py +16 -4
  600. mindspore/train/callback/_backup_and_restore.py +8 -11
  601. mindspore/train/callback/_callback.py +80 -3
  602. mindspore/train/callback/_checkpoint.py +82 -51
  603. mindspore/train/callback/_early_stop.py +12 -15
  604. mindspore/train/callback/_history.py +1 -1
  605. mindspore/train/callback/_lambda_callback.py +13 -13
  606. mindspore/train/callback/_landscape.py +21 -17
  607. mindspore/train/callback/_loss_monitor.py +9 -10
  608. mindspore/train/callback/_on_request_exit.py +16 -33
  609. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  610. mindspore/train/callback/_summary_collector.py +44 -30
  611. mindspore/train/callback/_time_monitor.py +62 -12
  612. mindspore/train/data_sink.py +10 -16
  613. mindspore/train/dataset_helper.py +154 -86
  614. mindspore/train/loss_scale_manager.py +14 -9
  615. mindspore/train/metrics/__init__.py +10 -2
  616. mindspore/train/metrics/accuracy.py +1 -1
  617. mindspore/train/metrics/auc.py +1 -1
  618. mindspore/train/metrics/bleu_score.py +2 -2
  619. mindspore/train/metrics/confusion_matrix.py +14 -14
  620. mindspore/train/metrics/cosine_similarity.py +3 -3
  621. mindspore/train/metrics/dice.py +1 -1
  622. mindspore/train/metrics/fbeta.py +1 -1
  623. mindspore/train/metrics/hausdorff_distance.py +8 -6
  624. mindspore/train/metrics/mean_surface_distance.py +5 -4
  625. mindspore/train/metrics/metric.py +49 -17
  626. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  627. mindspore/train/metrics/perplexity.py +1 -1
  628. mindspore/train/metrics/precision.py +2 -2
  629. mindspore/train/metrics/recall.py +2 -3
  630. mindspore/train/metrics/roc.py +7 -7
  631. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  632. mindspore/train/metrics/topk.py +7 -4
  633. mindspore/train/mind_ir_pb2.py +193 -48
  634. mindspore/train/model.py +377 -133
  635. mindspore/train/serialization.py +697 -245
  636. mindspore/train/summary/_summary_adapter.py +5 -2
  637. mindspore/train/summary/_writer_pool.py +4 -3
  638. mindspore/train/summary/summary_record.py +25 -23
  639. mindspore/train/train_thor/convert_utils.py +39 -23
  640. mindspore/train/train_thor/dataset_helper.py +4 -3
  641. mindspore/train/train_thor/model_thor.py +8 -8
  642. mindspore/version.py +1 -1
  643. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  644. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +647 -818
  645. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  646. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  647. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  648. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  649. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  650. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  651. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  652. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  653. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  654. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  655. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  656. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  657. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  658. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  659. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  660. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  661. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  662. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  663. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  664. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  665. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  666. mindspore/_extends/graph_kernel/expander.py +0 -80
  667. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  668. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  669. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  670. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  671. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  672. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  673. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  674. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  675. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  676. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  677. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  678. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  679. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  680. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  681. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  682. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  683. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  684. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  685. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  686. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  687. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  688. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  689. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  690. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  691. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  692. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  693. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  694. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  695. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  696. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  697. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  698. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  699. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  700. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  701. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  702. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  703. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  704. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  705. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  706. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  707. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  708. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  709. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  710. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  711. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  712. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  713. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  714. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  715. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  716. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  717. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  718. mindspore/dataset/engine/graphdata.py +0 -1586
  719. mindspore/include/api/net.h +0 -142
  720. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  721. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  722. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  723. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  724. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  725. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  726. mindspore/ops/_grad/grad_other_ops.py +0 -89
  727. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  728. mindspore/ops/_grad/grad_sparse.py +0 -323
  729. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  730. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  731. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  732. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  733. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  734. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  735. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  736. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  737. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  738. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  739. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  740. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  741. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  742. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  743. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  744. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  745. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  746. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  747. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  748. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  749. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  751. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  752. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  753. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  754. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  755. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  756. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  758. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  759. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  760. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  761. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  762. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  763. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  764. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  765. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  766. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  767. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  768. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  769. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  770. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  771. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  772. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  774. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  775. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  776. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  777. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  778. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  779. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  780. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  784. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  785. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  786. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  787. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  789. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  790. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  791. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  792. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  793. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  794. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  796. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  797. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  798. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  799. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  800. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  801. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  802. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  803. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  804. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  805. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  806. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  807. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  808. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  809. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  810. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  811. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  812. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  814. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  815. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  816. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  817. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  818. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  819. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  820. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  821. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  822. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  823. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  824. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  825. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  826. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  827. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  828. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  829. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  830. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  831. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  832. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  833. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  834. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  835. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  836. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  837. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  839. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  840. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  841. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  842. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  843. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  844. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  845. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  846. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  847. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  848. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  849. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  850. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  851. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  852. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  853. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  854. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  855. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  856. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  857. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  858. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  859. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  860. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  861. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  862. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  863. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  864. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  865. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  866. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  867. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  868. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  869. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  870. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  871. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  872. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  873. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  874. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  875. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  876. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  877. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  878. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  879. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  880. mindspore/rewrite/node_visitor.py +0 -44
  881. mindspore/rewrite/topological_manager.py +0 -203
  882. mindspore/scipy/sparse/linalg.py +0 -192
  883. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  884. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/train/amp.py CHANGED
@@ -19,8 +19,8 @@ import mindspore as ms
19
19
  from mindspore import nn
20
20
  from mindspore import _checkparam as validator
21
21
  from mindspore.common import dtype as mstype
22
- from mindspore.nn.wrap.cell_wrapper import _TrainPipelineAccuStepCell
23
- from mindspore.nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
22
+ from mindspore.nn.wrap.cell_wrapper import _TrainGradAccuStepCell
23
+ from mindspore.nn.wrap.loss_scale import _TrainGradAccuWithLossScaleCell
24
24
  from mindspore.ops import functional as F
25
25
  from mindspore.parallel._utils import _get_pipeline_stages
26
26
  from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
@@ -30,9 +30,6 @@ from mindspore.ops import Primitive
30
30
  from mindspore import log as logger
31
31
 
32
32
 
33
- STREE = None
34
-
35
-
36
33
  AMP_WHITE_LIST = [
37
34
  nn.Conv1d,
38
35
  nn.Conv2d,
@@ -64,17 +61,19 @@ AMP_BLACK_LIST = [
64
61
  nn.LayerNorm
65
62
  ]
66
63
 
64
+ MS_AMP_BY_REWRITE = False
65
+ _amp_cast_op = P.Cast
67
66
 
68
67
  class _OutputTo16(nn.Cell):
69
68
  """Wrap cell for amp. Cast network output back to float16."""
70
- def __init__(self, backbone):
69
+ def __init__(self, backbone, dtype=mstype.float16):
71
70
  super(_OutputTo16, self).__init__(auto_prefix=False)
72
71
  self._backbone = backbone
73
- if isinstance(backbone, nn.Cell) and backbone.jit_config_dict:
74
- self._jit_config_dict = backbone.jit_config_dict
72
+ self.dtype = dtype
73
+ self._get_attr_from_cell(backbone)
75
74
 
76
- def construct(self, x):
77
- return F.cast(self._backbone(x), mstype.float16)
75
+ def construct(self, *args, **kwargs):
76
+ return F.cast(self._backbone(*args, **kwargs), self.dtype)
78
77
 
79
78
 
80
79
  class _OutputTo32(nn.Cell):
@@ -82,68 +81,78 @@ class _OutputTo32(nn.Cell):
     def __init__(self, backbone):
         super(_OutputTo32, self).__init__(auto_prefix=False)
         self._backbone = backbone
-        if isinstance(backbone, nn.Cell) and backbone.jit_config_dict:
-            self._jit_config_dict = backbone.jit_config_dict
+        self._get_attr_from_cell(backbone)
 
-    def construct(self, *inputs):
-        out = self._backbone(*inputs)
+    def construct(self, *args, **kwargs):
+        out = self._backbone(*args, **kwargs)
         return F.mixed_precision_cast(mstype.float32, out)
 
 
-def _allow_mix_precision(node, allowed_list) -> bool:
+
+def _allow_mix_precision(node, allowed_list, dtype) -> bool:
     """
     Check whether current node need do mix precision. Follow conditions need to be satisfied:
     1) Type of node is one of (Primitive, nn.Cell)
-    2) Node is not P.Cast()
+    2) Node is not Cast Op
     3) to_float(mindspore.float16) is not set in Cell
     """
-    if node.get_instance() in allowed_list:
+    node_inst = node.get_instance()
+    if node_inst in allowed_list:
         return True
+    if node.get_targets() is None:
+        return False
     if not issubclass(node.get_instance_type(), (Primitive, nn.Cell)):
         return False
-    if isinstance(node.get_instance(), P.Cast):
+    if isinstance(node_inst, _amp_cast_op):
         return False
     if issubclass(node.get_instance_type(), nn.Cell):
-        # if cell is already in allowed_list, it means to_float(mindspore.float16) is set by amp.
-        # if cell is not in allowed_list, but has to_float(mindspore.float16),
-        # it means to_float(mindspore.float16) is set by user.
-        if node.get_instance().to_float_fp16:
+        # if cell is already in allowed_list, it means to_float() is set by amp.
+        # if cell is not in allowed_list, but has to_float(),
+        # it means to_float() is set by user.
+        to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
+        if hasattr(node_inst, to_float_flag) and getattr(node_inst, to_float_flag):
             return False
         allowed_list.append(node.get_instance())
         return True
 
 
-def _insert_cast_operator_process(node, stree):
+def _insert_cast_operator_process(node, dtype):
     """insert cast for operators in white_list."""
+    dtype_str = "mindspore.bfloat16" if dtype == mstype.bfloat16 else "mindspore.float16"
     new_cast_node = None
-    # insert cast float16 before the primitive operators
+    stree = node.get_symbol_tree()
+    # insert cast fp16/bf16 before the primitive operators
     if issubclass(node.get_instance_type(), Primitive):
-        for idx in range(len(node.get_inputs())):
+        for idx, arg in enumerate(node.get_args()):
             position = stree.before(node)
-            new_node = P.Cast()
-            arg = ms.rewrite.ScopedValue.create_name_values([node.get_inputs()[idx].get_targets()[0].value,
-                                                             "mindspore.float16"])
+            new_node = _amp_cast_op()
+            cast_args = ms.rewrite.ScopedValue.create_name_values([arg.value, dtype_str], [arg.scope, ""])
+            arg_provider = node.get_handler().get_arg_providers()[idx]
+            if arg_provider and len(arg_provider[0].get_target_users(arg_provider[1])) > 1:
+                cast_targets = [stree.unique_name(str(arg))]
+            else:
+                cast_targets = ms.rewrite.ScopedValue.create_name_values([arg.value], [arg.scope])
             new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
-                                                             targets=['x_cast_{}'.format(node.get_name())],
-                                                             args=arg,
+                                                             targets=cast_targets,
+                                                             args=cast_args,
                                                              name='incast_{}{}'.format(node.get_name(), idx))
             stree.insert(position, new_cast_node)
             node.set_arg_by_node(idx, new_cast_node)
-    # insert cast float16 before the Cell operators
+    # insert cast fp16/bf16 before the Cell operators
     elif issubclass(node.get_instance_type(), nn.Cell):
-        node.get_instance().to_float(mstype.float16)
+        node.get_instance().to_float(dtype)
     # ignore if subclass is not one of (Primitive, nn.Cell)
     else:
         return
 
     # insert cast float32 after the operators
     position = stree.after(node)
-    new_node = P.Cast()
-    arg = ms.rewrite.ScopedValue.create_name_values([node.get_targets()[0].value,
-                                                     "mindspore.float32"])
+    new_node = _amp_cast_op()
+    cast_args = ms.rewrite.ScopedValue.create_name_values([node.get_targets()[0].value,
                                                            "mindspore.float32"])
     new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
-                                                     targets=['x_cast_{}'.format(node.get_name())],
-                                                     args=arg,
+                                                     targets=[node.get_targets()[0]],
+                                                     args=cast_args,
                                                      name='outcast_{}'.format(node.get_name()))
     # insert node & unique names
     stree.insert(position, new_cast_node)
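For orientation only (not from the diff): after `_insert_cast_operator_process` runs, a whitelisted primitive is bracketed by `incast_*` nodes on its arguments and an `outcast_*` node on its result. A hand-written equivalent of the rewritten computation, assuming float32 inputs and the default float16 target dtype:

    import numpy as np
    import mindspore as ms
    from mindspore import ops

    matmul = ops.MatMul()
    a = ms.Tensor(np.ones((2, 3)), ms.float32)
    b = ms.Tensor(np.ones((3, 4)), ms.float32)

    a16 = ops.cast(a, ms.float16)       # incast_* nodes: every argument goes to the low-precision dtype
    b16 = ops.cast(b, ms.float16)
    y16 = matmul(a16, b16)              # the original node now consumes the cast outputs
    y32 = ops.cast(y16, ms.float32)     # outcast_* node: result returns to float32 for downstream users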
@@ -156,43 +165,102 @@ def _insert_cast_operator_process(node, stree):
             user.set_arg_by_node(idx, new_cast_node)
 
 
-def _insert_cast_operator_white_list(stree, white_list):
+def _insert_cast_operator_white_list(stree, white_list, dtype):
     """insert cast for operators in white_list."""
     allowed_list = []
-    for node in stree.nodes():
-        if node.get_targets() is None:
-            continue
+    # Ignore if net called ".to_float(dtype)"
+    net = stree.get_handler().get_origin_network()
+    to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
+    if isinstance(net, nn.Cell) and hasattr(net, to_float_flag) and getattr(net, to_float_flag):
+        return
+    node_list = []
+    node_list.extend(list(stree.nodes()))
+    while node_list:
+        node = node_list.pop()
         if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
+            if MS_AMP_BY_REWRITE:
+                _insert_cast_for_cell_container(node, dtype, allowed_list, white_list=white_list)
             for n in node.get_handler().node_list:
                 if n.get_node_type() == ms.rewrite.NodeType.Tree:
                     _insert_cast_operator_white_list(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)),
-                                                     white_list)
+                                                     white_list, dtype)
         elif node.get_node_type() == ms.rewrite.NodeType.Tree:
             substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
-            _insert_cast_operator_white_list(substree, white_list)
-        elif node.get_instance_type() in white_list and _allow_mix_precision(node, allowed_list):
-            _insert_cast_operator_process(node, stree)
+            _insert_cast_operator_white_list(substree, white_list, dtype)
+        elif node.get_node_type() in [ms.rewrite.NodeType.CallFunction, ms.rewrite.NodeType.ControlFlow]:
+            if isinstance(node.get_handler(), ms.rewrite.node.NodeManager):
+                nodes = [ms.rewrite.Node(n) for n in node.get_handler().nodes()]
+                node_list.extend(nodes)
+        elif node.get_instance_type() in white_list and _allow_mix_precision(node, allowed_list, dtype):
+            _insert_cast_operator_process(node, dtype)
 
 
-def _need_removed_cast_pair(node):
+def _insert_cast_for_cell_container(cell_container, dtype, allowed_list, *, white_list=None, black_list=None):
+    """
+    Insert cast for cell containers.
+    Only one of white_list and black_list can be set.
+    """
+
+    class CastNet(nn.Cell):
+        """Cast net"""
+        def __init__(self, dtype):
+            super().__init__()
+            self.cast = _amp_cast_op()
+            self.dtype = dtype
+
+        def construct(self, x):
+            return self.cast(x, self.dtype)
+
+    cast_flag = False
+    current_node = None
+    stree = cell_container.get_symbol_tree()
+    for node in cell_container.get_handler().nodes():
+        current_node = ms.rewrite.Node(node)
+        if (white_list is not None and current_node.get_instance_type() in white_list) or \
+                (black_list is not None and current_node.get_instance_type() not in black_list) and \
+                (_allow_mix_precision(current_node, allowed_list, dtype)):
+            cast_flag = True
+            current_node.get_instance().to_float(dtype)
+        elif cast_flag:
+            # cast next node back to float32
+            current_node.get_instance().to_float(mstype.float32)
+            cast_flag = False
+    if cast_flag and current_node:
+        # if last node in cell_container is casted to fp16/bf16, insert a cast node to cast value back to fp32
+        cast_node = ms.rewrite.Node.create_call_cell(cell=CastNet(mstype.float32),
+                                                     args=[current_node.get_targets()[0]],
+                                                     targets=[current_node.get_targets()[0]],
+                                                     name=f"outcast_{cell_container.get_name()}")
+        stree.insert(stree.after(current_node), cast_node)
+
+
+def _need_removed_cast_pair(node, dtype):
     """check whether the cast pairs should be removed."""
-    cast_dtypes = ms.rewrite.ScopedValue.create_name_values(["mindspore.float16", "mindspore.float32"])
+    dtype_str = "mindspore.bfloat16" if dtype == mstype.bfloat16 else "mindspore.float16"
+    cast_dtypes = ms.rewrite.ScopedValue.create_name_values([dtype_str, "mindspore.float32"])
     cast_dtype_f16 = cast_dtypes[0]
     cast_dtype_f32 = cast_dtypes[1]
-    # current node should be P.Cast()(x, mindspore.float32)
-    if node.get_instance_type() != P.Cast:
+    # current node should be Cast Op to float32
+    if node.get_instance_type() != _amp_cast_op:
         return False
     node_cast_type = node.get_args()[1]
     if node_cast_type != cast_dtype_f32:
         return False
-    # all user nodes should be P.Cast()(x, mindspore.float16) or Cell with to_float(mindspore.float16)
+    # all user nodes should be Cast Op to dtype or Cell with to_float(dtype)
     if not node.get_users():
         return False
+    all_nodes = [ms.rewrite.Node(n) for n in node.get_handler().get_node_manager().nodes()]
     for user in node.get_users():
+        # If ControlFlow node(if statement) exists between current node and user node,
+        # cast pair should not be removed.
+        middle_nodes = all_nodes[all_nodes.index(node): all_nodes.index(user)]
+        if any([n.get_node_type() == ms.rewrite.NodeType.ControlFlow for n in middle_nodes]):
+            return False
         if isinstance(user.get_instance(), nn.Cell):
-            if not user.get_instance().to_float_fp16:
+            to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
+            if not (hasattr(user.get_instance(), to_float_flag) and getattr(user.get_instance(), to_float_flag)):
                 return False
-        elif user.get_instance_type() == P.Cast:
+        elif user.get_instance_type() == _amp_cast_op:
             user_cast_type = user.get_args()[1]
             if user_cast_type != cast_dtype_f16:
                 return False
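Intuition for the check above (illustrative, not part of the diff): when two converted operators are chained, the float32 out-cast of the first feeds only float16 in-casts of the second, so the round trip can be dropped. Roughly:

    import numpy as np
    import mindspore as ms
    from mindspore import ops

    relu, matmul = ops.ReLU(), ops.MatMul()
    x16 = ms.Tensor(np.ones((2, 2)), ms.float16)

    # Before removal: the outcast of relu and the incast of matmul form a redundant fp32/fp16 pair.
    y32 = ops.cast(relu(x16), ms.float32)
    z16 = matmul(ops.cast(y32, ms.float16), x16)
    # After the cast pair is removed, matmul consumes relu's fp16 output directly.
    z16 = matmul(relu(x16), x16)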
@@ -201,18 +269,20 @@ def _need_removed_cast_pair(node):
     return True
 
 
-def _removed_cast_pair_process(stree, cast_f32_node):
+def _removed_cast_pair_process(cast_f32_node):
     """remove the duplicated cast operators."""
-    for user_node in cast_f32_node.get_users():
-        # remove cast f16 nodes
-        if user_node.get_instance_type() == P.Cast:
+    stree = cast_f32_node.get_symbol_tree()
+    cast_f32_users = cast_f32_node.get_users()
+    # remove cast f16 nodes
+    for user_node in cast_f32_users:
+        if user_node.get_instance_type() == _amp_cast_op:
             cast_f16_node = user_node
             # modify arguments using cast_f16's target[0] to cast_f32's args[0], which is f16 type
             for cast_f16_user in cast_f16_node.get_users():
                 for idx, arg in enumerate(cast_f16_user.get_args()):
                     if arg == cast_f16_node.get_targets()[0]:
                         cast_f16_user.set_arg(idx, cast_f32_node.get_args()[0])
-            stree.erase_node(cast_f16_node)
+            stree.erase(cast_f16_node)
         # update args of cell f16 nodes
         elif isinstance(user_node.get_instance(), nn.Cell):
             cell_f16_node = user_node
@@ -220,37 +290,81 @@ def _removed_cast_pair_process(stree, cast_f32_node):
                 if arg == cast_f32_node.get_targets()[0]:
                     cell_f16_node.set_arg(idx, cast_f32_node.get_args()[0])
     # remove the cast f32 node
-    stree.erase_node(cast_f32_node)
+    stree.erase(cast_f32_node)
 
 
-def _remove_duplicated_cast(stree):
+def _remove_duplicated_cast(stree, dtype):
     """remove the duplicated cast operators."""
-    for node in stree.nodes():
-        if node.get_targets() is None:
-            continue
+    node_list = []
+    node_list.extend(list(stree.nodes()))
+    while node_list:
+        node = node_list.pop()
         if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
             for n in node.get_handler().node_list:
                 if n.get_node_type() == ms.rewrite.NodeType.Tree:
-                    _remove_duplicated_cast(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)))
+                    _remove_duplicated_cast(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)), dtype)
         elif node.get_node_type() == ms.rewrite.NodeType.Tree:
             substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
-            _remove_duplicated_cast(substree)
-        elif _need_removed_cast_pair(node):
-            _removed_cast_pair_process(stree, node)
+            _remove_duplicated_cast(substree, dtype)
+        elif node.get_node_type() in [ms.rewrite.NodeType.CallFunction, ms.rewrite.NodeType.ControlFlow]:
+            if isinstance(node.get_handler(), ms.rewrite.node.NodeManager):
+                nodes = [ms.rewrite.Node(n) for n in node.get_handler().nodes()]
+                node_list.extend(nodes)
+        elif _need_removed_cast_pair(node, dtype):
+            _removed_cast_pair_process(node)
 
 
-def _auto_white_list(network, white_list):
+def _auto_white_list(network, white_list, dtype):
     """process the white list of network."""
-    global STREE
-    STREE = ms.rewrite.SymbolTree.create(network)
-    _insert_cast_operator_white_list(STREE, white_list)
-    _remove_duplicated_cast(STREE)
-    return STREE.get_network()
+    stree = ms.rewrite.SymbolTree.create(network)
+    _insert_cast_operator_white_list(stree, white_list, dtype)
+    _remove_duplicated_cast(stree, dtype)
+    return stree.get_network()
 
 
-def _auto_black_list(network, black_list):
+def _insert_cast_operator_black_list(stree, black_list, dtype):
+    """insert cast for operators not in black_list."""
+    allowed_list = []
+    # Ignore if net called ".to_float(dtype)"
+    net = stree.get_handler().get_origin_network()
+    to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
+    if isinstance(net, nn.Cell) and hasattr(net, to_float_flag) and getattr(net, to_float_flag):
+        return
+    for node in stree.nodes(all_nodes=True):
+        if node.get_targets() is None:
+            continue
+        if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
+            _insert_cast_for_cell_container(node, dtype, allowed_list, black_list=black_list)
+        elif isinstance(node.get_handler().get_node_manager(), ms.rewrite.node.CellContainer):
+            # nodes in CellContainer are processed by _insert_cast_for_cell_container
+            continue
+        elif node.get_instance_type() not in black_list and _allow_mix_precision(node, allowed_list, dtype):
+            _insert_cast_operator_process(node, dtype)
+
+
+def _remove_duplicated_cast_rewrite(stree, dtype):
+    """remove the duplicated cast operators."""
+    for node in stree.nodes(all_nodes=True):
+        if _need_removed_cast_pair(node, dtype):
+            user_nodes = node.get_users()
+            # remove cast f16 nodes
+            for user_node in user_nodes:
+                if user_node.get_instance_type() == _amp_cast_op:
+                    stree.erase(user_node)
+            # remove the cast f32 node
+            stree.erase(node)
+
+
+def _auto_black_list_rewrite(network, black_list, dtype):
+    stree = ms.rewrite.SymbolTree.create(network)
+    _insert_cast_operator_black_list(stree, black_list, dtype)
+    _remove_duplicated_cast_rewrite(stree, dtype)
+    return stree.get_network()
+
+
+def _auto_black_list(network, black_list, dtype):
     """process the black list of network."""
-    network.to_float(mstype.float16)
+    network.to_float(dtype)
     cells = network.name_cells()
     change = False
     for name in cells:
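A rough picture of what the non-rewrite black-list path does to a network (a sketch under assumed toy cells, not the diff itself): everything is switched to the low-precision dtype, blacklisted subcells are forced back to float32, and each of them is wrapped so its output rejoins the low-precision graph:

    import mindspore as ms
    from mindspore import nn

    class TinyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.dense = nn.Dense(8, 8)
            self.bn = nn.BatchNorm1d(8)       # BatchNorm1d is in AMP_BLACK_LIST

        def construct(self, x):
            return self.bn(self.dense(x))

    net = TinyNet()
    net.to_float(ms.float16)                  # step 1: whole network to fp16
    net.bn.to_float(ms.float32)               # step 2: blacklisted cell back to fp32
    # step 3 in the pass: the cell is replaced by _OutputTo16(subcell, dtype),
    # so its fp32 result is cast back to fp16 for the surrounding fp16 graph.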
@@ -258,32 +372,76 @@ def _auto_black_list(network, black_list):
         if subcell == network:
             continue
         if isinstance(subcell, tuple(black_list)):
-            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
+            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32), dtype)
             change = True
         else:
-            _auto_black_list(subcell, black_list)
+            _auto_black_list(subcell, black_list, dtype)
     if isinstance(network, nn.SequentialCell) and change:
         network.cell_list = list(network.cells())
+    return network
 
 
-def auto_mixed_precision(network, amp_level="O0"):
+def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
     """
-    auto mixed precision function.
+    Returns a network processed with auto mixed precision.
+
+    This interface will automatically perform mixed-precision processing on the input network, and the cells
+    and operators in the processed network will add precision conversion operations to calculate with lower
+    precision: ``mstype.float16`` or ``mstype.bfloat16`` . Inputs and parameters of cells and operators are
+    converted to lower precision float, and calculation results are converted back to full precision float,
+    i.e. ``mstype.float32`` .
+
+    The framework has a set of built-in blacklists and whitelists, and the `amp_level` determines which cells and
+    operators are specifically converted.
+
+    The current built-in whitelist contents are:
+
+    [:class:`mindspore.nn.Conv1d`, :class:`mindspore.nn.Conv2d`, :class:`mindspore.nn.Conv3d`,
+    :class:`mindspore.nn.Conv1dTranspose`, :class:`mindspore.nn.Conv2dTranspose`,
+    :class:`mindspore.nn.Conv3dTranspose`, :class:`mindspore.nn.Dense`, :class:`mindspore.nn.LSTMCell`,
+    :class:`mindspore.nn.RNNCell`, :class:`mindspore.nn.GRUCell`, :class:`mindspore.ops.Conv2D`,
+    :class:`mindspore.ops.Conv3D`, :class:`mindspore.ops.Conv2DTranspose`,
+    :class:`mindspore.ops.Conv3DTranspose`, :class:`mindspore.ops.MatMul`, :class:`mindspore.ops.BatchMatMul`,
+    :class:`mindspore.ops.PReLU`, :class:`mindspore.ops.ReLU`, :class:`mindspore.ops.Ger`]
+
+    The current built-in blacklist contents are:
+
+    [:class:`mindspore.nn.BatchNorm1d`, :class:`mindspore.nn.BatchNorm2d`, :class:`mindspore.nn.BatchNorm3d`,
+    :class:`mindspore.nn.LayerNorm`]
+
+    For details on automatic mixed precision, refer to
+    `Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/r2.2/advanced/mixed_precision.html>`_ .
+
+    Note:
+        - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
+          can result in a larger network hierarchy and slower performance.
+        - If interfaces like `Model` and `build_train_network` is used to train the network which is converted by
+          mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
+          need to be configured to ``O0`` to avoid the duplicated accuracy conversion.
 
     Args:
         network (Cell): Definition of the network.
-        amp_level (str): Supports ["O0", "O1", "O2", "O3"]. Default: "O0".
+        amp_level (str): Supports ["O0", "O1", "O2", "O3"]. Default: ``"O0"`` .
 
             - "O0": Do not change.
-            - "O1": Cast the operators in white_list to float16, the remaining operators are kept in float32.
-            - "O2": Cast network to float16, keep operators in black_list run in float32,
-            - "O3": Cast network to float16.
+            - "O1": Convert cells and operators in whitelist to lower precision operations, and keep full
+              precision operations for the rest.
+            - "O2": Keep full precision operations for cells and operators in blacklist, and convert the rest
+              to lower precision operations.
+            - "O3": Cast network to lower precision.
+
+        dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or ``mstype.bfloat16`` ,
+            default: ``mstype.float16`` .
 
     Raises:
-        ValueError: If amp level is not supported.
+        TypeError: If `network` is not a Cell.
+        ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
+        ValueError: If `amp_level` is not within the supported range.
 
     Examples:
-        >>> from mindspore import amp, nn
+        >>> from mindspore import amp
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
        >>> network = LeNet5()
        >>> amp_level = "O1"
        >>> net = amp.auto_mixed_precision(network, amp_level)
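A hedged usage sketch of the new `dtype` argument (reusing the LeNet5 network referenced in the docstring example above; bfloat16 availability depends on the backend):

    from mindspore import amp
    from mindspore import dtype as mstype

    network = LeNet5()                                                    # as in the docstring example
    net_fp16 = amp.auto_mixed_precision(network, "O1")                    # white-list mode, float16
    net_bf16 = amp.auto_mixed_precision(LeNet5(), "O2", mstype.bfloat16)  # black-list mode, bfloat16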
@@ -291,18 +449,37 @@ def auto_mixed_precision(network, amp_level="O0"):
     if not isinstance(network, nn.Cell):
         raise TypeError("The network type should be Cell.")
 
+    if dtype not in (mstype.float16, mstype.bfloat16):
+        raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")
+
     if amp_level == "O0":
-        pass
-    elif amp_level == "O1":
-        return _auto_white_list(network, AMP_WHITE_LIST)
+        return network
+
+    # Return network if the same amp level has already been configurated
+    if getattr(network, "_amp_level") in ("O1", "O2", "O3"):
+        logger.warning(f"The network's auto mixed-precision level is adjusted from {getattr(network, '_amp_level')} "
+                       f"to {amp_level}, and repeated calls to mixed-precision interfaces can cause performance "
+                       f"degradation.")
+
+    if amp_level == "O1":
+        network = _auto_white_list(network, AMP_WHITE_LIST, dtype)
     elif amp_level == "O2":
-        _auto_black_list(network, AMP_BLACK_LIST)
+        if MS_AMP_BY_REWRITE:
+            network = _auto_black_list_rewrite(network, AMP_BLACK_LIST, dtype)
+        else:
+            network = _auto_black_list(network, AMP_BLACK_LIST, dtype)
+            network = _OutputTo32(network)
     elif amp_level == "O3":
-        network.to_float(mstype.float16)
+        if MS_AMP_BY_REWRITE:
+            network = _auto_black_list_rewrite(network, [], dtype)
+        else:
+            network.to_float(dtype)
+            network = _OutputTo32(network)
     else:
         raise ValueError("The amp level {} is not supported".format(amp_level))
-    if amp_level in ("O2", "O3"):
-        network = _OutputTo32(network)
+
+    setattr(network, "_amp_level", amp_level)
+
     return network
 
 
@@ -393,8 +570,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
             super(WithLossCell, self).__init__(auto_prefix=False)
             self._backbone = backbone
             self._loss_fn = loss_fn
-            if isinstance(backbone, nn.Cell) and backbone.jit_config_dict:
-                self._jit_config_dict = backbone.jit_config_dict
+            self._get_attr_from_cell(backbone)
 
         def construct(self, data, label):
             out = self._backbone(data)
@@ -409,42 +585,80 @@ def _add_loss_network(network, loss_fn, cast_model_type):
     return network
 
 
+def _is_grad_accumulation(mcell):
+    if mcell.cls_name == "GradAccumulationCell":
+        return True
+    for cell in mcell.cells():
+        if _is_grad_accumulation(cell):
+            return True
+    return False
+
+
+def _auto_mixed_precision_process(network, config, level):
+    """Auto mixed precision process."""
+    if MS_AMP_BY_REWRITE:
+        if config["cast_model_type"] == mstype.float16 or level == "O2":
+            level = "O2" if config["keep_batchnorm_fp32"] else "O3"
+        elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
+            # cast_model_type set by kwargs
+            level = "O0"
+        network = auto_mixed_precision(network, level)
+    else:
+        if config["cast_model_type"] == mstype.float16:
+            network.to_float(mstype.float16)
+
+            if config["keep_batchnorm_fp32"]:
+                _do_keep_batchnorm_fp32(network)
+        elif not config["keep_batchnorm_fp32"] and level == "O2":
+            network.to_float(mstype.float16)
+        elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
+            pass
+        else:
+            network = auto_mixed_precision(network, level)
+    return network
+
+
 def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs):
     """
     Build the mixed precision training cell automatically.
 
+    Note:
+        - After using `custom_mixed_precision` or `auto_mixed_precision` for precision conversion, it is not supported
+          to perform the precision conversion again. If `build_train_network` is used to train a converted network,
+          `level` need to be configured to ``O0`` to avoid the duplicated accuracy conversion.
+
     Args:
         network (Cell): Definition of the network.
+        optimizer (:class:`mindspore.nn.Optimizer`): Define the optimizer to update the Parameter.
         loss_fn (Union[None, Cell]): Define the loss function. If None, the `network` should have the loss inside.
-            Default: None.
-        optimizer (Optimizer): Define the optimizer to update the Parameter.
-        level (str): Supports ["O0", "O1", "O2", "O3", "auto"]. Default: "O0".
+            Default: ``None`` .
+        level (str): Supports ['O0', 'O1', 'O2', 'O3', 'auto']. Default: ``'O0'`` .
 
-            - "O0": Do not change.
-            - "O1": Cast the operators in white_list to float16, the remaining operators are kept in float32.
+            - 'O0': Do not change.
+            - 'O1': Cast the operators in white_list to float16, the remaining operators are kept in float32.
              The operators in the whitelist: [Conv1d, Conv2d, Conv3d, Conv1dTranspose, Conv2dTranspose,
              Conv3dTranspose, Dense, LSTMCell, RNNCell, GRUCell, MatMul, BatchMatMul, PReLU, ReLU, Ger].
-            - "O2": Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
+            - 'O2': Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
              using dynamic loss scale.
-            - "O3": Cast network to float16, with additional property `keep_batchnorm_fp32=False` .
-            - auto: Set to level to recommended level in different devices. Set level to "O2" on GPU, Set
-              level to "O3" Ascend. The recommended level is chosen by the export experience, not applicable to all
+            - 'O3': Cast network to float16, with additional property `keep_batchnorm_fp32=False` .
+            - 'auto': Set to level to recommended level in different devices. Set level to 'O2' on GPU, Set
+              level to 'O3' Ascend. The recommended level is chosen by the export experience, not applicable to all
              scenarios. User should specify the level for special network.
 
-            "O2" is recommended on GPU, "O3" is recommended on Ascend. Property of `keep_batchnorm_fp32`,
+            'O2' is recommended on GPU, 'O3' is recommended on Ascend. Property of `keep_batchnorm_fp32`,
            `cast_model_type` and `loss_scale_manager` determined by `level` setting may be overwritten by settings in
            `kwargs`.
 
        boost_level (str): Option for argument `level` in `mindspore.boost` , level for boost mode
-            training. Supports ["O0", "O1", "O2"]. Default: "O0".
+            training. Supports ['O0', 'O1', 'O2']. Default: ``'O0'`` .
 
-            - "O0": Do not change.
-            - "O1": Enable the boost mode, the performance is improved by about 20%, and
+            - 'O0': Do not change.
+            - 'O1': Enable the boost mode, the performance is improved by about 20%, and
              the accuracy is the same as the original accuracy.
-            - "O2": Enable the boost mode, the performance is improved by about 30%, and
+            - 'O2': Enable the boost mode, the performance is improved by about 30%, and
              the accuracy is reduced by less than 3%.
 
-            If "O1" or "O2" mode is set, the boost related library will take effect automatically.
+            If 'O1' or 'O2' mode is set, the boost related library will take effect automatically.
 
        cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32` . If set, the
            network will be casted to `cast_model_type` ( `mstype.float16` or `mstype.float32` ), but not to be casted
@@ -461,6 +675,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
 
     Examples:
         >>> from mindspore import amp, nn
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
        >>> network = LeNet5()
        >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
        >>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
@@ -475,22 +691,12 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
     _check_kwargs(kwargs)
     config = dict(_config_level.get(level), **kwargs)
 
-    if config["cast_model_type"] == mstype.float16:
-        network.to_float(mstype.float16)
-
-        if config["keep_batchnorm_fp32"]:
-            _do_keep_batchnorm_fp32(network)
-    elif not config["keep_batchnorm_fp32"] and level == "O2":
-        network.to_float(mstype.float16)
-    elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
-        pass
-    else:
-        network = auto_mixed_precision(network, level)
+    network = _auto_mixed_precision_process(network, config, level)
 
     if loss_fn:
         network = _add_loss_network(network, loss_fn, config["cast_model_type"])
 
-    loss_scale = 1.0
+    loss_scale = None
     if config["loss_scale_manager"] is not None:
         loss_scale_manager = config["loss_scale_manager"]
         loss_scale = loss_scale_manager.get_loss_scale()
@@ -501,8 +707,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
                raise ValueError("Only `loss_scale_manager=None` or "
                                 "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`"
                                 "are supported on device `CPU`. ")
-            if _get_pipeline_stages() > 1:
-                network = _TrainPipelineWithLossScaleCell(network, optimizer,
+            if _get_pipeline_stages() > 1 or _is_grad_accumulation(network):
+                network = _TrainGradAccuWithLossScaleCell(network, optimizer,
                                                           scale_sense=update_cell).set_train()
            elif enable_boost:
                network = boost.BoostTrainOneStepWithLossScaleCell(network, optimizer,
@@ -511,8 +717,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
                network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
                                                           scale_sense=update_cell).set_train()
            return network
-    if _get_pipeline_stages() > 1:
-        network = _TrainPipelineAccuStepCell(network, optimizer).set_train()
+    if _get_pipeline_stages() > 1 or _is_grad_accumulation(network):
+        network = _TrainGradAccuStepCell(network, optimizer).set_train()
     elif enable_boost:
         network = boost.BoostTrainOneStepCell(network, optimizer, loss_scale).set_train()
     else:
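A hedged end-to-end sketch of `build_train_network` with a fixed loss scale, assembled from the docstring example above (LeNet5 as referenced there; the `FixedLossScaleManager` top-level import path and the `drop_overflow_update` keyword are taken from the error message in this function, not verified here):

    from mindspore import amp, nn, FixedLossScaleManager

    network = LeNet5()                     # as in the docstring example
    net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
    train_net = amp.build_train_network(network, net_opt, net_loss, level="O2",
                                        loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False))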
@@ -524,11 +730,35 @@ def get_white_list():
     """
     Provide a copy of internal white list used by auto mixed precision.
 
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
+    The current built-in whitelist contents are:
+
+    [:class:`mindspore.nn.Conv1d`, :class:`mindspore.nn.Conv2d`, :class:`mindspore.nn.Conv3d`,
+    :class:`mindspore.nn.Conv1dTranspose`, :class:`mindspore.nn.Conv2dTranspose`,
+    :class:`mindspore.nn.Conv3dTranspose`, :class:`mindspore.nn.Dense`, :class:`mindspore.nn.LSTMCell`,
+    :class:`mindspore.nn.RNNCell`, :class:`mindspore.nn.GRUCell`, :class:`mindspore.ops.Conv2D`,
+    :class:`mindspore.ops.Conv3D`, :class:`mindspore.ops.Conv2DTranspose`,
+    :class:`mindspore.ops.Conv3DTranspose`, :class:`mindspore.ops.MatMul`, :class:`mindspore.ops.BatchMatMul`,
+    :class:`mindspore.ops.PReLU`, :class:`mindspore.ops.ReLU`, :class:`mindspore.ops.Ger`]
 
     Returns:
         list, A copy of internal white list.
+
+    Examples:
+        >>> from mindspore import amp
+        >>> white_list = amp.get_white_list()
+        >>> print(white_list)
+        [<class 'mindspore.nn.layer.conv.Conv1d'>, <class 'mindspore.nn.layer.conv.Conv2d'>,
+        <class 'mindspore.nn.layer.conv.Conv3d'>, <class 'mindspore.nn.layer.conv.Conv1dTranspose'>,
+        <class 'mindspore.nn.layer.conv.Conv2dTranspose'>, <class 'mindspore.nn.layer.conv.Conv3dTranspose'>,
+        <class 'mindspore.nn.layer.basic.Dense'>, <class 'mindspore.nn.layer.rnn_cells.LSTMCell'>,
+        <class 'mindspore.nn.layer.rnn_cells.RNNCell'>, <class 'mindspore.nn.layer.rnn_cells.GRUCell'>,
+        <class 'mindspore.ops.operations.nn_ops.Conv2D'>, <class 'mindspore.ops.operations.nn_ops.Conv3D'>,
+        <class 'mindspore.ops.operations.nn_ops.Conv2DTranspose'>,
+        <class 'mindspore.ops.operations.nn_ops.Conv3DTranspose'>,
+        <class 'mindspore.ops.operations.nn_ops.Conv2DBackpropInput'>,
+        <class 'mindspore.ops.operations.math_ops.MatMul'>, <class 'mindspore.ops.operations.math_ops.BatchMatMul'>,
+        <class 'mindspore.ops.operations.nn_ops.PReLU'>, <class 'mindspore.ops.operations.nn_ops.ReLU'>,
+        <class 'mindspore.ops.operations.math_ops.Ger'>]
     """
     white_list = AMP_WHITE_LIST.copy()
     return white_list
@@ -538,39 +768,48 @@ def get_black_list():
     """
     Provide a copy of internal black list used by auto mixed precision.
 
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
+    The current built-in blacklist contents are:
+
+    [:class:`mindspore.nn.BatchNorm1d`, :class:`mindspore.nn.BatchNorm2d`, :class:`mindspore.nn.BatchNorm3d`,
+    :class:`mindspore.nn.LayerNorm`]
 
     Returns:
         list, A copy of internal black list.
+
+    Examples:
+        >>> from mindspore import amp
+        >>> black_list = amp.get_black_list()
+        >>> print(black_list)
+        [<class 'mindspore.nn.layer.normalization.BatchNorm1d'>, <class 'mindspore.nn.layer.normalization.BatchNorm2d'>,
+        <class 'mindspore.nn.layer.normalization.BatchNorm3d'>, <class 'mindspore.nn.layer.normalization.LayerNorm'>]
     """
     black_list = AMP_BLACK_LIST.copy()
     return black_list
 
 
-def custom_mixed_precision(network, *, white_list=None, black_list=None):
+def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=mstype.float16):
     """
     Custom mixed precision by setting whitelist or blacklist.
     When the `white_list` is provided, primitives and cells in `white_list` will perform the precision conversion.
-    When the `black_list` is provided, primitives and cells that are not in `black_list` will perform the pereision
-    conversion.
+    When the `black_list` is provided, cells that are not in `black_list` will perform the pereision conversion.
     Only one of `white_list` and `black_list` should be provided.
 
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
     Note:
-        - `custom_mixed_precision` should not be used at the same time as `auto_mixed_precision` . When both
-          `build_train_network` and `custom_mixed_precision` are used, `build_train_network` need to be called with
-          `level='O0'` before call `custom_mixed_precision` .
+        - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
+          can result in a larger network hierarchy and slower performance.
+        - If interfaces like `Model` and `build_train_network` is used to train the network which is converted by
+          mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
+          need to be configured to ``O0`` to avoid the duplicated accuracy conversion.
        - Primitives for blacklist is not support yet.
 
     Args:
        network (Cell): Definition of the network.
-        white_list (list[Primitive, Cell], optional): White list of custom mixed precision. Defaults: None, means
+        white_list (list[Primitive, Cell], optional): White list of custom mixed precision. Defaults: ``None`` , means
            white list is not used.
-        black_list (list[Primitive, Cell], optional): Black list of custom mixed precision. Defaults: None, means
+        black_list (list[Cell], optional): Black list of custom mixed precision. Defaults: ``None`` , means
            black list is not used.
+        dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or ``mstype.bfloat16`` ,
+            default: ``mstype.float16`` .
 
     Returns:
        network (Cell), A network supporting mixed precision.
@@ -578,13 +817,16 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None):
     Raises:
         TypeError: The network type is not Cell.
         ValueError: Neither `white_list` nor `black_list` is provided.
+        ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
         ValueError: Both `white_list` and `black_list` are provided.
 
     Examples:
-        >>> from mindspore import amp
-        >>> net = MyNet()
+        >>> from mindspore import amp, nn
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
+        >>> net = LeNet5()
        >>> custom_white_list = amp.get_white_list()
-        >>> custom_white_list.append(nn.Tanhshrink)
+        >>> custom_white_list.append(nn.Flatten)
        >>> net = amp.custom_mixed_precision(net, white_list=custom_white_list)
     """
     if not isinstance(network, nn.Cell):
@@ -597,13 +839,19 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None):
         raise ValueError("For custom_mixed_precision, the white_list or black_list cannot be provided "
                          "at the same time, please provide one or the other.")
 
+    if dtype not in (mstype.float16, mstype.bfloat16):
+        raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")
+
     if white_list is not None:
         _list_check(white_list, "white_list")
-        return _auto_white_list(network, white_list)
-
-    _list_check(black_list, "black_list")
-    _auto_black_list(network, black_list)
-    network = _OutputTo32(network)
+        network = _auto_white_list(network, white_list, dtype)
+    else:
+        _list_check(black_list, "black_list")
+        if MS_AMP_BY_REWRITE:
+            network = _auto_black_list_rewrite(network, black_list, dtype)
+        else:
+            network = _auto_black_list(network, black_list, dtype)
+            network = _OutputTo32(network)
     return network
 
 
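A hedged sketch of the black-list variant with the new `dtype` argument (LeNet5 as referenced in the docstring example; `nn.Softmax` is just an arbitrary extra cell chosen here to stay in full precision):

    from mindspore import amp, nn
    from mindspore import dtype as mstype

    net = LeNet5()                              # as in the docstring example
    custom_black_list = amp.get_black_list()
    custom_black_list.append(nn.Softmax)        # also keep Softmax cells in float32
    net = amp.custom_mixed_precision(net, black_list=custom_black_list, dtype=mstype.float16)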
@@ -623,11 +871,25 @@ def _list_check(custom_list: list, list_name: str):
         if not isinstance(elem, type):
             raise TypeError(f"The element in {list_name} should be a class, but got {elem}")
 
-        if not issubclass(elem, nn.Cell) and not issubclass(elem, Primitive):
+        if list_name == "white_list" and not issubclass(elem, nn.Cell) and not issubclass(elem, Primitive):
             raise TypeError(f"The subclass of element in {list_name} should be one of 'Cell' and 'Primitive', "
                             f"but got {elem}")
 
+        if list_name == "black_list" and not issubclass(elem, nn.Cell):
+            raise TypeError(f"The subclass of element in {list_name} should be one of 'Cell', but got {elem}")
+
     if list_name == 'black_list':
         for elem in AMP_BLACK_LIST:
             if elem not in custom_list:
                 logger.warning(f"{elem} is removed from internal black list.")
+
+def _config_amp(*, enable_rewrite: bool = None, cast_op: type = None):  # pylint: disable=unused-variable
+    """Configure auto mixed precision."""
+    global MS_AMP_BY_REWRITE
+    global _amp_cast_op
+
+    if enable_rewrite is not None:
+        MS_AMP_BY_REWRITE = enable_rewrite
+
+    if cast_op is not None:
+        _amp_cast_op = cast_op
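For completeness, a strongly hedged sketch of how `_config_amp` would be used. It is a private, underscore-prefixed helper, so the module path below is an assumption about where this file lives in the package and not a public API:

    from mindspore import ops
    from mindspore.train import amp as amp_module   # path assumed; _config_amp is private

    amp_module._config_amp(enable_rewrite=True)     # route O2/O3 through the rewrite-based passes
    amp_module._config_amp(cast_op=ops.Cast)        # pick the cast primitive inserted by the passes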