mindspore 2.3.0__cp39-none-any.whl → 2.3.0rc2__cp39-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (423)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +0 -1512
  3. mindspore/__init__.py +1 -2
  4. mindspore/_c_dataengine.cpython-39-aarch64-linux-gnu.so +0 -0
  5. mindspore/_c_expression.cpython-39-aarch64-linux-gnu.so +0 -0
  6. mindspore/_c_mindrecord.cpython-39-aarch64-linux-gnu.so +0 -0
  7. mindspore/_checkparam.py +25 -5
  8. mindspore/_extends/graph_kernel/model/graph_parallel.py +1 -1
  9. mindspore/_extends/parse/__init__.py +2 -2
  10. mindspore/_extends/parse/compile_config.py +0 -29
  11. mindspore/_extends/parse/namespace.py +2 -2
  12. mindspore/_extends/parse/parser.py +5 -21
  13. mindspore/_extends/parse/resources.py +7 -5
  14. mindspore/_extends/parse/standard_method.py +59 -40
  15. mindspore/_mindspore_offline_debug.cpython-39-aarch64-linux-gnu.so +0 -0
  16. mindspore/amp.py +5 -26
  17. mindspore/bin/cache_admin +0 -0
  18. mindspore/bin/cache_server +0 -0
  19. mindspore/boost/adasum.py +1 -1
  20. mindspore/boost/base.py +1 -1
  21. mindspore/boost/boost_cell_wrapper.py +1 -1
  22. mindspore/boost/grad_freeze.py +2 -2
  23. mindspore/boost/less_batch_normalization.py +6 -9
  24. mindspore/common/__init__.py +1 -8
  25. mindspore/common/_register_for_tensor.py +9 -8
  26. mindspore/common/api.py +65 -275
  27. mindspore/common/dtype.py +4 -8
  28. mindspore/common/dump.py +5 -2
  29. mindspore/common/jit_config.py +1 -1
  30. mindspore/common/lazy_inline.py +2 -14
  31. mindspore/common/parameter.py +15 -14
  32. mindspore/common/recompute.py +5 -20
  33. mindspore/common/sparse_tensor.py +6 -21
  34. mindspore/common/tensor.py +52 -100
  35. mindspore/communication/__init__.py +11 -6
  36. mindspore/communication/management.py +94 -92
  37. mindspore/context.py +18 -180
  38. mindspore/dataset/engine/datasets.py +46 -69
  39. mindspore/dataset/engine/datasets_user_defined.py +53 -72
  40. mindspore/dataset/engine/datasets_vision.py +2 -2
  41. mindspore/dataset/engine/queue.py +38 -56
  42. mindspore/dataset/engine/validators.py +5 -11
  43. mindspore/dataset/vision/__init__.py +5 -5
  44. mindspore/dataset/vision/c_transforms.py +5 -5
  45. mindspore/dataset/vision/py_transforms_util.py +1 -1
  46. mindspore/dataset/vision/transforms.py +46 -591
  47. mindspore/dataset/vision/utils.py +1 -121
  48. mindspore/dataset/vision/validators.py +3 -9
  49. mindspore/hal/__init__.py +1 -7
  50. mindspore/hal/device.py +1 -1
  51. mindspore/include/api/model.h +0 -3
  52. mindspore/include/dataset/vision.h +2 -54
  53. mindspore/include/mindapi/base/types.h +0 -1
  54. mindspore/lib/libdnnl.so.2 +0 -0
  55. mindspore/lib/libmindspore.so +0 -0
  56. mindspore/lib/libmindspore_backend.so +0 -0
  57. mindspore/lib/libmindspore_common.so +0 -0
  58. mindspore/lib/libmindspore_core.so +0 -0
  59. mindspore/lib/libmindspore_glog.so.0 +0 -0
  60. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  61. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  62. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  63. mindspore/lib/libmindspore_shared_lib.so +0 -0
  64. mindspore/lib/libmpi_adapter.so +0 -0
  65. mindspore/lib/libmpi_collective.so +0 -0
  66. mindspore/lib/libnnacl.so +0 -0
  67. mindspore/lib/libopencv_core.so.4.5 +0 -0
  68. mindspore/lib/libps_cache.so +0 -0
  69. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +0 -35
  70. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
  71. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
  72. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  73. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +0 -72
  74. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  75. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/{aclnn_all_finite.h → aclnn_add_custom.h} +11 -9
  76. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_decoder_kv_cache.h +1 -1
  77. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_prompt_kv_cache.h +1 -1
  78. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/lib/libcust_opapi.so +0 -0
  79. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +12 -184
  80. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +15 -7
  81. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +15 -7
  82. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.cpp +81 -0
  83. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.py +134 -0
  84. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.py +31 -77
  85. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.py +31 -77
  86. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/lib/linux/aarch64/libcust_opmaster_rt2.0.so +0 -0
  87. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/liboptiling.so +0 -0
  88. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/inc/op_proto.h +5 -4
  89. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/lib/linux/aarch64/libcust_opsproto_rt2.0.so +0 -0
  90. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  91. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  92. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  93. mindspore/lib/plugin/ascend/liblowlatency_collective.so +0 -0
  94. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  95. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/DeviceBin +0 -0
  96. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/PkgInspect +0 -0
  97. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/op_man +0 -0
  98. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/device/ascend910b/bin/ascend910b.bin +286 -275
  99. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_cann_host.so +0 -0
  100. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_host.so +0 -0
  101. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops.so +0 -0
  102. mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops_static.a +0 -0
  103. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/add/add_impl.h +0 -1
  104. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/apply_rotary_pos_emb_impl.h +0 -1
  105. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/asdop/asd_op_impl.h +0 -3
  106. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/backend_param.h +0 -5
  107. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/cast/cast_tiling.h +45 -1
  108. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/compare/compare_impl.h +0 -1
  109. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/flash_attention_score/flash_attention_score_impl.h +4 -8
  110. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/flash_attention_score/flash_attention_score_tiling.h +4 -11
  111. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/flash_attention_score/kernel/flash_attention_score_mix_hwsync.h +0 -18
  112. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/internal_kernel.h +0 -6
  113. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/internal_rtbackend.h +75 -1
  114. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/kernel/matmul.h +5 -5
  115. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/matmul_impl.h +3 -18
  116. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/pp_matmul_common_tiling.h +5 -5
  117. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/pp_matmul_info.h +2 -2
  118. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/tiling_data.h +3 -36
  119. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/kernel/matmul_stridedslice_fusion.h +2 -2
  120. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/matmul_stridedslice_fusion_impl.h +4 -22
  121. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/op_param.h +2 -16
  122. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/kernel/paged_attention_mix_hwsync.h +3 -1
  123. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_impl.h +4 -5
  124. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_tiling.h +4 -9
  125. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/attention_param.h +2 -5
  126. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_ext_param.h +0 -1
  127. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_qkv_param.h +4 -10
  128. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/sub_param.h +12 -0
  129. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/rms_norm_impl.h +0 -1
  130. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_impl.h +0 -1
  131. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/tune_repo/matmul_table.h +1 -1
  132. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/backend.h +2 -10
  133. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/elewise_utils.h +1 -5
  134. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log.h +0 -1
  135. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_tiling.h +0 -17
  136. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/math.h +7 -2
  137. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libAdd_impl.so +0 -0
  138. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libSub_impl.so +0 -0
  139. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_layernorm_impl.so +0 -0
  140. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_rms_norm_impl.so +0 -0
  141. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_impl.so +0 -0
  142. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libcast_impl.so +0 -0
  143. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libgelu_impl.so +0 -0
  144. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_impl.so +0 -0
  145. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_stridedslice_fusion_impl.so +0 -0
  146. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libms_kernels_internal.so +0 -0
  147. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libnot_equal_impl.so +0 -0
  148. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libreshape_and_cache_impl.so +0 -0
  149. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/librms_norm_impl.so +0 -0
  150. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_full_mix.o +0 -0
  151. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_tri_mix.o +0 -0
  152. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_full_mix.o +0 -0
  153. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_tri_mix.o +0 -0
  154. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_full_mix.o +0 -0
  155. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_tri_mix.o +0 -0
  156. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_full_mix.o +0 -0
  157. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_tri_mix.o +0 -0
  158. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bnsd_full_mix.o +0 -0
  159. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bsh_full_mix.o +0 -0
  160. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bnsd_full_mix.o +0 -0
  161. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bsh_full_mix.o +0 -0
  162. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblcal.so +0 -0
  163. mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblccl_wrapper.so +0 -0
  164. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  165. mindspore/mindrecord/filewriter.py +2 -2
  166. mindspore/mint/__init__.py +40 -720
  167. mindspore/mint/nn/__init__.py +7 -89
  168. mindspore/mint/nn/functional.py +16 -165
  169. mindspore/mint/optim/adamw.py +16 -15
  170. mindspore/nn/__init__.py +2 -0
  171. mindspore/nn/cell.py +98 -97
  172. mindspore/nn/extend/basic.py +2 -2
  173. mindspore/nn/extend/embedding.py +1 -1
  174. mindspore/nn/extend/layer/normalization.py +5 -7
  175. mindspore/nn/generator.py +297 -0
  176. mindspore/nn/layer/activation.py +3 -4
  177. mindspore/nn/layer/basic.py +16 -79
  178. mindspore/nn/layer/conv.py +8 -17
  179. mindspore/nn/layer/embedding.py +4 -1
  180. mindspore/nn/layer/math.py +1 -1
  181. mindspore/nn/layer/normalization.py +1 -1
  182. mindspore/nn/layer/pooling.py +0 -5
  183. mindspore/nn/layer/rnn_cells.py +2 -2
  184. mindspore/nn/loss/loss.py +19 -19
  185. mindspore/nn/optim/adasum.py +1 -1
  186. mindspore/nn/optim/sgd.py +2 -3
  187. mindspore/nn/probability/distribution/exponential.py +1 -1
  188. mindspore/nn/probability/distribution/geometric.py +1 -1
  189. mindspore/nn/probability/distribution/logistic.py +1 -1
  190. mindspore/nn/wrap/cell_wrapper.py +1 -25
  191. mindspore/nn/wrap/loss_scale.py +1 -24
  192. mindspore/numpy/array_ops.py +1 -5
  193. mindspore/numpy/dtypes.py +3 -3
  194. mindspore/numpy/math_ops.py +8 -8
  195. mindspore/ops/__init__.py +1 -1
  196. mindspore/ops/_grad_experimental/grad_comm_ops.py +16 -75
  197. mindspore/ops/_vmap/vmap_array_ops.py +0 -27
  198. mindspore/ops/_vmap/vmap_math_ops.py +1 -29
  199. mindspore/ops/_vmap/vmap_nn_ops.py +18 -19
  200. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +8 -34
  201. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +9 -2
  202. mindspore/ops/auto_generate/gen_arg_handler.py +0 -26
  203. mindspore/ops/auto_generate/gen_extend_func.py +27 -603
  204. mindspore/ops/auto_generate/gen_ops_def.py +203 -993
  205. mindspore/ops/auto_generate/gen_ops_prim.py +402 -1946
  206. mindspore/ops/auto_generate/pyboost_inner_prim.py +20 -90
  207. mindspore/ops/composite/base.py +6 -3
  208. mindspore/ops/composite/math_ops.py +1 -1
  209. mindspore/ops/composite/multitype_ops/_compile_utils.py +17 -24
  210. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  211. mindspore/ops/extend/__init__.py +3 -2
  212. mindspore/ops/extend/array_func.py +51 -10
  213. mindspore/ops/extend/nn_func.py +78 -2
  214. mindspore/ops/function/__init__.py +13 -8
  215. mindspore/ops/function/array_func.py +179 -455
  216. mindspore/ops/function/clip_func.py +1 -1
  217. mindspore/ops/function/grad/grad_func.py +3 -3
  218. mindspore/ops/function/math_func.py +103 -117
  219. mindspore/ops/function/nn_func.py +163 -275
  220. mindspore/ops/function/other_func.py +2 -2
  221. mindspore/ops/function/random_func.py +69 -202
  222. mindspore/ops/function/sparse_func.py +4 -4
  223. mindspore/ops/functional.py +327 -332
  224. mindspore/ops/operations/__init__.py +3 -13
  225. mindspore/ops/operations/_grad_ops.py +27 -3
  226. mindspore/ops/operations/_inner_ops.py +356 -53
  227. mindspore/ops/operations/_rl_inner_ops.py +2 -2
  228. mindspore/ops/operations/_tensor_array.py +8 -8
  229. mindspore/ops/operations/array_ops.py +65 -82
  230. mindspore/ops/operations/comm_ops.py +93 -784
  231. mindspore/ops/operations/custom_ops.py +28 -51
  232. mindspore/ops/operations/debug_ops.py +4 -4
  233. mindspore/ops/operations/inner_ops.py +2 -2
  234. mindspore/ops/operations/manually_defined/ops_def.py +4 -304
  235. mindspore/ops/operations/math_ops.py +50 -3
  236. mindspore/ops/operations/nn_ops.py +247 -14
  237. mindspore/ops/operations/other_ops.py +3 -3
  238. mindspore/ops/operations/random_ops.py +1 -1
  239. mindspore/ops/operations/sparse_ops.py +1 -1
  240. mindspore/ops/primitive.py +8 -9
  241. mindspore/ops/silent_check.py +5 -5
  242. mindspore/ops_generate/arg_dtype_cast.py +9 -2
  243. mindspore/ops_generate/arg_handler.py +0 -26
  244. mindspore/ops_generate/gen_aclnn_implement.py +4 -1
  245. mindspore/ops_generate/gen_ops.py +4 -26
  246. mindspore/ops_generate/gen_pyboost_func.py +12 -41
  247. mindspore/ops_generate/gen_utils.py +0 -21
  248. mindspore/ops_generate/pyboost_utils.py +2 -7
  249. mindspore/ops_generate/template.py +0 -1
  250. mindspore/parallel/_auto_parallel_context.py +1 -21
  251. mindspore/parallel/_tensor.py +5 -0
  252. mindspore/parallel/_transformer/transformer.py +1 -1
  253. mindspore/parallel/_utils.py +1 -15
  254. mindspore/parallel/algo_parameter_config.py +3 -1
  255. mindspore/parallel/checkpoint_transform.py +9 -12
  256. mindspore/parallel/cluster/process_entity/_api.py +29 -28
  257. mindspore/parallel/cluster/process_entity/_utils.py +3 -13
  258. mindspore/parallel/cluster/run.py +16 -13
  259. mindspore/parallel/parameter_broadcast.py +2 -2
  260. mindspore/parallel/shard.py +17 -31
  261. mindspore/profiler/__init__.py +2 -3
  262. mindspore/profiler/common/util.py +2 -107
  263. mindspore/profiler/envprofiling.py +1 -1
  264. mindspore/profiler/parser/ascend_analysis/constant.py +21 -8
  265. mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -82
  266. mindspore/profiler/parser/ascend_analysis/function_event.py +28 -43
  267. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +27 -49
  268. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +10 -15
  269. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +20 -25
  270. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +5 -5
  271. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +1 -10
  272. mindspore/profiler/parser/ascend_hccl_generator.py +1 -4
  273. mindspore/profiler/parser/ascend_msprof_exporter.py +22 -43
  274. mindspore/profiler/parser/ascend_timeline_generator.py +5 -7
  275. mindspore/profiler/parser/minddata_parser.py +3 -72
  276. mindspore/profiler/profiling.py +59 -176
  277. mindspore/rewrite/api/node.py +1 -1
  278. mindspore/rewrite/common/namespace.py +5 -5
  279. mindspore/rewrite/parsers/assign_parser.py +0 -2
  280. mindspore/rewrite/parsers/class_def_parser.py +4 -8
  281. mindspore/run_check/_check_version.py +1 -1
  282. mindspore/scipy/fft.py +3 -1
  283. mindspore/scipy/linalg.py +3 -2
  284. mindspore/scipy/ops.py +3 -5
  285. mindspore/scipy/optimize/__init__.py +2 -2
  286. mindspore/train/__init__.py +4 -4
  287. mindspore/train/anf_ir_pb2.py +2 -8
  288. mindspore/train/callback/__init__.py +2 -5
  289. mindspore/train/callback/_backup_and_restore.py +2 -2
  290. mindspore/train/callback/_checkpoint.py +16 -104
  291. mindspore/train/callback/_landscape.py +1 -1
  292. mindspore/train/callback/_time_monitor.py +1 -1
  293. mindspore/train/data_sink.py +4 -5
  294. mindspore/train/dataset_helper.py +20 -45
  295. mindspore/train/model.py +38 -266
  296. mindspore/train/serialization.py +105 -256
  297. mindspore/train/summary/_summary_adapter.py +1 -1
  298. mindspore/version.py +1 -1
  299. {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +2 -2
  300. {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +303 -420
  301. mindspore/_extends/pijit/__init__.py +0 -23
  302. mindspore/_extends/pijit/pijit_func_white_list.py +0 -343
  303. mindspore/common/file_system.py +0 -48
  304. mindspore/common/generator.py +0 -260
  305. mindspore/common/no_inline.py +0 -54
  306. mindspore/common/np_dtype.py +0 -25
  307. mindspore/communication/comm_func.py +0 -1140
  308. mindspore/hal/memory.py +0 -326
  309. mindspore/lib/libavcodec.so.59 +0 -0
  310. mindspore/lib/libavdevice.so.59 +0 -0
  311. mindspore/lib/libavfilter.so.8 +0 -0
  312. mindspore/lib/libavformat.so.59 +0 -0
  313. mindspore/lib/libavutil.so.57 +0 -0
  314. mindspore/lib/libmindspore_np_dtype.so +0 -0
  315. mindspore/lib/libswresample.so.4 +0 -0
  316. mindspore/lib/libswscale.so.6 +0 -0
  317. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/all_finite.cpp +0 -326
  318. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/all_finite.py +0 -180
  319. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_576ceaeef5870c451cab59af55ea46ad.json +0 -58
  320. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_576ceaeef5870c451cab59af55ea46ad.o +0 -0
  321. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_86a73ff6e28d734c96bb8d3054f7dd18.json +0 -58
  322. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_86a73ff6e28d734c96bb8d3054f7dd18.o +0 -0
  323. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_f55e0ebaad1f2f572e43677336992fa0.json +0 -58
  324. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_f55e0ebaad1f2f572e43677336992fa0.o +0 -0
  325. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/config/ascend910b/all_finite.json +0 -109
  326. mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/config/ascend910b/binary_info_config.json +0 -38
  327. mindspore/lib/plugin/ascend/custom_compiler/OWNERS +0 -12
  328. mindspore/lib/plugin/ascend/custom_compiler/setup.py +0 -255
  329. mindspore/lib/plugin/ascend/custom_compiler/start.sh +0 -26
  330. mindspore/lib/plugin/ascend/custom_compiler/template.json +0 -40
  331. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/acme.h +0 -24
  332. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/acme_op.h +0 -69
  333. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/base_type.h +0 -133
  334. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/op_creator.h +0 -32
  335. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/op_param.h +0 -35
  336. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/tiling_info.h +0 -60
  337. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/core/kernel_register.h +0 -37
  338. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/core/platform/platform_configs.h +0 -89
  339. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/core/platform/rt_funcs.h +0 -135
  340. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/add_op.h +0 -34
  341. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_backoff_base.h +0 -62
  342. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_elewise_op.h +0 -33
  343. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_ops.h +0 -88
  344. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_pa_op.h +0 -45
  345. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/cast_op.h +0 -52
  346. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/matmul_op.h +0 -95
  347. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/utils/asd_utils.h +0 -84
  348. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/utils/comm_utils.h +0 -61
  349. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_fp32.h +0 -224
  350. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/and_impl.h +0 -29
  351. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/div_impl.h +0 -29
  352. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/elewise_binary_impl.h +0 -48
  353. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/elewise_binary_tiling.h +0 -25
  354. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/and_kernel.h +0 -46
  355. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/div_kernel.h +0 -46
  356. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/elewise_binary_base.h +0 -260
  357. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/elewise_binary_kernel.h +0 -35
  358. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/max_kernel.h +0 -66
  359. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/min_kernel.h +0 -66
  360. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/mul_kernel.h +0 -66
  361. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/or_kernel.h +0 -46
  362. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/max_impl.h +0 -29
  363. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/min_impl.h +0 -29
  364. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/mul_impl.h +0 -29
  365. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/or_impl.h +0 -29
  366. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/abs_impl.h +0 -29
  367. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/elewise_unary_impl.h +0 -47
  368. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/elewise_unary_tiling.h +0 -24
  369. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/exp_impl.h +0 -29
  370. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/abs_kernel.h +0 -45
  371. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/elewise_unary_base.h +0 -148
  372. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/elewise_unary_kernel.h +0 -31
  373. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/exp_kernel.h +0 -45
  374. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/ln_kernel.h +0 -45
  375. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/not_kernel.h +0 -45
  376. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/reciprocal_kernel.h +0 -45
  377. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/relu_kernel.h +0 -55
  378. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/rsqrt_kernel.h +0 -45
  379. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/sqrt_kernel.h +0 -45
  380. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/ln_impl.h +0 -29
  381. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/not_impl.h +0 -29
  382. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/reciprocal_impl.h +0 -29
  383. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/relu_impl.h +0 -29
  384. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/rsqrt_impl.h +0 -29
  385. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/sqrt_impl.h +0 -29
  386. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/grouped_matmul_impl.h +0 -45
  387. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/grouped_matmul_tiling.h +0 -187
  388. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/kernel/grouped_matmul.h +0 -245
  389. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/kernel/grouped_matmul_interface.h +0 -24
  390. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/kernel/grouped_matmul_utils.h +0 -111
  391. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/tiling_data.h +0 -54
  392. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/compare_param.h +0 -31
  393. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/elewise_param.h +0 -41
  394. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/grouped_matmul_param.h +0 -40
  395. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/profiling_util.h +0 -364
  396. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_utils.h +0 -69
  397. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/register/kernel_creator.h +0 -39
  398. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/register/kernel_registry.h +0 -114
  399. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/utils.h +0 -98
  400. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix.json +0 -19
  401. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix.o +0 -0
  402. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix_mix_aic_0.o +0 -0
  403. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix_mix_aiv_0.o +0 -0
  404. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix.json +0 -19
  405. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix.o +0 -0
  406. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix_mix_aic_0.o +0 -0
  407. mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix_mix_aiv_0.o +0 -0
  408. mindspore/mint/linalg/__init__.py +0 -22
  409. mindspore/nn/layer/embedding_service.py +0 -531
  410. mindspore/nn/layer/embedding_service_layer.py +0 -393
  411. mindspore/ops/function/reshard_func.py +0 -102
  412. mindspore/ops/operations/_infer_ops.py +0 -19
  413. mindspore/ops/operations/reshard_ops.py +0 -53
  414. mindspore/profiler/common/process_pool.py +0 -41
  415. mindspore/profiler/common/singleton.py +0 -28
  416. mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
  417. mindspore/profiler/parser/ascend_memory_generator.py +0 -185
  418. mindspore/train/callback/_cluster_monitor.py +0 -201
  419. mindspore/train/callback/_flops_collector.py +0 -238
  420. mindspore/train/callback/_mindio_ttp.py +0 -443
  421. {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
  422. {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +0 -0
  423. {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
mindspore/nn/loss/loss.py CHANGED

@@ -1820,10 +1820,10 @@ class MultilabelMarginLoss(LossBase):
 
 class BCEWithLogitsLoss(LossBase):
     r"""
-    Adds sigmoid activation function to input `input` as logits, and uses the given logits to compute binary cross
-    entropy between the `input` and the `target`.
+    Adds sigmoid activation function to input logits, and uses the given logits to compute binary cross entropy
+    between the logits and the labels.
 
-    Sets input `input` as :math:`X`, input `target` as :math:`Y`, output as :math:`L`. Then,
+    Sets input `logits` as :math:`X`, input `labels` as :math:`Y`, output as :math:`L`. Then,
 
     .. math::
         p_{ij} = sigmoid(X_{ij}) = \frac{1}{1 + e^{-X_{ij}}}
@@ -1849,29 +1849,29 @@ class BCEWithLogitsLoss(LossBase):
           - ``'sum'``: the output elements will be summed.
 
         weight (Tensor, optional): A rescaling weight applied to the loss of each batch element.
-            If not None, it can be broadcast to a tensor with shape of `input`,
+            If not None, it can be broadcast to a tensor with shape of `logits`,
            data type must be float16 or float32. Default: ``None`` .
         pos_weight (Tensor, optional): A weight of positive examples. Must be a vector with length equal to the
-            number of classes. If not None, it must be broadcast to a tensor with shape of `input`, data type
+            number of classes. If not None, it must be broadcast to a tensor with shape of `logits`, data type
            must be float16 or float32. Default: ``None`` .
 
     Inputs:
-        - **input** (Tensor) - Input `input` with shape :math:`(N, *)` where :math:`*` means, any number
+        - **logits** (Tensor) - Input logits with shape :math:`(N, *)` where :math:`*` means, any number
          of additional dimensions. The data type must be float16 or float32.
-        - **target** (Tensor) - Ground truth label with shape :math:`(N, *)` where :math:`*` means, any number
-          of additional dimensions. The same shape and data type as `input`.
+        - **labels** (Tensor) - Ground truth label with shape :math:`(N, *)` where :math:`*` means, any number
+          of additional dimensions. The same shape and data type as `logits`.
 
     Outputs:
-        Tensor or Scalar, if `reduction` is ``'none'``, its shape is the same as `input`.
+        Tensor or Scalar, if `reduction` is ``'none'``, its shape is the same as `logits`.
         Otherwise, a scalar value will be returned.
 
     Raises:
-        TypeError: If input `input` or `target` is not Tensor.
-        TypeError: If data type of `input` or `target` is neither float16 nor float32.
+        TypeError: If input `logits` or `labels` is not Tensor.
+        TypeError: If data type of `logits` or `labels` is neither float16 nor float32.
         TypeError: If `weight` or `pos_weight` is a parameter.
         TypeError: If data type of `weight` or `pos_weight` is neither float16 nor float32.
         TypeError: If data type of `reduction` is not string.
-        ValueError: If `weight` or `pos_weight` can not be broadcast to a tensor with shape of `input`.
+        ValueError: If `weight` or `pos_weight` can not be broadcast to a tensor with shape of `logits`.
         ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.
 
     Supported Platforms:
@@ -1881,10 +1881,10 @@ class BCEWithLogitsLoss(LossBase):
         >>> import mindspore as ms
         >>> import mindspore.nn as nn
         >>> import numpy as np
-        >>> input = ms.Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
-        >>> target = ms.Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32))
+        >>> logits = ms.Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
+        >>> labels = ms.Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32))
         >>> loss = nn.BCEWithLogitsLoss()
-        >>> output = loss(input, target)
+        >>> output = loss(logits, labels)
         >>> print(output)
         0.3463612
     """
@@ -1900,10 +1900,10 @@ class BCEWithLogitsLoss(LossBase):
         self.weight = weight
         self.pos_weight = pos_weight
 
-    def construct(self, input, target):
-        _check_is_tensor('input', input, self.cls_name)
-        _check_is_tensor('target', target, self.cls_name)
-        loss = ops.binary_cross_entropy_with_logits(input, target, self.weight, self.pos_weight, self.reduction)
+    def construct(self, logits, labels):
+        _check_is_tensor('logits', logits, self.cls_name)
+        _check_is_tensor('labels', labels, self.cls_name)
+        loss = ops.binary_cross_entropy_with_logits(logits, labels, self.weight, self.pos_weight, self.reduction)
         return loss
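Reviewer note: only the parameter names and docs of BCEWithLogitsLoss differ between the two builds (2.3.0rc2 used `logits`/`labels`; the 2.3.0 release renamed them to `input`/`target`); the math and the example output are unchanged. A quick portability check, assuming only what the docstring above shows — positional calls behave identically on either version, while keyword calls must use that version's own names:

    import numpy as np
    import mindspore as ms
    import mindspore.nn as nn

    x = ms.Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]], np.float32))
    y = ms.Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]], np.float32))
    loss = nn.BCEWithLogitsLoss()
    print(loss(x, y))  # 0.3463612 per the docstring example, on either version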
 
@@ -29,7 +29,7 @@ from mindspore.parallel._utils import _get_global_rank, _get_stage_device_num
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-from mindspore.ops import Send, Receive
+from mindspore.ops.operations._inner_ops import Send, Receive
 from mindspore.common.tensor import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.communication.management import create_group
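Reviewer note: the hunk above (its file header was lost in extraction) shows a relocation that reappears in grad_comm_ops.py further down: 2.3.0 imports Send and Receive from the public mindspore.ops namespace, while rc2 reaches into the private _inner_ops module. A defensive import that tolerates both layouts, assuming (as the hunks suggest) that only 2.3.0 has the public alias:

    try:
        from mindspore.ops import Send, Receive  # public location in 2.3.0
    except ImportError:
        # fall back to the private location used by 2.3.0rc2
        from mindspore.ops.operations._inner_ops import Send, Receive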
mindspore/nn/optim/sgd.py CHANGED

@@ -195,9 +195,9 @@ class SGD(Optimizer):
                                 "or 'weight_decay' set in grouped 'params' must be float or int type.")
 
         if hasattr(self, "group_weight_decay") and self.group_weight_decay:
-            self.opt = tuple(P.SGD(dampening, 0.0, nesterov) for wd in self.group_weight_decay)
+            self.opt = tuple(P.SGD(dampening, wd, nesterov) for wd in self.group_weight_decay)
         else:
-            self.opt = tuple([P.SGD(dampening, 0.0, nesterov)] * len(self._parameters))
+            self.opt = tuple([P.SGD(dampening, float(weight_decay), nesterov)] * len(self._parameters))
 
         self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
@@ -222,7 +222,6 @@ class SGD(Optimizer):
         params = self._parameters
         accum = self.accum
         stat = self.stat
-        gradients = self.decay_weight(gradients)
         gradients = self.flatten_gradients(gradients)
         gradients = self.gradients_centralization(gradients)
         gradients = self.scale_grad(gradients)
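Reviewer note: the two builds apply SGD weight decay at different layers. 2.3.0 decays the gradients in Python (`self.decay_weight(gradients)`) and constructs `P.SGD` with a zero weight-decay attribute; rc2 drops the Python-side call and bakes the decay into the primitive. A minimal sketch of the intended equivalence, ignoring momentum and dampening (hypothetical helper names, not MindSpore source):

    # Both routes add wd * w to the gradient before the parameter update.
    def sgd_step_decay_outside(w, grad, wd, lr):
        grad = grad + wd * w      # 2.3.0: decay_weight() adjusts gradients in Python,
        return w - lr * grad      # then P.SGD runs with its weight_decay fixed at 0.0

    def sgd_step_decay_inside(w, grad, wd, lr):
        return w - lr * (grad + wd * w)   # rc2: P.SGD applies weight_decay itself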
@@ -152,7 +152,7 @@ class Exponential(Distribution):
         if self.rate is not None:
             check_greater_zero(self.rate, 'rate')
 
-        self.minval = np.finfo(np.float_).tiny
+        self.minval = np.finfo(np.float).tiny
 
         # ops needed for the class
         self.exp = exp_generic
@@ -152,7 +152,7 @@ class Geometric(Distribution):
         if self._probs is not None:
             check_prob(self.probs)
 
-        self.minval = np.finfo(np.float_).tiny
+        self.minval = np.finfo(np.float).tiny
 
         # ops needed for the class
         self.exp = exp_generic
@@ -170,7 +170,7 @@ class Logistic(Distribution):
         self.neg = P.Neg()
 
         self.threshold = np.log(np.finfo(np.float32).eps) + 1.
-        self.tiny = np.finfo(np.float_).tiny
+        self.tiny = np.finfo(np.float).tiny
         self.sd_const = np.pi / np.sqrt(3)
 
     def _softplus(self, x):
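Reviewer note: every `np.float_` ↔ `np.float` change in this diff (here and again in mindspore/numpy/dtypes.py below) tracks NumPy's alias removals: the bare `np.float` alias was deleted in NumPy 1.24, so rc2's spelling raises AttributeError on current NumPy, and `np.float_` itself is gone in NumPy 2.0. A version-proof spelling of the same constant, since both aliases resolved to float64:

    import numpy as np

    # np.float was removed in NumPy 1.24 and np.float_ in NumPy 2.0;
    # both were aliases of float64, so this works everywhere.
    FLOAT_TINY = np.finfo(np.float64).tiny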
@@ -17,18 +17,16 @@
 from __future__ import absolute_import
 from __future__ import division
 
-import os
 from types import FunctionType, MethodType
 
 from mindspore import log as logger
 from mindspore.parallel._utils import _get_device_num, _get_gradients_mean,\
     _get_parallel_mode, _get_enable_parallel_optimizer, _is_pynative_parallel
-from mindspore.context import ParallelMode, GRAPH_MODE, get_context
+from mindspore.context import ParallelMode
 from mindspore import _checkparam as validator
 from mindspore import ops, nn
 from mindspore.common import dtype as mstype
 from mindspore.common.parameter import Parameter, ParameterTuple
-from mindspore.common.tensor import Tensor
 from mindspore.ops.primitive import _primexpr
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
@@ -742,18 +740,6 @@ class _TrainGradAccuStepCell(TrainOneStepCell):
         self.hyper_map = ops.HyperMap()
         self.opt_shard = _get_enable_parallel_optimizer()
         self._get_attr_from_cell(network)
-        self.enable_mindio = False
-        mode = get_context("mode")
-        device_type = get_context("device_target")
-        if device_type != "Ascend" or mode != GRAPH_MODE:
-            return
-        graceful_exit = os.getenv("MS_ENABLE_MINDIO_GRACEFUL_EXIT")
-        ttp_lib_path = os.getenv("MS_MINDIO_TTP_LIB_PATH")
-        ttp_path_check = ttp_lib_path is not None and os.path.isfile(ttp_lib_path)
-        if graceful_exit == "true" and ttp_path_check:
-            self.g_one = Tensor([0.1])
-            self.allreduce_sum = ops.AllReduce()
-            self.enable_mindio = True
 
     def construct(self, *inputs):
         if not self.sense_flag:
@@ -762,11 +748,6 @@ class _TrainGradAccuStepCell(TrainOneStepCell):
         sens = ops.fill(ops.DType()(loss), ops.Shape()(loss), self.sens)
         grads = self.grad(self.network, self.weights)(*inputs, sens)
         accu_grads = ops.depend(self.accu_grads, grads)
-        if self.enable_mindio:
-            g_one = ops.depend(self.g_one, accu_grads)
-            g_one_res = self.allreduce_sum(g_one)
-            accu_grads = ops.depend(accu_grads, g_one_res)
-            grads = ops.depend(grads, g_one_res)
         if self.opt_shard:
             succ = self.optimizer(grads)
         else:
@@ -781,11 +762,6 @@ class _TrainGradAccuStepCell(TrainOneStepCell):
         loss = self.network(*inputs)
         grads = self.grad_no_sens(self.network, self.weights)(*inputs)
         accu_grads = ops.depend(self.accu_grads, grads)
-        if self.enable_mindio:
-            g_one = ops.depend(self.g_one, accu_grads)
-            g_one_res = self.allreduce_sum(g_one)
-            accu_grads = ops.depend(accu_grads, g_one_res)
-            grads = ops.depend(grads, g_one_res)
        if self.opt_shard:
            succ = self.optimizer(grads)
        else:
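Reviewer note: the removed blocks are 2.3.0's MindIO TTP hook (absent in rc2; compare the 2.3.0-only file mindspore/train/callback/_mindio_ttp.py in the list above). It arms only on Ascend in graph mode, and the tiny `g_one` all-reduce chained through `ops.depend` between gradient accumulation and the optimizer step acts as a cluster-wide synchronization point. How a user would opt in on 2.3.0 — environment variable names are taken from the diff, the library path is a placeholder:

    import os

    # 2.3.0-only: both conditions must hold before the train cell is built.
    os.environ["MS_ENABLE_MINDIO_GRACEFUL_EXIT"] = "true"
    os.environ["MS_MINDIO_TTP_LIB_PATH"] = "/path/to/mindio_ttp.so"  # placeholder path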
@@ -29,7 +29,6 @@ from mindspore.ops.operations.math_ops import NPUGetFloatStatusV2, NPUClearFloat
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
-from mindspore.ops.operations.nn_ops import AllFinite
 from mindspore.common import dtype as mstype
 from mindspore.common.api import jit
 from mindspore._c_expression import MSContext
@@ -373,15 +372,6 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
         self.loss_scaling_manager = None
         self._ascend_check_overflow_mode = os.environ.get('MS_ASCEND_CHECK_OVERFLOW_MODE')
 
-        self.enable_allfinite = False
-        runtime_conf = os.environ.get('MS_DEV_RUNTIME_CONF')
-        global_jit_config = context.get_jit_config()
-        if runtime_conf is not None and ("all_finite:True" in runtime_conf or "all_finite:true" in runtime_conf):
-            self.enable_allfinite = True
-        elif runtime_conf is not None and ("all_finite:False" in runtime_conf or "all_finite:false" in runtime_conf):
-            self.enable_allfinite = False
-        elif global_jit_config:
-            self.enable_allfinite = global_jit_config["jit_level"] == "O0" or global_jit_config["jit_level"] == "O1"
 
         if isinstance(scale_sense, Cell):
             self.loss_scaling_manager = scale_sense
@@ -488,15 +478,6 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
         overflow = self.less_equal(self.base, flag_sum)
         return overflow
 
-    def _get_distributed_overflow_status_on_infnan_enable_allfinite(self, compute_output):
-        """check overflow status on infnan kernel mode."""
-        overflow = AllFinite()(compute_output)
-
-        if self.is_distributed:
-            overflow = P.Cast()(overflow, mstype.int8)
-            overflow = P.Cast()(self.allreduce(overflow), mstype.bool_)
-        return overflow
-
     def _get_gpu_overflow_status(self, compute_output):
         """get overflow status of gpu."""
         overflow = self._get_distributed_overflow_status_on_infnan_mode(_grad_overflow, compute_output)
@@ -504,11 +485,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
 
     def _get_ascend_overflow_status_on_infnan_mode(self, compute_output):
         """get overflow status of ascend on infnan mode."""
-        overflow = False
-        if self.enable_allfinite:
-            overflow = self._get_distributed_overflow_status_on_infnan_enable_allfinite(compute_output)
-        else:
-            overflow = self._get_distributed_overflow_status_on_infnan_mode(_ascend_grad_overflow, compute_output)
+        overflow = self._get_distributed_overflow_status_on_infnan_mode(_ascend_grad_overflow, compute_output)
         return overflow
 
     def _get_ascend_overflow_status_on_saturation_mode(self, status, compute_output):
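Reviewer note: 2.3.0 adds a fused AllFinite fast path for overflow detection in infnan mode (absent in rc2). The removed constructor block shows the precedence: an explicit `all_finite:True`/`all_finite:False` token in MS_DEV_RUNTIME_CONF wins, otherwise the check turns on automatically at jit levels O0 and O1. Opting in explicitly on 2.3.0, using the exact strings the removed code parses:

    import os

    # Force the fused AllFinite path regardless of jit_level (2.3.0 only);
    # set before constructing TrainOneStepWithLossScaleCell.
    os.environ["MS_DEV_RUNTIME_CONF"] = "all_finite:True"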
@@ -2606,11 +2606,7 @@ def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
     array1 = ar1.ravel()
     array2 = ar2.ravel()
     concat_array = concatenate((array1, array2))
-    if return_indices:
-        concat_sort_indices = F.argsort(concat_array)
-        concat_array = concat_array[concat_sort_indices]
-    else:
-        concat_array, concat_sort_indices = concat_array.sort()
+    concat_array, concat_sort_indices = concat_array.sort()
 
     mask_res = concat_array[1:] == concat_array[:-1]
     res = F.masked_select(concat_array[1:], mask_res)
mindspore/numpy/dtypes.py CHANGED

@@ -86,7 +86,7 @@ dtype_map = {
 }
 
 all_types = [
-    'np.int_',
+    'np.int',
     'np.int8',
     'np.int16',
     'np.int32',
@@ -96,11 +96,11 @@ all_types = [
     'np.uint16',
     'np.uint32',
     'np.uint64',
-    'np.float_',
+    'np.float',
     'np.float16',
     'np.float32',
     'np.float64',
-    'np.bool_']
+    'np.bool']
 
 promotion_rule = {
     (uint8, uint16): uint16,
@@ -4166,18 +4166,18 @@ def multi_dot(arrays):
     Examples:
         >>> import mindspore.numpy as np
         >>> A = np.ones((10000, 100))
-        >>> B = np.ones((100, 100))
-        >>> C = np.ones((100, 5))
+        >>> B = np.ones((100, 1000))
+        >>> C = np.ones((1000, 5))
         >>> D = np.ones((5, 333))
         >>> output = np.multi_dot([A, B, C, D])
         >>> print(output)
-        [[50000. 50000. 50000. ... 50000. 50000. 50000.]
-         [50000. 50000. 50000. ... 50000. 50000. 50000.]
-         [50000. 50000. 50000. ... 50000. 50000. 50000.]
+        [[500000. 500000. 500000. ... 500000. 500000. 500000.]
+         [500000. 500000. 500000. ... 500000. 500000. 500000.]
+         [500000. 500000. 500000. ... 500000. 500000. 500000.]
         ...
-         [50000. 50000. 50000. ... 50000. 50000. 50000.]
-         [50000. 50000. 50000. ... 50000. 50000. 50000.]
-         [50000. 50000. 50000. ... 50000. 50000. 50000.]]
+         [500000. 500000. 500000. ... 500000. 500000. 500000.]
+         [500000. 500000. 500000. ... 500000. 500000. 500000.]
+         [500000. 500000. 500000. ... 500000. 500000. 500000.]]
     """
     if len(arrays) < 2:
         _raise_value_error('Expecting at least 2 arrays')
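Reviewer note: both docstring variants are arithmetically consistent — for all-ones factors, every entry of the chained product equals the product of the chain's inner dimensions, so rc2's shapes give 100 * 1000 * 5 = 500000 and 2.3.0's give 100 * 100 * 5 = 50000. The same rule checked on smaller shapes with plain NumPy:

    import numpy as np

    # Inner dimensions 3 and 4 -> every entry equals 3 * 4 = 12.
    out = np.linalg.multi_dot([np.ones((2, 3)), np.ones((3, 4)), np.ones((4, 5))])
    assert out[0, 0] == 3 * 4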
mindspore/ops/__init__.py CHANGED

@@ -44,7 +44,7 @@ __primitive__ = [
 __all__ = ["get_vm_impl_fn", "vm_impl_registry",
            "op_info_register", "custom_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp",
            "CpuRegOp", "CustomRegOp", "DataType",
-           "constexpr", "reshard"]
+           "constexpr"]
 __all__.extend(__primitive__)
 __all__.extend(composite.__all__)
 __all__.extend(operations.__all__)
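Reviewer note: dropping "reshard" from `__all__` matches files 411 and 413 in the list above (ops/function/reshard_func.py and ops/operations/reshard_ops.py exist only in 2.3.0). Accordingly, this import should succeed on 2.3.0 and fail on rc2 — an assumption based purely on the export list shown here:

    from mindspore.ops import reshard  # present in 2.3.0; expect ImportError on 2.3.0rc2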
@@ -22,8 +22,7 @@ from mindspore.ops import functional as F
 from mindspore.communication import get_rank, get_group_size
 from mindspore.parallel._utils import _get_enable_parallel_optimizer, _get_grad_accumulation_shard
 from mindspore.ops import operations as P
-from mindspore.ops import Send, Receive
-from mindspore.ops.operations._inner_ops import issubclass_
+from mindspore.ops.operations._inner_ops import Send, Receive, issubclass_
 from mindspore.common.sparse_tensor import RowTensorInner
 from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
 from mindspore.ops.operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce,
@@ -31,7 +30,7 @@ from mindspore.ops.operations.comm_ops import (AllGather, _MiniStepAllGather, _H
                                                _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
                                                ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, _AllSwap,
                                                _VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator,
-                                               _MicroStepAllGather, Reduce, CollectiveGather, CollectiveScatter)
+                                               _MicroStepAllGather)
 from mindspore.ops._grad_experimental.grad_base import bprop_getters
 from mindspore.ops.operations import _grad_ops as G
 
@@ -211,17 +210,21 @@ def get_bprop_mirror_micro_step_operator(self):
     def bprop(x, z, out, dout):
         real_grad = z
         assign_out = dout
-        if issubclass_(F.typeof(dout), mstype.tensor_type):
-            z = F.depend(z, dout)
-            if dev_num > 1:
+        if mean_flag:
+            if issubclass_(F.typeof(dout), mstype.tensor_type):
+                z = F.depend(z, dout)
                 real_grad = all_reduce(z)
-                if mean_flag:
-                    real_grad = F.tensor_mul(real_grad, scale)
-            else:
-                real_grad = z
-        if opt_shard:
-            return (real_grad, cast(out_tensor, dtype(z)))
-        return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign(z, real_grad))
+                real_grad = F.tensor_mul(real_grad, scale)
+            if opt_shard:
+                return (real_grad, cast(out_tensor, dtype(z)))
+            return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign(z, real_grad))
+        else:
+            if issubclass_(F.typeof(dout), mstype.tensor_type):
+                z = F.depend(z, dout)
+                real_grad = all_reduce(z)
+            if opt_shard:
+                return (real_grad, cast(out_tensor, dtype(z)))
+            return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign(z, real_grad))
         return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign_out)
 
     return bprop
@@ -241,13 +244,11 @@ def get_bprop_broad_cast(self):
 def get_bprop_all_gather(self):
     """Generate bprop for AllGather"""
     fusion = self.get_attr_dict()["fusion"]
-    self.group = self.get_attr_dict()["group"]
     reduce_scatter = ReduceScatter(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
     if hasattr(self, "instance_name") and self.instance_name:
         instance_name = "grad_" + self.instance_name
         reduce_scatter.set_prim_instance_name(instance_name)
     mean_flag = self.get_attr_dict()["mean_flag"]
-    self.rank_size = self.get_attr_dict()["rank_size"]
     if self.rank_size == 0:
         raise ValueError(f"The 'rank_size' can not be zero, but got {self.rank_size}.")
     scale = 1.0 / self.rank_size
@@ -377,66 +378,6 @@ def get_bprop_reduce_scatter(self):
     return bprop
 
 
-@bprop_getters.register(Reduce)
-def get_bprop_reduce(self):
-    """Generate bprop for Reduce"""
-    dest_rank = self.get_attr_dict()["dest_rank"]
-    group = self.get_attr_dict()["group"]
-    reduce_grad = Broadcast(dest_rank, group)
-    if hasattr(self, "instance_name") and self.instance_name:
-        instance_name = "grad" + self.instance_name
-        reduce_grad.set_prim_instance_name(instance_name)
-
-    def bprop(x, out, dout):
-        dx = reduce_grad((dout,))
-        return (dx[0],)
-
-    return bprop
-
-
-@bprop_getters.register(CollectiveGather)
-def get_bprop_collective_gather(self):
-    """Generate bprop for CollectiveGather"""
-    group = self.get_attr_dict()["group"]
-    dest_rank = self.get_attr_dict()["dest_rank"]
-    collective_gather_grad = Broadcast(dest_rank, group)
-    rank = get_rank(group)
-    dev_num = self.rank_size
-    split = P.Split(output_num=dev_num)
-    if hasattr(self, "instance_name") and self.instance_name:
-        instance_name = "grad" + self.instance_name
-        collective_gather_grad.set_prim_instance_name(instance_name)
-
-    def bprop(x, out, dout):
-        grad = collective_gather_grad((dout,))
-        dx = split(grad[0])[rank]
-        return (dx,)
-
-    return bprop
-
-
-@bprop_getters.register(CollectiveScatter)
-def get_bprop_collective_scatter(self):
-    """Generate bprop for CollectiveScatter"""
-    group = self.get_attr_dict()["group"]
-    dest_rank = self.get_attr_dict()["src_rank"]
-    rank = get_rank(group)
-    collective_scatter_grad = CollectiveGather(dest_rank, group)
-    if hasattr(self, "instance_name") and self.instance_name:
-        instance_name = "grad" + self.instance_name
-        collective_scatter_grad.set_prim_instance_name(instance_name)
-
-    def bprop(x, out, dout):
-        dx_out = collective_scatter_grad(dout)
-        if rank == dest_rank:
-            dx = dx_out
-        else:
-            dx = F.depend(F.zeros_like(x), dx_out)
-        return (dx,)
-
-    return bprop
-
-
 @bprop_getters.register(_AllSwap)
 def get_bprop_allswap(self):
     """Generate bprop for _AllSwap."""
@@ -2113,33 +2113,6 @@ def get_split_vmap_rule(prim, axis_size):
 
     return vmap_rule
 
-@vmap_rules_getters.register(P.SearchSorted)
-def get_searchsorted_vmap_rule(prim, axis_size):
-    """VmapRule for `SearchSorted`."""
-    def vmap_rule(sequence_bdim, values_bdim, sorter_bdim, dtype_bdim, right_bdim):
-        is_all_none, result = vmap_general_preprocess(prim, sequence_bdim, values_bdim,
-                                                      sorter_bdim, dtype_bdim, right_bdim)
-        if is_all_none:
-            return result
-
-        sequence, sequence_dim = sequence_bdim
-        values, values_dim = values_bdim
-        sorter, sorter_dim = sorter_bdim
-
-        sequence = _bdim_at_front(sequence, sequence_dim, axis_size)
-        values = _bdim_at_front(values, values_dim, axis_size)
-        if sorter is not None and sorter_dim is not None:
-            sorter = _bdim_at_front(sorter, sorter_dim, axis_size)
-
-        dtype, _ = dtype_bdim
-        right, _ = right_bdim
-
-        outputs = prim(sequence, values, sorter, dtype, right)
-
-        return outputs, 0
-
-    return vmap_rule
-
 
 get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(NonZero)(get_unsupported_dynamic_vmap_rule)
 get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.Unique)(get_unsupported_dynamic_vmap_rule)
@@ -63,6 +63,7 @@ def _broadcast_shape(nd, x_ndim, x_shape):
 @vmap_rules_getters.register(P.BitwiseAnd)
 @vmap_rules_getters.register(P.BitwiseOr)
 @vmap_rules_getters.register(P.BitwiseXor)
+@vmap_rules_getters.register(P.IsClose)
 @vmap_rules_getters.register(P.Xlogy)
 @vmap_rules_getters.register(P.ApproximateEqual)
 @vmap_rules_getters.register(P.TruncateDiv)
@@ -887,35 +888,6 @@ def get_logit_vmap_rule(prim_func, axis_size):
 
     return vmap_rule
 
-
-@vmap_rules_getters.register(P.IsClose)
-def get_isclose_vmap_rule(prim, axis_size):
-    """VmapRule for `IsClose` operation"""
-
-    def vmap_rule(x_bdim, y_bdim, rtol_bdim, atol_bdim, equal_nan_bdim):
-        is_all_none, result = vmap_general_preprocess(prim, x_bdim, x_bdim, rtol_bdim, atol_bdim, equal_nan_bdim)
-        if is_all_none:
-            return result
-
-        x, x_dim = x_bdim
-        y, y_dim = y_bdim
-        rtol, _ = rtol_bdim
-        atol, _ = atol_bdim
-        equal_nan, _ = equal_nan_bdim
-
-        if x_dim == y_dim:
-            out = prim(x, y, rtol, atol, equal_nan)
-            return out, x_dim
-        if y_dim is None:
-            y = _broadcast_by_axis(y, x_dim, axis_size)
-        else:
-            y = mnp.moveaxis(y, y_dim, x_dim)
-
-        out = prim(x, y, rtol, atol, equal_nan)
-        return out, x_dim
-
-    return vmap_rule
-
 get_assign_vmap_rule = vmap_rules_getters.register(P.AssignAdd)(get_assign_vmap_rule)
 get_assign_vmap_rule = vmap_rules_getters.register(P.AssignSub)(get_assign_vmap_rule)
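Reviewer note: 2.3.0 serves IsClose with the dedicated rule removed above, which unpacks the non-tensor rtol/atol/equal_nan batch dims explicitly, while rc2 routes it through the generic broadcast-binary rule registered in the first hunk. Worth flagging: the removed rule's preprocess call passes `x_bdim` twice where `y_bdim` was presumably intended. Either way the public entry point is just vmap over isclose; a hedged smoke test, assuming the default in_axes=0 batching:

    import mindspore as ms
    from mindspore import ops

    x = ms.Tensor([[1.0, 2.0], [3.0, 4.0]])
    y = ms.Tensor([[1.0, 2.1], [3.0, 4.0]])
    print(ops.vmap(ops.isclose)(x, y))  # row-wise batched comparison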
 
@@ -31,7 +31,6 @@ from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_prepr
 from mindspore.ops.primitive import Primitive
 from mindspore.ops.auto_generate.gen_arg_handler import Format
 from mindspore.ops.auto_generate import Embedding
-from mindspore.ops.auto_generate import gen_arg_handler as handler
 
 
 @vmap_rules_getters.register(P.ApplyAdaMax)
@@ -299,19 +298,25 @@ def get_bce_with_logits_loss_vamp_rule(prim, axis_size):
 
     if isinstance(prim, str):
         prim = Primitive(prim)
+        prim_reduction = 'none'
+    else:
+        prim_reduction = prim.reduction
     prim_name = prim.name
     bce_logits_with_loss_op = NN.BCEWithLogitsLoss('none')
+    if prim_reduction == 'mean':
+        reduce_op = P.ReduceMean()
+    elif prim_reduction == "sum":
+        reduce_op = P.ReduceSum()
 
-    def vmap_rule(logits_bdim, label_bdim, weight_bdim, pos_weight_bdim, reduction_bdim):
-        is_all_none, result = vmap_general_preprocess(prim, logits_bdim, label_bdim, weight_bdim, pos_weight_bdim,
-                                                      reduction_bdim)
+    def vmap_rule(logits_bdim, label_bdim, weight_bdim, pos_weight_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, logits_bdim, label_bdim,
+                                                      weight_bdim, pos_weight_bdim)
         if is_all_none:
             return result
         logits, logits_dim = logits_bdim
         label, label_dim = label_bdim
         weight, weight_dim = weight_bdim
         pos_weight, pos_weight_dim = pos_weight_bdim
-        prim_reduction, _ = reduction_bdim
         logits_rank = F.rank(logits)
         label_rank = F.rank(label)
         weight_rank = F.rank(weight)
@@ -327,14 +332,11 @@ def get_bce_with_logits_loss_vamp_rule(prim, axis_size):
         shape = F.shape(logits)
         shape_ok = shape == F.shape(label) and shape == F.shape(weight) and shape == F.shape(pos_weight)
         if logits_dim_ok and shape_ok:
-            if prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'none'):
-                output = prim(logits, label, weight, pos_weight, prim_reduction)
-            elif prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'mean'):
+            if prim_reduction == 'none':
+                output = prim(logits, label, weight, pos_weight)
+            elif prim_reduction in ('mean', 'sum'):
                 out = bce_logits_with_loss_op(logits, label, weight, pos_weight)
-                output = P.ReduceMean()(out, reduce_indexes)
-            elif prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'sum'):
-                out = bce_logits_with_loss_op(logits, label, weight, pos_weight)
-                output = P.ReduceSum()(out, reduce_indexes)
+                output = reduce_op(out, reduce_indexes)
             else:
                 raise RuntimeError("For {} vmap, the attribute of reduction must in "
                                    "('none', 'mean', 'sum'), but got {}."
@@ -350,14 +352,11 @@ def get_bce_with_logits_loss_vamp_rule(prim, axis_size):
         pos_weight_shape = F.shape(pos_weight)
         weight = _handle_broadcasting(weight, weight_shape, logits_shape)
         pos_weight = _handle_broadcasting(pos_weight, pos_weight_shape, logits_shape)
-        if prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'none'):
-            output = prim(logits, label, weight, pos_weight, prim_reduction)
-        elif prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'mean'):
-            out = bce_logits_with_loss_op(logits, label, weight, pos_weight)
-            output = P.ReduceMean()(out, reduce_indexes)
-        elif prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'sum'):
+        if prim_reduction == 'none':
+            output = prim(logits, label, weight, pos_weight)
+        elif prim_reduction in ('mean', 'sum'):
             out = bce_logits_with_loss_op(logits, label, weight, pos_weight)
-            output = P.ReduceSum()(out, reduce_indexes)
+            output = reduce_op(out, reduce_indexes)
         else:
             raise RuntimeError("For {} vmap, the attribute of reduction must in "
                                "('none', 'mean', 'sum'), but got {}."