mindspore 2.1.0-cp38-cp38-manylinux1_x86_64.whl → 2.2.0-cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (550)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +49 -16
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  20. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  21. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  22. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  23. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  24. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  25. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  26. mindspore/_check_jit_forbidden_api.py +3 -1
  27. mindspore/_checkparam.py +26 -32
  28. mindspore/_extends/graph_kernel/__init__.py +0 -1
  29. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  30. mindspore/_extends/graph_kernel/splitter.py +1 -9
  31. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  32. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  33. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  34. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  35. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
  36. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  37. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  38. mindspore/_extends/parse/__init__.py +12 -15
  39. mindspore/_extends/parse/namespace.py +7 -33
  40. mindspore/_extends/parse/parser.py +61 -71
  41. mindspore/_extends/parse/resources.py +1 -1
  42. mindspore/_extends/parse/standard_method.py +72 -95
  43. mindspore/_extends/parse/trope.py +1 -1
  44. mindspore/_extends/remote/kernel_build_server.py +24 -7
  45. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  46. mindspore/_install_custom.py +43 -0
  47. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  48. mindspore/amp.py +47 -11
  49. mindspore/bin/cache_admin +0 -0
  50. mindspore/bin/cache_server +0 -0
  51. mindspore/boost/boost.py +1 -8
  52. mindspore/boost/boost_cell_wrapper.py +3 -2
  53. mindspore/boost/grad_accumulation.py +1 -1
  54. mindspore/boost/group_loss_scale_manager.py +8 -7
  55. mindspore/common/__init__.py +5 -3
  56. mindspore/common/_jit_fallback_utils.py +6 -0
  57. mindspore/common/_register_for_adapter.py +2 -0
  58. mindspore/common/_register_for_tensor.py +2 -2
  59. mindspore/common/_stub_tensor.py +13 -0
  60. mindspore/common/_utils.py +13 -0
  61. mindspore/common/api.py +173 -258
  62. mindspore/common/auto_dynamic_shape.py +498 -0
  63. mindspore/common/dtype.py +18 -11
  64. mindspore/common/dump.py +6 -4
  65. mindspore/common/initializer.py +14 -14
  66. mindspore/common/jit_config.py +33 -15
  67. mindspore/common/lazy_inline.py +126 -7
  68. mindspore/common/mindir_util.py +101 -0
  69. mindspore/common/parameter.py +51 -41
  70. mindspore/common/seed.py +4 -4
  71. mindspore/common/sparse_tensor.py +13 -14
  72. mindspore/common/tensor.py +240 -145
  73. mindspore/communication/__init__.py +7 -4
  74. mindspore/communication/_comm_helper.py +83 -4
  75. mindspore/communication/management.py +152 -84
  76. mindspore/config/op_info.config +13 -2
  77. mindspore/config/super_bar_config.json +4 -2
  78. mindspore/context.py +143 -59
  79. mindspore/dataset/__init__.py +5 -5
  80. mindspore/dataset/audio/__init__.py +2 -2
  81. mindspore/dataset/audio/transforms.py +52 -52
  82. mindspore/dataset/callback/ds_callback.py +16 -2
  83. mindspore/dataset/core/config.py +68 -51
  84. mindspore/dataset/engine/cache_client.py +28 -5
  85. mindspore/dataset/engine/datasets.py +250 -112
  86. mindspore/dataset/engine/datasets_audio.py +43 -211
  87. mindspore/dataset/engine/datasets_standard_format.py +11 -35
  88. mindspore/dataset/engine/datasets_text.py +43 -67
  89. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  90. mindspore/dataset/engine/datasets_vision.py +219 -1029
  91. mindspore/dataset/engine/iterators.py +11 -4
  92. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  93. mindspore/dataset/engine/obs/util.py +3 -0
  94. mindspore/dataset/engine/samplers.py +1 -1
  95. mindspore/dataset/engine/validators.py +19 -5
  96. mindspore/dataset/text/__init__.py +3 -3
  97. mindspore/dataset/text/transforms.py +101 -127
  98. mindspore/dataset/text/utils.py +205 -138
  99. mindspore/dataset/transforms/__init__.py +1 -1
  100. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  101. mindspore/dataset/transforms/transforms.py +95 -40
  102. mindspore/dataset/utils/browse_dataset.py +8 -2
  103. mindspore/dataset/utils/line_reader.py +17 -19
  104. mindspore/dataset/vision/__init__.py +3 -3
  105. mindspore/dataset/vision/c_transforms.py +6 -3
  106. mindspore/dataset/vision/transforms.py +409 -287
  107. mindspore/dataset/vision/utils.py +13 -14
  108. mindspore/dataset/vision/validators.py +11 -1
  109. mindspore/experimental/map_parameter.py +14 -0
  110. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  111. mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
  112. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  113. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  114. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  115. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  116. mindspore/gen_ops.py +273 -0
  117. mindspore/include/OWNERS +0 -1
  118. mindspore/include/api/data_type.h +2 -1
  119. mindspore/include/api/graph.h +0 -15
  120. mindspore/include/api/kernel.h +2 -0
  121. mindspore/include/api/kernel_api.h +37 -12
  122. mindspore/include/api/model.h +0 -14
  123. mindspore/include/api/types.h +37 -4
  124. mindspore/include/c_api/ms/abstract.h +67 -0
  125. mindspore/include/c_api/ms/attribute.h +197 -0
  126. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  127. mindspore/include/c_api/ms/base/macros.h +32 -0
  128. mindspore/include/c_api/ms/base/status.h +33 -0
  129. mindspore/include/c_api/ms/base/types.h +282 -0
  130. mindspore/include/c_api/ms/context.h +102 -0
  131. mindspore/include/c_api/ms/graph.h +160 -0
  132. mindspore/include/c_api/ms/node.h +606 -0
  133. mindspore/include/c_api/ms/tensor.h +161 -0
  134. mindspore/include/c_api/ms/value.h +84 -0
  135. mindspore/include/dataset/constants.h +6 -5
  136. mindspore/include/dataset/execute.h +23 -13
  137. mindspore/include/dataset/text.h +26 -26
  138. mindspore/include/dataset/transforms.h +13 -13
  139. mindspore/include/dataset/vision.h +60 -60
  140. mindspore/include/dataset/vision_ascend.h +5 -6
  141. mindspore/include/dataset/vision_lite.h +17 -17
  142. mindspore/include/mindapi/base/type_id.h +1 -0
  143. mindspore/include/mindapi/base/types.h +1 -0
  144. mindspore/lib/libdnnl.so.2 +0 -0
  145. mindspore/lib/libjemalloc.so.2 +0 -0
  146. mindspore/lib/libmindspore.so +0 -0
  147. mindspore/lib/libmindspore_backend.so +0 -0
  148. mindspore/lib/libmindspore_common.so +0 -0
  149. mindspore/lib/libmindspore_core.so +0 -0
  150. mindspore/lib/libmindspore_glog.so.0 +0 -0
  151. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  152. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  153. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  154. mindspore/lib/libmindspore_shared_lib.so +0 -0
  155. mindspore/lib/libnnacl.so +0 -0
  156. mindspore/lib/libopencv_core.so.4.5 +0 -0
  157. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  158. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  159. mindspore/lib/libps_cache.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  163. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  164. mindspore/lib/plugin/ascend/libakg.so +0 -0
  165. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  166. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  167. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  168. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  169. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  170. mindspore/lib/plugin/cpu/libakg.so +0 -0
  171. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  172. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  173. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  174. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  175. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  176. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  177. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  178. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  179. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  180. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  181. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  182. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  183. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  184. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  185. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  186. mindspore/nn/__init__.py +0 -2
  187. mindspore/nn/cell.py +316 -74
  188. mindspore/nn/dynamic_lr.py +21 -21
  189. mindspore/nn/layer/activation.py +21 -28
  190. mindspore/nn/layer/basic.py +15 -13
  191. mindspore/nn/layer/channel_shuffle.py +1 -1
  192. mindspore/nn/layer/container.py +271 -9
  193. mindspore/nn/layer/conv.py +310 -207
  194. mindspore/nn/layer/dense.py +8 -5
  195. mindspore/nn/layer/embedding.py +33 -27
  196. mindspore/nn/layer/flash_attention.py +82 -41
  197. mindspore/nn/layer/image.py +8 -6
  198. mindspore/nn/layer/math.py +13 -18
  199. mindspore/nn/layer/normalization.py +107 -66
  200. mindspore/nn/layer/padding.py +1 -1
  201. mindspore/nn/layer/pooling.py +131 -109
  202. mindspore/nn/layer/rnn_cells.py +22 -17
  203. mindspore/nn/layer/rnns.py +13 -16
  204. mindspore/nn/layer/thor_layer.py +1 -1
  205. mindspore/nn/layer/transformer.py +221 -154
  206. mindspore/nn/learning_rate_schedule.py +9 -1
  207. mindspore/nn/loss/loss.py +235 -174
  208. mindspore/nn/optim/ada_grad.py +2 -1
  209. mindspore/nn/optim/adadelta.py +1 -0
  210. mindspore/nn/optim/adafactor.py +2 -1
  211. mindspore/nn/optim/adam.py +7 -4
  212. mindspore/nn/optim/adamax.py +3 -2
  213. mindspore/nn/optim/adasum.py +2 -2
  214. mindspore/nn/optim/asgd.py +2 -3
  215. mindspore/nn/optim/ftrl.py +6 -5
  216. mindspore/nn/optim/lamb.py +7 -4
  217. mindspore/nn/optim/lars.py +1 -1
  218. mindspore/nn/optim/lazyadam.py +5 -3
  219. mindspore/nn/optim/momentum.py +2 -1
  220. mindspore/nn/optim/optimizer.py +53 -4
  221. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  222. mindspore/nn/optim/rmsprop.py +4 -3
  223. mindspore/nn/optim/rprop.py +23 -12
  224. mindspore/nn/optim/sgd.py +26 -11
  225. mindspore/nn/optim/thor.py +9 -7
  226. mindspore/nn/probability/bijector/bijector.py +5 -5
  227. mindspore/nn/probability/bijector/power_transform.py +27 -27
  228. mindspore/nn/probability/bijector/softplus.py +3 -3
  229. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  230. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  231. mindspore/nn/probability/distribution/beta.py +3 -3
  232. mindspore/nn/probability/distribution/categorical.py +7 -7
  233. mindspore/nn/probability/distribution/cauchy.py +0 -1
  234. mindspore/nn/probability/distribution/distribution.py +3 -3
  235. mindspore/nn/probability/distribution/gamma.py +3 -3
  236. mindspore/nn/probability/distribution/geometric.py +4 -4
  237. mindspore/nn/probability/distribution/gumbel.py +4 -4
  238. mindspore/nn/probability/distribution/log_normal.py +2 -2
  239. mindspore/nn/probability/distribution/logistic.py +2 -2
  240. mindspore/nn/probability/distribution/poisson.py +4 -4
  241. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  242. mindspore/nn/probability/distribution/uniform.py +6 -6
  243. mindspore/nn/wrap/cell_wrapper.py +78 -34
  244. mindspore/nn/wrap/grad_reducer.py +8 -5
  245. mindspore/nn/wrap/loss_scale.py +105 -42
  246. mindspore/numpy/array_creations.py +1 -2
  247. mindspore/numpy/array_ops.py +3 -2
  248. mindspore/offline_debug/convert_async.py +2 -2
  249. mindspore/ops/_grad_experimental/__init__.py +0 -5
  250. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
  251. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  252. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  253. mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
  254. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  255. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
  256. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  257. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  258. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  259. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  260. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  261. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  262. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  263. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  264. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  265. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  266. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  267. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  268. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  269. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  270. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  271. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  272. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  273. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  274. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  275. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  276. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  277. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  278. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  279. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  280. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  281. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  282. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  283. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  284. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  285. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  286. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  287. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  288. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  289. mindspore/ops/_primitive_cache.py +1 -1
  290. mindspore/ops/_tracefunc.py +45 -13
  291. mindspore/ops/_utils/utils.py +4 -1
  292. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  293. mindspore/ops/_vmap/vmap_base.py +3 -3
  294. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  295. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  296. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  297. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  298. mindspore/ops/arg_dtype_cast.py +54 -0
  299. mindspore/ops/composite/base.py +37 -10
  300. mindspore/ops/composite/math_ops.py +5 -4
  301. mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
  302. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  303. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  304. mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
  305. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  306. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  308. mindspore/ops/deprecated.py +304 -0
  309. mindspore/ops/function/__init__.py +4 -1
  310. mindspore/ops/function/array_func.py +167 -189
  311. mindspore/ops/function/clip_func.py +81 -13
  312. mindspore/ops/function/debug_func.py +1 -1
  313. mindspore/ops/function/grad/grad_func.py +18 -8
  314. mindspore/ops/function/image_func.py +10 -4
  315. mindspore/ops/function/linalg_func.py +5 -5
  316. mindspore/ops/function/math_func.py +575 -386
  317. mindspore/ops/function/nn_func.py +470 -251
  318. mindspore/ops/function/random_func.py +86 -56
  319. mindspore/ops/function/sparse_func.py +1 -1
  320. mindspore/ops/function/sparse_unary_func.py +14 -12
  321. mindspore/ops/function/vmap_func.py +6 -5
  322. mindspore/ops/functional.py +15 -10
  323. mindspore/ops/op_info_register.py +235 -19
  324. mindspore/ops/operations/__init__.py +25 -17
  325. mindspore/ops/operations/_grad_ops.py +52 -7
  326. mindspore/ops/operations/_inner_ops.py +213 -12
  327. mindspore/ops/operations/_quant_ops.py +4 -8
  328. mindspore/ops/operations/_sequence_ops.py +42 -0
  329. mindspore/ops/operations/array_ops.py +64 -280
  330. mindspore/ops/operations/comm_ops.py +105 -57
  331. mindspore/ops/operations/custom_ops.py +10 -3
  332. mindspore/ops/operations/debug_ops.py +8 -4
  333. mindspore/ops/operations/image_ops.py +18 -12
  334. mindspore/ops/operations/math_ops.py +185 -138
  335. mindspore/ops/operations/nn_ops.py +716 -492
  336. mindspore/ops/operations/other_ops.py +0 -22
  337. mindspore/ops/operations/random_ops.py +53 -111
  338. mindspore/ops/operations/sparse_ops.py +3 -1
  339. mindspore/ops/primitive.py +24 -18
  340. mindspore/parallel/_auto_parallel_context.py +68 -8
  341. mindspore/parallel/_cost_model_context.py +2 -2
  342. mindspore/parallel/_offload_context.py +17 -3
  343. mindspore/parallel/_parallel_serialization.py +2 -2
  344. mindspore/parallel/_ps_context.py +12 -0
  345. mindspore/parallel/_tensor.py +14 -12
  346. mindspore/parallel/_transformer/layers.py +5 -3
  347. mindspore/parallel/_transformer/loss.py +1 -0
  348. mindspore/parallel/_transformer/moe.py +2 -2
  349. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  350. mindspore/parallel/_transformer/transformer.py +23 -3
  351. mindspore/parallel/_utils.py +11 -7
  352. mindspore/parallel/algo_parameter_config.py +85 -5
  353. mindspore/parallel/checkpoint_transform.py +6 -10
  354. mindspore/parallel/shard.py +4 -4
  355. mindspore/profiler/common/struct_type.py +3 -3
  356. mindspore/profiler/common/util.py +3 -2
  357. mindspore/profiler/envprofiling.py +1 -1
  358. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  359. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  360. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  361. mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
  362. mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
  363. mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
  364. mindspore/profiler/parser/ascend_op_generator.py +5 -5
  365. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  366. mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
  367. mindspore/profiler/parser/base_timeline_generator.py +9 -7
  368. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
  369. mindspore/profiler/parser/flops_parser.py +15 -11
  370. mindspore/profiler/parser/framework_parser.py +37 -21
  371. mindspore/profiler/parser/hccl_parser.py +16 -12
  372. mindspore/profiler/parser/integrator.py +22 -11
  373. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  374. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  375. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  376. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  377. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  378. mindspore/profiler/parser/optime_parser.py +1 -1
  379. mindspore/profiler/parser/profiler_info.py +2 -2
  380. mindspore/profiler/parser/step_trace_parser.py +11 -14
  381. mindspore/profiler/profiling.py +139 -71
  382. mindspore/rewrite/api/node.py +102 -19
  383. mindspore/rewrite/api/node_type.py +5 -1
  384. mindspore/rewrite/api/scoped_value.py +9 -17
  385. mindspore/rewrite/api/symbol_tree.py +131 -47
  386. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  387. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  388. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  389. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  390. mindspore/rewrite/common/rewrite_elog.py +5 -1
  391. mindspore/rewrite/namer.py +33 -24
  392. mindspore/rewrite/namespace.py +14 -5
  393. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  394. mindspore/rewrite/node/call_function.py +79 -0
  395. mindspore/rewrite/node/cell_container.py +135 -0
  396. mindspore/rewrite/node/control_flow.py +88 -0
  397. mindspore/rewrite/{node.py → node/node.py} +273 -234
  398. mindspore/rewrite/node/node_manager.py +254 -0
  399. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  400. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  401. mindspore/rewrite/parsers/assign_parser.py +216 -221
  402. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  403. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  404. mindspore/rewrite/parsers/constant_parser.py +9 -6
  405. mindspore/rewrite/parsers/container_parser.py +9 -7
  406. mindspore/rewrite/parsers/for_parser.py +36 -15
  407. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  408. mindspore/rewrite/parsers/if_parser.py +28 -24
  409. mindspore/rewrite/parsers/module_parser.py +196 -25
  410. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  411. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  412. mindspore/rewrite/parsers/return_parser.py +6 -6
  413. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  414. mindspore/rewrite/sparsify/utils.py +1 -1
  415. mindspore/rewrite/symbol_tree.py +525 -577
  416. mindspore/rewrite/symbol_tree_builder.py +9 -193
  417. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  418. mindspore/run_check/_check_version.py +2 -2
  419. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  420. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  421. mindspore/scipy/linalg.py +1 -1
  422. mindspore/scipy/optimize/minimize.py +7 -3
  423. mindspore/train/_utils.py +7 -3
  424. mindspore/train/amp.py +323 -123
  425. mindspore/train/anf_ir_pb2.py +14 -2
  426. mindspore/train/callback/_backup_and_restore.py +2 -12
  427. mindspore/train/callback/_callback.py +29 -4
  428. mindspore/train/callback/_checkpoint.py +23 -8
  429. mindspore/train/callback/_early_stop.py +2 -2
  430. mindspore/train/callback/_landscape.py +4 -4
  431. mindspore/train/callback/_loss_monitor.py +2 -2
  432. mindspore/train/callback/_on_request_exit.py +2 -2
  433. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  434. mindspore/train/callback/_summary_collector.py +14 -7
  435. mindspore/train/callback/_time_monitor.py +58 -5
  436. mindspore/train/data_sink.py +5 -11
  437. mindspore/train/dataset_helper.py +83 -57
  438. mindspore/train/loss_scale_manager.py +2 -2
  439. mindspore/train/metrics/__init__.py +3 -3
  440. mindspore/train/metrics/cosine_similarity.py +1 -1
  441. mindspore/train/metrics/hausdorff_distance.py +3 -2
  442. mindspore/train/metrics/mean_surface_distance.py +3 -2
  443. mindspore/train/metrics/metric.py +39 -19
  444. mindspore/train/metrics/roc.py +2 -2
  445. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  446. mindspore/train/mind_ir_pb2.py +85 -36
  447. mindspore/train/model.py +185 -45
  448. mindspore/train/serialization.py +390 -150
  449. mindspore/train/summary/_writer_pool.py +3 -2
  450. mindspore/train/summary/summary_record.py +14 -10
  451. mindspore/train/train_thor/convert_utils.py +3 -3
  452. mindspore/train/train_thor/dataset_helper.py +1 -1
  453. mindspore/version.py +1 -1
  454. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
  455. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +458 -518
  456. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  457. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  458. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  459. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  460. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  461. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  462. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  463. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  464. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  465. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  466. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  467. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  468. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  469. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  470. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  471. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  472. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  473. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  474. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  475. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  476. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  477. mindspore/_extends/graph_kernel/expander.py +0 -80
  478. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  479. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  480. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  481. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  482. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  483. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  484. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  485. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  486. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  487. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  488. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  489. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  490. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  491. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  492. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  493. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  494. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  495. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  496. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  497. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  498. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  499. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  500. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  501. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  502. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  503. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  504. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  505. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  506. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  507. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  509. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  510. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  511. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  512. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  513. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  514. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  515. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  516. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  517. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  518. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  519. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  520. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  521. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  522. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  523. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  524. mindspore/dataset/datapreprocess/__init__.py +0 -20
  525. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  526. mindspore/include/api/net.h +0 -142
  527. mindspore/nn/lr_scheduler.py +0 -262
  528. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  529. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  530. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  531. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  532. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  533. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  534. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  537. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  539. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  541. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  542. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  543. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  546. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  547. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  548. mindspore/rewrite/node_visitor.py +0 -44
  549. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  550. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/nn/probability/distribution/transformed_distribution.py

@@ -16,6 +16,7 @@
  import numpy as np
  from mindspore import _checkparam as validator
  from mindspore.ops import operations as P
+ from mindspore.ops import functional as F
  from mindspore.common import dtype as mstype
  import mindspore.nn as nn
  from .distribution import Distribution
@@ -125,7 +126,6 @@ class TransformedDistribution(Distribution):
  self.cast_base = P.Cast()
  self.equal_base = P.Equal()
  self.select_base = P.Select()
- self.fill_base = P.Fill()

  # broadcast bijector batch_shape and distribution batch_shape
  self._broadcast_shape = self._broadcast_bijector_dist()
@@ -176,9 +176,9 @@ class TransformedDistribution(Distribution):
  """
  if self.batch_shape is None or self.bijector.batch_shape is None:
  return None
- bijector_shape_tensor = self.fill_base(
+ bijector_shape_tensor = F.fill(
  self.dtype, self.bijector.batch_shape, 0.0)
- dist_shape_tensor = self.fill_base(self.dtype, self.batch_shape, 0.0)
+ dist_shape_tensor = F.fill(self.dtype, self.batch_shape, 0.0)
  return (bijector_shape_tensor + dist_shape_tensor).shape

  def _cdf(self, value, *args, **kwargs):

mindspore/nn/probability/distribution/uniform.py

@@ -14,6 +14,7 @@
  # ============================================================================
  """Uniform Distribution"""
  import numpy as np
+ from mindspore.ops import functional as F
  from mindspore.ops import operations as P
  from mindspore.ops import composite as C
  from mindspore import _checkparam as Validator
@@ -170,7 +171,6 @@ class Uniform(Distribution):
  self.cast = P.Cast()
  self.const = P.ScalarToTensor()
  self.dtypeop = P.DType()
- self.fill = P.Fill()
  self.less = P.Less()
  self.lessequal = P.LessEqual()
  self.logicaland = P.LogicalAnd()
@@ -287,10 +287,10 @@ class Uniform(Distribution):
  value = self._check_value(value, 'value')
  value = self.cast(value, self.dtype)
  low, high = self._check_param_type(low, high)
- neg_ones = self.fill(self.dtype, self.shape(value), -1.0)
+ neg_ones = F.fill(self.dtype, self.shape(value), -1.0)
  prob = self.exp(neg_ones * self.log(high - low))
  broadcast_shape = self.shape(prob)
- zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
+ zeros = F.fill(self.dtypeop(prob), broadcast_shape, 0.0)
  comp_lo = self.less(value, low)
  comp_hi = self.lessequal(value, high)
  less_than_low = self.select(comp_lo, zeros, prob)
@@ -316,7 +316,7 @@ class Uniform(Distribution):
  kl = self.log(high_b - low_b) - self.log(high_a - low_a)
  comp = self.logicaland(self.lessequal(
  low_b, low_a), self.lessequal(high_a, high_b))
- inf = self.fill(self.dtypeop(kl), self.shape(kl), np.inf)
+ inf = F.fill(self.dtypeop(kl), self.shape(kl), np.inf)
  return self.select(comp, kl, inf)

  def _cdf(self, value, low=None, high=None):
@@ -338,8 +338,8 @@ class Uniform(Distribution):
  low, high = self._check_param_type(low, high)
  prob = (value - low) / (high - low)
  broadcast_shape = self.shape(prob)
- zeros = self.fill(self.dtypeop(prob), broadcast_shape, 0.0)
- ones = self.fill(self.dtypeop(prob), broadcast_shape, 1.0)
+ zeros = F.fill(self.dtypeop(prob), broadcast_shape, 0.0)
+ ones = F.fill(self.dtypeop(prob), broadcast_shape, 1.0)
  comp_lo = self.less(value, low)
  comp_hi = self.less(value, high)
  less_than_low = self.select(comp_lo, zeros, prob)
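
Both distribution files above drop the instantiated P.Fill() primitive in favor of the functional fill interface, which takes the same (dtype, shape, value) arguments. A minimal illustrative sketch of the replacement pattern (shapes and constants below are made up):

    import numpy as np
    import mindspore as ms
    from mindspore.ops import functional as F

    # 2.1.0 style: fill_op = P.Fill(); zeros = fill_op(ms.float32, (2, 3), 0.0)
    # 2.2.0 style: call the functional form directly, same argument order
    zeros = F.fill(ms.float32, (2, 3), 0.0)
    ones = F.fill(ms.float32, (2, 3), 1.0)
    inf_mask = F.fill(ms.float32, (2, 3), np.inf)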

mindspore/nn/wrap/cell_wrapper.py

@@ -99,7 +99,7 @@ class WithLossCell(Cell):
  >>> from mindspore import Tensor, nn
  >>> import numpy as np
  >>> # Define the network structure of LeNet5. Refer to
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
  >>> net = LeNet5()
  >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
  >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
@@ -115,8 +115,7 @@ class WithLossCell(Cell):
  super(WithLossCell, self).__init__(auto_prefix=False)
  self._backbone = backbone
  self._loss_fn = loss_fn
- if isinstance(backbone, Cell) and backbone.jit_config_dict:
- self._jit_config_dict = backbone.jit_config_dict
+ self._get_attr_from_cell(backbone)

  def construct(self, data, label):
  out = self._backbone(data)
@@ -133,7 +132,7 @@ class WithLossCell(Cell):
  Examples:
  >>> from mindspore import nn
  >>> # Define the network structure of LeNet5. Refer to
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
  >>> net = LeNet5()
  >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
  >>> net_with_criterion = nn.WithLossCell(net, loss_fn)
@@ -176,7 +175,7 @@ class WithGradCell(Cell):
  >>> import mindspore as ms
  >>> from mindspore import nn
  >>> # Defined a network without loss function, taking LeNet5 as an example.
- >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+ >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
  >>> net = LeNet5()
  >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
  >>> grad_net = nn.WithGradCell(net, loss_fn)
@@ -199,8 +198,7 @@ class WithGradCell(Cell):
  else:
  self.network_with_loss = WithLossCell(self.network, self.loss_fn)
  self.network_with_loss.set_train()
- if isinstance(network, Cell) and network.jit_config_dict:
- self._jit_config_dict = network.jit_config_dict
+ self._get_attr_from_cell(network)

  def construct(self, *inputs):
  weights = self.weights
@@ -219,7 +217,7 @@ class ForwardValueAndGrad(Cell):
  The backward graph will be created in the gradient function to calculating gradient.

  Args:
- network (Cell): The training network.
+ network (Union[Cell, Function, MethodType]): The training network.
  weights (ParameterTuple): The parameters of the training network that need to calculate the gradient.
  Default: ``None`` .
  get_all (bool): If ``True`` , get all the gradients with respect to inputs. Default: ``False`` .
@@ -302,8 +300,7 @@ class ForwardValueAndGrad(Cell):
  self.get_by_list = get_by_list
  self.sens_param = sens_param
  self.grad = C.GradOperation(get_all=self.get_all, get_by_list=self.get_by_list, sens_param=self.sens_param)
- if isinstance(network, Cell) and network.jit_config_dict:
- self._jit_config_dict = network.jit_config_dict
+ self._get_attr_from_cell(network)

  def construct(self, *inputs):
  grad_inputs = inputs
@@ -349,7 +346,7 @@ class TrainOneStepCell(Cell):
  Examples:
  >>> import mindspore.nn as nn
  >>> # Define the network structure of LeNet5. Refer to
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
  >>> net = LeNet5()
  >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
  >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
@@ -414,8 +411,7 @@ class TrainOneStepCell(Cell):
  create_group(server_group_name, group_list[current_index])
  group = server_group_name
  self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree, group=group)
- if isinstance(network, Cell) and network.jit_config_dict:
- self._jit_config_dict = network.jit_config_dict
+ self._get_attr_from_cell(network)

  def construct(self, *inputs):
  if not self.sense_flag:
@@ -514,8 +510,7 @@ class _VirtualDatasetCell(Cell):
  super(_VirtualDatasetCell, self).__init__(auto_prefix=False)
  self._backbone = backbone
  self._virtual_dataset = _VirtualDataset()
- if isinstance(backbone, Cell) and backbone.jit_config_dict:
- self._jit_config_dict = backbone.jit_config_dict
+ self._get_attr_from_cell(backbone)

  def construct(self, *inputs):
  output = self._virtual_dataset(*inputs)
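
Each wrapper cell above replaces its hand-rolled jit_config_dict copy with a shared Cell helper. A rough sketch of what _get_attr_from_cell presumably does, inferred only from the inline code it replaces (the real helper in mindspore/nn/cell.py may propagate additional attributes):

    from mindspore.nn import Cell

    def _get_attr_from_cell(wrapper, network):
        # Inferred behaviour: forward the wrapped cell's JIT config to the
        # wrapper so both are compiled with the same jit settings.
        if isinstance(network, Cell) and network.jit_config_dict:
            wrapper._jit_config_dict = network.jit_config_dict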
@@ -524,6 +519,8 @@ class _VirtualDatasetCell(Cell):

  @_primexpr
  def _check_shape_value_on_axis_divided_by_target_value(input_shape, micro_size):
+ if F.isconstant(input_shape[0]) is False:
+ return
  if input_shape[0] % micro_size != 0:
  raise ValueError(f"For micro batch initialization, the 0th dimension shape of input({input_shape[0]}) must be "
  f"divided by micro size({micro_size})")
@@ -548,8 +545,8 @@ class _MicroBatch(Cell):
  for each_input in inputs:
  input_shape = self.shape(each_input)
  _check_shape_value_on_axis_divided_by_target_value(input_shape, self.micro_size)
- micro_batch_begin = i * input_shape[0] // self.micro_size
- micro_batch_end = (i + 1) * input_shape[0] // self.micro_size
+ micro_batch_begin = (input_shape[0] // self.micro_size) * i
+ micro_batch_end = (input_shape[0] // self.micro_size) * (i + 1)
  strided_slice_begin = (micro_batch_begin,)
  strided_slice_strides = (1,)
  for _ in range(len(input_shape) - 1):
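
The reordered slice bounds divide the batch dimension first and then scale by the micro-batch index; once the divisibility check above passes, the old and new expressions yield the same slices. Illustrative numbers (not from the source):

    batch_size, micro_size = 64, 4
    for i in range(micro_size):
        micro_batch_begin = (batch_size // micro_size) * i        # 0, 16, 32, 48
        micro_batch_end = (batch_size // micro_size) * (i + 1)    # 16, 32, 48, 64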
@@ -589,7 +586,7 @@ class MicroBatchInterleaved(Cell):
  Examples:
  >>> import mindspore.nn as nn
  >>> # Define the network structure of LeNet5. Refer to
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
  >>> net = LeNet5()
  >>> net = nn.MicroBatchInterleaved(net, 2)
  """
@@ -610,8 +607,7 @@ class MicroBatchInterleaved(Cell):
  interleave_data.strided_slice.add_prim_attr("strided_slice_flag", True)
  interleave_data.strided_slice.add_prim_attr("interleave_num", interleave_num)
  self.interleave_inputs.append(interleave_data)
- if isinstance(network, Cell) and network.jit_config_dict:
- self._jit_config_dict = network.jit_config_dict
+ self._get_attr_from_cell(network)

  def construct(self, *inputs):
  output = 0.0
@@ -638,7 +634,7 @@ class PipelineCell(Cell):
  Examples:
  >>> import mindspore.nn as nn
  >>> # Define the network structure of LeNet5. Refer to
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
  >>> net = LeNet5()
  >>> net = nn.PipelineCell(net, 4)
  """
@@ -648,13 +644,64 @@
  self.micro_inputs = nn.CellList()
  self.micro_size = micro_size
  self.add_list = []
+ if not isinstance(micro_size, int):
+ raise TypeError("For 'PipelineCell', the argument 'micro_size' must be integer, "
+ "but got the type : {}.".format(type(micro_size)))
+ if micro_size <= 0:
+ raise ValueError("For 'PipelineCell', the argument 'micro_size' must be large than 0, "
+ "but got {}.".format(micro_size))
  for i in range(micro_size):
  micro_input = _MicroBatch(micro_size)
  self.micro_inputs.append(micro_input)
  self.add = P.Add().add_prim_attr("pipeline_end", i)
  self.add_list.append(self.add)
- if isinstance(network, Cell) and network.jit_config_dict:
- self._jit_config_dict = network.jit_config_dict
+ self._get_attr_from_cell(network)
+
+ def construct(self, *inputs):
+ ret = None
+ for i in range(self.micro_size):
+ micro_input = self.micro_inputs[i](i, *inputs)
+ output = self.network(*micro_input)
+ if ret is not None:
+ ret = self.add_list[i](ret, output)
+ else:
+ ret = output
+ return ret
+
+ class GradAccumulationCell(Cell):
+ """
+ Wrap the network with Micro Batch.
+
+ Args:
+ network (Cell): The target network to wrap.
+ micro_size (int): MicroBatch size.
+
+ Supported Platforms:
+ ``Ascend`` ``GPU``
+
+ Examples:
+ >>> net = Net()
+ >>> net = GradAccumulationCell(net, 4)
+ """
+ def __init__(self, network, micro_size):
+ super(GradAccumulationCell, self).__init__(auto_prefix=False)
+ self.network = network
+ self.micro_inputs = nn.CellList()
+ self.micro_size = micro_size
+ self.add_list = []
+ if not isinstance(micro_size, int):
+ raise TypeError("For 'GradAccumulationCell', the argument 'micro_size' must be integer, "
+ "but got the type : {}.".format(type(micro_size)))
+ if micro_size <= 0:
+ raise ValueError("For 'GradAccumulationCell', the argument 'micro_size' must be large than 0, "
+ "but got {}.".format(micro_size))
+ for i in range(micro_size):
+ micro_input = _MicroBatch(micro_size)
+ micro_input.strided_slice.add_prim_attr("grad_accu_num", micro_size)
+ self.micro_inputs.append(micro_input)
+ self.add = P.Add().add_prim_attr("forward_end", i)
+ self.add_list.append(self.add)
+ self._get_attr_from_cell(network)

  def construct(self, *inputs):
  ret = None
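
With the new argument checks, an invalid micro_size now fails when the wrapper is constructed rather than later during graph compilation. An illustrative sketch (LeNet5 is assumed to come from the linked docs example):

    import mindspore.nn as nn

    net_ok = nn.PipelineCell(LeNet5(), 4)       # valid: positive integer micro_size
    net_bad = nn.PipelineCell(LeNet5(), 0)      # raises ValueError (micro_size must be > 0)
    net_bad = nn.PipelineCell(LeNet5(), "4")    # raises TypeError (micro_size must be an int)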
@@ -674,23 +721,22 @@ def _pipeline_clear_grad(accu_grad, grad):
  return F.assign(accu_grad, zeros)


- class _TrainPipelineAccuStepCell(TrainOneStepCell):
+ class _TrainGradAccuStepCell(TrainOneStepCell):
  """
  Wraps the network with an optimizer in pipeline mode.
  """
  def __init__(self, network, optimizer, sens=None):
- super(_TrainPipelineAccuStepCell, self).__init__(network, optimizer, sens)
+ super(_TrainGradAccuStepCell, self).__init__(network, optimizer, sens)
  self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
  self.hyper_map = ops.HyperMap()
  self.opt_shard = _get_enable_parallel_optimizer()
- if isinstance(network, Cell) and network.jit_config_dict:
- self._jit_config_dict = network.jit_config_dict
+ self._get_attr_from_cell(network)

  def construct(self, *inputs):
  if not self.sense_flag:
  return self._no_sens_impl(*inputs)
  loss = self.network(*inputs)
- sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
+ sens = ops.fill(ops.DType()(loss), ops.Shape()(loss), self.sens)
  grads = self.grad(self.network, self.weights)(*inputs, sens)
  accu_grads = ops.depend(self.accu_grads, grads)
  if self.opt_shard:
@@ -735,7 +781,7 @@ class VirtualDatasetCellTriple(Cell):
  Examples:
  >>> import mindspore.nn as nn
  >>> # Define the network structure of LeNet5. Refer to
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
  >>> net = LeNet5()
  >>> net = nn.VirtualDatasetCellTriple(net)
  """
@@ -744,8 +790,7 @@ class VirtualDatasetCellTriple(Cell):
  super(VirtualDatasetCellTriple, self).__init__(auto_prefix=False)
  logger.warning("WARN_DEPRECATED: The usage of VirtualDatasetCellTriple is deprecated.")
  self._backbone = backbone
- if isinstance(backbone, Cell) and backbone.jit_config_dict:
- self._jit_config_dict = backbone.jit_config_dict
+ self._get_attr_from_cell(backbone)

  def construct(self, a, b, c):
  return self._backbone(a, b, c)
@@ -779,7 +824,7 @@ class WithEvalCell(Cell):
  Examples:
  >>> import mindspore.nn as nn
  >>> # Define a forward network without loss function, taking LeNet5 as an example.
- >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+ >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
  >>> net = LeNet5()
  >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
  >>> eval_net = nn.WithEvalCell(net, loss_fn)
@@ -790,8 +835,7 @@ class WithEvalCell(Cell):
  self._network = network
  self._loss_fn = loss_fn
  self.add_cast_fp32 = validator.check_value_type("add_cast_fp32", add_cast_fp32, [bool], self.cls_name)
- if isinstance(network, Cell) and network.jit_config_dict:
- self._jit_config_dict = network.jit_config_dict
+ self._get_attr_from_cell(network)

  def construct(self, data, label):
  outputs = self._network(data)

mindspore/nn/wrap/grad_reducer.py

@@ -314,12 +314,15 @@ class DistributedGradReducer(Cell):
  Before running the following examples, you need to configure the communication environment variables.

  For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
- Please see the `Ascend tutorial
- <https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/train_ascend.html#preparations>`_
+ Please see the `rank table Startup
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
  for more details.

- For the GPU devices, users need to prepare the host file and mpi, please see the `GPU tutorial
- <https://www.mindspore.cn/tutorials/experts/en/r2.1/parallel/train_gpu.html#preparation>`_ .
+ For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
+ <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
+
+ For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
+ Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .

  This example should be run with multiple devices.

@@ -356,7 +359,7 @@
  ... def construct(self, *args):
  ... weights = self.weights
  ... loss = self.network(*args)
- ... sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
+ ... sens = F.fill(ops.DType()(loss), ops.Shape()(loss), self.sens)
  ... grads = self.grad(self.network, weights)(*args, sens)
  ... if self.reducer_flag:
  ... # apply grad reducer on grads

mindspore/nn/wrap/loss_scale.py

@@ -15,6 +15,7 @@
  """Loss scale cell for loss scale training."""
  from __future__ import absolute_import

+ import os
  import mindspore.context as context
  from mindspore.context import ParallelMode
  from mindspore.parallel._utils import _get_enable_parallel_optimizer
@@ -30,6 +31,7 @@ from mindspore.ops import composite as C
  from mindspore.ops import operations as P
  from mindspore.common import dtype as mstype
  from mindspore.common.api import jit
+ from mindspore._c_expression import MSContext

  _grad_scale = C.MultitypeFuncGraph("grad_scale")
  reciprocal = P.Reciprocal()
@@ -60,6 +62,28 @@ def _tensor_grad_overflow_row_tensor(grad):
  return grad_overflow(grad.values)


+ _ascend_grad_overflow = C.MultitypeFuncGraph("_ascend_grad_overflow")
+ ascend_grad_overflow = P.IsFinite()
+
+
+ @_ascend_grad_overflow.register("Tensor")
+ def _tensor_ascend_grad_overflow(grad):
+ status = ascend_grad_overflow(grad)
+ base = Tensor(1.0, dtype=mstype.float32)
+ output = base - status.all()
+ output = P.Reshape()(output, ((1,)))
+ return output
+
+
+ @_ascend_grad_overflow.register("RowTensor")
+ def _tensor_ascend_grad_overflow_row_tensor(grad):
+ status = ascend_grad_overflow(grad.values)
+ base = Tensor(1.0, dtype=mstype.float32)
+ output = base - status.all()
+ output = P.Reshape()(output, ((1,)))
+ return output
+
+
  class DynamicLossScaleUpdateCell(Cell):
  r"""
  Dynamic Loss scale update cell.
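
The new _ascend_grad_overflow branch derives the overflow flag from the gradient values themselves: P.IsFinite() is reduced to a single boolean, and subtracting it from 1.0 gives 0.0 for clean gradients and 1.0 when any element is inf or nan. A small standalone sketch of the same computation outside the MultitypeFuncGraph machinery:

    import numpy as np
    from mindspore import Tensor, ops
    from mindspore.common import dtype as mstype

    def grad_overflow_flag(grad):
        # 1.0 if grad contains inf/nan anywhere, else 0.0, reshaped to (1,)
        finite = ops.IsFinite()(grad).all()
        flag = Tensor(1.0, dtype=mstype.float32) - ops.cast(finite, mstype.float32)
        return ops.reshape(flag, (1,))

    clean = grad_overflow_flag(Tensor(np.ones(4), mstype.float32))               # expected [0.]
    blown = grad_overflow_flag(Tensor(np.array([1.0, np.inf]), mstype.float32))  # expected [1.]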
@@ -296,16 +320,18 @@
  >>> size, in_features, out_features = 16, 16, 10
  >>> #1) when the type of scale_sense is Cell:
  >>> net = Net(in_features, out_features)
- >>> loss = nn.MSELoss()
+ >>> loss_fn = nn.MSELoss()
  >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
- >>> net_with_loss = nn.WithLossCell(net, loss)
- >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
- >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
+ >>> net_with_loss = nn.WithLossCell(net, loss_fn)
  >>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
  >>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
- >>> output = train_network(input, labels)
- >>> status, scaling_sens = train_network.start_overflow_check(loss, train_network.scaling_sens)
- >>> grads = train_network.grad(train_network.network, weights)(*inputs, scaling_sens_filled)
+ >>> loss = net_with_loss(input, labels)
+ >>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
+ >>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
+ >>> status = Tensor([0] * 8, mindspore.int32)
+ >>> scaling_sens = train_network.scale_sense
+ >>> scaling_sens_filled = ops.ones_like(loss) * ops.cast(scaling_sens, ops.dtype(loss))
+ >>> grads = train_network.grad(train_network.network, train_network.weights)(input, labels, scaling_sens_filled)
  >>> grads = train_network.grad_reducer(grads)
  >>> cond = train_network.get_overflow_status(status, grads)
  >>> overflow = train_network.process_loss_scale(cond)
@@ -341,7 +367,12 @@
  self.allreduce = P.AllReduce()
  self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
  self.gpu_target = (context.get_context("device_target") == "GPU")
+ self.ascend_910a_target = (MSContext.get_instance().get_ascend_soc_version() == 'ascend910')
+ self.ascend_910b_target = (MSContext.get_instance().get_ascend_soc_version() == 'ascend910b')
  self.loss_scaling_manager = None
+ self._ascend910b_check_overflow_status_mode = os.environ.get('MS_ASCEND_CHECK_OVERFLOW_MODE')
+
+
  if isinstance(scale_sense, Cell):
  self.loss_scaling_manager = scale_sense
  self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
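
Overflow checking on Ascend now branches on the MS_ASCEND_CHECK_OVERFLOW_MODE environment variable, read once at construction time; on 910B, any value other than INFNAN_MODE keeps the saturation-mode NPU status-register path. A hedged usage sketch (set the variable before the train cell is built):

    import os

    # Ask for IEEE-style inf/nan overflow detection on Ascend 910B.
    # Leaving this unset (or setting any other value) falls back to the
    # NPUGetFloatStatusV2 / NPUClearFloatStatusV2 saturation-mode path.
    os.environ['MS_ASCEND_CHECK_OVERFLOW_MODE'] = 'INFNAN_MODE'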
@@ -358,6 +389,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
358
389
  "the 'scale_sense' must be Cell or Tensor, but got 'scale_sense' type: {}."
359
390
  .format(type(scale_sense)))
360
391
  self.enable_tuple_broaden = True
392
+ self._get_attr_from_cell(network)
361
393
 
362
394
  def construct(self, *inputs):
363
395
  weights = self.weights
@@ -418,13 +450,68 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
  is cleaned up when the function returns.
  """
  status = Tensor([0] * 8, mstype.int32)
- if not self.gpu_target:
+ if self.ascend_910a_target or (self.ascend_910b_target and \
+ self._ascend910b_check_overflow_status_mode != "INFNAN_MODE"):
  status = F.depend(status, pre_cond)
  # clear overflow buffer
  clear_status = NPUClearFloatStatusV2()(status)
  compute_input = F.depend(compute_input, clear_status)
  return status, compute_input

+ def _check_overflow_status_on_infnan_mode(self, grad_overflow_check_func, compute_output):
+ """check overflow status on infnan mode."""
+ flag_sum = self.hyper_map(F.partial(grad_overflow_check_func), compute_output)
+ flag_sum = P.AddN()(flag_sum)
+ # convert flag_sum to scalar
+ flag_sum = P.Reshape()(flag_sum, (()))
+ return flag_sum
+
+ def _get_distributed_overflow_status_on_infnan_mode(self, grad_overflow_check_func, compute_output):
+ """converge the distributed overflow status on infnan mode."""
+ flag_sum = self._check_overflow_status_on_infnan_mode(grad_overflow_check_func, compute_output)
+
+ if self.is_distributed:
+ # sum overflow flag over devices
+ flag_reduce = self.allreduce(flag_sum)
+ overflow = self.less_equal(self.base, flag_reduce)
+ else:
+ overflow = self.less_equal(self.base, flag_sum)
+ return overflow
+
+ def _get_gpu_overflow_status(self, compute_output):
+ """get overflow status of gpu."""
+ overflow = self._get_distributed_overflow_status_on_infnan_mode(_grad_overflow, compute_output)
+ return overflow
+
+ def _get_ascend_overflow_status_on_infnan_mode(self, compute_output):
+ """get overflow status of ascend on infnan mode."""
+ overflow = self._get_distributed_overflow_status_on_infnan_mode(_ascend_grad_overflow, compute_output)
+ return overflow
+
+ def _get_ascend_overflow_status_on_saturation_mode(self, status, compute_output):
+ """get overflow status of ascend on saturation mode"""
+ status = F.depend(status, compute_output)
+ get_status = NPUGetFloatStatusV2()(status)
+
+ if self.is_distributed:
+ # sum overflow flag over devices
+ flag_reduce = self.allreduce(get_status)
+ # get_status not equal to [0]*8 means overflow
+ flag = self.equal(self.base0, flag_reduce)
+ status = F.depend(status, flag)
+ # distributed needs to skip allreduce to avoid its overflow affecting the next step
+ clear_status = NPUClearFloatStatusV2()(status)
+ flag = F.depend(flag, clear_status)
+ overall_finite = self.reduce_all(flag)
+ else:
+ status = F.depend(status, get_status)
+ clear_status = NPUClearFloatStatusV2()(status)
+ get_status = F.depend(get_status, clear_status)
+ flag = self.equal(self.base0, get_status)
+ overall_finite = self.reduce_all(flag)
+ overflow = self.logic_not(overall_finite)
+ return overflow
+
  @jit
  def get_overflow_status(self, status, compute_output):
  """
@@ -442,39 +529,15 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
  Returns:
  bool, whether the overflow occurs or not.
  """
- if not self.gpu_target:
- status = F.depend(status, compute_output)
- get_status = NPUGetFloatStatusV2()(status)
-
- if self.is_distributed:
- # sum overflow flag over devices
- flag_reduce = self.allreduce(get_status)
- # get_status not equal to [0]*8 means overflow
- flag = self.equal(self.base0, flag_reduce)
- status = F.depend(status, flag)
- # distributed needs to skip allreduce to avoid its overflow affecting the next step
- clear_status = NPUClearFloatStatusV2()(status)
- flag = F.depend(flag, clear_status)
- overall_finite = self.reduce_all(flag)
- else:
- status = F.depend(status, get_status)
- clear_status = NPUClearFloatStatusV2()(status)
- get_status = F.depend(get_status, clear_status)
- flag = self.equal(self.base0, get_status)
- overall_finite = self.reduce_all(flag)
- overflow = self.logic_not(overall_finite)
- else:
- flag_sum = self.hyper_map(F.partial(_grad_overflow), compute_output)
- flag_sum = P.AddN()(flag_sum)
- # convert flag_sum to scalar
- flag_sum = P.Reshape()(flag_sum, (()))
-
- if self.is_distributed:
- # sum overflow flag over devices
- flag_reduce = self.allreduce(flag_sum)
- overflow = self.less_equal(self.base, flag_reduce)
+ if self.gpu_target:
+ overflow = self._get_gpu_overflow_status(compute_output)
+ elif self.ascend_910b_target:
+ if self._ascend910b_check_overflow_status_mode != "INFNAN_MODE":
+ overflow = self._get_ascend_overflow_status_on_saturation_mode(status, compute_output)
  else:
- overflow = self.less_equal(self.base, flag_sum)
+ overflow = self._get_ascend_overflow_status_on_infnan_mode(compute_output)
+ else: # ascend_910a_target
+ overflow = self._get_ascend_overflow_status_on_saturation_mode(status, compute_output)
  return overflow

  def process_loss_scale(self, overflow):
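After the refactor, get_overflow_status only dispatches: GPU always uses the INFNAN check, Ascend 910B consults MS_ASCEND_CHECK_OVERFLOW_MODE, and Ascend 910A stays on the saturation check. A small sketch of that decision table (illustrative names, not the cell's API):

# Hedged sketch of the dispatch implemented above; names are illustrative only.
import os

def select_overflow_check(device_target, soc_version):
    """Return which overflow-detection path the cell would take."""
    if device_target == "GPU":
        return "infnan"
    if soc_version == "ascend910b":
        # 910B honours MS_ASCEND_CHECK_OVERFLOW_MODE; anything other than
        # "INFNAN_MODE" falls back to the saturation/status-register path.
        mode = os.environ.get("MS_ASCEND_CHECK_OVERFLOW_MODE")
        return "infnan" if mode == "INFNAN_MODE" else "saturation"
    return "saturation"  # Ascend 910A

print(select_overflow_check("GPU", None))             # infnan
print(select_overflow_check("Ascend", "ascend910b"))  # depends on the env var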
@@ -517,7 +580,7 @@ def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
  return new_grad


- class _TrainPipelineWithLossScaleCell(TrainOneStepCell):
+ class _TrainGradAccuWithLossScaleCell(TrainOneStepCell):
  """
  Append an optimizer to the training network after that the construct
  function can be called to create the backward graph.
@@ -528,7 +591,7 @@ class _TrainPipelineWithLossScaleCell(TrainOneStepCell):
  scale_sense (Cell): Cell to do the loss scale.
  """
  def __init__(self, network, optimizer, scale_sense):
- super(_TrainPipelineWithLossScaleCell, self).__init__(network, optimizer, sens=None)
+ super(_TrainGradAccuWithLossScaleCell, self).__init__(network, optimizer, sens=None)
  self.network = network
  self.network.add_flags(defer_inline=True)
  self.weights = optimizer.parameters
@@ -1304,7 +1304,7 @@ def triu(m, k=0):
  if rank < 1:
  _raise_value_error("input m's rank should be larger than 0")
  elif rank == 1:
- mask = tri(m.shape[0], k=k-1, dtype=mstype.bool_)
+ mask = tri(m.shape[0], k=k - 1, dtype=mstype.bool_)
  return where(mask, zeros(1, m.dtype), m)
  # Only Ascend hardware will reduce accuracy
  if device_target == "Ascend":
@@ -2587,7 +2587,6 @@ def _limit_stat_length(stat_length, shape):
  return tuple((min(stat_pair[0], shape[i]), min(stat_pair[1], shape[i])) for i, stat_pair in enumerate(stat_length))


- @constexpr
  def _convert_pad_to_nd(pad_values, ndim):
  """broadcasts the pad_values to (ndim * 2)"""
  if not isinstance(pad_values, (int, list, tuple, Tensor)):
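_convert_pad_to_nd normalises a pad specification into one (before, after) pair per dimension, i.e. ndim * 2 values. A hypothetical pure-Python sketch of the NumPy pad_width-style broadcasting convention it targets (the MindSpore helper may differ in detail):

# Hypothetical illustration of broadcasting a pad spec to (before, after) per axis.
def broadcast_pad_to_nd(pad_values, ndim):
    """Broadcast an int, a (before, after) pair, or per-axis pairs to ndim pairs."""
    if isinstance(pad_values, int):
        pad_values = ((pad_values, pad_values),)
    elif pad_values and isinstance(pad_values[0], int):
        pad_values = (tuple(pad_values),)                 # one (before, after) pair
    else:
        pad_values = tuple(tuple(p) for p in pad_values)  # already per-axis pairs
    if len(pad_values) == 1:
        pad_values = pad_values * ndim                    # repeat for every axis
    return pad_values

print(broadcast_pad_to_nd(2, 3))        # ((2, 2), (2, 2), (2, 2))
print(broadcast_pad_to_nd((1, 3), 2))   # ((1, 3), (1, 3))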
@@ -2585,8 +2585,9 @@ def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
  """
  def unique_w_ind(arr):
  array, sort_indices = arr.ravel().sort()
- cmp_array1 = F.cat((array, Tensor([0], dtype=array.dtype)))
- cmp_array2 = F.cat((Tensor([0], dtype=array.dtype), array))
+ array_type = array.dtype
+ cmp_array1 = F.cat((array, Tensor([0], dtype=array_type)))
+ cmp_array2 = F.cat((Tensor([0], dtype=array_type), array))
  mask = cmp_array1 != cmp_array2
  mask[0] = True
  array = F.masked_select(array, mask[:-1])
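unique_w_ind deduplicates a sorted array by comparing it against a shifted copy: an element is kept exactly where it differs from its left neighbour. A small NumPy illustration of the same trick (not the MindSpore function itself):

# Illustrative sketch in NumPy of the shifted-comparison dedup used above.
import numpy as np

arr = np.array([3, 1, 2, 1, 3])
array = np.sort(arr.ravel())                     # [1, 1, 2, 3, 3]
cmp_array1 = np.concatenate((array, [0]))        # pad on the right
cmp_array2 = np.concatenate(([0], array))        # pad on the left (shift by one)
mask = cmp_array1 != cmp_array2                  # True where the value changes
mask[0] = True                                   # the first element is always kept
print(array[mask[:-1]])                          # [1 2 3]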
@@ -77,11 +77,11 @@ class ConvertToolLoader:
  'dump_data_parser').DumpDataParser
  self.format_conversion = import_module(
  'shape_conversion').FormatConversionMain
- except ModuleNotFoundError:
+ except ModuleNotFoundError as err:
  self.reset_system_path()
  raise ModuleNotFoundError(
  "Failed to load CANN conversion tools under {}. Please make sure Ascend " \
- "toolkit has been installed properly.".format(self.toolkit_path))
+ "toolkit has been installed properly.".format(self.toolkit_path)) from err

  try:
  self.progress = import_module("progress").Progress
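The "from err" change adopts standard Python exception chaining, so the original ModuleNotFoundError stays visible in the traceback instead of being swallowed. A generic sketch of the pattern (hypothetical helper, not the loader's API):

# Generic pattern: chain the re-raised error to the original one so the traceback
# shows both, linked by "The above exception was the direct cause of ...".
def load_tool(name):
    try:
        return __import__(name)
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError(
            "Failed to load conversion tool '{}'.".format(name)) from err

# load_tool("nonexistent_tool")  # uncomment to see the chained traceback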