mindspore 2.1.0__cp37-cp37m-manylinux1_x86_64.whl → 2.2.10__cp37-cp37m-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.
Files changed (580)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +46 -19
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  25. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  26. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +38 -0
  31. mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +12 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +61 -71
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +74 -104
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +13 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +28 -5
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +141 -88
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/cell_wrapper.py +84 -34
  268. mindspore/nn/wrap/grad_reducer.py +8 -5
  269. mindspore/nn/wrap/loss_scale.py +105 -42
  270. mindspore/numpy/array_creations.py +1 -2
  271. mindspore/numpy/array_ops.py +3 -2
  272. mindspore/numpy/utils_const.py +5 -5
  273. mindspore/offline_debug/convert_async.py +2 -2
  274. mindspore/ops/_grad_experimental/__init__.py +0 -5
  275. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  276. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  277. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  278. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  279. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  280. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  281. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  282. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  283. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  284. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  285. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  286. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  287. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  288. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  289. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  290. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  291. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  292. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  293. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  294. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  295. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  296. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  297. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  298. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  299. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  300. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  301. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  302. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  303. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  304. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  305. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  306. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  307. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  309. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  310. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  311. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  312. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  313. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  314. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  315. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  316. mindspore/ops/_primitive_cache.py +1 -1
  317. mindspore/ops/_tracefunc.py +45 -13
  318. mindspore/ops/_utils/utils.py +6 -1
  319. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  320. mindspore/ops/_vmap/vmap_base.py +3 -3
  321. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  322. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  323. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  324. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  325. mindspore/ops/arg_dtype_cast.py +54 -0
  326. mindspore/ops/composite/base.py +37 -10
  327. mindspore/ops/composite/math_ops.py +5 -4
  328. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  329. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  330. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  331. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  332. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  333. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  334. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  335. mindspore/ops/deprecated.py +304 -0
  336. mindspore/ops/function/__init__.py +4 -1
  337. mindspore/ops/function/array_func.py +174 -193
  338. mindspore/ops/function/clip_func.py +81 -13
  339. mindspore/ops/function/debug_func.py +1 -1
  340. mindspore/ops/function/grad/grad_func.py +18 -9
  341. mindspore/ops/function/image_func.py +10 -4
  342. mindspore/ops/function/linalg_func.py +5 -5
  343. mindspore/ops/function/math_func.py +575 -386
  344. mindspore/ops/function/nn_func.py +568 -260
  345. mindspore/ops/function/random_func.py +88 -57
  346. mindspore/ops/function/sparse_func.py +1 -1
  347. mindspore/ops/function/sparse_unary_func.py +14 -12
  348. mindspore/ops/function/vmap_func.py +6 -5
  349. mindspore/ops/functional.py +15 -10
  350. mindspore/ops/op_info_register.py +244 -25
  351. mindspore/ops/operations/__init__.py +28 -19
  352. mindspore/ops/operations/_grad_ops.py +72 -7
  353. mindspore/ops/operations/_inner_ops.py +350 -17
  354. mindspore/ops/operations/_quant_ops.py +4 -8
  355. mindspore/ops/operations/_sequence_ops.py +42 -0
  356. mindspore/ops/operations/array_ops.py +68 -282
  357. mindspore/ops/operations/comm_ops.py +107 -59
  358. mindspore/ops/operations/custom_ops.py +94 -70
  359. mindspore/ops/operations/debug_ops.py +8 -4
  360. mindspore/ops/operations/image_ops.py +18 -12
  361. mindspore/ops/operations/inner_ops.py +26 -3
  362. mindspore/ops/operations/math_ops.py +189 -141
  363. mindspore/ops/operations/nn_ops.py +794 -489
  364. mindspore/ops/operations/other_ops.py +0 -22
  365. mindspore/ops/operations/random_ops.py +53 -111
  366. mindspore/ops/operations/sparse_ops.py +3 -1
  367. mindspore/ops/primitive.py +24 -18
  368. mindspore/parallel/_auto_parallel_context.py +68 -8
  369. mindspore/parallel/_cost_model_context.py +2 -2
  370. mindspore/parallel/_offload_context.py +17 -3
  371. mindspore/parallel/_parallel_serialization.py +12 -5
  372. mindspore/parallel/_ps_context.py +12 -0
  373. mindspore/parallel/_tensor.py +18 -13
  374. mindspore/parallel/_transformer/layers.py +5 -3
  375. mindspore/parallel/_transformer/loss.py +1 -0
  376. mindspore/parallel/_transformer/moe.py +2 -2
  377. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  378. mindspore/parallel/_transformer/transformer.py +23 -3
  379. mindspore/parallel/_utils.py +11 -7
  380. mindspore/parallel/algo_parameter_config.py +85 -5
  381. mindspore/parallel/checkpoint_transform.py +19 -12
  382. mindspore/parallel/shard.py +21 -14
  383. mindspore/profiler/common/struct_type.py +3 -3
  384. mindspore/profiler/common/util.py +4 -2
  385. mindspore/profiler/envprofiling.py +1 -1
  386. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  387. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  388. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  389. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  390. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  391. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  392. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  393. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  394. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  395. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  396. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  397. mindspore/profiler/parser/flops_parser.py +15 -11
  398. mindspore/profiler/parser/framework_parser.py +38 -22
  399. mindspore/profiler/parser/hccl_parser.py +16 -12
  400. mindspore/profiler/parser/integrator.py +22 -11
  401. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  402. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  403. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  404. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  405. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  406. mindspore/profiler/parser/optime_parser.py +1 -1
  407. mindspore/profiler/parser/profiler_info.py +21 -2
  408. mindspore/profiler/parser/step_trace_parser.py +11 -14
  409. mindspore/profiler/profiling.py +179 -89
  410. mindspore/rewrite/api/node.py +102 -19
  411. mindspore/rewrite/api/node_type.py +5 -1
  412. mindspore/rewrite/api/pattern_engine.py +1 -1
  413. mindspore/rewrite/api/scoped_value.py +9 -17
  414. mindspore/rewrite/api/symbol_tree.py +131 -47
  415. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  416. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  417. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  418. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  419. mindspore/rewrite/common/rewrite_elog.py +5 -1
  420. mindspore/rewrite/namer.py +33 -24
  421. mindspore/rewrite/namespace.py +14 -5
  422. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  423. mindspore/rewrite/node/call_function.py +79 -0
  424. mindspore/rewrite/node/cell_container.py +135 -0
  425. mindspore/rewrite/node/control_flow.py +88 -0
  426. mindspore/rewrite/{node.py → node/node.py} +273 -234
  427. mindspore/rewrite/node/node_manager.py +254 -0
  428. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  429. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  430. mindspore/rewrite/parsers/assign_parser.py +216 -221
  431. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  432. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  433. mindspore/rewrite/parsers/constant_parser.py +9 -6
  434. mindspore/rewrite/parsers/container_parser.py +9 -7
  435. mindspore/rewrite/parsers/for_parser.py +36 -15
  436. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  437. mindspore/rewrite/parsers/if_parser.py +28 -24
  438. mindspore/rewrite/parsers/module_parser.py +196 -25
  439. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  440. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  441. mindspore/rewrite/parsers/return_parser.py +6 -6
  442. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  443. mindspore/rewrite/sparsify/utils.py +1 -1
  444. mindspore/rewrite/symbol_tree.py +523 -578
  445. mindspore/rewrite/symbol_tree_builder.py +9 -193
  446. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  447. mindspore/run_check/_check_version.py +6 -4
  448. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  449. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  450. mindspore/scipy/linalg.py +1 -1
  451. mindspore/scipy/optimize/minimize.py +7 -3
  452. mindspore/train/_utils.py +7 -3
  453. mindspore/train/amp.py +323 -123
  454. mindspore/train/anf_ir_pb2.py +14 -2
  455. mindspore/train/callback/_backup_and_restore.py +2 -12
  456. mindspore/train/callback/_callback.py +29 -4
  457. mindspore/train/callback/_checkpoint.py +23 -8
  458. mindspore/train/callback/_early_stop.py +2 -2
  459. mindspore/train/callback/_landscape.py +4 -4
  460. mindspore/train/callback/_loss_monitor.py +2 -2
  461. mindspore/train/callback/_on_request_exit.py +2 -2
  462. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  463. mindspore/train/callback/_summary_collector.py +15 -8
  464. mindspore/train/callback/_time_monitor.py +58 -5
  465. mindspore/train/data_sink.py +5 -11
  466. mindspore/train/dataset_helper.py +84 -57
  467. mindspore/train/loss_scale_manager.py +2 -2
  468. mindspore/train/metrics/__init__.py +3 -3
  469. mindspore/train/metrics/cosine_similarity.py +1 -1
  470. mindspore/train/metrics/hausdorff_distance.py +3 -2
  471. mindspore/train/metrics/mean_surface_distance.py +3 -2
  472. mindspore/train/metrics/metric.py +39 -19
  473. mindspore/train/metrics/roc.py +2 -2
  474. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  475. mindspore/train/mind_ir_pb2.py +85 -36
  476. mindspore/train/model.py +187 -47
  477. mindspore/train/serialization.py +487 -161
  478. mindspore/train/summary/_summary_adapter.py +1 -1
  479. mindspore/train/summary/_writer_pool.py +3 -2
  480. mindspore/train/summary/summary_record.py +37 -17
  481. mindspore/train/train_thor/convert_utils.py +3 -3
  482. mindspore/train/train_thor/dataset_helper.py +1 -1
  483. mindspore/version.py +1 -1
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +488 -528
  486. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
  487. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  489. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  490. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  491. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  492. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  493. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  494. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  495. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  497. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  498. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  499. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  500. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  501. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  502. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  503. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  504. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  505. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  506. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  507. mindspore/_extends/graph_kernel/expander.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  509. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  510. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  511. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  512. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  513. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  514. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  515. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  516. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  517. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  518. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  519. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  520. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  521. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  522. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  523. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  524. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  525. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  526. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  527. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  528. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  529. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  530. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  531. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  532. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  534. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  535. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  536. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  537. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  538. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  539. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  540. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  543. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  544. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  545. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  546. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  547. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  548. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  549. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  550. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  551. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  552. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  553. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  554. mindspore/dataset/datapreprocess/__init__.py +0 -20
  555. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  556. mindspore/include/api/net.h +0 -142
  557. mindspore/nn/lr_scheduler.py +0 -262
  558. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  559. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  560. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  561. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  562. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  563. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  567. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  568. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  569. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  571. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  575. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  578. mindspore/rewrite/node_visitor.py +0 -44
  579. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  580. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0

mindspore/nn/probability/distribution/transformed_distribution.py

@@ -16,6 +16,7 @@
  import numpy as np
  from mindspore import _checkparam as validator
  from mindspore.ops import operations as P
+ from mindspore.ops import functional as F
  from mindspore.common import dtype as mstype
  import mindspore.nn as nn
  from .distribution import Distribution

@@ -125,7 +126,6 @@ class TransformedDistribution(Distribution):
  self.cast_base = P.Cast()
  self.equal_base = P.Equal()
  self.select_base = P.Select()
- self.fill_base = P.Fill()

  # broadcast bijector batch_shape and distribution batch_shape
  self._broadcast_shape = self._broadcast_bijector_dist()

@@ -176,9 +176,9 @@ class TransformedDistribution(Distribution):
  """
  if self.batch_shape is None or self.bijector.batch_shape is None:
      return None
- bijector_shape_tensor = self.fill_base(
+ bijector_shape_tensor = F.fill(
      self.dtype, self.bijector.batch_shape, 0.0)
- dist_shape_tensor = self.fill_base(self.dtype, self.batch_shape, 0.0)
+ dist_shape_tensor = F.fill(self.dtype, self.batch_shape, 0.0)
  return (bijector_shape_tensor + dist_shape_tensor).shape

  def _cdf(self, value, *args, **kwargs):

mindspore/nn/probability/distribution/uniform.py

@@ -14,6 +14,7 @@
  # ============================================================================
  """Uniform Distribution"""
  import numpy as np
+ from mindspore.ops import functional as F
  from mindspore.ops import operations as P
  from mindspore.ops import composite as C
  from mindspore import _checkparam as Validator

@@ -170,7 +171,6 @@ class Uniform(Distribution):
  self.cast = P.Cast()
  self.const = P.ScalarToTensor()
  self.dtypeop = P.DType()
- self.fill = P.Fill()
  self.less = P.Less()
  self.lessequal = P.LessEqual()
  self.logicaland = P.LogicalAnd()

@@ -287,10 +287,10 @@ class Uniform(Distribution):
  value = self._check_value(value, 'value')
  value = self.cast(value, self.dtype)
  low, high = self._check_param_type(low, high)
- neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
+ neg_ones = F.fill(self.dtype, self.shape(value), -1.0)
  prob = self.exp(neg_ones * self.log(high - low))
  broadcast_shape = self.shape(prob)
- zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
+ zeros = F.fill(self.dtypeop(prob), broadcast_shape, 0.0)
  comp_lo = self.less(value, low)
  comp_hi = self.lessequal(value, high)
  less_than_low = self.select(comp_lo, zeros, prob)

@@ -316,7 +316,7 @@ class Uniform(Distribution):
  kl = self.log(high_b - low_b) - self.log(high_a - low_a)
  comp = self.logicaland(self.lessequal(
      low_b, low_a), self.lessequal(high_a, high_b))
- inf = self.fill(self.dtypeop(kl), self.shape(kl), np.inf)
+ inf = F.fill(self.dtypeop(kl), self.shape(kl), np.inf)
  return self.select(comp, kl, inf)

  def _cdf(self, value, low=None, high=None):

@@ -338,8 +338,8 @@ class Uniform(Distribution):
  low, high = self._check_param_type(low, high)
  prob = (value - low) / (high - low)
  broadcast_shape = self.shape(prob)
- zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
- ones = self.fill(self.dtypeop(prob), broadcast_shape, 1.0)
+ zeros = F.fill(self.dtypeop(prob), broadcast_shape, 0.0)
+ ones = F.fill(self.dtypeop(prob), broadcast_shape, 1.0)
  comp_lo = self.less(value, low)
  comp_hi = self.less(value, high)
  less_than_low = self.select(comp_lo, zeros, prob)
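
Note: the recurring change across these two distribution files replaces the stateful `P.Fill()` primitive, previously instantiated in `__init__` and stored on the Cell, with the functional interface `mindspore.ops.functional.fill`. A minimal before/after sketch, assuming a MindSpore 2.2 install where the deprecated `P.Fill` is still importable for comparison:

```python
import mindspore as ms
from mindspore.ops import functional as F
from mindspore.ops import operations as P

dtype, shape = ms.float32, (2, 3)

# Before (2.1): instantiate the Fill primitive and keep it as an attribute.
fill_op = P.Fill()
zeros_old = fill_op(dtype, shape, 0.0)

# After (2.2): call the functional wrapper directly; nothing to store.
zeros_new = F.fill(dtype, shape, 0.0)

assert (zeros_old == zeros_new).all()
```
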
mindspore/nn/wrap/cell_wrapper.py

@@ -99,7 +99,7 @@ class WithLossCell(Cell):
  >>> from mindspore import Tensor, nn
  >>> import numpy as np
  >>> # Define the network structure of LeNet5. Refer to
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
  >>> net = LeNet5()
  >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
  >>> net_with_criterion = nn.WithLossCell(net, loss_fn)

@@ -115,8 +115,7 @@ class WithLossCell(Cell):
  super(WithLossCell, self).__init__(auto_prefix=False)
  self._backbone = backbone
  self._loss_fn = loss_fn
- if isinstance(backbone, Cell) and backbone.jit_config_dict:
-     self._jit_config_dict = backbone.jit_config_dict
+ self._get_attr_from_cell(backbone)

  def construct(self, data, label):
      out = self._backbone(data)
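
Note: this hunk and many below collapse the repeated `jit_config_dict` copy into a single `Cell._get_attr_from_cell` helper. The helper's body is not part of this diff; a hypothetical reconstruction consistent with the removed lines would be:

```python
# Hypothetical sketch only: the real helper ships inside mindspore.nn.Cell
# and its implementation is not shown in this diff.
class Cell:  # stand-in stub for illustration
    def _get_attr_from_cell(self, network):
        # Mirrors the two lines removed at each call site above.
        if isinstance(network, Cell) and getattr(network, "jit_config_dict", None):
            self._jit_config_dict = network.jit_config_dict
```
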
@@ -133,7 +132,7 @@ class WithLossCell(Cell):
  Examples:
  >>> from mindspore import nn
  >>> # Define the network structure of LeNet5. Refer to
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
  >>> net = LeNet5()
  >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
  >>> net_with_criterion = nn.WithLossCell(net, loss_fn)

@@ -176,7 +175,7 @@ class WithGradCell(Cell):
  >>> import mindspore as ms
  >>> from mindspore import nn
  >>> # Defined a network without loss function, taking LeNet5 as an example.
- >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+ >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
  >>> net = LeNet5()
  >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
  >>> grad_net = nn.WithGradCell(net, loss_fn)

@@ -199,8 +198,7 @@ class WithGradCell(Cell):
  else:
      self.network_with_loss = WithLossCell(self.network, self.loss_fn)
  self.network_with_loss.set_train()
- if isinstance(network, Cell) and network.jit_config_dict:
-     self._jit_config_dict = network.jit_config_dict
+ self._get_attr_from_cell(network)

  def construct(self, *inputs):
      weights = self.weights

@@ -219,7 +217,7 @@ class ForwardValueAndGrad(Cell):
  The backward graph will be created in the gradient function to calculating gradient.

  Args:
- network (Cell): The training network.
+ network (Union[Cell, Function, MethodType]): The training network.
  weights (ParameterTuple): The parameters of the training network that need to calculate the gradient.
      Default: ``None`` .
  get_all (bool): If ``True`` , get all the gradients with respect to inputs. Default: ``False`` .

@@ -302,8 +300,7 @@ class ForwardValueAndGrad(Cell):
  self.get_by_list = get_by_list
  self.sens_param = sens_param
  self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param)
- if isinstance(network, Cell) and network.jit_config_dict:
-     self._jit_config_dict = network.jit_config_dict
+ self._get_attr_from_cell(network)

  def construct(self, *inputs):
      grad_inputs = inputs

@@ -349,7 +346,7 @@ class TrainOneStepCell(Cell):
  Examples:
  >>> import mindspore.nn as nn
  >>> # Define the network structure of LeNet5. Refer to
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
  >>> net = LeNet5()
  >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
  >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

@@ -414,8 +411,7 @@ class TrainOneStepCell(Cell):
  create_group(server_group_name, group_list[current_index])
  group = server_group_name
  self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree, group=group)
- if isinstance(network, Cell) and network.jit_config_dict:
-     self._jit_config_dict = network.jit_config_dict
+ self._get_attr_from_cell(network)

  def construct(self, *inputs):
      if not self.sense_flag:

@@ -514,8 +510,7 @@ class _VirtualDatasetCell(Cell):
  super(_VirtualDatasetCell, self).__init__(auto_prefix=False)
  self._backbone = backbone
  self._virtual_dataset = _VirtualDataset()
- if isinstance(backbone, Cell) and backbone.jit_config_dict:
-     self._jit_config_dict = backbone.jit_config_dict
+ self._get_attr_from_cell(backbone)

  def construct(self, *inputs):
      output = self._virtual_dataset(*inputs)

@@ -524,6 +519,8 @@ class _VirtualDatasetCell(Cell):

  @_primexpr
  def _check_shape_value_on_axis_divided_by_target_value(input_shape, micro_size):
+     if F.isconstant(input_shape[0]) is False:
+         return
      if input_shape[0] % micro_size != 0:
          raise ValueError(f"For micro batch initialization, the 0th dimension shape of input({input_shape[0]}) must be "
                           f"divided by micro size({micro_size})")

@@ -548,8 +545,8 @@ class _MicroBatch(Cell):
  for each_input in inputs:
      input_shape = self.shape(each_input)
      _check_shape_value_on_axis_divided_by_target_value(input_shape, self.micro_size)
-     micro_batch_begin = i * input_shape[0] // self.micro_size
-     micro_batch_end = (i + 1) * input_shape[0] // self.micro_size
+     micro_batch_begin = (input_shape[0] // self.micro_size) * i
+     micro_batch_end = (input_shape[0] // self.micro_size) * (i + 1)
      strided_slice_begin = (micro_batch_begin,)
      strided_slice_strides = (1,)
      for _ in range(len(input_shape) - 1):
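
Note: the reordered slice arithmetic is not a pure refactor. With dynamic shapes, `F.isconstant` returns False and the divisibility check above is skipped; `i * n // micro_size` can then produce unevenly sized micro-batches, whereas `(n // micro_size) * i` keeps every slice the same size. Illustrative values:

```python
# Slice boundaries for a batch of 10 split into 4 micro-batches.
n, micro_size = 10, 4

old = [(i * n // micro_size, (i + 1) * n // micro_size) for i in range(micro_size)]
new = [((n // micro_size) * i, (n // micro_size) * (i + 1)) for i in range(micro_size)]

print(old)  # [(0, 2), (2, 5), (5, 7), (7, 10)] -> micro-batches of 2, 3, 2, 3
print(new)  # [(0, 2), (2, 4), (4, 6), (6, 8)]  -> uniform micro-batches of 2
```
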
@@ -589,7 +586,7 @@ class MicroBatchInterleaved(Cell):
  Examples:
  >>> import mindspore.nn as nn
  >>> # Define the network structure of LeNet5. Refer to
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
  >>> net = LeNet5()
  >>> net = nn.MicroBatchInterleaved(net, 2)
  """

@@ -610,8 +607,7 @@
  interleave_data.strided_slice.add_prim_attr("strided_slice_flag", True)
  interleave_data.strided_slice.add_prim_attr("interleave_num", interleave_num)
  self.interleave_inputs.append(interleave_data)
- if isinstance(network, Cell) and network.jit_config_dict:
-     self._jit_config_dict = network.jit_config_dict
+ self._get_attr_from_cell(network)

  def construct(self, *inputs):
      output = 0.0

@@ -638,7 +634,7 @@ class PipelineCell(Cell):
  Examples:
  >>> import mindspore.nn as nn
  >>> # Define the network structure of LeNet5. Refer to
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
  >>> net = LeNet5()
  >>> net = nn.PipelineCell(net, 4)
  """

@@ -648,13 +644,70 @@
  self.micro_inputs = nn.CellList()
  self.micro_size = micro_size
  self.add_list = []
+ if not isinstance(network, Cell):
+     raise TypeError("For 'PipelineCell', the argument 'network' must cell type, "
+                     "but got the type : {}.".format(type(network)))
+ if not isinstance(micro_size, int):
+     raise TypeError("For 'PipelineCell', the argument 'micro_size' must be integer, "
+                     "but got the type : {}.".format(type(micro_size)))
+ if micro_size <= 0:
+     raise ValueError("For 'PipelineCell', the argument 'micro_size' must be large than 0, "
+                      "but got {}.".format(micro_size))
  for i in range(micro_size):
      micro_input = _MicroBatch(micro_size)
      self.micro_inputs.append(micro_input)
      self.add = P.Add().add_prim_attr("pipeline_end", i)
      self.add_list.append(self.add)
- if isinstance(network, Cell) and network.jit_config_dict:
-     self._jit_config_dict = network.jit_config_dict
+ self._get_attr_from_cell(network)
+
+ def construct(self, *inputs):
+     ret = None
+     for i in range(self.micro_size):
+         micro_input = self.micro_inputs[i](i, *inputs)
+         output = self.network(*micro_input)
+         if ret is not None:
+             ret = self.add_list[i](ret, output)
+         else:
+             ret = output
+     return ret
+
+ class GradAccumulationCell(Cell):
+     """
+     Wrap the network with Micro Batch.
+
+     Args:
+         network (Cell): The target network to wrap.
+         micro_size (int): MicroBatch size.
+
+     Supported Platforms:
+         ``Ascend`` ``GPU``
+
+     Examples:
+         >>> net = Net()
+         >>> net = GradAccumulationCell(net, 4)
+     """
+     def __init__(self, network, micro_size):
+         super(GradAccumulationCell, self).__init__(auto_prefix=False)
+         self.network = network
+         self.micro_inputs = nn.CellList()
+         self.micro_size = micro_size
+         self.add_list = []
+         if not isinstance(network, Cell):
+             raise TypeError("For 'GradAccumulationCell', the argument 'network' must cell type, "
+                             "but got the type : {}.".format(type(network)))
+         if not isinstance(micro_size, int):
+             raise TypeError("For 'GradAccumulationCell', the argument 'micro_size' must be integer, "
+                             "but got the type : {}.".format(type(micro_size)))
+         if micro_size <= 0:
+             raise ValueError("For 'GradAccumulationCell', the argument 'micro_size' must be large than 0, "
+                              "but got {}.".format(micro_size))
+         for i in range(micro_size):
+             micro_input = _MicroBatch(micro_size)
+             micro_input.strided_slice.add_prim_attr("grad_accu_num", micro_size)
+             self.micro_inputs.append(micro_input)
+             self.add = P.Add().add_prim_attr("forward_end", i)
+             self.add_list.append(self.add)
+         self._get_attr_from_cell(network)

  def construct(self, *inputs):
      ret = None
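
Note: `PipelineCell` (and the new `GradAccumulationCell`) now validates its arguments eagerly instead of failing later during graph compilation. A quick sketch of the new behavior, assuming the 2.2.10 wheel is installed and using `nn.Dense` as a stand-in network:

```python
import mindspore.nn as nn

net = nn.Dense(4, 2)  # any Cell works for this demonstration

try:
    nn.PipelineCell(net, micro_size=0)
except ValueError as e:
    print(e)  # ... 'micro_size' must be large than 0, but got 0.

try:
    nn.PipelineCell("not a cell", micro_size=4)
except TypeError as e:
    print(e)  # ... 'network' must cell type, but got the type : <class 'str'>.
```
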
@@ -674,23 +727,22 @@ def _pipeline_clear_grad(accu_grad, grad):
  return F.assign(accu_grad, zeros)


- class _TrainPipelineAccuStepCell(TrainOneStepCell):
+ class _TrainGradAccuStepCell(TrainOneStepCell):
      """
      Wraps the network with an optimizer in pipeline mode.
      """
      def __init__(self, network, optimizer, sens=None):
-         super(_TrainPipelineAccuStepCell, self).__init__(network, optimizer, sens)
+         super(_TrainGradAccuStepCell, self).__init__(network, optimizer, sens)
          self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
          self.hyper_map = ops.HyperMap()
          self.opt_shard = _get_enable_parallel_optimizer()
-         if isinstance(network, Cell) and network.jit_config_dict:
-             self._jit_config_dict = network.jit_config_dict
+         self._get_attr_from_cell(network)

      def construct(self, *inputs):
          if not self.sense_flag:
              return self._no_sens_impl(*inputs)
          loss = self.network(*inputs)
-         sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
+         sens = ops.fill(ops.DType()(loss), ops.Shape()(loss), self.sens)
          grads = self.grad(self.network, self.weights)(*inputs, sens)
          accu_grads = ops.depend(self.accu_grads, grads)
          if self.opt_shard:

@@ -735,7 +787,7 @@ class VirtualDatasetCellTriple(Cell):
  Examples:
  >>> import mindspore.nn as nn
  >>> # Define the network structure of LeNet5. Refer to
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
  >>> net = LeNet5()
  >>> net = nn.VirtualDatasetCellTriple(net)
  """

@@ -744,8 +796,7 @@
  super(VirtualDatasetCellTriple, self).__init__(auto_prefix=False)
  logger.warning("WARN_DEPRECATED: The usage of VirtualDatasetCellTriple is deprecated.")
  self._backbone = backbone
- if isinstance(backbone, Cell) and backbone.jit_config_dict:
-     self._jit_config_dict = backbone.jit_config_dict
+ self._get_attr_from_cell(backbone)

  def construct(self, a, b, c):
      return self._backbone(a, b, c)

@@ -779,7 +830,7 @@ class WithEvalCell(Cell):
  Examples:
  >>> import mindspore.nn as nn
  >>> # Define a forward network without loss function, taking LeNet5 as an example.
- >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+ >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
  >>> net = LeNet5()
  >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
  >>> eval_net = nn.WithEvalCell(net, loss_fn)

@@ -790,8 +841,7 @@
  self._network = network
  self._loss_fn = loss_fn
  self.add_cast_fp32 = validator.check_value_type("add_cast_fp32", add_cast_fp32, [bool], self.cls_name)
- if isinstance(network, Cell) and network.jit_config_dict:
-     self._jit_config_dict = network.jit_config_dict
+ self._get_attr_from_cell(network)

  def construct(self, data, label):
      outputs = self._network(data)

mindspore/nn/wrap/grad_reducer.py

@@ -314,12 +314,15 @@ class DistributedGradReducer(Cell):
  Before running the following examples, you need to configure the communication environment variables.

  For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
- Please see the `Ascend tutorial
- <https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/train_ascend.html#preparations>`_
+ Please see the `rank table Startup
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
  for more details.

- For the GPU devices, users need to prepare the host file and mpi, please see the `GPU tutorial
- <https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/train_gpu.html#preparation>`_ .
+ For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
+
+ For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
+ Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .

  This example should be run with multiple devices.

@@ -356,7 +359,7 @@
  ...     def construct(self, *args):
  ...         weights = self.weights
  ...         loss = self.network(*args)
- ...         sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
+ ...         sens = F.fill(ops.DType()(loss), ops.Shape()(loss), self.sens)
  ...         grads = self.grad(self.network, weights)(*args, sens)
  ...         if self.reducer_flag:
  ...             # apply grad reducer on grads

mindspore/nn/wrap/loss_scale.py

@@ -15,6 +15,7 @@
  """Loss scale cell for loss scale training."""
  from __future__ import absolute_import

+ import os
  import mindspore.context as context
  from mindspore.context import ParallelMode
  from mindspore.parallel._utils import _get_enable_parallel_optimizer

@@ -30,6 +31,7 @@ from mindspore.ops import composite as C
  from mindspore.ops import operations as P
  from mindspore.common import dtype as mstype
  from mindspore.common.api import jit
+ from mindspore._c_expression import MSContext

  _grad_scale = C.MultitypeFuncGraph("grad_scale")
  reciprocal = P.Reciprocal()

@@ -60,6 +62,28 @@ def _tensor_grad_overflow_row_tensor(grad):
      return grad_overflow(grad.values)


+ _ascend_grad_overflow = C.MultitypeFuncGraph("_ascend_grad_overflow")
+ ascend_grad_overflow = P.IsFinite()
+
+
+ @_ascend_grad_overflow.register("Tensor")
+ def _tensor_ascend_grad_overflow(grad):
+     status = ascend_grad_overflow(grad)
+     base = Tensor(1.0, dtype=mstype.float32)
+     output = base - status.all()
+     output = P.Reshape()(output, ((1,)))
+     return output
+
+
+ @_ascend_grad_overflow.register("RowTensor")
+ def _tensor_ascend_grad_overflow_row_tensor(grad):
+     status = ascend_grad_overflow(grad.values)
+     base = Tensor(1.0, dtype=mstype.float32)
+     output = base - status.all()
+     output = P.Reshape()(output, ((1,)))
+     return output
+
+
  class DynamicLossScaleUpdateCell(Cell):
      r"""
      Dynamic Loss scale update cell.
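
Note: the new `_ascend_grad_overflow` branch detects overflow arithmetically with `IsFinite` instead of reading the NPU float-status register: `1.0 - IsFinite(grad).all()` evaluates to `1.0` exactly when some gradient element is inf or nan. A sketch of the computation it registers, assuming a MindSpore install (values illustrative):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import operations as P

grad = Tensor(np.array([1.0, float("inf"), 3.0]), ms.float32)

finite = P.IsFinite()(grad)                        # [True, False, True]
overflow = Tensor(1.0, ms.float32) - finite.all()  # 1.0 -> overflow detected
print(P.Reshape()(overflow, (1,)))                 # [1.]
```
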
@@ -296,16 +320,18 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
         >>> size, in_features, out_features = 16, 16, 10
         >>> #1) when the type of scale_sense is Cell:
         >>> net = Net(in_features, out_features)
-        >>> loss = nn.MSELoss()
+        >>> loss_fn = nn.MSELoss()
         >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
-        >>> net_with_loss = nn.WithLossCell(net, loss)
-        >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
-        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
+        >>> net_with_loss = nn.WithLossCell(net, loss_fn)
         >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
         >>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
-        >>> output = train_network(input, labels)
-        >>> status, scaling_sens = train_network.start_overflow_check(loss, train_network.scaling_sens)
-        >>> grads = train_network.grad(train_network.network, weights)(*inputs, scaling_sens_filled)
+        >>> loss = net_with_loss(input, labels)
+        >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
+        >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
+        >>> status = Tensor([0] * 8, mindspore.int32)
+        >>> scaling_sens = train_network.scale_sense
+        >>> scaling_sens_filled = ops.ones_like(loss) * ops.cast(scaling_sens, ops.dtype(loss))
+        >>> grads = train_network.grad(train_network.network, train_network.weights)(input, labels, scaling_sens_filled)
         >>> grads = train_network.grad_reducer(grads)
         >>> cond = train_network.get_overflow_status(status, grads)
         >>> overflow = train_network.process_loss_scale(cond)
@@ -341,7 +367,12 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
         self.allreduce = P.AllReduce()
         self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
         self.gpu_target = (context.get_context("device_target") == "GPU")
+        self.ascend_910a_target = (MSContext.get_instance().get_ascend_soc_version() == 'ascend910')
+        self.ascend_910b_target = (MSContext.get_instance().get_ascend_soc_version() == 'ascend910b')
         self.loss_scaling_manager = None
+        self._ascend910b_check_overflow_status_mode = os.environ.get('MS_ASCEND_CHECK_OVERFLOW_MODE')
+
+
         if isinstance(scale_sense, Cell):
             self.loss_scaling_manager = scale_sense
             self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
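The new `_ascend910b_check_overflow_status_mode` attribute is read from the environment once, at construction time, so the detection mode must be chosen before the cell is built. A sketch, assuming the variable is exported in the launch script or set early in the training entry point:

```python
import os

# Choose inf/nan-based overflow detection on Ascend 910B; any other value
# (or an unset variable) routes to the saturation-mode path shown below.
os.environ['MS_ASCEND_CHECK_OVERFLOW_MODE'] = 'INFNAN_MODE'
```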
@@ -358,6 +389,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
                             "the 'scale_sense' must be Cell or Tensor, but got 'scale_sense' type: {}."
                             .format(type(scale_sense)))
         self.enable_tuple_broaden = True
+        self._get_attr_from_cell(network)
 
     def construct(self, *inputs):
         weights = self.weights
@@ -418,13 +450,68 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
         is cleaned up when the function returns.
         """
         status = Tensor([0] * 8, mstype.int32)
-        if not self.gpu_target:
+        if self.ascend_910a_target or (self.ascend_910b_target and \
+                self._ascend910b_check_overflow_status_mode != "INFNAN_MODE"):
             status = F.depend(status, pre_cond)
             # clear overflow buffer
             clear_status = NPUClearFloatStatusV2()(status)
             compute_input = F.depend(compute_input, clear_status)
         return status, compute_input
 
+    def _check_overflow_status_on_infnan_mode(self, grad_overflow_check_func, compute_output):
+        """check overflow status on infnan mode."""
+        flag_sum = self.hyper_map(F.partial(grad_overflow_check_func), compute_output)
+        flag_sum = P.AddN()(flag_sum)
+        # convert flag_sum to scalar
+        flag_sum = P.Reshape()(flag_sum, (()))
+        return flag_sum
+
+    def _get_distributed_overflow_status_on_infnan_mode(self, grad_overflow_check_func, compute_output):
+        """converge the distributed overflow status on infnan mode."""
+        flag_sum = self._check_overflow_status_on_infnan_mode(grad_overflow_check_func, compute_output)
+
+        if self.is_distributed:
+            # sum overflow flag over devices
+            flag_reduce = self.allreduce(flag_sum)
+            overflow = self.less_equal(self.base, flag_reduce)
+        else:
+            overflow = self.less_equal(self.base, flag_sum)
+        return overflow
+
+    def _get_gpu_overflow_status(self, compute_output):
+        """get overflow status of gpu."""
+        overflow = self._get_distributed_overflow_status_on_infnan_mode(_grad_overflow, compute_output)
+        return overflow
+
+    def _get_ascend_overflow_status_on_infnan_mode(self, compute_output):
+        """get overflow status of ascend on infnan mode."""
+        overflow = self._get_distributed_overflow_status_on_infnan_mode(_ascend_grad_overflow, compute_output)
+        return overflow
+
+    def _get_ascend_overflow_status_on_saturation_mode(self, status, compute_output):
+        """get overflow status of ascend on saturation mode"""
+        status = F.depend(status, compute_output)
+        get_status = NPUGetFloatStatusV2()(status)
+
+        if self.is_distributed:
+            # sum overflow flag over devices
+            flag_reduce = self.allreduce(get_status)
+            # get_status not equal to [0]*8 means overflow
+            flag = self.equal(self.base0, flag_reduce)
+            status = F.depend(status, flag)
+            # distributed needs to skip allreduce to avoid its overflow affecting the next step
+            clear_status = NPUClearFloatStatusV2()(status)
+            flag = F.depend(flag, clear_status)
+            overall_finite = self.reduce_all(flag)
+        else:
+            status = F.depend(status, get_status)
+            clear_status = NPUClearFloatStatusV2()(status)
+            get_status = F.depend(get_status, clear_status)
+            flag = self.equal(self.base0, get_status)
+            overall_finite = self.reduce_all(flag)
+        overflow = self.logic_not(overall_finite)
+        return overflow
+
     @jit
     def get_overflow_status(self, status, compute_output):
         """
@@ -442,39 +529,15 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
         Returns:
             bool, whether the overflow occurs or not.
         """
-        if not self.gpu_target:
-            status = F.depend(status, compute_output)
-            get_status = NPUGetFloatStatusV2()(status)
-
-            if self.is_distributed:
-                # sum overflow flag over devices
-                flag_reduce = self.allreduce(get_status)
-                # get_status not equal to [0]*8 means overflow
-                flag = self.equal(self.base0, flag_reduce)
-                status = F.depend(status, flag)
-                # distributed needs to skip allreduce to avoid its overflow affecting the next step
-                clear_status = NPUClearFloatStatusV2()(status)
-                flag = F.depend(flag, clear_status)
-                overall_finite = self.reduce_all(flag)
-            else:
-                status = F.depend(status, get_status)
-                clear_status = NPUClearFloatStatusV2()(status)
-                get_status = F.depend(get_status, clear_status)
-                flag = self.equal(self.base0, get_status)
-                overall_finite = self.reduce_all(flag)
-            overflow = self.logic_not(overall_finite)
-        else:
-            flag_sum = self.hyper_map(F.partial(_grad_overflow), compute_output)
-            flag_sum = P.AddN()(flag_sum)
-            # convert flag_sum to scalar
-            flag_sum = P.Reshape()(flag_sum, (()))
-
-            if self.is_distributed:
-                # sum overflow flag over devices
-                flag_reduce = self.allreduce(flag_sum)
-                overflow = self.less_equal(self.base, flag_reduce)
+        if self.gpu_target:
+            overflow = self._get_gpu_overflow_status(compute_output)
+        elif self.ascend_910b_target:
+            if self._ascend910b_check_overflow_status_mode != "INFNAN_MODE":
+                overflow = self._get_ascend_overflow_status_on_saturation_mode(status, compute_output)
             else:
-                overflow = self.less_equal(self.base, flag_sum)
+                overflow = self._get_ascend_overflow_status_on_infnan_mode(compute_output)
+        else:  # ascend_910a_target
+            overflow = self._get_ascend_overflow_status_on_saturation_mode(status, compute_output)
         return overflow
 
     def process_loss_scale(self, overflow):
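The refactor leaves `get_overflow_status` as a thin dispatcher: GPU always takes the inf/nan path, Ascend 910B consults the environment mode, and Ascend 910A keeps the saturation registers. Downstream, its boolean result feeds `process_loss_scale`, roughly as in the docstring example above (a hedged sketch reusing that example's names; the real `construct` also handles the sense gradients and return values):

```python
# Continuing the docstring example: skip the parameter update on overflow,
# and let the loss-scaling manager adjust the scale either way.
cond = train_network.get_overflow_status(status, grads)
overflow = train_network.process_loss_scale(cond)
if not overflow:
    train_network.optimizer(grads)
```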
@@ -517,7 +580,7 @@ def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
     return new_grad
 
 
-class _TrainPipelineWithLossScaleCell(TrainOneStepCell):
+class _TrainGradAccuWithLossScaleCell(TrainOneStepCell):
     """
     Append an optimizer to the training network after that the construct
     function can be called to create the backward graph.
@@ -528,7 +591,7 @@ class _TrainPipelineWithLossScaleCell(TrainOneStepCell):
         scale_sense (Cell): Cell to do the loss scale.
     """
     def __init__(self, network, optimizer, scale_sense):
-        super(_TrainPipelineWithLossScaleCell, self).__init__(network, optimizer, sens=None)
+        super(_TrainGradAccuWithLossScaleCell, self).__init__(network, optimizer, sens=None)
         self.network = network
         self.network.add_flags(defer_inline=True)
         self.weights = optimizer.parameters
@@ -1304,7 +1304,7 @@ def triu(m, k=0):
     if rank < 1:
         _raise_value_error("input m's rank should be larger than 0")
     elif rank == 1:
-        mask = tri(m.shape[0], k=k-1, dtype=mstype.bool_)
+        mask = tri(m.shape[0], k=k - 1, dtype=mstype.bool_)
         return where(mask, zeros(1, m.dtype), m)
     # Only Ascend hardware will reduce accuracy
     if device_target == "Ascend":
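For the rank-1 branch touched here, `tri(n, k=k - 1)` builds a boolean mask of the entries to zero, and `where` broadcasts the 1-D input against it, mirroring NumPy's behaviour of returning an (n, n) result for a 1-D input. A small usage sketch of the public `mindspore.numpy.triu` wrapper (output shown as expected, assuming NumPy-compatible semantics):

```python
import mindspore.numpy as mnp

v = mnp.array([1.0, 2.0, 3.0])
# Row i keeps the entries at column j >= i + k; everything below is zeroed.
print(mnp.triu(v))
# [[1. 2. 3.]
#  [0. 2. 3.]
#  [0. 0. 3.]]
```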
@@ -2587,7 +2587,6 @@ def _limit_stat_length(stat_length, shape):
     return tuple((min(stat_pair[0], shape[i]), min(stat_pair[1], shape[i])) for i, stat_pair in enumerate(stat_length))
 
 
-@constexpr
 def _convert_pad_to_nd(pad_values, ndim):
     """broadcasts the pad_values to (ndim * 2)"""
     if not isinstance(pad_values, (int, list, tuple, Tensor)):
@@ -2585,8 +2585,9 @@ def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
     """
     def unique_w_ind(arr):
         array, sort_indices = arr.ravel().sort()
-        cmp_array1 = F.cat((array, Tensor([0], dtype=array.dtype)))
-        cmp_array2 = F.cat((Tensor([0], dtype=array.dtype), array))
+        array_type = array.dtype
+        cmp_array1 = F.cat((array, Tensor([0], dtype=array_type)))
+        cmp_array2 = F.cat((Tensor([0], dtype=array_type), array))
         mask = cmp_array1 != cmp_array2
         mask[0] = True
         array = F.masked_select(array, mask[:-1])
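The helper relies on the classic shift-and-compare trick: after sorting, an element starts a new run exactly where the sorted array differs from itself shifted by one position. A plain-NumPy sketch of the same masking logic (illustrative only, not the MindSpore code path):

```python
import numpy as np

def unique_sorted(arr):
    """Return the unique values of arr via the sorted shift-compare trick."""
    array = np.sort(arr.ravel())
    # Pad both copies so the comparison lines up element-for-element.
    cmp_array1 = np.concatenate((array, [0]))
    cmp_array2 = np.concatenate(([0], array))
    mask = cmp_array1 != cmp_array2
    mask[0] = True                    # the first element always starts a run
    return array[mask[:-1]]

print(unique_sorted(np.array([3, 1, 2, 3, 1])))  # [1 2 3]
```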