mindspore 2.3.0__cp39-none-any.whl → 2.3.0rc2__cp39-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +0 -1512
- mindspore/__init__.py +1 -2
- mindspore/_c_dataengine.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_checkparam.py +25 -5
- mindspore/_extends/graph_kernel/model/graph_parallel.py +1 -1
- mindspore/_extends/parse/__init__.py +2 -2
- mindspore/_extends/parse/compile_config.py +0 -29
- mindspore/_extends/parse/namespace.py +2 -2
- mindspore/_extends/parse/parser.py +5 -21
- mindspore/_extends/parse/resources.py +7 -5
- mindspore/_extends/parse/standard_method.py +59 -40
- mindspore/_mindspore_offline_debug.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +5 -26
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/base.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +1 -1
- mindspore/boost/grad_freeze.py +2 -2
- mindspore/boost/less_batch_normalization.py +6 -9
- mindspore/common/__init__.py +1 -8
- mindspore/common/_register_for_tensor.py +9 -8
- mindspore/common/api.py +65 -275
- mindspore/common/dtype.py +4 -8
- mindspore/common/dump.py +5 -2
- mindspore/common/jit_config.py +1 -1
- mindspore/common/lazy_inline.py +2 -14
- mindspore/common/parameter.py +15 -14
- mindspore/common/recompute.py +5 -20
- mindspore/common/sparse_tensor.py +6 -21
- mindspore/common/tensor.py +52 -100
- mindspore/communication/__init__.py +11 -6
- mindspore/communication/management.py +94 -92
- mindspore/context.py +18 -180
- mindspore/dataset/engine/datasets.py +46 -69
- mindspore/dataset/engine/datasets_user_defined.py +53 -72
- mindspore/dataset/engine/datasets_vision.py +2 -2
- mindspore/dataset/engine/queue.py +38 -56
- mindspore/dataset/engine/validators.py +5 -11
- mindspore/dataset/vision/__init__.py +5 -5
- mindspore/dataset/vision/c_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +1 -1
- mindspore/dataset/vision/transforms.py +46 -591
- mindspore/dataset/vision/utils.py +1 -121
- mindspore/dataset/vision/validators.py +3 -9
- mindspore/hal/__init__.py +1 -7
- mindspore/hal/device.py +1 -1
- mindspore/include/api/model.h +0 -3
- mindspore/include/dataset/vision.h +2 -54
- mindspore/include/mindapi/base/types.h +0 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libmpi_collective.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +0 -35
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +0 -72
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/{aclnn_all_finite.h → aclnn_add_custom.h} +11 -9
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_decoder_kv_cache.h +1 -1
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_prompt_kv_cache.h +1 -1
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/lib/libcust_opapi.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +12 -184
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +15 -7
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +15 -7
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.cpp +81 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.py +31 -77
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.py +31 -77
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/lib/linux/aarch64/libcust_opmaster_rt2.0.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/liboptiling.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/inc/op_proto.h +5 -4
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/lib/linux/aarch64/libcust_opsproto_rt2.0.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/liblowlatency_collective.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/DeviceBin +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/PkgInspect +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/op_man +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/device/ascend910b/bin/ascend910b.bin +286 -275
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_cann_host.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_host.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops_static.a +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/add/add_impl.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/apply_rotary_pos_emb_impl.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/asdop/asd_op_impl.h +0 -3
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/backend_param.h +0 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/cast/cast_tiling.h +45 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/compare/compare_impl.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/flash_attention_score/flash_attention_score_impl.h +4 -8
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/flash_attention_score/flash_attention_score_tiling.h +4 -11
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/flash_attention_score/kernel/flash_attention_score_mix_hwsync.h +0 -18
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/internal_kernel.h +0 -6
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/internal_rtbackend.h +75 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/kernel/matmul.h +5 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/matmul_impl.h +3 -18
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/pp_matmul_common_tiling.h +5 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/pp_matmul_info.h +2 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/tiling_data.h +3 -36
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/kernel/matmul_stridedslice_fusion.h +2 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/matmul_stridedslice_fusion_impl.h +4 -22
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/op_param.h +2 -16
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/kernel/paged_attention_mix_hwsync.h +3 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_impl.h +4 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_tiling.h +4 -9
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/attention_param.h +2 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_ext_param.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_qkv_param.h +4 -10
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/sub_param.h +12 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/rms_norm_impl.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_impl.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/tune_repo/matmul_table.h +1 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/backend.h +2 -10
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/elewise_utils.h +1 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_tiling.h +0 -17
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/math.h +7 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libAdd_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libSub_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_layernorm_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_rms_norm_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libcast_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libgelu_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_stridedslice_fusion_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libms_kernels_internal.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libnot_equal_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libreshape_and_cache_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/librms_norm_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblcal.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblccl_wrapper.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/mindrecord/filewriter.py +2 -2
- mindspore/mint/__init__.py +40 -720
- mindspore/mint/nn/__init__.py +7 -89
- mindspore/mint/nn/functional.py +16 -165
- mindspore/mint/optim/adamw.py +16 -15
- mindspore/nn/__init__.py +2 -0
- mindspore/nn/cell.py +98 -97
- mindspore/nn/extend/basic.py +2 -2
- mindspore/nn/extend/embedding.py +1 -1
- mindspore/nn/extend/layer/normalization.py +5 -7
- mindspore/nn/generator.py +297 -0
- mindspore/nn/layer/activation.py +3 -4
- mindspore/nn/layer/basic.py +16 -79
- mindspore/nn/layer/conv.py +8 -17
- mindspore/nn/layer/embedding.py +4 -1
- mindspore/nn/layer/math.py +1 -1
- mindspore/nn/layer/normalization.py +1 -1
- mindspore/nn/layer/pooling.py +0 -5
- mindspore/nn/layer/rnn_cells.py +2 -2
- mindspore/nn/loss/loss.py +19 -19
- mindspore/nn/optim/adasum.py +1 -1
- mindspore/nn/optim/sgd.py +2 -3
- mindspore/nn/probability/distribution/exponential.py +1 -1
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/logistic.py +1 -1
- mindspore/nn/wrap/cell_wrapper.py +1 -25
- mindspore/nn/wrap/loss_scale.py +1 -24
- mindspore/numpy/array_ops.py +1 -5
- mindspore/numpy/dtypes.py +3 -3
- mindspore/numpy/math_ops.py +8 -8
- mindspore/ops/__init__.py +1 -1
- mindspore/ops/_grad_experimental/grad_comm_ops.py +16 -75
- mindspore/ops/_vmap/vmap_array_ops.py +0 -27
- mindspore/ops/_vmap/vmap_math_ops.py +1 -29
- mindspore/ops/_vmap/vmap_nn_ops.py +18 -19
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +8 -34
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +9 -2
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -26
- mindspore/ops/auto_generate/gen_extend_func.py +27 -603
- mindspore/ops/auto_generate/gen_ops_def.py +203 -993
- mindspore/ops/auto_generate/gen_ops_prim.py +402 -1946
- mindspore/ops/auto_generate/pyboost_inner_prim.py +20 -90
- mindspore/ops/composite/base.py +6 -3
- mindspore/ops/composite/math_ops.py +1 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +17 -24
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/extend/__init__.py +3 -2
- mindspore/ops/extend/array_func.py +51 -10
- mindspore/ops/extend/nn_func.py +78 -2
- mindspore/ops/function/__init__.py +13 -8
- mindspore/ops/function/array_func.py +179 -455
- mindspore/ops/function/clip_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +3 -3
- mindspore/ops/function/math_func.py +103 -117
- mindspore/ops/function/nn_func.py +163 -275
- mindspore/ops/function/other_func.py +2 -2
- mindspore/ops/function/random_func.py +69 -202
- mindspore/ops/function/sparse_func.py +4 -4
- mindspore/ops/functional.py +327 -332
- mindspore/ops/operations/__init__.py +3 -13
- mindspore/ops/operations/_grad_ops.py +27 -3
- mindspore/ops/operations/_inner_ops.py +356 -53
- mindspore/ops/operations/_rl_inner_ops.py +2 -2
- mindspore/ops/operations/_tensor_array.py +8 -8
- mindspore/ops/operations/array_ops.py +65 -82
- mindspore/ops/operations/comm_ops.py +93 -784
- mindspore/ops/operations/custom_ops.py +28 -51
- mindspore/ops/operations/debug_ops.py +4 -4
- mindspore/ops/operations/inner_ops.py +2 -2
- mindspore/ops/operations/manually_defined/ops_def.py +4 -304
- mindspore/ops/operations/math_ops.py +50 -3
- mindspore/ops/operations/nn_ops.py +247 -14
- mindspore/ops/operations/other_ops.py +3 -3
- mindspore/ops/operations/random_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +1 -1
- mindspore/ops/primitive.py +8 -9
- mindspore/ops/silent_check.py +5 -5
- mindspore/ops_generate/arg_dtype_cast.py +9 -2
- mindspore/ops_generate/arg_handler.py +0 -26
- mindspore/ops_generate/gen_aclnn_implement.py +4 -1
- mindspore/ops_generate/gen_ops.py +4 -26
- mindspore/ops_generate/gen_pyboost_func.py +12 -41
- mindspore/ops_generate/gen_utils.py +0 -21
- mindspore/ops_generate/pyboost_utils.py +2 -7
- mindspore/ops_generate/template.py +0 -1
- mindspore/parallel/_auto_parallel_context.py +1 -21
- mindspore/parallel/_tensor.py +5 -0
- mindspore/parallel/_transformer/transformer.py +1 -1
- mindspore/parallel/_utils.py +1 -15
- mindspore/parallel/algo_parameter_config.py +3 -1
- mindspore/parallel/checkpoint_transform.py +9 -12
- mindspore/parallel/cluster/process_entity/_api.py +29 -28
- mindspore/parallel/cluster/process_entity/_utils.py +3 -13
- mindspore/parallel/cluster/run.py +16 -13
- mindspore/parallel/parameter_broadcast.py +2 -2
- mindspore/parallel/shard.py +17 -31
- mindspore/profiler/__init__.py +2 -3
- mindspore/profiler/common/util.py +2 -107
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +21 -8
- mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -82
- mindspore/profiler/parser/ascend_analysis/function_event.py +28 -43
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +27 -49
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +10 -15
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +20 -25
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +5 -5
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +1 -10
- mindspore/profiler/parser/ascend_hccl_generator.py +1 -4
- mindspore/profiler/parser/ascend_msprof_exporter.py +22 -43
- mindspore/profiler/parser/ascend_timeline_generator.py +5 -7
- mindspore/profiler/parser/minddata_parser.py +3 -72
- mindspore/profiler/profiling.py +59 -176
- mindspore/rewrite/api/node.py +1 -1
- mindspore/rewrite/common/namespace.py +5 -5
- mindspore/rewrite/parsers/assign_parser.py +0 -2
- mindspore/rewrite/parsers/class_def_parser.py +4 -8
- mindspore/run_check/_check_version.py +1 -1
- mindspore/scipy/fft.py +3 -1
- mindspore/scipy/linalg.py +3 -2
- mindspore/scipy/ops.py +3 -5
- mindspore/scipy/optimize/__init__.py +2 -2
- mindspore/train/__init__.py +4 -4
- mindspore/train/anf_ir_pb2.py +2 -8
- mindspore/train/callback/__init__.py +2 -5
- mindspore/train/callback/_backup_and_restore.py +2 -2
- mindspore/train/callback/_checkpoint.py +16 -104
- mindspore/train/callback/_landscape.py +1 -1
- mindspore/train/callback/_time_monitor.py +1 -1
- mindspore/train/data_sink.py +4 -5
- mindspore/train/dataset_helper.py +20 -45
- mindspore/train/model.py +38 -266
- mindspore/train/serialization.py +105 -256
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +2 -2
- {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +303 -420
- mindspore/_extends/pijit/__init__.py +0 -23
- mindspore/_extends/pijit/pijit_func_white_list.py +0 -343
- mindspore/common/file_system.py +0 -48
- mindspore/common/generator.py +0 -260
- mindspore/common/no_inline.py +0 -54
- mindspore/common/np_dtype.py +0 -25
- mindspore/communication/comm_func.py +0 -1140
- mindspore/hal/memory.py +0 -326
- mindspore/lib/libavcodec.so.59 +0 -0
- mindspore/lib/libavdevice.so.59 +0 -0
- mindspore/lib/libavfilter.so.8 +0 -0
- mindspore/lib/libavformat.so.59 +0 -0
- mindspore/lib/libavutil.so.57 +0 -0
- mindspore/lib/libmindspore_np_dtype.so +0 -0
- mindspore/lib/libswresample.so.4 +0 -0
- mindspore/lib/libswscale.so.6 +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/all_finite.cpp +0 -326
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/all_finite.py +0 -180
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_576ceaeef5870c451cab59af55ea46ad.json +0 -58
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_576ceaeef5870c451cab59af55ea46ad.o +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_86a73ff6e28d734c96bb8d3054f7dd18.json +0 -58
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_86a73ff6e28d734c96bb8d3054f7dd18.o +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_f55e0ebaad1f2f572e43677336992fa0.json +0 -58
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_f55e0ebaad1f2f572e43677336992fa0.o +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/config/ascend910b/all_finite.json +0 -109
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/config/ascend910b/binary_info_config.json +0 -38
- mindspore/lib/plugin/ascend/custom_compiler/OWNERS +0 -12
- mindspore/lib/plugin/ascend/custom_compiler/setup.py +0 -255
- mindspore/lib/plugin/ascend/custom_compiler/start.sh +0 -26
- mindspore/lib/plugin/ascend/custom_compiler/template.json +0 -40
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/acme.h +0 -24
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/acme_op.h +0 -69
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/base_type.h +0 -133
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/op_creator.h +0 -32
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/op_param.h +0 -35
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/tiling_info.h +0 -60
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/core/kernel_register.h +0 -37
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/core/platform/platform_configs.h +0 -89
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/core/platform/rt_funcs.h +0 -135
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/add_op.h +0 -34
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_backoff_base.h +0 -62
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_elewise_op.h +0 -33
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_ops.h +0 -88
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_pa_op.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/cast_op.h +0 -52
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/matmul_op.h +0 -95
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/utils/asd_utils.h +0 -84
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/utils/comm_utils.h +0 -61
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_fp32.h +0 -224
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/and_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/div_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/elewise_binary_impl.h +0 -48
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/elewise_binary_tiling.h +0 -25
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/and_kernel.h +0 -46
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/div_kernel.h +0 -46
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/elewise_binary_base.h +0 -260
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/elewise_binary_kernel.h +0 -35
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/max_kernel.h +0 -66
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/min_kernel.h +0 -66
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/mul_kernel.h +0 -66
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/or_kernel.h +0 -46
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/max_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/min_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/mul_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/or_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/abs_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/elewise_unary_impl.h +0 -47
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/elewise_unary_tiling.h +0 -24
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/exp_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/abs_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/elewise_unary_base.h +0 -148
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/elewise_unary_kernel.h +0 -31
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/exp_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/ln_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/not_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/reciprocal_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/relu_kernel.h +0 -55
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/rsqrt_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/sqrt_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/ln_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/not_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/reciprocal_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/relu_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/rsqrt_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/sqrt_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/grouped_matmul_impl.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/grouped_matmul_tiling.h +0 -187
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/kernel/grouped_matmul.h +0 -245
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/kernel/grouped_matmul_interface.h +0 -24
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/kernel/grouped_matmul_utils.h +0 -111
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/tiling_data.h +0 -54
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/compare_param.h +0 -31
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/elewise_param.h +0 -41
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/grouped_matmul_param.h +0 -40
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/profiling_util.h +0 -364
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_utils.h +0 -69
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/register/kernel_creator.h +0 -39
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/register/kernel_registry.h +0 -114
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/utils.h +0 -98
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix.json +0 -19
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix_mix_aic_0.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix_mix_aiv_0.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix.json +0 -19
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix_mix_aic_0.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix_mix_aiv_0.o +0 -0
- mindspore/mint/linalg/__init__.py +0 -22
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/nn/layer/embedding_service_layer.py +0 -393
- mindspore/ops/function/reshard_func.py +0 -102
- mindspore/ops/operations/_infer_ops.py +0 -19
- mindspore/ops/operations/reshard_ops.py +0 -53
- mindspore/profiler/common/process_pool.py +0 -41
- mindspore/profiler/common/singleton.py +0 -28
- mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
- mindspore/profiler/parser/ascend_memory_generator.py +0 -185
- mindspore/train/callback/_cluster_monitor.py +0 -201
- mindspore/train/callback/_flops_collector.py +0 -238
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
mindspore/nn/loss/loss.py
CHANGED
@@ -1820,10 +1820,10 @@ class MultilabelMarginLoss(LossBase):
 
 class BCEWithLogitsLoss(LossBase):
     r"""
-    Adds sigmoid activation function to input
-
+    Adds sigmoid activation function to input logits, and uses the given logits to compute binary cross entropy
+    between the logits and the labels.
 
-    Sets input `
+    Sets input `logits` as :math:`X`, input `labels` as :math:`Y`, output as :math:`L`. Then,
 
     .. math::
         p_{ij} = sigmoid(X_{ij}) = \frac{1}{1 + e^{-X_{ij}}}
@@ -1849,29 +1849,29 @@ class BCEWithLogitsLoss(LossBase):
             - ``'sum'``: the output elements will be summed.
 
         weight (Tensor, optional): A rescaling weight applied to the loss of each batch element.
-            If not None, it can be broadcast to a tensor with shape of `
+            If not None, it can be broadcast to a tensor with shape of `logits`,
             data type must be float16 or float32. Default: ``None`` .
         pos_weight (Tensor, optional): A weight of positive examples. Must be a vector with length equal to the
-            number of classes. If not None, it must be broadcast to a tensor with shape of `
+            number of classes. If not None, it must be broadcast to a tensor with shape of `logits`, data type
            must be float16 or float32. Default: ``None`` .
 
    Inputs:
-        - **
+        - **logits** (Tensor) - Input logits with shape :math:`(N, *)` where :math:`*` means, any number
          of additional dimensions. The data type must be float16 or float32.
-        - **
-          of additional dimensions. The same shape and data type as `
+        - **labels** (Tensor) - Ground truth label with shape :math:`(N, *)` where :math:`*` means, any number
+          of additional dimensions. The same shape and data type as `logits`.
 
    Outputs:
-        Tensor or Scalar, if `reduction` is ``'none'``, its shape is the same as `
+        Tensor or Scalar, if `reduction` is ``'none'``, its shape is the same as `logits`.
        Otherwise, a scalar value will be returned.
 
    Raises:
-        TypeError: If input `
-        TypeError: If data type of `
+        TypeError: If input `logits` or `labels` is not Tensor.
+        TypeError: If data type of `logits` or `labels` is neither float16 nor float32.
        TypeError: If `weight` or `pos_weight` is a parameter.
        TypeError: If data type of `weight` or `pos_weight` is neither float16 nor float32.
        TypeError: If data type of `reduction` is not string.
-        ValueError: If `weight` or `pos_weight` can not be broadcast to a tensor with shape of `
+        ValueError: If `weight` or `pos_weight` can not be broadcast to a tensor with shape of `logits`.
        ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.
 
    Supported Platforms:
@@ -1881,10 +1881,10 @@ class BCEWithLogitsLoss(LossBase):
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> import numpy as np
-        >>>
-        >>>
+        >>> logits = ms.Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]).astype(np.float32))
+        >>> labels = ms.Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]).astype(np.float32))
        >>> loss = nn.BCEWithLogitsLoss()
-        >>> output = loss(
+        >>> output = loss(logits, labels)
        >>> print(output)
        0.3463612
        """
@@ -1900,10 +1900,10 @@ class BCEWithLogitsLoss(LossBase):
        self.weight = weight
        self.pos_weight = pos_weight
 
-    def construct(self,
-        _check_is_tensor('
-        _check_is_tensor('
-        loss = ops.binary_cross_entropy_with_logits(
+    def construct(self, logits, labels):
+        _check_is_tensor('logits', logits, self.cls_name)
+        _check_is_tensor('labels', labels, self.cls_name)
+        loss = ops.binary_cross_entropy_with_logits(logits, labels, self.weight, self.pos_weight, self.reduction)
        return loss
mindspore/nn/optim/adasum.py
CHANGED
@@ -29,7 +29,7 @@ from mindspore.parallel._utils import _get_global_rank, _get_stage_device_num
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-from mindspore.ops import Send, Receive
+from mindspore.ops.operations._inner_ops import Send, Receive
 from mindspore.common.tensor import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.communication.management import create_group
mindspore/nn/optim/sgd.py
CHANGED
@@ -195,9 +195,9 @@ class SGD(Optimizer):
                             "or 'weight_decay' set in grouped 'params' must be float or int type.")
 
        if hasattr(self, "group_weight_decay") and self.group_weight_decay:
-            self.opt = tuple(P.SGD(dampening,
+            self.opt = tuple(P.SGD(dampening, wd, nesterov) for wd in self.group_weight_decay)
        else:
-            self.opt = tuple([P.SGD(dampening,
+            self.opt = tuple([P.SGD(dampening, float(weight_decay), nesterov)] * len(self._parameters))
 
        self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
 
@@ -222,7 +222,6 @@
        params = self._parameters
        accum = self.accum
        stat = self.stat
-        gradients = self.decay_weight(gradients)
        gradients = self.flatten_gradients(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
mindspore/nn/probability/distribution/exponential.py
CHANGED
@@ -152,7 +152,7 @@ class Exponential(Distribution):
        if self.rate is not None:
            check_greater_zero(self.rate, 'rate')
 
-        self.minval = np.finfo(np.
+        self.minval = np.finfo(np.float).tiny
 
        # ops needed for the class
        self.exp = exp_generic
mindspore/nn/probability/distribution/logistic.py
CHANGED
@@ -170,7 +170,7 @@ class Logistic(Distribution):
        self.neg = P.Neg()
 
        self.threshold = np.log(np.finfo(np.float32).eps) + 1.
-        self.tiny = np.finfo(np.
+        self.tiny = np.finfo(np.float).tiny
        self.sd_const = np.pi / np.sqrt(3)
 
    def _softplus(self, x):
mindspore/nn/wrap/cell_wrapper.py
CHANGED
@@ -17,18 +17,16 @@
 from __future__ import absolute_import
 from __future__ import division
 
-import os
 from types import FunctionType, MethodType
 
 from mindspore import log as logger
 from mindspore.parallel._utils import _get_device_num, _get_gradients_mean,\
     _get_parallel_mode, _get_enable_parallel_optimizer, _is_pynative_parallel
-from mindspore.context import ParallelMode
+from mindspore.context import ParallelMode
 from mindspore import _checkparam as validator
 from mindspore import ops, nn
 from mindspore.common import dtype as mstype
 from mindspore.common.parameter import Parameter, ParameterTuple
-from mindspore.common.tensor import Tensor
 from mindspore.ops.primitive import _primexpr
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
@@ -742,18 +740,6 @@ class _TrainGradAccuStepCell(TrainOneStepCell):
        self.hyper_map = ops.HyperMap()
        self.opt_shard = _get_enable_parallel_optimizer()
        self._get_attr_from_cell(network)
-        self.enable_mindio = False
-        mode = get_context("mode")
-        device_type = get_context("device_target")
-        if device_type != "Ascend" or mode != GRAPH_MODE:
-            return
-        graceful_exit = os.getenv("MS_ENABLE_MINDIO_GRACEFUL_EXIT")
-        ttp_lib_path = os.getenv("MS_MINDIO_TTP_LIB_PATH")
-        ttp_path_check = ttp_lib_path is not None and os.path.isfile(ttp_lib_path)
-        if graceful_exit == "true" and ttp_path_check:
-            self.g_one = Tensor([0.1])
-            self.allreduce_sum = ops.AllReduce()
-            self.enable_mindio = True
 
    def construct(self, *inputs):
        if not self.sense_flag:
@@ -762,11 +748,6 @@ class _TrainGradAccuStepCell(TrainOneStepCell):
        sens = ops.fill(ops.DType()(loss), ops.Shape()(loss), self.sens)
        grads = self.grad(self.network, self.weights)(*inputs, sens)
        accu_grads = ops.depend(self.accu_grads, grads)
-        if self.enable_mindio:
-            g_one = ops.depend(self.g_one, accu_grads)
-            g_one_res = self.allreduce_sum(g_one)
-            accu_grads = ops.depend(accu_grads, g_one_res)
-            grads = ops.depend(grads, g_one_res)
        if self.opt_shard:
            succ = self.optimizer(grads)
        else:
@@ -781,11 +762,6 @@ class _TrainGradAccuStepCell(TrainOneStepCell):
        loss = self.network(*inputs)
        grads = self.grad_no_sens(self.network, self.weights)(*inputs)
        accu_grads = ops.depend(self.accu_grads, grads)
-        if self.enable_mindio:
-            g_one = ops.depend(self.g_one, accu_grads)
-            g_one_res = self.allreduce_sum(g_one)
-            accu_grads = ops.depend(accu_grads, g_one_res)
-            grads = ops.depend(grads, g_one_res)
        if self.opt_shard:
            succ = self.optimizer(grads)
        else:
mindspore/nn/wrap/loss_scale.py
CHANGED
@@ -29,7 +29,6 @@ from mindspore.ops.operations.math_ops import NPUGetFloatStatusV2, NPUClearFloat
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
-from mindspore.ops.operations.nn_ops import AllFinite
 from mindspore.common import dtype as mstype
 from mindspore.common.api import jit
 from mindspore._c_expression import MSContext
@@ -373,15 +372,6 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
        self.loss_scaling_manager = None
        self._ascend_check_overflow_mode = os.environ.get('MS_ASCEND_CHECK_OVERFLOW_MODE')
 
-        self.enable_allfinite = False
-        runtime_conf = os.environ.get('MS_DEV_RUNTIME_CONF')
-        global_jit_config = context.get_jit_config()
-        if runtime_conf is not None and ("all_finite:True" in runtime_conf or "all_finite:true" in runtime_conf):
-            self.enable_allfinite = True
-        elif runtime_conf is not None and ("all_finite:False" in runtime_conf or "all_finite:false" in runtime_conf):
-            self.enable_allfinite = False
-        elif global_jit_config:
-            self.enable_allfinite = global_jit_config["jit_level"] == "O0" or global_jit_config["jit_level"] == "O1"
 
        if isinstance(scale_sense, Cell):
            self.loss_scaling_manager = scale_sense
@@ -488,15 +478,6 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
        overflow = self.less_equal(self.base, flag_sum)
        return overflow
 
-    def _get_distributed_overflow_status_on_infnan_enable_allfinite(self, compute_output):
-        """check overflow status on infnan kernel mode."""
-        overflow = AllFinite()(compute_output)
-
-        if self.is_distributed:
-            overflow = P.Cast()(overflow, mstype.int8)
-            overflow = P.Cast()(self.allreduce(overflow), mstype.bool_)
-        return overflow
-
    def _get_gpu_overflow_status(self, compute_output):
        """get overflow status of gpu."""
        overflow = self._get_distributed_overflow_status_on_infnan_mode(_grad_overflow, compute_output)
@@ -504,11 +485,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
 
    def _get_ascend_overflow_status_on_infnan_mode(self, compute_output):
        """get overflow status of ascend on infnan mode."""
-        overflow =
-        if self.enable_allfinite:
-            overflow = self._get_distributed_overflow_status_on_infnan_enable_allfinite(compute_output)
-        else:
-            overflow = self._get_distributed_overflow_status_on_infnan_mode(_ascend_grad_overflow, compute_output)
+        overflow = self._get_distributed_overflow_status_on_infnan_mode(_ascend_grad_overflow, compute_output)
        return overflow
 
    def _get_ascend_overflow_status_on_saturation_mode(self, status, compute_output):
mindspore/numpy/array_ops.py
CHANGED
@@ -2606,11 +2606,7 @@ def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
        array1 = ar1.ravel()
        array2 = ar2.ravel()
        concat_array = concatenate((array1, array2))
-
-            concat_sort_indices = F.argsort(concat_array)
-            concat_array = concat_array[concat_sort_indices]
-        else:
-            concat_array, concat_sort_indices = concat_array.sort()
+        concat_array, concat_sort_indices = concat_array.sort()
 
        mask_res = concat_array[1:] == concat_array[:-1]
        res = F.masked_select(concat_array[1:], mask_res)
mindspore/numpy/dtypes.py
CHANGED
@@ -86,7 +86,7 @@ dtype_map = {
 }
 
 all_types = [
-    'np.
+    'np.int',
     'np.int8',
     'np.int16',
     'np.int32',
@@ -96,11 +96,11 @@ all_types = [
     'np.uint16',
     'np.uint32',
     'np.uint64',
-    'np.
+    'np.float',
     'np.float16',
     'np.float32',
     'np.float64',
-    'np.
+    'np.bool']
 
 promotion_rule = {
     (uint8, uint16): uint16,
mindspore/numpy/math_ops.py
CHANGED
@@ -4166,18 +4166,18 @@ def multi_dot(arrays):
    Examples:
        >>> import mindspore.numpy as np
        >>> A = np.ones((10000, 100))
-        >>> B = np.ones((100,
-        >>> C = np.ones((
+        >>> B = np.ones((100, 1000))
+        >>> C = np.ones((1000, 5))
        >>> D = np.ones((5, 333))
        >>> output = np.multi_dot([A, B, C, D])
        >>> print(output)
-        [[
-        [
-        [
+        [[500000. 500000. 500000. ... 500000. 500000. 500000.]
+        [500000. 500000. 500000. ... 500000. 500000. 500000.]
+        [500000. 500000. 500000. ... 500000. 500000. 500000.]
        ...
-        [
-        [
-        [
+        [500000. 500000. 500000. ... 500000. 500000. 500000.]
+        [500000. 500000. 500000. ... 500000. 500000. 500000.]
+        [500000. 500000. 500000. ... 500000. 500000. 500000.]]
    """
    if len(arrays) < 2:
        _raise_value_error('Expecting at least 2 arrays')
mindspore/ops/__init__.py
CHANGED
@@ -44,7 +44,7 @@ __primitive__ = [
 __all__ = ["get_vm_impl_fn", "vm_impl_registry",
           "op_info_register", "custom_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp",
           "CpuRegOp", "CustomRegOp", "DataType",
-           "constexpr"
+           "constexpr"]
 __all__.extend(__primitive__)
 __all__.extend(composite.__all__)
 __all__.extend(operations.__all__)
mindspore/ops/_grad_experimental/grad_comm_ops.py
CHANGED
@@ -22,8 +22,7 @@ from mindspore.ops import functional as F
 from mindspore.communication import get_rank, get_group_size
 from mindspore.parallel._utils import _get_enable_parallel_optimizer, _get_grad_accumulation_shard
 from mindspore.ops import operations as P
-from mindspore.ops import Send, Receive
-from mindspore.ops.operations._inner_ops import issubclass_
+from mindspore.ops.operations._inner_ops import Send, Receive, issubclass_
 from mindspore.common.sparse_tensor import RowTensorInner
 from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
 from mindspore.ops.operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce,
@@ -31,7 +30,7 @@ from mindspore.ops.operations.comm_ops import (AllGather, _MiniStepAllGather, _H
                                               _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
                                               ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, _AllSwap,
                                               _VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator,
-                                               _MicroStepAllGather
+                                               _MicroStepAllGather)
 from mindspore.ops._grad_experimental.grad_base import bprop_getters
 from mindspore.ops.operations import _grad_ops as G
 
@@ -211,17 +210,21 @@ def get_bprop_mirror_micro_step_operator(self):
    def bprop(x, z, out, dout):
        real_grad = z
        assign_out = dout
-        if
-
-
+        if mean_flag:
+            if issubclass_(F.typeof(dout), mstype.tensor_type):
+                z = F.depend(z, dout)
                real_grad = all_reduce(z)
-
-
-
-
-
-
-
+                real_grad = F.tensor_mul(real_grad, scale)
+            if opt_shard:
+                return (real_grad, cast(out_tensor, dtype(z)))
+            return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign(z, real_grad))
+        else:
+            if issubclass_(F.typeof(dout), mstype.tensor_type):
+                z = F.depend(z, dout)
+                real_grad = all_reduce(z)
+            if opt_shard:
+                return (real_grad, cast(out_tensor, dtype(z)))
+            return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign(z, real_grad))
        return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign_out)
 
    return bprop
@@ -241,13 +244,11 @@ def get_bprop_broad_cast(self):
 def get_bprop_all_gather(self):
    """Generate bprop for AllGather"""
    fusion = self.get_attr_dict()["fusion"]
-    self.group = self.get_attr_dict()["group"]
    reduce_scatter = ReduceScatter(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
    if hasattr(self, "instance_name") and self.instance_name:
        instance_name = "grad_" + self.instance_name
        reduce_scatter.set_prim_instance_name(instance_name)
    mean_flag = self.get_attr_dict()["mean_flag"]
-    self.rank_size = self.get_attr_dict()["rank_size"]
    if self.rank_size == 0:
        raise ValueError(f"The 'rank_size' can not be zero, but got {self.rank_size}.")
    scale = 1.0 / self.rank_size
@@ -377,66 +378,6 @@ def get_bprop_reduce_scatter(self):
    return bprop
 
 
-@bprop_getters.register(Reduce)
-def get_bprop_reduce(self):
-    """Generate bprop for Reduce"""
-    dest_rank = self.get_attr_dict()["dest_rank"]
-    group = self.get_attr_dict()["group"]
-    reduce_grad = Broadcast(dest_rank, group)
-    if hasattr(self, "instance_name") and self.instance_name:
-        instance_name = "grad" + self.instance_name
-        reduce_grad.set_prim_instance_name(instance_name)
-
-    def bprop(x, out, dout):
-        dx = reduce_grad((dout,))
-        return (dx[0],)
-
-    return bprop
-
-
-@bprop_getters.register(CollectiveGather)
-def get_bprop_collective_gather(self):
-    """Generate bprop for CollectiveGather"""
-    group = self.get_attr_dict()["group"]
-    dest_rank = self.get_attr_dict()["dest_rank"]
-    collective_gather_grad = Broadcast(dest_rank, group)
-    rank = get_rank(group)
-    dev_num = self.rank_size
-    split = P.Split(output_num=dev_num)
-    if hasattr(self, "instance_name") and self.instance_name:
-        instance_name = "grad" + self.instance_name
-        collective_gather_grad.set_prim_instance_name(instance_name)
-
-    def bprop(x, out, dout):
-        grad = collective_gather_grad((dout,))
-        dx = split(grad[0])[rank]
-        return (dx,)
-
-    return bprop
-
-
-@bprop_getters.register(CollectiveScatter)
-def get_bprop_collective_scatter(self):
-    """Generate bprop for CollectiveScatter"""
-    group = self.get_attr_dict()["group"]
-    dest_rank = self.get_attr_dict()["src_rank"]
-    rank = get_rank(group)
-    collective_scatter_grad = CollectiveGather(dest_rank, group)
-    if hasattr(self, "instance_name") and self.instance_name:
-        instance_name = "grad" + self.instance_name
-        collective_scatter_grad.set_prim_instance_name(instance_name)
-
-    def bprop(x, out, dout):
-        dx_out = collective_scatter_grad(dout)
-        if rank == dest_rank:
-            dx = dx_out
-        else:
-            dx = F.depend(F.zeros_like(x), dx_out)
-        return (dx,)
-
-    return bprop
-
-
 @bprop_getters.register(_AllSwap)
 def get_bprop_allswap(self):
    """Generate bprop for _AllSwap."""
mindspore/ops/_vmap/vmap_array_ops.py
CHANGED
@@ -2113,33 +2113,6 @@ def get_split_vmap_rule(prim, axis_size):
 
    return vmap_rule
 
-@vmap_rules_getters.register(P.SearchSorted)
-def get_searchsorted_vmap_rule(prim, axis_size):
-    """VmapRule for `SearchSorted`."""
-    def vmap_rule(sequence_bdim, values_bdim, sorter_bdim, dtype_bdim, right_bdim):
-        is_all_none, result = vmap_general_preprocess(prim, sequence_bdim, values_bdim,
-                                                      sorter_bdim, dtype_bdim, right_bdim)
-        if is_all_none:
-            return result
-
-        sequence, sequence_dim = sequence_bdim
-        values, values_dim = values_bdim
-        sorter, sorter_dim = sorter_bdim
-
-        sequence = _bdim_at_front(sequence, sequence_dim, axis_size)
-        values = _bdim_at_front(values, values_dim, axis_size)
-        if sorter is not None and sorter_dim is not None:
-            sorter = _bdim_at_front(sorter, sorter_dim, axis_size)
-
-        dtype, _ = dtype_bdim
-        right, _ = right_bdim
-
-        outputs = prim(sequence, values, sorter, dtype, right)
-
-        return outputs, 0
-
-    return vmap_rule
-
 
 get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(NonZero)(get_unsupported_dynamic_vmap_rule)
 get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.Unique)(get_unsupported_dynamic_vmap_rule)
mindspore/ops/_vmap/vmap_math_ops.py
CHANGED
@@ -63,6 +63,7 @@ def _broadcast_shape(nd, x_ndim, x_shape):
 @vmap_rules_getters.register(P.BitwiseAnd)
 @vmap_rules_getters.register(P.BitwiseOr)
 @vmap_rules_getters.register(P.BitwiseXor)
+@vmap_rules_getters.register(P.IsClose)
 @vmap_rules_getters.register(P.Xlogy)
 @vmap_rules_getters.register(P.ApproximateEqual)
 @vmap_rules_getters.register(P.TruncateDiv)
@@ -887,35 +888,6 @@ def get_logit_vmap_rule(prim_func, axis_size):
 
    return vmap_rule
 
-
-@vmap_rules_getters.register(P.IsClose)
-def get_isclose_vmap_rule(prim, axis_size):
-    """VmapRule for `IsClose` operation"""
-
-    def vmap_rule(x_bdim, y_bdim, rtol_bdim, atol_bdim, equal_nan_bdim):
-        is_all_none, result = vmap_general_preprocess(prim, x_bdim, x_bdim, rtol_bdim, atol_bdim, equal_nan_bdim)
-        if is_all_none:
-            return result
-
-        x, x_dim = x_bdim
-        y, y_dim = y_bdim
-        rtol, _ = rtol_bdim
-        atol, _ = atol_bdim
-        equal_nan, _ = equal_nan_bdim
-
-        if x_dim == y_dim:
-            out = prim(x, y, rtol, atol, equal_nan)
-            return out, x_dim
-        if y_dim is None:
-            y = _broadcast_by_axis(y, x_dim, axis_size)
-        else:
-            y = mnp.moveaxis(y, y_dim, x_dim)
-
-        out = prim(x, y, rtol, atol, equal_nan)
-        return out, x_dim
-
-    return vmap_rule
-
 get_assign_vmap_rule = vmap_rules_getters.register(P.AssignAdd)(get_assign_vmap_rule)
 get_assign_vmap_rule = vmap_rules_getters.register(P.AssignSub)(get_assign_vmap_rule)
 
mindspore/ops/_vmap/vmap_nn_ops.py
CHANGED
@@ -31,7 +31,6 @@ from mindspore.ops._vmap.vmap_base import vmap_rules_getters, vmap_general_prepr
 from mindspore.ops.primitive import Primitive
 from mindspore.ops.auto_generate.gen_arg_handler import Format
 from mindspore.ops.auto_generate import Embedding
-from mindspore.ops.auto_generate import gen_arg_handler as handler
 
 
 @vmap_rules_getters.register(P.ApplyAdaMax)
@@ -299,19 +298,25 @@ def get_bce_with_logits_loss_vamp_rule(prim, axis_size):
 
    if isinstance(prim, str):
        prim = Primitive(prim)
+        prim_reduction = 'none'
+    else:
+        prim_reduction = prim.reduction
    prim_name = prim.name
    bce_logits_with_loss_op = NN.BCEWithLogitsLoss('none')
+    if prim_reduction == 'mean':
+        reduce_op = P.ReduceMean()
+    elif prim_reduction == "sum":
+        reduce_op = P.ReduceSum()
 
-    def vmap_rule(logits_bdim, label_bdim, weight_bdim, pos_weight_bdim
-        is_all_none, result = vmap_general_preprocess(prim, logits_bdim, label_bdim,
-
+    def vmap_rule(logits_bdim, label_bdim, weight_bdim, pos_weight_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, logits_bdim, label_bdim,
+                                                      weight_bdim, pos_weight_bdim)
        if is_all_none:
            return result
        logits, logits_dim = logits_bdim
        label, label_dim = label_bdim
        weight, weight_dim = weight_bdim
        pos_weight, pos_weight_dim = pos_weight_bdim
-        prim_reduction, _ = reduction_bdim
        logits_rank = F.rank(logits)
        label_rank = F.rank(label)
        weight_rank = F.rank(weight)
@@ -327,14 +332,11 @@ def get_bce_with_logits_loss_vamp_rule(prim, axis_size):
            shape = F.shape(logits)
            shape_ok = shape == F.shape(label) and shape == F.shape(weight) and shape == F.shape(pos_weight)
            if logits_dim_ok and shape_ok:
-                if prim_reduction ==
-                    output = prim(logits, label, weight, pos_weight
-                elif prim_reduction
+                if prim_reduction == 'none':
+                    output = prim(logits, label, weight, pos_weight)
+                elif prim_reduction in ('mean', 'sum'):
                    out = bce_logits_with_loss_op(logits, label, weight, pos_weight)
-                    output =
-                elif prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'sum'):
-                    out = bce_logits_with_loss_op(logits, label, weight, pos_weight)
-                    output = P.ReduceSum()(out, reduce_indexes)
+                    output = reduce_op(out, reduce_indexes)
                else:
                    raise RuntimeError("For {} vmap, the attribute of reduction must in "
                                       "('none', 'mean', 'sum'), but got {}."
@@ -350,14 +352,11 @@ def get_bce_with_logits_loss_vamp_rule(prim, axis_size):
        pos_weight_shape = F.shape(pos_weight)
        weight = _handle_broadcasting(weight, weight_shape, logits_shape)
        pos_weight = _handle_broadcasting(pos_weight, pos_weight_shape, logits_shape)
-        if prim_reduction ==
-            output = prim(logits, label, weight, pos_weight
-        elif prim_reduction
-            out = bce_logits_with_loss_op(logits, label, weight, pos_weight)
-            output = P.ReduceMean()(out, reduce_indexes)
-        elif prim_reduction == handler.str_to_enum("BCEWithLogitsLoss", "reduction", 'sum'):
+        if prim_reduction == 'none':
+            output = prim(logits, label, weight, pos_weight)
+        elif prim_reduction in ('mean', 'sum'):
            out = bce_logits_with_loss_op(logits, label, weight, pos_weight)
-            output =
+            output = reduce_op(out, reduce_indexes)
        else:
            raise RuntimeError("For {} vmap, the attribute of reduction must in "
                               "('none', 'mean', 'sum'), but got {}."