mindspore 2.2.14__cp39-cp39-win_amd64.whl → 2.4.0__cp39-cp39-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (1217)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +8 -5
  5. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +124 -25
  9. mindspore/_extends/builtin_operations.py +2 -1
  10. mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
  11. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
  12. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
  13. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
  14. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  15. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
  16. mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
  17. mindspore/_extends/parse/__init__.py +18 -14
  18. mindspore/_extends/parse/compile_config.py +299 -0
  19. mindspore/_extends/parse/namespace.py +2 -2
  20. mindspore/_extends/parse/parser.py +182 -68
  21. mindspore/_extends/parse/resources.py +45 -14
  22. mindspore/_extends/parse/standard_method.py +192 -252
  23. mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
  24. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  25. mindspore/_extends/remote/kernel_build_server.py +2 -0
  26. mindspore/_profiler.py +30 -0
  27. mindspore/amp.py +67 -26
  28. mindspore/atlprov.dll +0 -0
  29. mindspore/avcodec-59.dll +0 -0
  30. mindspore/avdevice-59.dll +0 -0
  31. mindspore/avfilter-8.dll +0 -0
  32. mindspore/avformat-59.dll +0 -0
  33. mindspore/avutil-57.dll +0 -0
  34. mindspore/boost/adasum.py +1 -1
  35. mindspore/boost/base.py +1 -1
  36. mindspore/boost/boost_cell_wrapper.py +2 -2
  37. mindspore/boost/grad_freeze.py +2 -2
  38. mindspore/boost/group_loss_scale_manager.py +1 -1
  39. mindspore/boost/less_batch_normalization.py +9 -6
  40. mindspore/c1.dll +0 -0
  41. mindspore/c1xx.dll +0 -0
  42. mindspore/c2.dll +0 -0
  43. mindspore/common/__init__.py +20 -7
  44. mindspore/common/_jit_fallback_utils.py +2 -3
  45. mindspore/common/_pijit_context.py +190 -0
  46. mindspore/common/_register_for_adapter.py +7 -0
  47. mindspore/common/_register_for_recompute.py +48 -0
  48. mindspore/common/_register_for_tensor.py +10 -10
  49. mindspore/common/_stub_tensor.py +7 -1
  50. mindspore/common/_tensor_overload.py +139 -0
  51. mindspore/common/_utils.py +5 -17
  52. mindspore/common/api.py +449 -129
  53. mindspore/common/auto_dynamic_shape.py +27 -14
  54. mindspore/common/dtype.py +17 -10
  55. mindspore/common/dump.py +8 -11
  56. mindspore/common/file_system.py +48 -0
  57. mindspore/common/generator.py +254 -0
  58. mindspore/common/hook_handle.py +65 -30
  59. mindspore/common/initializer.py +1 -1
  60. mindspore/common/jit_config.py +34 -14
  61. mindspore/common/lazy_inline.py +72 -19
  62. mindspore/common/mindir_util.py +12 -2
  63. mindspore/common/mutable.py +79 -14
  64. mindspore/common/no_inline.py +54 -0
  65. mindspore/common/np_dtype.py +25 -0
  66. mindspore/common/parameter.py +73 -21
  67. mindspore/common/recompute.py +292 -0
  68. mindspore/common/seed.py +9 -9
  69. mindspore/common/sparse_tensor.py +276 -24
  70. mindspore/common/symbol.py +122 -0
  71. mindspore/common/tensor.py +668 -514
  72. mindspore/communication/__init__.py +6 -11
  73. mindspore/communication/_comm_helper.py +43 -3
  74. mindspore/communication/comm_func.py +1395 -0
  75. mindspore/communication/management.py +117 -104
  76. mindspore/config/op_info.config +22 -54
  77. mindspore/context.py +455 -71
  78. mindspore/dataset/__init__.py +5 -5
  79. mindspore/dataset/audio/__init__.py +6 -6
  80. mindspore/dataset/audio/transforms.py +711 -158
  81. mindspore/dataset/callback/ds_callback.py +2 -2
  82. mindspore/dataset/core/config.py +7 -0
  83. mindspore/dataset/core/validator_helpers.py +7 -0
  84. mindspore/dataset/engine/cache_client.py +2 -2
  85. mindspore/dataset/engine/datasets.py +201 -116
  86. mindspore/dataset/engine/datasets_audio.py +14 -14
  87. mindspore/dataset/engine/datasets_standard_format.py +83 -3
  88. mindspore/dataset/engine/datasets_text.py +39 -39
  89. mindspore/dataset/engine/datasets_user_defined.py +230 -141
  90. mindspore/dataset/engine/datasets_vision.py +78 -74
  91. mindspore/dataset/engine/iterators.py +29 -0
  92. mindspore/dataset/engine/obs/util.py +7 -0
  93. mindspore/dataset/engine/offload.py +5 -7
  94. mindspore/dataset/engine/queue.py +138 -66
  95. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  96. mindspore/dataset/engine/validators.py +41 -15
  97. mindspore/dataset/text/__init__.py +2 -5
  98. mindspore/dataset/text/transforms.py +408 -121
  99. mindspore/dataset/text/utils.py +9 -9
  100. mindspore/dataset/transforms/__init__.py +0 -3
  101. mindspore/dataset/transforms/transforms.py +261 -76
  102. mindspore/dataset/utils/browse_dataset.py +9 -9
  103. mindspore/dataset/utils/line_reader.py +2 -0
  104. mindspore/dataset/vision/__init__.py +7 -10
  105. mindspore/dataset/vision/c_transforms.py +10 -10
  106. mindspore/dataset/vision/py_transforms_util.py +1 -1
  107. mindspore/dataset/vision/transforms.py +2844 -549
  108. mindspore/dataset/vision/utils.py +161 -10
  109. mindspore/dataset/vision/validators.py +16 -3
  110. mindspore/dnnl.dll +0 -0
  111. mindspore/dpcmi.dll +0 -0
  112. mindspore/{rewrite/ast_creator_register.py → experimental/es/__init__.py} +5 -20
  113. mindspore/experimental/es/embedding_service.py +883 -0
  114. mindspore/experimental/es/embedding_service_layer.py +581 -0
  115. mindspore/experimental/llm_boost/__init__.py +21 -0
  116. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  117. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  118. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  119. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  120. mindspore/experimental/llm_boost/register.py +129 -0
  121. mindspore/experimental/llm_boost/utils.py +31 -0
  122. mindspore/experimental/optim/__init__.py +12 -2
  123. mindspore/experimental/optim/adadelta.py +161 -0
  124. mindspore/experimental/optim/adagrad.py +168 -0
  125. mindspore/experimental/optim/adam.py +35 -34
  126. mindspore/experimental/optim/adamax.py +170 -0
  127. mindspore/experimental/optim/adamw.py +124 -15
  128. mindspore/experimental/optim/asgd.py +153 -0
  129. mindspore/experimental/optim/lr_scheduler.py +66 -121
  130. mindspore/experimental/optim/nadam.py +157 -0
  131. mindspore/experimental/optim/optimizer.py +18 -8
  132. mindspore/experimental/optim/radam.py +194 -0
  133. mindspore/experimental/optim/rmsprop.py +154 -0
  134. mindspore/experimental/optim/rprop.py +164 -0
  135. mindspore/experimental/optim/sgd.py +28 -19
  136. mindspore/hal/__init__.py +40 -0
  137. mindspore/hal/_ascend.py +57 -0
  138. mindspore/hal/_base.py +57 -0
  139. mindspore/hal/_cpu.py +56 -0
  140. mindspore/hal/_gpu.py +57 -0
  141. mindspore/hal/contiguous_tensors_handle.py +175 -0
  142. mindspore/hal/device.py +356 -0
  143. mindspore/hal/event.py +179 -0
  144. mindspore/hal/memory.py +326 -0
  145. mindspore/hal/stream.py +357 -0
  146. mindspore/include/api/data_type.h +2 -2
  147. mindspore/include/api/dual_abi_helper.h +16 -3
  148. mindspore/include/api/model.h +4 -3
  149. mindspore/include/api/model_group.h +13 -1
  150. mindspore/include/api/status.h +14 -0
  151. mindspore/include/api/types.h +10 -10
  152. mindspore/include/c_api/model_c.h +173 -0
  153. mindspore/include/c_api/types_c.h +19 -0
  154. mindspore/include/dataset/config.h +2 -2
  155. mindspore/include/dataset/constants.h +2 -2
  156. mindspore/include/dataset/execute.h +3 -5
  157. mindspore/include/dataset/vision.h +58 -2
  158. mindspore/jpeg62.dll +0 -0
  159. mindspore/log.py +3 -3
  160. mindspore/mindrecord/__init__.py +5 -1
  161. mindspore/mindrecord/config.py +809 -0
  162. mindspore/mindrecord/filereader.py +25 -0
  163. mindspore/mindrecord/filewriter.py +138 -103
  164. mindspore/mindrecord/mindpage.py +40 -6
  165. mindspore/mindrecord/shardutils.py +3 -2
  166. mindspore/mindrecord/shardwriter.py +7 -0
  167. mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
  168. mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
  169. mindspore/mindrecord/tools/csv_to_mr.py +4 -9
  170. mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
  171. mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
  172. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
  173. mindspore/mindspore_backend.dll +0 -0
  174. mindspore/mindspore_common.dll +0 -0
  175. mindspore/mindspore_core.dll +0 -0
  176. mindspore/mindspore_glog.dll +0 -0
  177. mindspore/mindspore_np_dtype.dll +0 -0
  178. mindspore/mindspore_ops.dll +0 -0
  179. mindspore/mint/__init__.py +1586 -0
  180. mindspore/mint/distributed/__init__.py +31 -0
  181. mindspore/mint/distributed/distributed.py +254 -0
  182. mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
  183. mindspore/mint/nn/__init__.py +757 -0
  184. mindspore/mint/nn/functional.py +679 -0
  185. mindspore/mint/nn/layer/__init__.py +39 -0
  186. mindspore/mint/nn/layer/activation.py +133 -0
  187. mindspore/mint/nn/layer/normalization.py +477 -0
  188. mindspore/mint/nn/layer/pooling.py +110 -0
  189. mindspore/mint/optim/__init__.py +24 -0
  190. mindspore/mint/optim/adamw.py +206 -0
  191. mindspore/mint/special/__init__.py +63 -0
  192. mindspore/msobj140.dll +0 -0
  193. mindspore/mspdb140.dll +0 -0
  194. mindspore/mspdbcore.dll +0 -0
  195. mindspore/mspdbst.dll +0 -0
  196. mindspore/mspft140.dll +0 -0
  197. mindspore/msvcdis140.dll +0 -0
  198. mindspore/msvcp140_1.dll +0 -0
  199. mindspore/msvcp140_2.dll +0 -0
  200. mindspore/msvcp140_atomic_wait.dll +0 -0
  201. mindspore/msvcp140_codecvt_ids.dll +0 -0
  202. mindspore/multiprocessing/__init__.py +73 -0
  203. mindspore/nn/cell.py +461 -323
  204. mindspore/nn/dynamic_lr.py +2 -2
  205. mindspore/nn/layer/activation.py +292 -135
  206. mindspore/nn/layer/basic.py +288 -83
  207. mindspore/nn/layer/channel_shuffle.py +3 -16
  208. mindspore/nn/layer/container.py +3 -3
  209. mindspore/nn/layer/conv.py +75 -66
  210. mindspore/nn/layer/embedding.py +221 -45
  211. mindspore/nn/layer/image.py +4 -7
  212. mindspore/nn/layer/math.py +1 -1
  213. mindspore/nn/layer/normalization.py +150 -68
  214. mindspore/nn/layer/padding.py +64 -87
  215. mindspore/nn/layer/pooling.py +175 -12
  216. mindspore/nn/layer/rnn_cells.py +6 -16
  217. mindspore/nn/layer/rnns.py +6 -5
  218. mindspore/nn/layer/thor_layer.py +1 -2
  219. mindspore/nn/layer/timedistributed.py +1 -1
  220. mindspore/nn/layer/transformer.py +55 -53
  221. mindspore/nn/learning_rate_schedule.py +6 -5
  222. mindspore/nn/loss/__init__.py +2 -2
  223. mindspore/nn/loss/loss.py +145 -88
  224. mindspore/nn/optim/__init__.py +2 -1
  225. mindspore/nn/optim/ada_grad.py +4 -2
  226. mindspore/nn/optim/adadelta.py +4 -2
  227. mindspore/nn/optim/adafactor.py +1 -1
  228. mindspore/nn/optim/adam.py +102 -181
  229. mindspore/nn/optim/adamax.py +4 -2
  230. mindspore/nn/optim/adasum.py +3 -3
  231. mindspore/nn/optim/asgd.py +4 -2
  232. mindspore/nn/optim/ftrl.py +31 -61
  233. mindspore/nn/optim/lamb.py +5 -3
  234. mindspore/nn/optim/lars.py +2 -2
  235. mindspore/nn/optim/lazyadam.py +6 -4
  236. mindspore/nn/optim/momentum.py +13 -25
  237. mindspore/nn/optim/optimizer.py +6 -3
  238. mindspore/nn/optim/proximal_ada_grad.py +4 -2
  239. mindspore/nn/optim/rmsprop.py +9 -3
  240. mindspore/nn/optim/rprop.py +4 -2
  241. mindspore/nn/optim/sgd.py +5 -3
  242. mindspore/nn/optim/tft_wrapper.py +127 -0
  243. mindspore/nn/optim/thor.py +2 -2
  244. mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
  245. mindspore/nn/probability/distribution/beta.py +2 -2
  246. mindspore/nn/probability/distribution/categorical.py +4 -6
  247. mindspore/nn/probability/distribution/cauchy.py +2 -2
  248. mindspore/nn/probability/distribution/exponential.py +2 -2
  249. mindspore/nn/probability/distribution/geometric.py +1 -1
  250. mindspore/nn/probability/distribution/gumbel.py +2 -2
  251. mindspore/nn/probability/distribution/logistic.py +1 -1
  252. mindspore/nn/probability/distribution/poisson.py +2 -2
  253. mindspore/nn/probability/distribution/uniform.py +2 -2
  254. mindspore/nn/reinforcement/_tensors_queue.py +13 -1
  255. mindspore/nn/wrap/__init__.py +2 -1
  256. mindspore/nn/wrap/cell_wrapper.py +46 -12
  257. mindspore/nn/wrap/grad_reducer.py +148 -8
  258. mindspore/nn/wrap/loss_scale.py +44 -7
  259. mindspore/numpy/__init__.py +2 -0
  260. mindspore/numpy/array_creations.py +67 -68
  261. mindspore/numpy/array_ops.py +70 -66
  262. mindspore/numpy/dtypes.py +3 -3
  263. mindspore/numpy/fft.py +966 -0
  264. mindspore/numpy/logic_ops.py +11 -10
  265. mindspore/numpy/math_ops.py +147 -152
  266. mindspore/numpy/utils.py +3 -0
  267. mindspore/numpy/utils_const.py +4 -4
  268. mindspore/opencv_core452.dll +0 -0
  269. mindspore/opencv_imgcodecs452.dll +0 -0
  270. mindspore/opencv_imgproc452.dll +0 -0
  271. mindspore/ops/__init__.py +9 -6
  272. mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
  273. mindspore/ops/_grad_experimental/grad_comm_ops.py +135 -36
  274. mindspore/ops/_grad_experimental/grad_math_ops.py +61 -298
  275. mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
  276. mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
  277. mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
  278. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  279. mindspore/ops/_op_impl/__init__.py +0 -1
  280. mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
  281. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
  282. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
  283. mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
  284. mindspore/ops/_op_impl/cpu/__init__.py +1 -3
  285. mindspore/ops/_op_impl/cpu/adam.py +2 -2
  286. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
  287. mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
  288. mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
  289. mindspore/ops/_vmap/vmap_array_ops.py +162 -101
  290. mindspore/ops/_vmap/vmap_base.py +8 -1
  291. mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
  292. mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
  293. mindspore/ops/_vmap/vmap_image_ops.py +70 -13
  294. mindspore/ops/_vmap/vmap_math_ops.py +147 -59
  295. mindspore/ops/_vmap/vmap_nn_ops.py +292 -117
  296. mindspore/ops/_vmap/vmap_other_ops.py +1 -1
  297. mindspore/ops/auto_generate/__init__.py +31 -0
  298. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  299. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  300. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  301. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  302. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  303. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  304. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  305. mindspore/ops/composite/__init__.py +5 -2
  306. mindspore/ops/composite/base.py +201 -66
  307. mindspore/ops/composite/math_ops.py +10 -49
  308. mindspore/ops/composite/multitype_ops/_compile_utils.py +192 -618
  309. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
  310. mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
  311. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
  312. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
  313. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
  314. mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
  315. mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
  316. mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
  317. mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
  318. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
  319. mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
  320. mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
  321. mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
  322. mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
  323. mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
  324. mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
  325. mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
  326. mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
  327. mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
  328. mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
  329. mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
  330. mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
  331. mindspore/ops/composite/multitype_ops/not_in_impl.py +8 -3
  332. mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
  333. mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
  334. mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
  335. mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
  336. mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
  337. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
  338. mindspore/ops/deprecated.py +14 -3
  339. mindspore/ops/function/__init__.py +53 -11
  340. mindspore/ops/function/array_func.py +1269 -1821
  341. mindspore/ops/function/clip_func.py +19 -31
  342. mindspore/ops/function/debug_func.py +114 -5
  343. mindspore/ops/function/fft_func.py +44 -0
  344. mindspore/ops/function/grad/grad_func.py +30 -22
  345. mindspore/ops/function/image_func.py +27 -21
  346. mindspore/ops/function/linalg_func.py +35 -68
  347. mindspore/ops/function/math_func.py +1170 -2697
  348. mindspore/ops/function/nn_func.py +2116 -1128
  349. mindspore/ops/function/other_func.py +8 -8
  350. mindspore/ops/function/parameter_func.py +5 -93
  351. mindspore/ops/function/random_func.py +435 -113
  352. mindspore/ops/function/reshard_func.py +104 -0
  353. mindspore/ops/function/sparse_func.py +4 -4
  354. mindspore/ops/function/sparse_unary_func.py +9 -16
  355. mindspore/ops/function/spectral_func.py +1 -1
  356. mindspore/ops/function/vmap_func.py +16 -15
  357. mindspore/ops/functional.py +355 -346
  358. mindspore/ops/op_info_register.py +18 -45
  359. mindspore/ops/operations/__init__.py +38 -24
  360. mindspore/ops/operations/_grad_ops.py +21 -927
  361. mindspore/ops/operations/_infer_ops.py +19 -0
  362. mindspore/ops/operations/_inner_ops.py +173 -607
  363. mindspore/ops/operations/_rl_inner_ops.py +2 -2
  364. mindspore/ops/operations/_scalar_ops.py +5 -480
  365. mindspore/ops/operations/_sequence_ops.py +6 -36
  366. mindspore/ops/operations/_tensor_array.py +8 -8
  367. mindspore/ops/operations/array_ops.py +106 -2837
  368. mindspore/ops/operations/comm_ops.py +799 -127
  369. mindspore/ops/operations/custom_ops.py +124 -119
  370. mindspore/ops/operations/debug_ops.py +142 -41
  371. mindspore/ops/operations/image_ops.py +1 -217
  372. mindspore/ops/operations/inner_ops.py +5 -40
  373. mindspore/ops/operations/linalg_ops.py +1 -49
  374. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  375. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  376. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  377. mindspore/ops/operations/math_ops.py +666 -4972
  378. mindspore/ops/operations/nn_ops.py +205 -2213
  379. mindspore/ops/operations/other_ops.py +60 -49
  380. mindspore/ops/operations/random_ops.py +50 -54
  381. mindspore/ops/operations/reshard_ops.py +53 -0
  382. mindspore/ops/operations/sparse_ops.py +4 -4
  383. mindspore/ops/primitive.py +216 -103
  384. mindspore/ops_generate/__init__.py +27 -0
  385. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  386. mindspore/ops_generate/arg_handler.py +197 -0
  387. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  388. mindspore/ops_generate/gen_constants.py +36 -0
  389. mindspore/ops_generate/gen_ops.py +1099 -0
  390. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  391. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  392. mindspore/ops_generate/gen_utils.py +209 -0
  393. mindspore/ops_generate/op_proto.py +145 -0
  394. mindspore/ops_generate/pyboost_utils.py +367 -0
  395. mindspore/ops_generate/template.py +261 -0
  396. mindspore/parallel/__init__.py +8 -4
  397. mindspore/parallel/_auto_parallel_context.py +100 -10
  398. mindspore/parallel/_cell_wrapper.py +99 -9
  399. mindspore/parallel/_cost_model_context.py +1 -1
  400. mindspore/parallel/_dp_allreduce_fusion.py +159 -159
  401. mindspore/parallel/_parallel_serialization.py +67 -23
  402. mindspore/parallel/_ps_context.py +1 -1
  403. mindspore/parallel/_recovery_context.py +1 -1
  404. mindspore/parallel/_tensor.py +99 -22
  405. mindspore/parallel/_transformer/__init__.py +1 -1
  406. mindspore/parallel/_transformer/layers.py +1 -1
  407. mindspore/parallel/_transformer/loss.py +1 -1
  408. mindspore/parallel/_transformer/moe.py +1 -1
  409. mindspore/parallel/_transformer/op_parallel_config.py +1 -1
  410. mindspore/parallel/_transformer/transformer.py +2 -2
  411. mindspore/parallel/_utils.py +173 -6
  412. mindspore/parallel/algo_parameter_config.py +8 -10
  413. mindspore/parallel/checkpoint_transform.py +204 -38
  414. mindspore/parallel/cluster/__init__.py +15 -0
  415. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  416. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  417. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  418. mindspore/parallel/cluster/run.py +136 -0
  419. mindspore/parallel/mpi/__init__.py +1 -1
  420. mindspore/parallel/mpi/_mpi_config.py +1 -1
  421. mindspore/parallel/parameter_broadcast.py +151 -0
  422. mindspore/parallel/shard.py +279 -37
  423. mindspore/parallel/transform_safetensors.py +993 -0
  424. mindspore/pgodb140.dll +0 -0
  425. mindspore/pgort140.dll +0 -0
  426. mindspore/profiler/__init__.py +4 -2
  427. mindspore/profiler/common/constant.py +29 -0
  428. mindspore/profiler/common/process_pool.py +41 -0
  429. mindspore/profiler/common/registry.py +47 -0
  430. mindspore/profiler/common/singleton.py +28 -0
  431. mindspore/profiler/common/util.py +153 -0
  432. mindspore/profiler/dynamic_profiler.py +694 -0
  433. mindspore/profiler/envprofiling.py +18 -20
  434. mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
  435. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  436. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  437. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  438. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  439. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  440. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  441. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  442. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  443. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  444. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  445. mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
  446. mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
  447. mindspore/profiler/parser/ascend_flops_generator.py +20 -4
  448. mindspore/profiler/parser/ascend_hccl_generator.py +29 -278
  449. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  450. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  451. mindspore/profiler/parser/ascend_msprof_exporter.py +148 -146
  452. mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
  453. mindspore/profiler/parser/ascend_op_generator.py +92 -42
  454. mindspore/profiler/parser/ascend_timeline_generator.py +298 -133
  455. mindspore/profiler/parser/base_timeline_generator.py +25 -25
  456. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  457. mindspore/profiler/parser/framework_parser.py +4 -393
  458. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  459. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  460. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  461. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  462. mindspore/profiler/parser/integrator.py +3 -1
  463. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  464. mindspore/profiler/parser/minddata_parser.py +72 -3
  465. mindspore/profiler/parser/profiler_info.py +94 -7
  466. mindspore/profiler/profiler.py +153 -0
  467. mindspore/profiler/profiling.py +631 -508
  468. mindspore/rewrite/__init__.py +2 -14
  469. mindspore/rewrite/api/node.py +122 -36
  470. mindspore/rewrite/api/pattern_engine.py +2 -3
  471. mindspore/rewrite/api/scoped_value.py +16 -15
  472. mindspore/rewrite/api/symbol_tree.py +45 -29
  473. mindspore/rewrite/ast_helpers/__init__.py +3 -6
  474. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  475. mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
  476. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  477. mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
  478. mindspore/rewrite/common/__init__.py +1 -2
  479. mindspore/rewrite/common/config.py +24 -0
  480. mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
  481. mindspore/rewrite/{namer.py → common/namer.py} +63 -18
  482. mindspore/rewrite/common/namespace.py +118 -0
  483. mindspore/rewrite/node/__init__.py +5 -5
  484. mindspore/rewrite/node/call_function.py +23 -7
  485. mindspore/rewrite/node/cell_container.py +7 -3
  486. mindspore/rewrite/node/control_flow.py +53 -28
  487. mindspore/rewrite/node/node.py +212 -196
  488. mindspore/rewrite/node/node_manager.py +51 -22
  489. mindspore/rewrite/node/node_topological_manager.py +3 -23
  490. mindspore/rewrite/parsers/__init__.py +12 -0
  491. mindspore/rewrite/parsers/arguments_parser.py +8 -9
  492. mindspore/rewrite/parsers/assign_parser.py +637 -413
  493. mindspore/rewrite/parsers/attribute_parser.py +3 -4
  494. mindspore/rewrite/parsers/class_def_parser.py +115 -148
  495. mindspore/rewrite/parsers/constant_parser.py +5 -5
  496. mindspore/rewrite/parsers/container_parser.py +4 -6
  497. mindspore/rewrite/parsers/expr_parser.py +55 -0
  498. mindspore/rewrite/parsers/for_parser.py +31 -98
  499. mindspore/rewrite/parsers/function_def_parser.py +13 -5
  500. mindspore/rewrite/parsers/if_parser.py +28 -10
  501. mindspore/rewrite/parsers/module_parser.py +8 -182
  502. mindspore/rewrite/parsers/parser.py +1 -5
  503. mindspore/rewrite/parsers/parser_register.py +1 -1
  504. mindspore/rewrite/parsers/return_parser.py +5 -10
  505. mindspore/rewrite/parsers/while_parser.py +59 -0
  506. mindspore/rewrite/sparsify/utils.py +1 -1
  507. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  508. mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +705 -186
  509. mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
  510. mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
  511. mindspore/run_check/_check_version.py +40 -115
  512. mindspore/run_check/run_check.py +1 -1
  513. mindspore/safeguard/rewrite_obfuscation.py +597 -263
  514. mindspore/swresample-4.dll +0 -0
  515. mindspore/swscale-6.dll +0 -0
  516. mindspore/tbbmalloc.dll +0 -0
  517. mindspore/tinyxml2.dll +0 -0
  518. mindspore/train/__init__.py +7 -5
  519. mindspore/train/_utils.py +204 -4
  520. mindspore/train/amp.py +335 -295
  521. mindspore/train/anf_ir_pb2.py +14 -2
  522. mindspore/train/callback/__init__.py +5 -2
  523. mindspore/train/callback/_backup_and_restore.py +5 -5
  524. mindspore/train/callback/_callback.py +4 -4
  525. mindspore/train/callback/_checkpoint.py +220 -43
  526. mindspore/train/callback/_cluster_monitor.py +201 -0
  527. mindspore/train/callback/_early_stop.py +2 -2
  528. mindspore/train/callback/_flops_collector.py +239 -0
  529. mindspore/train/callback/_landscape.py +15 -9
  530. mindspore/train/callback/_loss_monitor.py +5 -5
  531. mindspore/train/callback/_on_request_exit.py +136 -33
  532. mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
  533. mindspore/train/callback/_summary_collector.py +12 -12
  534. mindspore/train/callback/_tft_register.py +352 -0
  535. mindspore/train/callback/_time_monitor.py +3 -3
  536. mindspore/train/data_sink.py +6 -5
  537. mindspore/train/dataset_helper.py +66 -23
  538. mindspore/train/loss_scale_manager.py +2 -2
  539. mindspore/train/metrics/accuracy.py +7 -7
  540. mindspore/train/metrics/confusion_matrix.py +8 -6
  541. mindspore/train/metrics/cosine_similarity.py +6 -4
  542. mindspore/train/metrics/error.py +2 -2
  543. mindspore/train/metrics/metric.py +3 -3
  544. mindspore/train/metrics/perplexity.py +2 -1
  545. mindspore/train/metrics/roc.py +4 -4
  546. mindspore/train/metrics/topk.py +2 -2
  547. mindspore/train/mind_ir_pb2.py +116 -37
  548. mindspore/train/model.py +382 -76
  549. mindspore/train/serialization.py +787 -288
  550. mindspore/train/summary/_summary_adapter.py +1 -1
  551. mindspore/train/summary/summary_record.py +51 -28
  552. mindspore/train/train_thor/convert_utils.py +3 -3
  553. mindspore/turbojpeg.dll +0 -0
  554. mindspore/utils/__init__.py +21 -0
  555. mindspore/utils/utils.py +60 -0
  556. mindspore/vcmeta.dll +0 -0
  557. mindspore/vcruntime140.dll +0 -0
  558. mindspore/vcruntime140_1.dll +0 -0
  559. mindspore/version.py +1 -1
  560. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/METADATA +8 -4
  561. mindspore-2.4.0.dist-info/RECORD +1406 -0
  562. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +1 -0
  563. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
  564. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
  565. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
  566. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
  567. mindspore/gen_ops.py +0 -273
  568. mindspore/include/c_api/ms/abstract.h +0 -67
  569. mindspore/include/c_api/ms/attribute.h +0 -197
  570. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  571. mindspore/include/c_api/ms/base/macros.h +0 -32
  572. mindspore/include/c_api/ms/base/status.h +0 -33
  573. mindspore/include/c_api/ms/base/types.h +0 -282
  574. mindspore/include/c_api/ms/context.h +0 -102
  575. mindspore/include/c_api/ms/graph.h +0 -160
  576. mindspore/include/c_api/ms/node.h +0 -606
  577. mindspore/include/c_api/ms/tensor.h +0 -161
  578. mindspore/include/c_api/ms/value.h +0 -84
  579. mindspore/mindspore_shared_lib.dll +0 -0
  580. mindspore/nn/layer/flash_attention.py +0 -189
  581. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  582. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  583. mindspore/ops/_op_impl/cpu/concat.py +0 -39
  584. mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
  585. mindspore/ops/_op_impl/tbe/__init__.py +0 -47
  586. mindspore/ops/_op_impl/tbe/abs.py +0 -38
  587. mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
  588. mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
  589. mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
  590. mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
  591. mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
  592. mindspore/ops/_op_impl/tbe/acos.py +0 -37
  593. mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
  594. mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
  595. mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
  596. mindspore/ops/_op_impl/tbe/acosh.py +0 -37
  597. mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
  598. mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
  599. mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
  600. mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
  601. mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
  602. mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
  603. mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
  604. mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
  605. mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
  606. mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
  607. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
  608. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
  609. mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
  610. mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
  611. mindspore/ops/_op_impl/tbe/add.py +0 -42
  612. mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
  613. mindspore/ops/_op_impl/tbe/add_n.py +0 -39
  614. mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
  615. mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
  616. mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
  617. mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
  618. mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
  619. mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
  620. mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
  621. mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
  622. mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
  623. mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
  624. mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
  625. mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
  626. mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
  627. mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
  628. mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
  629. mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
  630. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
  631. mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
  632. mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
  633. mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
  634. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
  635. mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
  636. mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
  637. mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
  638. mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
  639. mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
  640. mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
  641. mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
  642. mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
  643. mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
  644. mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
  645. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
  646. mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
  647. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
  648. mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
  649. mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
  650. mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
  651. mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
  652. mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
  653. mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
  654. mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
  655. mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
  656. mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
  657. mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
  658. mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
  659. mindspore/ops/_op_impl/tbe/asin.py +0 -37
  660. mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
  661. mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
  662. mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
  663. mindspore/ops/_op_impl/tbe/asinh.py +0 -37
  664. mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
  665. mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
  666. mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
  667. mindspore/ops/_op_impl/tbe/assign.py +0 -79
  668. mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
  669. mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
  670. mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
  671. mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
  672. mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
  673. mindspore/ops/_op_impl/tbe/atan.py +0 -37
  674. mindspore/ops/_op_impl/tbe/atan2.py +0 -38
  675. mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
  676. mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
  677. mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
  678. mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
  679. mindspore/ops/_op_impl/tbe/atanh.py +0 -37
  680. mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
  681. mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
  682. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
  683. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
  684. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
  685. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
  686. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
  687. mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
  688. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
  689. mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
  690. mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
  691. mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
  692. mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
  693. mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
  694. mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
  695. mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
  696. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
  697. mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
  698. mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
  699. mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
  700. mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
  701. mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
  702. mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
  703. mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
  704. mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
  705. mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
  706. mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
  707. mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
  708. mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
  709. mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
  710. mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
  711. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
  712. mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
  713. mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
  714. mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
  715. mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
  716. mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
  717. mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
  718. mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
  719. mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
  720. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
  721. mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
  722. mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
  723. mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
  724. mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
  725. mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
  726. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
  727. mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
  728. mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
  729. mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
  730. mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
  731. mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
  732. mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
  733. mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
  734. mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
  735. mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
  736. mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
  737. mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
  738. mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
  739. mindspore/ops/_op_impl/tbe/cast.py +0 -55
  740. mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
  741. mindspore/ops/_op_impl/tbe/cdist.py +0 -38
  742. mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
  743. mindspore/ops/_op_impl/tbe/ceil.py +0 -37
  744. mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
  745. mindspore/ops/_op_impl/tbe/celu.py +0 -39
  746. mindspore/ops/_op_impl/tbe/centralization.py +0 -39
  747. mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
  748. mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
  749. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
  750. mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
  751. mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
  752. mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
  753. mindspore/ops/_op_impl/tbe/concat.py +0 -40
  754. mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
  755. mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
  756. mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
  757. mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
  758. mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
  759. mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
  760. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
  761. mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
  762. mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
  763. mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
  764. mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
  765. mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
  766. mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
  767. mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
  768. mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
  769. mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
  770. mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
  771. mindspore/ops/_op_impl/tbe/cos.py +0 -37
  772. mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
  773. mindspore/ops/_op_impl/tbe/cosh.py +0 -37
  774. mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
  775. mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
  776. mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
  777. mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
  778. mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
  779. mindspore/ops/_op_impl/tbe/cummin.py +0 -41
  780. mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
  781. mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
  782. mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
  783. mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
  784. mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
  785. mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
  786. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
  787. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
  788. mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
  789. mindspore/ops/_op_impl/tbe/diag.py +0 -38
  790. mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
  791. mindspore/ops/_op_impl/tbe/dilation.py +0 -40
  792. mindspore/ops/_op_impl/tbe/div.py +0 -41
  793. mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
  794. mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
  795. mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
  796. mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
  797. mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
  798. mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
  799. mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
  800. mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
  801. mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
  802. mindspore/ops/_op_impl/tbe/elu.py +0 -38
  803. mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
  804. mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
  805. mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
  806. mindspore/ops/_op_impl/tbe/equal.py +0 -42
  807. mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
  808. mindspore/ops/_op_impl/tbe/erf.py +0 -37
  809. mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
  810. mindspore/ops/_op_impl/tbe/erfc.py +0 -37
  811. mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
  812. mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
  813. mindspore/ops/_op_impl/tbe/exp.py +0 -40
  814. mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
  815. mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
  816. mindspore/ops/_op_impl/tbe/expm1.py +0 -37
  817. mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
  818. mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
  819. mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
  820. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
  821. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
  822. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
  823. mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
  824. mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
  825. mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
  826. mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
  827. mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
  828. mindspore/ops/_op_impl/tbe/fill.py +0 -56
  829. mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
  830. mindspore/ops/_op_impl/tbe/flatten.py +0 -48
  831. mindspore/ops/_op_impl/tbe/floor.py +0 -37
  832. mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
  833. mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
  834. mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
  835. mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
  836. mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
  837. mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
  838. mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
  839. mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
  840. mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
  841. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
  842. mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
  843. mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
  844. mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
  845. mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
  846. mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
  847. mindspore/ops/_op_impl/tbe/gelu.py +0 -37
  848. mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
  849. mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
  850. mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
  851. mindspore/ops/_op_impl/tbe/ger.py +0 -43
  852. mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
  853. mindspore/ops/_op_impl/tbe/greater.py +0 -43
  854. mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
  855. mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
  856. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
  857. mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
  858. mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
  859. mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
  860. mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
  861. mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
  862. mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
  863. mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
  864. mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
  865. mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
  866. mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
  867. mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
  868. mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
  869. mindspore/ops/_op_impl/tbe/im2col.py +0 -42
  870. mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
  871. mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
  872. mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
  873. mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
  874. mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
  875. mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
  876. mindspore/ops/_op_impl/tbe/inv.py +0 -38
  877. mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
  878. mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
  879. mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
  880. mindspore/ops/_op_impl/tbe/invert.py +0 -37
  881. mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
  882. mindspore/ops/_op_impl/tbe/iou.py +0 -38
  883. mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
  884. mindspore/ops/_op_impl/tbe/is_close.py +0 -40
  885. mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
  886. mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
  887. mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
  888. mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
  889. mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
  890. mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
  891. mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
  892. mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
  893. mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
  894. mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
  895. mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
  896. mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
  897. mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
  898. mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
  899. mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
  900. mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
  901. mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
  902. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
  903. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
  904. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
  905. mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
  906. mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
  907. mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
  908. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
  909. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
  910. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
  911. mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
  912. mindspore/ops/_op_impl/tbe/lerp.py +0 -38
  913. mindspore/ops/_op_impl/tbe/less.py +0 -41
  914. mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
  915. mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
  916. mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
  917. mindspore/ops/_op_impl/tbe/log.py +0 -40
  918. mindspore/ops/_op_impl/tbe/log1p.py +0 -37
  919. mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
  920. mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
  921. mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
  922. mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
  923. mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
  924. mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
  925. mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
  926. mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
  927. mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
  928. mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
  929. mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
  930. mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
  931. mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
  932. mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
  933. mindspore/ops/_op_impl/tbe/lrn.py +0 -41
  934. mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
  935. mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
  936. mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
  937. mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
  938. mindspore/ops/_op_impl/tbe/matmul.py +0 -53
  939. mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
  940. mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
  941. mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
  942. mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
  943. mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
  944. mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
  945. mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
  946. mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
  947. mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
  948. mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
  949. mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
  950. mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
  951. mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
  952. mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
  953. mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
  954. mindspore/ops/_op_impl/tbe/maximum.py +0 -39
  955. mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
  956. mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
  957. mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
  958. mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
  959. mindspore/ops/_op_impl/tbe/minimum.py +0 -40
  960. mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
  961. mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
  962. mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
  963. mindspore/ops/_op_impl/tbe/mish.py +0 -37
  964. mindspore/ops/_op_impl/tbe/mod.py +0 -41
  965. mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
  966. mindspore/ops/_op_impl/tbe/mul.py +0 -37
  967. mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
  968. mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
  969. mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
  970. mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
  971. mindspore/ops/_op_impl/tbe/neg.py +0 -39
  972. mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
  973. mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
  974. mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
  975. mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
  976. mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
  977. mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
  978. mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
  979. mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
  980. mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
  981. mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
  982. mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
  983. mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
  984. mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
  985. mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
  986. mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
  987. mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
  988. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
  989. mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
  990. mindspore/ops/_op_impl/tbe/pack.py +0 -58
  991. mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
  992. mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
  993. mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
  994. mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
  995. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
  996. mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
  997. mindspore/ops/_op_impl/tbe/pdist.py +0 -36
  998. mindspore/ops/_op_impl/tbe/pooling.py +0 -46
  999. mindspore/ops/_op_impl/tbe/population_count.py +0 -38
  1000. mindspore/ops/_op_impl/tbe/pow.py +0 -41
  1001. mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
  1002. mindspore/ops/_op_impl/tbe/prelu.py +0 -37
  1003. mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
  1004. mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
  1005. mindspore/ops/_op_impl/tbe/range.py +0 -39
  1006. mindspore/ops/_op_impl/tbe/real_div.py +0 -38
  1007. mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
  1008. mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
  1009. mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
  1010. mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
  1011. mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
  1012. mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
  1013. mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
  1014. mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
  1015. mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
  1016. mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
  1017. mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
  1018. mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
  1019. mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
  1020. mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
  1021. mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
  1022. mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
  1023. mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
  1024. mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
  1025. mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
  1026. mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
  1027. mindspore/ops/_op_impl/tbe/relu.py +0 -39
  1028. mindspore/ops/_op_impl/tbe/relu6.py +0 -38
  1029. mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
  1030. mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
  1031. mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
  1032. mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
  1033. mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
  1034. mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
  1035. mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
  1036. mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
  1037. mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
  1038. mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
  1039. mindspore/ops/_op_impl/tbe/renorm.py +0 -39
  1040. mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
  1041. mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
  1042. mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
  1043. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
  1044. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
  1045. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
  1046. mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
  1047. mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
  1048. mindspore/ops/_op_impl/tbe/rint.py +0 -37
  1049. mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
  1050. mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
  1051. mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
  1052. mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
  1053. mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
  1054. mindspore/ops/_op_impl/tbe/roll.py +0 -42
  1055. mindspore/ops/_op_impl/tbe/round.py +0 -38
  1056. mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
  1057. mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
  1058. mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
  1059. mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
  1060. mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
  1061. mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
  1062. mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
  1063. mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
  1064. mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
  1065. mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
  1066. mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
  1067. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
  1068. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
  1069. mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
  1070. mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
  1071. mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
  1072. mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
  1073. mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
  1074. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
  1075. mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
  1076. mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
  1077. mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
  1078. mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
  1079. mindspore/ops/_op_impl/tbe/select.py +0 -38
  1080. mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
  1081. mindspore/ops/_op_impl/tbe/selu.py +0 -39
  1082. mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
  1083. mindspore/ops/_op_impl/tbe/sgd.py +0 -62
  1084. mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
  1085. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
  1086. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
  1087. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
  1088. mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
  1089. mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
  1090. mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
  1091. mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
  1092. mindspore/ops/_op_impl/tbe/sign.py +0 -38
  1093. mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
  1094. mindspore/ops/_op_impl/tbe/sin.py +0 -37
  1095. mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
  1096. mindspore/ops/_op_impl/tbe/sinh.py +0 -37
  1097. mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
  1098. mindspore/ops/_op_impl/tbe/slice.py +0 -58
  1099. mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
  1100. mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
  1101. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
  1102. mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
  1103. mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
  1104. mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
  1105. mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
  1106. mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
  1107. mindspore/ops/_op_impl/tbe/softmax.py +0 -37
  1108. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
  1109. mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
  1110. mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
  1111. mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
  1112. mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
  1113. mindspore/ops/_op_impl/tbe/softplus.py +0 -37
  1114. mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
  1115. mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
  1116. mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
  1117. mindspore/ops/_op_impl/tbe/softsign.py +0 -37
  1118. mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
  1119. mindspore/ops/_op_impl/tbe/sort.py +0 -38
  1120. mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
  1121. mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
  1122. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
  1123. mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
  1124. mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
  1125. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
  1126. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
  1127. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
  1128. mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
  1129. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
  1130. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
  1131. mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
  1132. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
  1133. mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
  1134. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
  1135. mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
  1136. mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
  1137. mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
  1138. mindspore/ops/_op_impl/tbe/split_d.py +0 -38
  1139. mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
  1140. mindspore/ops/_op_impl/tbe/split_v.py +0 -39
  1141. mindspore/ops/_op_impl/tbe/splitv.py +0 -39
  1142. mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
  1143. mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
  1144. mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
  1145. mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
  1146. mindspore/ops/_op_impl/tbe/square.py +0 -38
  1147. mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
  1148. mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
  1149. mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
  1150. mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
  1151. mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
  1152. mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
  1153. mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
  1154. mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
  1155. mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
  1156. mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
  1157. mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
  1158. mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
  1159. mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
  1160. mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
  1161. mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
  1162. mindspore/ops/_op_impl/tbe/sub.py +0 -39
  1163. mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
  1164. mindspore/ops/_op_impl/tbe/tan.py +0 -38
  1165. mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
  1166. mindspore/ops/_op_impl/tbe/tanh.py +0 -37
  1167. mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
  1168. mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
  1169. mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
  1170. mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
  1171. mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
  1172. mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
  1173. mindspore/ops/_op_impl/tbe/tile.py +0 -37
  1174. mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
  1175. mindspore/ops/_op_impl/tbe/top_k.py +0 -42
  1176. mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
  1177. mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
  1178. mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
  1179. mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
  1180. mindspore/ops/_op_impl/tbe/transpose.py +0 -60
  1181. mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
  1182. mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
  1183. mindspore/ops/_op_impl/tbe/trunc.py +0 -39
  1184. mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
  1185. mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
  1186. mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
  1187. mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
  1188. mindspore/ops/_op_impl/tbe/unpack.py +0 -38
  1189. mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
  1190. mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
  1191. mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
  1192. mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
  1193. mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
  1194. mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
  1195. mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
  1196. mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
  1197. mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
  1198. mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
  1199. mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
  1200. mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
  1201. mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
  1202. mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
  1203. mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
  1204. mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
  1205. mindspore/ops/_tracefunc.py +0 -241
  1206. mindspore/ops/arg_dtype_cast.py +0 -54
  1207. mindspore/ops/silent_check.py +0 -162
  1208. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  1209. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  1210. mindspore/rewrite/api/tree_node_helper.py +0 -60
  1211. mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
  1212. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
  1213. mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
  1214. mindspore/rewrite/namespace.py +0 -53
  1215. mindspore-2.2.14.dist-info/RECORD +0 -1924
  1216. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  1217. {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/train/amp.py CHANGED
@@ -14,6 +14,11 @@
14
14
  # ============================================================================
15
15
  """Auto mixed precision."""
16
16
  from __future__ import absolute_import
17
+ import inspect
18
+ import types
19
+ from typing import Any
20
+ import functools
21
+ import collections
17
22
 
18
23
  import mindspore as ms
19
24
  from mindspore import nn
@@ -27,8 +32,9 @@ from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScal
27
32
  from mindspore import boost, context
28
33
  from mindspore.ops import operations as P
29
34
  from mindspore.ops import Primitive
35
+ from mindspore.ops import auto_generate as gen
30
36
  from mindspore import log as logger
31
-
37
+ from mindspore._c_expression.amp import pop_amp_strategy, push_amp_strategy, create_amp_strategy, AmpLevel
32
38
 
33
39
  AMP_WHITE_LIST = [
34
40
  nn.Conv1d,
@@ -50,19 +56,81 @@ AMP_WHITE_LIST = [
50
56
  P.BatchMatMul,
51
57
  P.PReLU,
52
58
  P.ReLU,
53
- P.Ger
59
+ P.Ger,
54
60
  ]
55
61
 
56
-
57
62
  AMP_BLACK_LIST = [
58
63
  nn.BatchNorm1d,
59
64
  nn.BatchNorm2d,
60
65
  nn.BatchNorm3d,
61
- nn.LayerNorm
66
+ nn.LayerNorm,
67
+ ]
68
+
69
+ AMP_AUTO_WHITE_LIST = [
70
+ P.Conv2D,
71
+ P.Conv3D,
72
+ P.Conv2DTranspose,
73
+ P.Conv3DTranspose,
74
+ gen.Convolution,
75
+ P.MatMul,
76
+ gen.MatMulExt,
77
+ P.BatchMatMul,
78
+ gen.BatchMatMulExt,
79
+ gen.PReLU,
80
+ P.Einsum,
81
+ gen.Dense,
82
+ gen.Addmm,
62
83
  ]
63
84
 
85
+ AMP_AUTO_BLACK_LIST = [
86
+ gen.Pow,
87
+ gen.ACos,
88
+ gen.Asin,
89
+ gen.Cosh,
90
+ P.Erfinv,
91
+ P.Exp,
92
+ P.Expm1,
93
+ P.Log,
94
+ P.Log1p,
95
+ P.Reciprocal,
96
+ P.Rsqrt,
97
+ P.Sinh,
98
+ P.Tan,
99
+ P.Softplus,
100
+ gen.SoftplusExt,
101
+ P.LayerNorm,
102
+ gen.LayerNormExt,
103
+ P.BatchNorm,
104
+ gen.GroupNorm,
105
+ P.KLDivLoss,
106
+ P.SmoothL1Loss,
107
+ P.MultilabelMarginLoss,
108
+ P.SoftMarginLoss,
109
+ P.TripletMarginLoss,
110
+ P.MultiMarginLoss,
111
+ P.BCEWithLogitsLoss,
112
+ P.Pdist,
113
+ P.Cdist,
114
+ P.Renorm,
115
+ ]
116
+
117
+ # Indicates which inputs of primitives need to be converted
118
+ AMP_PRIM_ARG_TABLE = collections.defaultdict(list, {})
119
+
120
+ # Primitives in inner amp black list will not be converted in O2/O3
121
+ _INNER_AMP_BLACK_LIST = []
122
+
64
123
  MS_AMP_BY_REWRITE = False
65
- _amp_cast_op = P.Cast
124
+
125
+
126
+ def amp_cast(value, dtype):
127
+ """This function is used to insert cast operators for tensors during auto mixed precision."""
128
+ if isinstance(value, ms.Tensor) and value.dtype in mstype.float_type:
129
+ return P.Cast()(value, dtype)
130
+ return value
131
+
132
+ _amp_cast_op = amp_cast
133
+
66
134
 
67
135
  class _OutputTo16(nn.Cell):
68
136
  """Wrap cell for amp. Cast network output back to float16."""
@@ -88,278 +156,185 @@ class _OutputTo32(nn.Cell):
88
156
  return F.mixed_precision_cast(mstype.float32, out)
89
157
 
90
158
 
91
-
92
- def _allow_mix_precision(node, allowed_list, dtype) -> bool:
159
+ def _operator_need_cast(node, force_cast: bool, white_list=None, black_list=None) -> bool:
93
160
  """
94
- Check whether current node need do mix precision. Follow conditions need to be satisfied:
95
- 1) Type of node is one of (Primitive, nn.Cell)
96
- 2) Node is not Cast Op
97
- 3) to_float(mindspore.float16) is not set in Cell
161
+ Check whether current node is a operator that need to be casted. Follow conditions need to be satisfied:
162
+ 1) Type of node is CallPrimitive and type of instance is Primitive
163
+ 2) Type of instance is not P.Cast
164
+ 3) force_cast is True, which means one of upper layer cells is under casting
165
+ 4) white_list exist and type of node is in white_list
166
+ 5) black_list exist and type of node is in not black_list
98
167
  """
99
- node_inst = node.get_instance()
100
- if node_inst in allowed_list:
101
- return True
102
- if node.get_targets() is None:
168
+ if node.get_node_type() != ms.rewrite.NodeType.CallPrimitive:
103
169
  return False
104
- if not issubclass(node.get_instance_type(), (Primitive, nn.Cell)):
170
+ if not inspect.isclass(node.get_instance_type()):
105
171
  return False
106
- if isinstance(node_inst, _amp_cast_op):
172
+ if not issubclass(node.get_instance_type(), Primitive):
107
173
  return False
108
- if issubclass(node.get_instance_type(), nn.Cell):
109
- # if cell is already in allowed_list, it means to_float() is set by amp.
110
- # if cell is not in allowed_list, but has to_float(),
111
- # it means to_float() is set by user.
112
- to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
113
- if hasattr(node_inst, to_float_flag) and getattr(node_inst, to_float_flag):
114
- return False
115
- allowed_list.append(node.get_instance())
116
- return True
174
+ if issubclass(node.get_instance_type(), P.Cast):
175
+ return False
176
+ if node.get_instance_type() in _INNER_AMP_BLACK_LIST:
177
+ return False
178
+ if force_cast:
179
+ return True
180
+ if white_list is not None and node.get_instance_type() in white_list:
181
+ return True
182
+ if black_list is not None and node.get_instance_type() not in black_list:
183
+ return True
184
+ return False
117
185
 
118
186
 
119
- def _insert_cast_operator_process(node, dtype):
120
- """insert cast for operators in white_list."""
121
- dtype_str = "mindspore.bfloat16" if dtype == mstype.bfloat16 else "mindspore.float16"
122
- new_cast_node = None
123
- stree = node.get_symbol_tree()
124
- # insert cast fp16/bf16 before the primitive operators
125
- if issubclass(node.get_instance_type(), Primitive):
126
- for idx, arg in enumerate(node.get_args()):
127
- position = stree.before(node)
128
- new_node = _amp_cast_op()
129
- cast_args = ms.rewrite.ScopedValue.create_name_values([arg.value, dtype_str], [arg.scope, ""])
130
- arg_provider = node.get_handler().get_arg_providers()[idx]
131
- if arg_provider and len(arg_provider[0].get_target_users(arg_provider[1])) > 1:
132
- cast_targets = [stree.unique_name(str(arg))]
133
- else:
134
- cast_targets = ms.rewrite.ScopedValue.create_name_values([arg.value], [arg.scope])
135
- new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
136
- targets=cast_targets,
137
- args=cast_args,
138
- name='incast_{}{}'.format(node.get_name(), idx))
139
- stree.insert(position, new_cast_node)
140
- node.set_arg_by_node(idx, new_cast_node)
141
- # insert cast fp16/bf16 before the Cell operators
142
- elif issubclass(node.get_instance_type(), nn.Cell):
143
- node.get_instance().to_float(dtype)
144
- # ignore if subclass is not one of (Primitive, nn.Cell)
145
- else:
146
- return
147
-
148
- # insert cast float32 after the operators
149
- position = stree.after(node)
150
- new_node = _amp_cast_op()
151
- cast_args = ms.rewrite.ScopedValue.create_name_values([node.get_targets()[0].value,
152
- "mindspore.float32"])
153
- new_cast_node = ms.rewrite.Node.create_call_cell(new_node,
154
- targets=[node.get_targets()[0]],
155
- args=cast_args,
156
- name='outcast_{}'.format(node.get_name()))
157
- # insert node & unique names
158
- stree.insert(position, new_cast_node)
159
- # update argument names
160
- for user in node.get_users():
161
- if user.get_name() == new_cast_node.get_name():
162
- continue
163
- for idx, arg in enumerate(user.get_args()):
164
- if arg == node.get_targets()[0]:
165
- user.set_arg_by_node(idx, new_cast_node)
166
-
167
-
168
- def _insert_cast_operator_white_list(stree, white_list, dtype):
169
- """insert cast for operators in white_list."""
170
- allowed_list = []
171
- # Ignore if net called ".to_float(dtype)"
172
- net = stree.get_handler().get_origin_network()
173
- to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
174
- if isinstance(net, nn.Cell) and hasattr(net, to_float_flag) and getattr(net, to_float_flag):
175
- return
176
- node_list = []
177
- node_list.extend(list(stree.nodes()))
178
- while node_list:
179
- node = node_list.pop()
180
- if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
181
- if MS_AMP_BY_REWRITE:
182
- _insert_cast_for_cell_container(node, dtype, allowed_list, white_list=white_list)
183
- for n in node.get_handler().node_list:
184
- if n.get_node_type() == ms.rewrite.NodeType.Tree:
185
- _insert_cast_operator_white_list(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)),
186
- white_list, dtype)
187
- elif node.get_node_type() == ms.rewrite.NodeType.Tree:
188
- substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
189
- _insert_cast_operator_white_list(substree, white_list, dtype)
190
- elif node.get_node_type() in [ms.rewrite.NodeType.CallFunction, ms.rewrite.NodeType.ControlFlow]:
191
- if isinstance(node.get_handler(), ms.rewrite.node.NodeManager):
192
- nodes = [ms.rewrite.Node(n) for n in node.get_handler().nodes()]
193
- node_list.extend(nodes)
194
- elif node.get_instance_type() in white_list and _allow_mix_precision(node, allowed_list, dtype):
195
- _insert_cast_operator_process(node, dtype)
187
+ def _precision_set_by_user(cell_inst: nn.Cell) -> bool:
188
+ """Check whether cell precision is set by user."""
189
+ for flag in ["fp32", "fp16", "bf16"]:
190
+ if hasattr(cell_inst, flag) and getattr(cell_inst, flag):
191
+ return True
192
+ return False
196
193
 
197
194
 
198
- def _insert_cast_for_cell_container(cell_container, dtype, allowed_list, *, white_list=None, black_list=None):
195
+ def _net_need_cast(node, force_cast: bool, white_list=None, black_list=None) -> bool:
199
196
  """
200
- Insert cast for cell containers.
201
- Only one of white_list and black_list can be set.
197
+ Check whether current node is type of tree whose network needs to be casted. Follow conditions need to
198
+ be satisfied:
199
+ 1) Type of node is Tree and type of instance is Cell
200
+ 2) Cell.to_float(xxx) is not set by user
201
+ 3) force_cast is True, which means one of upper layer networks is under casting
202
+ 4) white_list exist and type of node is in white_list
203
+ 5) black_list exist and type of node is in not black_list
202
204
  """
205
+ if node.get_node_type() != ms.rewrite.NodeType.Tree:
206
+ return False
207
+ if not inspect.isclass(node.get_instance_type()):
208
+ return False
209
+ if not issubclass(node.get_instance_type(), nn.Cell):
210
+ return False
211
+ if node.get_instance_type() in _INNER_AMP_BLACK_LIST:
212
+ return False
213
+ if _precision_set_by_user(node.get_instance()):
214
+ return False
215
+ if force_cast:
216
+ return True
217
+ if white_list is not None and node.get_instance_type() in white_list:
218
+ return True
219
+ if black_list is not None and node.get_instance_type() not in black_list:
220
+ return True
221
+ return False
222
+
223
+
224
+ def _insert_cast_for_operator(node, dtype):
225
+ """insert cast pair for node."""
226
+ dtype_str = "bfloat16" if dtype == mstype.bfloat16 else "float16"
227
+ stree = node.get_symbol_tree()
228
+ # insert cast fp16/bf16 for inputs of node
229
+ for idx, arg in enumerate(node.get_args()):
230
+ if arg.type != ms.rewrite.ValueType.NamingValue:
231
+ continue
232
+ incast_args = ms.rewrite.ScopedValue.create_name_values([arg.value, dtype_str], [arg.scope, "mindspore"])
233
+ arg_providers = node.get_arg_providers()
234
+ if not arg_providers or idx not in arg_providers or \
235
+ len(arg_providers[idx][0].get_target_users(arg_providers[idx][1])) > 1:
236
+ # create new target names when argument is used by other node
237
+ incast_targets = [stree.unique_name(f"{arg.value}_var")]
238
+ else:
239
+ incast_targets = ms.rewrite.ScopedValue.create_name_values([arg.value], [arg.scope])
240
+ incast_node = ms.rewrite.Node.create_call_function(_amp_cast_op, targets=incast_targets, args=incast_args)
241
+ stree.insert(stree.before(node), incast_node)
242
+ node.set_arg_by_node(idx, incast_node)
243
+ # insert cast fp32 for outputs of node
244
+ for _, target in enumerate(node.get_targets()):
245
+ if target.type != ms.rewrite.ValueType.NamingValue:
246
+ continue
247
+ outcast_args = ms.rewrite.ScopedValue.create_name_values([target.value, "float32"],
248
+ [target.scope, "mindspore"])
249
+ outcast_targets = ms.rewrite.ScopedValue.create_name_values([target.value], [target.scope])
250
+ outcast_node = ms.rewrite.Node.create_call_function(_amp_cast_op, targets=outcast_targets, args=outcast_args)
251
+ stree.insert(stree.after(node), outcast_node)
203
252
 
204
- class CastNet(nn.Cell):
205
- """Cast net"""
206
- def __init__(self, dtype):
207
- super().__init__()
208
- self.cast = _amp_cast_op()
209
- self.dtype = dtype
210
-
211
- def construct(self, x):
212
- return self.cast(x, self.dtype)
213
-
214
- cast_flag = False
215
- current_node = None
216
- stree = cell_container.get_symbol_tree()
217
- for node in cell_container.get_handler().nodes():
218
- current_node = ms.rewrite.Node(node)
219
- if (white_list is not None and current_node.get_instance_type() in white_list) or \
220
- (black_list is not None and current_node.get_instance_type() not in black_list) and \
221
- (_allow_mix_precision(current_node, allowed_list, dtype)):
222
- cast_flag = True
223
- current_node.get_instance().to_float(dtype)
224
- elif cast_flag:
225
- # cast next node back to float32
226
- current_node.get_instance().to_float(mstype.float32)
227
- cast_flag = False
228
- if cast_flag and current_node:
229
- # if last node in cell_container is casted to fp16/bf16, insert a cast node to cast value back to fp32
230
- cast_node = ms.rewrite.Node.create_call_cell(cell=CastNet(mstype.float32),
231
- args=[current_node.get_targets()[0]],
232
- targets=[current_node.get_targets()[0]],
233
- name=f"outcast_{cell_container.get_name()}")
234
- stree.insert(stree.after(current_node), cast_node)
253
+
254
+ def _insert_cast_for_operators(stree, dtype, force_cast, *, white_list=None, black_list=None):
255
+ """insert cast for operators not in black_list."""
256
+ # get all nodes of stree exclude nodes in subtree.
257
+ all_nodes = stree.all_nodes(False)
258
+ for node in all_nodes:
259
+ if not node.get_targets():
260
+ continue
261
+ if _operator_need_cast(node, force_cast, white_list, black_list):
262
+ _insert_cast_for_operator(node, dtype)
263
+ elif node.get_node_type() == ms.rewrite.NodeType.Tree:
264
+ force_cast_ = force_cast or _net_need_cast(node, force_cast, white_list, black_list)
265
+ if not _precision_set_by_user(node.get_instance()):
266
+ subtree = node.get_sub_tree()
267
+ _insert_cast_for_operators(subtree, dtype, force_cast_, white_list=white_list, black_list=black_list)
235
268
 
236
269
 
237
270
  def _need_removed_cast_pair(node, dtype):
238
271
  """check whether the cast pairs should be removed."""
239
- dtype_str = "mindspore.bfloat16" if dtype == mstype.bfloat16 else "mindspore.float16"
240
- cast_dtypes = ms.rewrite.ScopedValue.create_name_values([dtype_str, "mindspore.float32"])
272
+ dtype_str = "bfloat16" if dtype == mstype.bfloat16 else "float16"
273
+ cast_dtypes = ms.rewrite.ScopedValue.create_name_values([dtype_str, "float32"], ["mindspore", "mindspore"])
241
274
  cast_dtype_f16 = cast_dtypes[0]
242
275
  cast_dtype_f32 = cast_dtypes[1]
243
- # current node should be Cast Op to float32
276
+ # current node should be cast fp32
244
277
  if node.get_instance_type() != _amp_cast_op:
245
278
  return False
246
279
  node_cast_type = node.get_args()[1]
247
280
  if node_cast_type != cast_dtype_f32:
248
281
  return False
249
- # all user nodes should be Cast Op to dtype or Cell with to_float(dtype)
282
+ # all user nodes should be cast fp16/bf16
250
283
  if not node.get_users():
251
284
  return False
252
285
  all_nodes = [ms.rewrite.Node(n) for n in node.get_handler().get_node_manager().nodes()]
253
286
  for user in node.get_users():
254
- # If ControlFlow node(if statement) exists between current node and user node,
287
+ # If ControlFlow node(e.g. if, for, while) exists between current node and user node,
255
288
  # cast pair should not be removed.
256
289
  middle_nodes = all_nodes[all_nodes.index(node): all_nodes.index(user)]
257
290
  if any([n.get_node_type() == ms.rewrite.NodeType.ControlFlow for n in middle_nodes]):
258
291
  return False
259
- if isinstance(user.get_instance(), nn.Cell):
260
- to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
261
- if not (hasattr(user.get_instance(), to_float_flag) and getattr(user.get_instance(), to_float_flag)):
262
- return False
263
- elif user.get_instance_type() == _amp_cast_op:
264
- user_cast_type = user.get_args()[1]
265
- if user_cast_type != cast_dtype_f16:
266
- return False
267
- else:
292
+ if user.get_instance_type() != _amp_cast_op:
268
293
  return False
294
+ user_cast_type = user.get_args()[1]
295
+ if user_cast_type != cast_dtype_f16:
296
+ return False
297
+ # cast pair detected, check next user
298
+ continue
269
299
  return True
270
300
 
271
301
 
272
- def _removed_cast_pair_process(cast_f32_node):
273
- """remove the duplicated cast operators."""
274
- stree = cast_f32_node.get_symbol_tree()
275
- cast_f32_users = cast_f32_node.get_users()
276
- # remove cast f16 nodes
277
- for user_node in cast_f32_users:
278
- if user_node.get_instance_type() == _amp_cast_op:
279
- cast_f16_node = user_node
280
- # modify arguments using cast_f16's target[0] to cast_f32's args[0], which is f16 type
281
- for cast_f16_user in cast_f16_node.get_users():
282
- for idx, arg in enumerate(cast_f16_user.get_args()):
283
- if arg == cast_f16_node.get_targets()[0]:
284
- cast_f16_user.set_arg(idx, cast_f32_node.get_args()[0])
285
- stree.erase(cast_f16_node)
286
- # update args of cell f16 nodes
287
- elif isinstance(user_node.get_instance(), nn.Cell):
288
- cell_f16_node = user_node
289
- for idx, arg in enumerate(cell_f16_node.get_args()):
290
- if arg == cast_f32_node.get_targets()[0]:
291
- cell_f16_node.set_arg(idx, cast_f32_node.get_args()[0])
292
- # remove the cast f32 node
293
- stree.erase(cast_f32_node)
294
-
295
-
296
302
  def _remove_duplicated_cast(stree, dtype):
297
303
  """remove the duplicated cast operators."""
298
- node_list = []
299
- node_list.extend(list(stree.nodes()))
300
- while node_list:
301
- node = node_list.pop()
302
- if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
303
- for n in node.get_handler().node_list:
304
- if n.get_node_type() == ms.rewrite.NodeType.Tree:
305
- _remove_duplicated_cast(ms.rewrite.TreeNodeHelper.get_sub_tree(ms.rewrite.Node(n)), dtype)
306
- elif node.get_node_type() == ms.rewrite.NodeType.Tree:
307
- substree = ms.rewrite.TreeNodeHelper.get_sub_tree(node)
308
- _remove_duplicated_cast(substree, dtype)
309
- elif node.get_node_type() in [ms.rewrite.NodeType.CallFunction, ms.rewrite.NodeType.ControlFlow]:
310
- if isinstance(node.get_handler(), ms.rewrite.node.NodeManager):
311
- nodes = [ms.rewrite.Node(n) for n in node.get_handler().nodes()]
312
- node_list.extend(nodes)
313
- elif _need_removed_cast_pair(node, dtype):
314
- _removed_cast_pair_process(node)
315
-
316
-
317
- def _auto_white_list(network, white_list, dtype):
318
- """process the white list of network."""
319
- stree = ms.rewrite.SymbolTree.create(network)
320
- _insert_cast_operator_white_list(stree, white_list, dtype)
321
- _remove_duplicated_cast(stree, dtype)
322
- return stree.get_network()
323
-
324
-
325
- def _insert_cast_operator_black_list(stree, black_list, dtype):
326
- """insert cast for operators not in black_list."""
327
- allowed_list = []
328
- # Ignore if net called ".to_float(dtype)"
329
- net = stree.get_handler().get_origin_network()
330
- to_float_flag = "bf16" if dtype == mstype.bfloat16 else "fp16"
331
- if isinstance(net, nn.Cell) and hasattr(net, to_float_flag) and getattr(net, to_float_flag):
332
- return
333
- for node in stree.nodes(all_nodes=True):
334
- if node.get_targets() is None:
335
- continue
336
- if node.get_node_type() == ms.rewrite.NodeType.CellContainer:
337
- _insert_cast_for_cell_container(node, dtype, allowed_list, black_list=black_list)
338
- elif isinstance(node.get_handler().get_node_manager(), ms.rewrite.node.CellContainer):
339
- # nodes in CellContainer are processed by _insert_cast_for_cell_container
340
- continue
341
- elif node.get_instance_type() not in black_list and _allow_mix_precision(node, allowed_list, dtype):
342
- _insert_cast_operator_process(node, dtype)
343
-
344
-
345
- def _remove_duplicated_cast_rewrite(stree, dtype):
346
- """remove the duplicated cast operators."""
347
- for node in stree.nodes(all_nodes=True):
304
+ all_nodes = list(stree.nodes(all_nodes=True))
305
+ for node in all_nodes:
348
306
  if _need_removed_cast_pair(node, dtype):
349
- user_nodes = node.get_users()
350
- # remove cast f16 nodes
351
- for user_node in user_nodes:
352
- if user_node.get_instance_type() == _amp_cast_op:
353
- stree.erase(user_node)
354
- # remove the cast f32 node
307
+ incast_nodes = node.get_users()
308
+ # remove cast fp16/bf16 nodes
309
+ for incast_node in incast_nodes:
310
+ # get_target_users() return {target0: [(user0, arg_idx), ...], ...}
311
+ target_users = list(incast_node.get_target_users().values())
312
+ if not target_users or not target_users[0]:
313
+ continue
314
+ for user_node, arg_idx in target_users[0]:
315
+ user_node.set_arg(arg_idx, incast_node.get_args()[0])
316
+ stree.erase(incast_node)
317
+ # remove the cast fp32 node
355
318
  stree.erase(node)
356
319
 
357
320
 
358
- def _auto_black_list_rewrite(network, black_list, dtype):
321
+ def _auto_mixed_precision_rewrite(network, dtype, *, white_list=None, black_list=None):
322
+ """Implement auto mixed precision by rewrite"""
323
+ if (white_list is None and black_list is None) or (white_list is not None and black_list is not None):
324
+ raise ValueError("For _auto_mixed_precision_rewrite, one of white_list and black_list must be provided.")
325
+ # enable rewrite configs for amp
326
+ ms.rewrite.common.namespace._ms_cells_to_subtree = True
327
+ ms.rewrite.parsers.assign_parser.AssignParser._share_one_implementation = True
328
+ # insert casts by rewrite
359
329
  stree = ms.rewrite.SymbolTree.create(network)
360
- _insert_cast_operator_black_list(stree, black_list, dtype)
361
- _remove_duplicated_cast_rewrite(stree, dtype)
362
- return stree.get_network()
330
+ _insert_cast_for_operators(stree, dtype, False, white_list=white_list, black_list=black_list)
331
+ _remove_duplicated_cast(stree, dtype)
332
+ new_net = stree.get_network()
333
+ # disable rewrite configs
334
+ ms.rewrite.parsers.assign_parser.AssignParser._share_one_implementation = False
335
+ ms.rewrite.common.namespace._ms_cells_to_subtree = False
336
+ ms.rewrite.common.config.clear_caches()
337
+ return new_net
363
338
 
364
339
 
365
340
  def _auto_black_list(network, black_list, dtype):
@@ -381,6 +356,42 @@ def _auto_black_list(network, black_list, dtype):
381
356
  return network
382
357
 
383
358
 
359
+ class amp_decorator:
360
+ """
361
+ Auto mixed precision decorator.
362
+ Type of lists: List[Tuple[str, List[int]]]
363
+ """
364
+ def __init__(self, amp_level, amp_dtype, white_list, black_list):
365
+ self.amp_level = amp_level
366
+ self.amp_dtype = amp_dtype
367
+ self.white_list = white_list
368
+ self.black_list = black_list
369
+
370
+ def __enter__(self):
371
+ push_amp_strategy(self.amp_level, self.amp_dtype, self.white_list, self.black_list)
372
+
373
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
374
+ pop_amp_strategy()
375
+
376
+
377
+ def _set_amp_decorator(obj, amp_level, amp_dtype, white_list, black_list):
378
+ """
379
+ Set auto mixed precision context decorator for object.
380
+ Type of lists: List[Tuple[str, List[int]]]
381
+ """
382
+ if inspect.isfunction(obj) or inspect.ismethod(obj):
383
+ @functools.wraps(obj)
384
+ def wrapper(*args, **kwargs):
385
+ with amp_decorator(amp_level, amp_dtype, white_list, black_list):
386
+ return obj(*args, **kwargs)
387
+ return wrapper
388
+ if isinstance(obj, nn.Cell):
389
+ obj.construct = types.MethodType(
390
+ _set_amp_decorator(obj.construct.__func__, amp_level, amp_dtype, white_list, black_list), obj)
391
+ return obj
392
+ raise TypeError(f"For amp_level '{amp_level}', the network type should be Cell or function, bot got {type(obj)}.")
393
+
394
+
384
395
  def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
385
396
  """
386
397
  Returns a network processed with auto mixed precision.
@@ -391,26 +402,44 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
391
402
  converted to lower precision float, and calculation results are converted back to full precision float,
392
403
  i.e. ``mstype.float32`` .
393
404
 
394
- The framework has a set of built-in blacklists and whitelists, and the `amp_level` determines which cells and
395
- operators are specifically converted.
405
+ The `amp_level` and its corresponding lists determine which cells and operators are converted.
396
406
 
397
- The current built-in whitelist contents are:
407
+ When `amp_level` is set to ``O0``, no cells and operators are converted.
398
408
 
399
- [:class:`mindspore.nn.Conv1d`, :class:`mindspore.nn.Conv2d`, :class:`mindspore.nn.Conv3d`,
400
- :class:`mindspore.nn.Conv1dTranspose`, :class:`mindspore.nn.Conv2dTranspose`,
401
- :class:`mindspore.nn.Conv3dTranspose`, :class:`mindspore.nn.Dense`, :class:`mindspore.nn.LSTMCell`,
402
- :class:`mindspore.nn.RNNCell`, :class:`mindspore.nn.GRUCell`, :class:`mindspore.ops.Conv2D`,
403
- :class:`mindspore.ops.Conv3D`, :class:`mindspore.ops.Conv2DTranspose`,
404
- :class:`mindspore.ops.Conv3DTranspose`, :class:`mindspore.ops.MatMul`, :class:`mindspore.ops.BatchMatMul`,
405
- :class:`mindspore.ops.PReLU`, :class:`mindspore.ops.ReLU`, :class:`mindspore.ops.Ger`]
409
+ When `amp_level` is set to ``O1``, cells and operators in whitelist will be converted to lower precision
410
+ operations. For details on whitelist, refer to :func:`mindspore.amp.get_white_list`.
406
411
 
407
- The current built-in blacklist contents are:
412
+ When `amp_level` is set to ``O2``, cells in blacklist will maintain full precision, and cells outside the
413
+ list will be converted to low precision. For details on blacklist, refer to :func:`mindspore.amp.get_black_list`.
408
414
 
409
- [:class:`mindspore.nn.BatchNorm1d`, :class:`mindspore.nn.BatchNorm2d`, :class:`mindspore.nn.BatchNorm3d`,
410
- :class:`mindspore.nn.LayerNorm`]
415
+ When `amp_level` is set to ``O3``, all cells will be converted to low precision.
416
+
417
+ When `amp_level` is set to ``auto``, operators in `auto_whitelist` will be converted to lower precision
418
+ operations, operators in `auto_blacklist` will be converted to full precision operations, operators in
419
+ `promote_list` will be converted to the higher accuracy float type of the operator inputs, and operators
420
+ not listed will run in the type defined by their inputs.
421
+
422
+ Operators in `auto_whitelist` are:
423
+
424
+ ``Conv2D``, ``Conv3D``, ``Conv2DTranspose``, ``Conv3DTranspose``, ``Convolution``, ``MatMul``, ``MatMulExt``,
425
+ ``BatchMatMul``, ``BatchMatMulExt``, ``PReLU``, ``Einsum``, ``Dense``, ``Addmm``
426
+
427
+ Operators in `auto_blacklist` are:
428
+
429
+ ``Pow``, ``ACos``, ``Asin``, ``Cosh``, ``Erfinv``, ``Exp``, ``Expm1``, ``Log``, ``Log1p``, ``Reciprocal``,
430
+ ``Rsqrt``, ``Sinh``, ``Tan``, ``Softplus``, ``SoftplusExt``, ``LayerNorm``, ``LayerNormExt``, ``BatchNorm``,
431
+ ``GroupNorm``, ``KLDivLoss``, ``SmoothL1Loss``, ``MultilabelMarginLoss``, ``SoftMarginLoss``,
432
+ ``TripletMarginLoss``, ``MultiMarginLoss``, ``BCEWithLogitsLoss``, ``Pdist``, ``Cdist``, ``Renorm``,
433
+ ``ReduceProd``, ``Softmax``, ``LogSoftmax``, ``CumProd``, ``CumSum``, ``CumsumExt``, ``ProdExt``, ``SumExt``,
434
+ ``Norm``
435
+
436
+ Operators in `promote_list` are:
437
+
438
+ ``Addcdiv``, ``Addcmul``, ``Cross``, ``_PyboostCrossPrim``, ``Dot``, ``GridSampler2D``, ``GridSampler3D``,
439
+ ``BiasAdd``
411
440
 
412
441
  For details on automatic mixed precision, refer to
413
- `Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/r2.2/advanced/mixed_precision.html>`_ .
442
+ `Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/master/beginner/mixed_precision.html>`_ .
414
443
 
415
444
  Note:
416
445
  - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
@@ -418,10 +447,18 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
418
447
  - If interfaces like `Model` and `build_train_network` are used to train the network which is converted by
419
448
  mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
420
449
  needs to be configured to ``O0`` to avoid the duplicated accuracy conversion.
450
+ - When `amp_level` is set to ``auto``, the output of the network may be lower precision. In this case, you
451
+ may need to manually convert the type to avoid type inconsistency errors of the loss function.
452
+ - When `amp_level` is set to ``auto``, and cells in the network are configured with `to_float`, the accuracy
453
+ specified by `to_float` takes effect first.
454
+
455
+ .. warning::
456
+ ``auto`` level of `amp_level` is an experimental API that is subject to change or deletion.
421
457
 
422
458
  Args:
423
- network (Cell): Definition of the network.
424
- amp_level (str): Supports ["O0", "O1", "O2", "O3"]. Default: ``"O0"`` .
459
+ network (Union[Cell, function]): Definition of the network. Function type is supported only when `amp_level`
460
+ is set to ``auto`` .
461
+ amp_level (str): Supports ["O0", "O1", "O2", "O3", "auto"]. Default: ``"O0"`` .
425
462
 
426
463
  - "O0": Do not change.
427
464
  - "O1": Convert cells and operators in whitelist to lower precision operations, and keep full
@@ -429,25 +466,34 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
429
466
  - "O2": Keep full precision operations for cells and operators in blacklist, and convert the rest
430
467
  to lower precision operations.
431
468
  - "O3": Cast network to lower precision.
469
+ - "auto": Operators in `auto_whitelist` will be converted to lower precision operations, operators in
470
+ `auto_blacklist` will be converted to full precision, operators in `promote_list` will be converted
471
+ to the higher accuracy float type of the operator inputs, and operators not listed will run in the
472
+ type defined by their inputs.
432
473
 
433
474
  dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or ``mstype.bfloat16`` ,
434
475
  default: ``mstype.float16`` .
435
476
 
436
477
  Raises:
437
- TypeError: If `network` is not a Cell.
478
+ TypeError: If `network` is not a Cell or a function.
438
479
  ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
439
480
  ValueError: If `amp_level` is not within the supported range.
440
481
 
441
482
  Examples:
442
483
  >>> from mindspore import amp
443
484
  >>> # Define the network structure of LeNet5. Refer to
444
- >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
485
+ >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
445
486
  >>> network = LeNet5()
446
487
  >>> amp_level = "O1"
447
488
  >>> net = amp.auto_mixed_precision(network, amp_level)
448
489
  """
449
490
  if not isinstance(network, nn.Cell):
450
- raise TypeError("The network type should be Cell.")
491
+ if amp_level == "auto":
492
+ if not inspect.isfunction(network) and not inspect.ismethod(network):
493
+ raise TypeError("For amp_level 'auto', the network type should be Cell or function.")
494
+ # function is supported for amp_level 'auto'
495
+ else:
496
+ raise TypeError(f"For amp_level '{amp_level}', the network type should be Cell.")
451
497
 
452
498
  if dtype not in (mstype.float16, mstype.bfloat16):
453
499
  raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")
@@ -456,27 +502,35 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
456
502
  return network
457
503
 
458
504
  # Return network if the same amp level has already been configured
459
- if getattr(network, "_amp_level") in ("O1", "O2", "O3"):
505
+ if hasattr(network, "_amp_level") and getattr(network, "_amp_level") in ("O1", "O2", "O3", "auto"):
460
506
  logger.warning(f"The network's auto mixed-precision level is adjusted from {getattr(network, '_amp_level')} "
461
507
  f"to {amp_level}, and repeated calls to mixed-precision interfaces can cause performance "
462
508
  f"degradation.")
463
509
 
464
510
  if amp_level == "O1":
465
- network = _auto_white_list(network, AMP_WHITE_LIST, dtype)
511
+ network = _auto_mixed_precision_rewrite(network, dtype, white_list=AMP_WHITE_LIST)
466
512
  elif amp_level == "O2":
467
513
  if MS_AMP_BY_REWRITE:
468
- network = _auto_black_list_rewrite(network, AMP_BLACK_LIST, dtype)
514
+ network = _auto_mixed_precision_rewrite(network, dtype, black_list=AMP_BLACK_LIST)
469
515
  else:
470
516
  network = _auto_black_list(network, AMP_BLACK_LIST, dtype)
471
517
  network = _OutputTo32(network)
472
518
  elif amp_level == "O3":
473
519
  if MS_AMP_BY_REWRITE:
474
- network = _auto_black_list_rewrite(network, [], dtype)
520
+ network = _auto_mixed_precision_rewrite(network, dtype, black_list=[])
475
521
  else:
476
522
  network.to_float(dtype)
477
523
  network = _OutputTo32(network)
524
+ elif amp_level == "auto":
525
+ white_list = [(prim.__name__, AMP_PRIM_ARG_TABLE[prim]) for prim in AMP_AUTO_WHITE_LIST]
526
+ black_list = [(prim.__name__, AMP_PRIM_ARG_TABLE[prim]) for prim in AMP_AUTO_BLACK_LIST]
527
+ # set amp_strategy attribute for the object
528
+ amp_strategy = create_amp_strategy(AmpLevel.AmpAuto, dtype, white_list, black_list)
529
+ setattr(network, "amp_strategy", amp_strategy)
530
+ # set amp_strategy context decorator for the object
531
+ network = _set_amp_decorator(network, AmpLevel.AmpAuto, dtype, white_list, black_list)
478
532
  else:
479
- raise ValueError("The amp level {} is not supported".format(amp_level))
533
+ raise ValueError(f"The amp level {amp_level} is not supported")
480
534
 
481
535
  setattr(network, "_amp_level", amp_level)
482
536
 
@@ -516,6 +570,10 @@ _config_level = {
516
570
  "O3": {
517
571
  "keep_batchnorm_fp32": False,
518
572
  "cast_model_type": mstype.float16,
573
+ "loss_scale_manager": None},
574
+ "auto": {
575
+ "keep_batchnorm_fp32": False,
576
+ "cast_model_type": mstype.float32,
519
577
  "loss_scale_manager": None}}
520
578
 
521
579
 
@@ -540,20 +598,11 @@ def _check_kwargs(key_words):
540
598
  def _check_level(level, boost_level):
541
599
  """Check level."""
542
600
  if not isinstance(level, str):
543
- raise TypeError("The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'], \
544
- but got type {}.".format(type(level)))
601
+ raise TypeError(f"The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'],"
602
+ f"but got type {type(level)}.")
545
603
  validator.check('level', level, "", ['O0', 'O1', 'O2', 'O3', 'auto'], validator.IN)
546
604
  validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], validator.IN)
547
605
 
548
- if level == "auto":
549
- device_target = context.get_context('device_target')
550
- if device_target == "GPU":
551
- level = "O2"
552
- elif device_target == "Ascend":
553
- level = "O3"
554
- else:
555
- raise ValueError("Level `auto` only support when `device_target` is GPU or Ascend.")
556
-
557
606
  enable_boost = False
558
607
  if boost_level in ["O1", "O2"]:
559
608
  enable_boost = True
@@ -578,7 +627,8 @@ def _add_loss_network(network, loss_fn, cast_model_type):
578
627
  return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)
579
628
 
580
629
  validator.check_value_type('loss_fn', loss_fn, nn.Cell)
581
- if cast_model_type == mstype.float16:
630
+ if cast_model_type in (mstype.float16, mstype.bfloat16) or \
631
+ (hasattr(network, "_amp_level") and getattr(network, "_amp_level") in ("O2", "O3", "auto")):
582
632
  network = WithLossCell(network, loss_fn)
583
633
  else:
584
634
  network = nn.WithLossCell(network, loss_fn)
@@ -634,20 +684,10 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
634
684
  Default: ``None`` .
635
685
  level (str): Supports ['O0', 'O1', 'O2', 'O3', 'auto']. Default: ``'O0'`` .
636
686
 
637
- - 'O0': Do not change.
638
- - 'O1': Cast the operators in white_list to float16, the remaining operators are kept in float32.
639
- The operators in the whitelist: [Conv1d, Conv2d, Conv3d, Conv1dTranspose, Conv2dTranspose,
640
- Conv3dTranspose, Dense, LSTMCell, RNNCell, GRUCell, MatMul, BatchMatMul, PReLU, ReLU, Ger].
641
- - 'O2': Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
642
- using dynamic loss scale.
643
- - 'O3': Cast network to float16, with additional property `keep_batchnorm_fp32=False` .
644
- - 'auto': Set to level to recommended level in different devices. Set level to 'O2' on GPU, Set
645
- level to 'O3' Ascend. The recommended level is chosen by the export experience, not applicable to all
646
- scenarios. User should specify the level for special network.
647
-
648
- 'O2' is recommended on GPU, 'O3' is recommended on Ascend. Property of `keep_batchnorm_fp32`,
649
- `cast_model_type` and `loss_scale_manager` determined by `level` setting may be overwritten by settings in
650
- `kwargs`.
687
+ For details on amp level, refer to :func:`mindspore.amp.auto_mixed_precision`.
688
+
689
+ Property of `keep_batchnorm_fp32`, `cast_model_type` and `loss_scale_manager` determined by `level`
690
+ setting may be overwritten by settings in `kwargs`.
651
691
 
652
692
  boost_level (str): Option for argument `level` in `mindspore.boost` , level for boost mode
653
693
  training. Supports ['O0', 'O1', 'O2']. Default: ``'O0'`` .
@@ -670,13 +710,13 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
670
710
  take no effect on this property.
671
711
 
672
712
  Raises:
673
- ValueError: If device is CPU, property `loss_scale_manager` is not `None` or `FixedLossScaleManager`
674
- (with property `drop_overflow_update=False` ).
713
+ ValueError: If device is CPU, property `loss_scale_manager` is not `None` or
714
+ :class:`mindspore.amp.FixedLossScaleManager` (with property `drop_overflow_update=False` ).
675
715
 
676
716
  Examples:
677
717
  >>> from mindspore import amp, nn
678
718
  >>> # Define the network structure of LeNet5. Refer to
679
- >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
719
+ >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
680
720
  >>> network = LeNet5()
681
721
  >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
682
722
  >>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
@@ -728,7 +768,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
728
768
 
729
769
  def get_white_list():
730
770
  """
731
- Provide a copy of internal white list used by auto mixed precision.
771
+ Provide a copy of internal white list used by auto mixed precision with `amp_level` set to ``O1``.
732
772
 
733
773
  The current built-in whitelist contents are:
734
774
 
@@ -766,7 +806,7 @@ def get_white_list():
766
806
 
767
807
  def get_black_list():
768
808
  """
769
- Provide a copy of internal black list used by auto mixed precision.
809
+ Provide a copy of internal black list used by auto mixed precision with `amp_level` set to ``O2``.
770
810
 
771
811
  The current built-in blacklist contents are:
772
812
 
@@ -789,7 +829,6 @@ def get_black_list():
789
829
 
790
830
  def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=mstype.float16):
791
831
  """
792
- Custom mixed precision by setting whitelist or blacklist.
793
832
  When the `white_list` is provided, primitives and cells in `white_list` will perform the precision conversion.
794
833
  When the `black_list` is provided, cells that are not in `black_list` will perform the precision conversion.
795
834
  Only one of `white_list` and `black_list` should be provided.
@@ -823,7 +862,7 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=m
823
862
  Examples:
824
863
  >>> from mindspore import amp, nn
825
864
  >>> # Define the network structure of LeNet5. Refer to
826
- >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
865
+ >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
827
866
  >>> net = LeNet5()
828
867
  >>> custom_white_list = amp.get_white_list()
829
868
  >>> custom_white_list.append(nn.Flatten)
@@ -844,11 +883,11 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=m
844
883
 
845
884
  if white_list is not None:
846
885
  _list_check(white_list, "white_list")
847
- network = _auto_white_list(network, white_list, dtype)
886
+ network = _auto_mixed_precision_rewrite(network, dtype, white_list=white_list)
848
887
  else:
849
888
  _list_check(black_list, "black_list")
850
889
  if MS_AMP_BY_REWRITE:
851
- network = _auto_black_list_rewrite(network, black_list, dtype)
890
+ network = _auto_mixed_precision_rewrite(network, dtype, black_list=black_list)
852
891
  else:
853
892
  network = _auto_black_list(network, black_list, dtype)
854
893
  network = _OutputTo32(network)
@@ -883,7 +922,8 @@ def _list_check(custom_list: list, list_name: str):
883
922
  if elem not in custom_list:
884
923
  logger.warning(f"{elem} is removed from internal black list.")
885
924
 
886
- def _config_amp(*, enable_rewrite: bool = None, cast_op: type = None): # pylint: disable=unused-variable
925
+
926
+ def _config_amp(*, enable_rewrite: bool = None, cast_op: types.FunctionType = None): # pylint: disable=unused-variable
887
927
  """Configure auto mixed precision."""
888
928
  global MS_AMP_BY_REWRITE
889
929
  global _amp_cast_op