mindspore 2.3.0__cp39-none-any.whl → 2.3.0rc2__cp39-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +0 -1512
- mindspore/__init__.py +1 -2
- mindspore/_c_dataengine.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_checkparam.py +25 -5
- mindspore/_extends/graph_kernel/model/graph_parallel.py +1 -1
- mindspore/_extends/parse/__init__.py +2 -2
- mindspore/_extends/parse/compile_config.py +0 -29
- mindspore/_extends/parse/namespace.py +2 -2
- mindspore/_extends/parse/parser.py +5 -21
- mindspore/_extends/parse/resources.py +7 -5
- mindspore/_extends/parse/standard_method.py +59 -40
- mindspore/_mindspore_offline_debug.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +5 -26
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/base.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +1 -1
- mindspore/boost/grad_freeze.py +2 -2
- mindspore/boost/less_batch_normalization.py +6 -9
- mindspore/common/__init__.py +1 -8
- mindspore/common/_register_for_tensor.py +9 -8
- mindspore/common/api.py +65 -275
- mindspore/common/dtype.py +4 -8
- mindspore/common/dump.py +5 -2
- mindspore/common/jit_config.py +1 -1
- mindspore/common/lazy_inline.py +2 -14
- mindspore/common/parameter.py +15 -14
- mindspore/common/recompute.py +5 -20
- mindspore/common/sparse_tensor.py +6 -21
- mindspore/common/tensor.py +52 -100
- mindspore/communication/__init__.py +11 -6
- mindspore/communication/management.py +94 -92
- mindspore/context.py +18 -180
- mindspore/dataset/engine/datasets.py +46 -69
- mindspore/dataset/engine/datasets_user_defined.py +53 -72
- mindspore/dataset/engine/datasets_vision.py +2 -2
- mindspore/dataset/engine/queue.py +38 -56
- mindspore/dataset/engine/validators.py +5 -11
- mindspore/dataset/vision/__init__.py +5 -5
- mindspore/dataset/vision/c_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +1 -1
- mindspore/dataset/vision/transforms.py +46 -591
- mindspore/dataset/vision/utils.py +1 -121
- mindspore/dataset/vision/validators.py +3 -9
- mindspore/hal/__init__.py +1 -7
- mindspore/hal/device.py +1 -1
- mindspore/include/api/model.h +0 -3
- mindspore/include/dataset/vision.h +2 -54
- mindspore/include/mindapi/base/types.h +0 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libmpi_collective.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +0 -35
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +0 -72
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/{aclnn_all_finite.h → aclnn_add_custom.h} +11 -9
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_decoder_kv_cache.h +1 -1
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_prompt_kv_cache.h +1 -1
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/lib/libcust_opapi.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +12 -184
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +15 -7
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +15 -7
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.cpp +81 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.py +31 -77
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.py +31 -77
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/lib/linux/aarch64/libcust_opmaster_rt2.0.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/liboptiling.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/inc/op_proto.h +5 -4
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/lib/linux/aarch64/libcust_opsproto_rt2.0.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/liblowlatency_collective.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/DeviceBin +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/PkgInspect +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/bin/op_man +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/device/ascend910b/bin/ascend910b.bin +286 -275
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_cann_host.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/host/libasdops_host.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops_static.a +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/add/add_impl.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/apply_rotary_pos_emb_impl.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/asdop/asd_op_impl.h +0 -3
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/backend_param.h +0 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/cast/cast_tiling.h +45 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/compare/compare_impl.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/flash_attention_score/flash_attention_score_impl.h +4 -8
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/flash_attention_score/flash_attention_score_tiling.h +4 -11
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/flash_attention_score/kernel/flash_attention_score_mix_hwsync.h +0 -18
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/internal_kernel.h +0 -6
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/internal_rtbackend.h +75 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/kernel/matmul.h +5 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul/matmul_impl.h +3 -18
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/pp_matmul_common_tiling.h +5 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/pp_matmul_info.h +2 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_common/tiling_data.h +3 -36
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/kernel/matmul_stridedslice_fusion.h +2 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/matmul_stridedslice/matmul_stridedslice_fusion_impl.h +4 -22
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/op_param.h +2 -16
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/kernel/paged_attention_mix_hwsync.h +3 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_impl.h +4 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/paged_attention/paged_attention_tiling.h +4 -9
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/attention_param.h +2 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_ext_param.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/matmul_qkv_param.h +4 -10
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/sub_param.h +12 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/rms_norm/rms_norm_impl.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/sub/sub_impl.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/tune_repo/matmul_table.h +1 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/backend.h +2 -10
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/elewise_utils.h +1 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log.h +0 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_tiling.h +0 -17
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/math.h +7 -2
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libAdd_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libSub_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_layernorm_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libadd_rms_norm_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libcast_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libgelu_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libmatmul_stridedslice_fusion_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libms_kernels_internal.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libnot_equal_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libreshape_and_cache_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/librms_norm_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bnsd_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_bf16_bsh_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bnsd_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/flash_attention_score_fp16_bsh_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_bf16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/BSAttention/paged_attention_fp16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblcal.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblccl_wrapper.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/mindrecord/filewriter.py +2 -2
- mindspore/mint/__init__.py +40 -720
- mindspore/mint/nn/__init__.py +7 -89
- mindspore/mint/nn/functional.py +16 -165
- mindspore/mint/optim/adamw.py +16 -15
- mindspore/nn/__init__.py +2 -0
- mindspore/nn/cell.py +98 -97
- mindspore/nn/extend/basic.py +2 -2
- mindspore/nn/extend/embedding.py +1 -1
- mindspore/nn/extend/layer/normalization.py +5 -7
- mindspore/nn/generator.py +297 -0
- mindspore/nn/layer/activation.py +3 -4
- mindspore/nn/layer/basic.py +16 -79
- mindspore/nn/layer/conv.py +8 -17
- mindspore/nn/layer/embedding.py +4 -1
- mindspore/nn/layer/math.py +1 -1
- mindspore/nn/layer/normalization.py +1 -1
- mindspore/nn/layer/pooling.py +0 -5
- mindspore/nn/layer/rnn_cells.py +2 -2
- mindspore/nn/loss/loss.py +19 -19
- mindspore/nn/optim/adasum.py +1 -1
- mindspore/nn/optim/sgd.py +2 -3
- mindspore/nn/probability/distribution/exponential.py +1 -1
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/logistic.py +1 -1
- mindspore/nn/wrap/cell_wrapper.py +1 -25
- mindspore/nn/wrap/loss_scale.py +1 -24
- mindspore/numpy/array_ops.py +1 -5
- mindspore/numpy/dtypes.py +3 -3
- mindspore/numpy/math_ops.py +8 -8
- mindspore/ops/__init__.py +1 -1
- mindspore/ops/_grad_experimental/grad_comm_ops.py +16 -75
- mindspore/ops/_vmap/vmap_array_ops.py +0 -27
- mindspore/ops/_vmap/vmap_math_ops.py +1 -29
- mindspore/ops/_vmap/vmap_nn_ops.py +18 -19
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +8 -34
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +9 -2
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -26
- mindspore/ops/auto_generate/gen_extend_func.py +27 -603
- mindspore/ops/auto_generate/gen_ops_def.py +203 -993
- mindspore/ops/auto_generate/gen_ops_prim.py +402 -1946
- mindspore/ops/auto_generate/pyboost_inner_prim.py +20 -90
- mindspore/ops/composite/base.py +6 -3
- mindspore/ops/composite/math_ops.py +1 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +17 -24
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/extend/__init__.py +3 -2
- mindspore/ops/extend/array_func.py +51 -10
- mindspore/ops/extend/nn_func.py +78 -2
- mindspore/ops/function/__init__.py +13 -8
- mindspore/ops/function/array_func.py +179 -455
- mindspore/ops/function/clip_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +3 -3
- mindspore/ops/function/math_func.py +103 -117
- mindspore/ops/function/nn_func.py +163 -275
- mindspore/ops/function/other_func.py +2 -2
- mindspore/ops/function/random_func.py +69 -202
- mindspore/ops/function/sparse_func.py +4 -4
- mindspore/ops/functional.py +327 -332
- mindspore/ops/operations/__init__.py +3 -13
- mindspore/ops/operations/_grad_ops.py +27 -3
- mindspore/ops/operations/_inner_ops.py +356 -53
- mindspore/ops/operations/_rl_inner_ops.py +2 -2
- mindspore/ops/operations/_tensor_array.py +8 -8
- mindspore/ops/operations/array_ops.py +65 -82
- mindspore/ops/operations/comm_ops.py +93 -784
- mindspore/ops/operations/custom_ops.py +28 -51
- mindspore/ops/operations/debug_ops.py +4 -4
- mindspore/ops/operations/inner_ops.py +2 -2
- mindspore/ops/operations/manually_defined/ops_def.py +4 -304
- mindspore/ops/operations/math_ops.py +50 -3
- mindspore/ops/operations/nn_ops.py +247 -14
- mindspore/ops/operations/other_ops.py +3 -3
- mindspore/ops/operations/random_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +1 -1
- mindspore/ops/primitive.py +8 -9
- mindspore/ops/silent_check.py +5 -5
- mindspore/ops_generate/arg_dtype_cast.py +9 -2
- mindspore/ops_generate/arg_handler.py +0 -26
- mindspore/ops_generate/gen_aclnn_implement.py +4 -1
- mindspore/ops_generate/gen_ops.py +4 -26
- mindspore/ops_generate/gen_pyboost_func.py +12 -41
- mindspore/ops_generate/gen_utils.py +0 -21
- mindspore/ops_generate/pyboost_utils.py +2 -7
- mindspore/ops_generate/template.py +0 -1
- mindspore/parallel/_auto_parallel_context.py +1 -21
- mindspore/parallel/_tensor.py +5 -0
- mindspore/parallel/_transformer/transformer.py +1 -1
- mindspore/parallel/_utils.py +1 -15
- mindspore/parallel/algo_parameter_config.py +3 -1
- mindspore/parallel/checkpoint_transform.py +9 -12
- mindspore/parallel/cluster/process_entity/_api.py +29 -28
- mindspore/parallel/cluster/process_entity/_utils.py +3 -13
- mindspore/parallel/cluster/run.py +16 -13
- mindspore/parallel/parameter_broadcast.py +2 -2
- mindspore/parallel/shard.py +17 -31
- mindspore/profiler/__init__.py +2 -3
- mindspore/profiler/common/util.py +2 -107
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +21 -8
- mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -82
- mindspore/profiler/parser/ascend_analysis/function_event.py +28 -43
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +27 -49
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +10 -15
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +20 -25
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +5 -5
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +1 -10
- mindspore/profiler/parser/ascend_hccl_generator.py +1 -4
- mindspore/profiler/parser/ascend_msprof_exporter.py +22 -43
- mindspore/profiler/parser/ascend_timeline_generator.py +5 -7
- mindspore/profiler/parser/minddata_parser.py +3 -72
- mindspore/profiler/profiling.py +59 -176
- mindspore/rewrite/api/node.py +1 -1
- mindspore/rewrite/common/namespace.py +5 -5
- mindspore/rewrite/parsers/assign_parser.py +0 -2
- mindspore/rewrite/parsers/class_def_parser.py +4 -8
- mindspore/run_check/_check_version.py +1 -1
- mindspore/scipy/fft.py +3 -1
- mindspore/scipy/linalg.py +3 -2
- mindspore/scipy/ops.py +3 -5
- mindspore/scipy/optimize/__init__.py +2 -2
- mindspore/train/__init__.py +4 -4
- mindspore/train/anf_ir_pb2.py +2 -8
- mindspore/train/callback/__init__.py +2 -5
- mindspore/train/callback/_backup_and_restore.py +2 -2
- mindspore/train/callback/_checkpoint.py +16 -104
- mindspore/train/callback/_landscape.py +1 -1
- mindspore/train/callback/_time_monitor.py +1 -1
- mindspore/train/data_sink.py +4 -5
- mindspore/train/dataset_helper.py +20 -45
- mindspore/train/model.py +38 -266
- mindspore/train/serialization.py +105 -256
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +2 -2
- {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +303 -420
- mindspore/_extends/pijit/__init__.py +0 -23
- mindspore/_extends/pijit/pijit_func_white_list.py +0 -343
- mindspore/common/file_system.py +0 -48
- mindspore/common/generator.py +0 -260
- mindspore/common/no_inline.py +0 -54
- mindspore/common/np_dtype.py +0 -25
- mindspore/communication/comm_func.py +0 -1140
- mindspore/hal/memory.py +0 -326
- mindspore/lib/libavcodec.so.59 +0 -0
- mindspore/lib/libavdevice.so.59 +0 -0
- mindspore/lib/libavfilter.so.8 +0 -0
- mindspore/lib/libavformat.so.59 +0 -0
- mindspore/lib/libavutil.so.57 +0 -0
- mindspore/lib/libmindspore_np_dtype.so +0 -0
- mindspore/lib/libswresample.so.4 +0 -0
- mindspore/lib/libswscale.so.6 +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/all_finite.cpp +0 -326
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/all_finite.py +0 -180
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_576ceaeef5870c451cab59af55ea46ad.json +0 -58
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_576ceaeef5870c451cab59af55ea46ad.o +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_86a73ff6e28d734c96bb8d3054f7dd18.json +0 -58
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_86a73ff6e28d734c96bb8d3054f7dd18.o +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_f55e0ebaad1f2f572e43677336992fa0.json +0 -58
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/ascend910b/all_finite/AllFinite_f55e0ebaad1f2f572e43677336992fa0.o +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/config/ascend910b/all_finite.json +0 -109
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/config/ascend910b/binary_info_config.json +0 -38
- mindspore/lib/plugin/ascend/custom_compiler/OWNERS +0 -12
- mindspore/lib/plugin/ascend/custom_compiler/setup.py +0 -255
- mindspore/lib/plugin/ascend/custom_compiler/start.sh +0 -26
- mindspore/lib/plugin/ascend/custom_compiler/template.json +0 -40
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/acme.h +0 -24
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/acme_op.h +0 -69
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/base_type.h +0 -133
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/op_creator.h +0 -32
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/op_param.h +0 -35
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/tiling_info.h +0 -60
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/core/kernel_register.h +0 -37
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/core/platform/platform_configs.h +0 -89
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/core/platform/rt_funcs.h +0 -135
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/add_op.h +0 -34
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_backoff_base.h +0 -62
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_elewise_op.h +0 -33
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_ops.h +0 -88
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/asd_pa_op.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/cast_op.h +0 -52
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/matmul_op.h +0 -95
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/utils/asd_utils.h +0 -84
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/utils/comm_utils.h +0 -61
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/apply_rotary_pos_emb/kernel/apply_rotary_pos_emb_fp32.h +0 -224
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/and_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/div_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/elewise_binary_impl.h +0 -48
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/elewise_binary_tiling.h +0 -25
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/and_kernel.h +0 -46
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/div_kernel.h +0 -46
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/elewise_binary_base.h +0 -260
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/elewise_binary_kernel.h +0 -35
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/max_kernel.h +0 -66
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/min_kernel.h +0 -66
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/mul_kernel.h +0 -66
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/kernel/or_kernel.h +0 -46
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/max_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/min_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/mul_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_binary/or_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/abs_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/elewise_unary_impl.h +0 -47
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/elewise_unary_tiling.h +0 -24
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/exp_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/abs_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/elewise_unary_base.h +0 -148
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/elewise_unary_kernel.h +0 -31
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/exp_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/ln_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/not_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/reciprocal_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/relu_kernel.h +0 -55
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/rsqrt_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/kernel/sqrt_kernel.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/ln_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/not_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/reciprocal_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/relu_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/rsqrt_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/elewise_unary/sqrt_impl.h +0 -29
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/grouped_matmul_impl.h +0 -45
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/grouped_matmul_tiling.h +0 -187
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/kernel/grouped_matmul.h +0 -245
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/kernel/grouped_matmul_interface.h +0 -24
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/kernel/grouped_matmul_utils.h +0 -111
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/grouped_matmul/tiling_data.h +0 -54
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/compare_param.h +0 -31
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/elewise_param.h +0 -41
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/param/grouped_matmul_param.h +0 -40
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/profiling_util.h +0 -364
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/log/log_utils.h +0 -69
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/register/kernel_creator.h +0 -39
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/register/kernel_registry.h +0 -114
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/utils/utils.h +0 -98
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix.json +0 -19
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix_mix_aic_0.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MatMulPostFusionMixTactic/matmul_postfusion_mix_mix_aiv_0.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix.json +0 -19
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix_mix_aic_0.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/MultiMatMulPostFusionMixTactic/multi_matmul_postfusion_mix_mix_aiv_0.o +0 -0
- mindspore/mint/linalg/__init__.py +0 -22
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/nn/layer/embedding_service_layer.py +0 -393
- mindspore/ops/function/reshard_func.py +0 -102
- mindspore/ops/operations/_infer_ops.py +0 -19
- mindspore/ops/operations/reshard_ops.py +0 -53
- mindspore/profiler/common/process_pool.py +0 -41
- mindspore/profiler/common/singleton.py +0 -28
- mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
- mindspore/profiler/parser/ascend_memory_generator.py +0 -185
- mindspore/train/callback/_cluster_monitor.py +0 -201
- mindspore/train/callback/_flops_collector.py +0 -238
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
mindspore/train/serialization.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2020-
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
|
@@ -17,7 +17,6 @@
|
|
|
17
17
|
from __future__ import absolute_import
|
|
18
18
|
from __future__ import division
|
|
19
19
|
|
|
20
|
-
import binascii
|
|
21
20
|
import copy
|
|
22
21
|
import json
|
|
23
22
|
import os
|
|
@@ -31,7 +30,6 @@ from io import BytesIO
|
|
|
31
30
|
import math
|
|
32
31
|
import sys
|
|
33
32
|
import time
|
|
34
|
-
import google
|
|
35
33
|
import numpy as np
|
|
36
34
|
|
|
37
35
|
from mindspore.train.checkpoint_pb2 import Checkpoint
|
|
@@ -54,7 +52,6 @@ from mindspore.common.parameter import Parameter, _offload_if_config
|
|
|
54
52
|
from mindspore.common.tensor import Tensor
|
|
55
53
|
from mindspore._c_expression import Tensor as Tensor_
|
|
56
54
|
from mindspore.common._utils import is_shape_unknown
|
|
57
|
-
from mindspore.common.file_system import FileSystem, _register_basic_file_system, _register_mindio_file_system
|
|
58
55
|
from mindspore.communication.management import get_rank, get_group_size
|
|
59
56
|
from mindspore.experimental import MapParameter
|
|
60
57
|
from mindspore.ops import Cast
|
|
@@ -70,9 +67,6 @@ from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_paramet
|
|
|
70
67
|
from mindspore.train._utils import read_proto
|
|
71
68
|
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
|
|
72
69
|
split_mindir, split_dynamic_mindir
|
|
73
|
-
from mindspore.common.generator import Generator
|
|
74
|
-
from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
|
|
75
|
-
from mindspore.parallel.parameter_broadcast import parameter_broadcast
|
|
76
70
|
from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
|
|
77
71
|
|
|
78
72
|
tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
|
|
@@ -102,19 +96,6 @@ INT_64_MAX = 9223372036854775807
|
|
|
102
96
|
|
|
103
97
|
cpu_cast = Cast().set_device("CPU")
|
|
104
98
|
|
|
105
|
-
_ckpt_fs = FileSystem()
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
def init_ckpt_file_system(fs: FileSystem):
|
|
109
|
-
"""Initialize checkpoint file system"""
|
|
110
|
-
if _register_mindio_file_system(fs):
|
|
111
|
-
return
|
|
112
|
-
_register_basic_file_system(fs)
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
# Initialize checkpoint file system
|
|
116
|
-
init_ckpt_file_system(_ckpt_fs)
|
|
117
|
-
|
|
118
99
|
|
|
119
100
|
class ParamDictFuture:
|
|
120
101
|
def __init__(self, executor, param_dict_future):
|
|
@@ -252,19 +233,18 @@ def _save_weight(checkpoint_dir, model_name, iteration, params):
|
|
|
252
233
|
logger.warning(f"Checkpoint dir: '{checkpoint_dir}' is not existed.")
|
|
253
234
|
|
|
254
235
|
|
|
255
|
-
def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False
|
|
236
|
+
def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False):
|
|
256
237
|
"""Execute the process of saving checkpoint into file."""
|
|
257
238
|
try:
|
|
258
239
|
with _ckpt_mutex:
|
|
259
240
|
if os.path.exists(ckpt_file_name):
|
|
260
241
|
os.chmod(ckpt_file_name, stat.S_IWUSR)
|
|
261
242
|
os.remove(ckpt_file_name)
|
|
262
|
-
with
|
|
243
|
+
with open(ckpt_file_name, "ab") as f:
|
|
263
244
|
plain_data = None
|
|
264
245
|
if enc_key is not None:
|
|
265
246
|
plain_data = BytesIO()
|
|
266
247
|
|
|
267
|
-
crc_num = 0
|
|
268
248
|
for name, value in data_list.items():
|
|
269
249
|
if name == "random_op":
|
|
270
250
|
_write_random_seed(name, value, f)
|
|
@@ -279,16 +259,16 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
|
|
|
279
259
|
_offload_if_config(value[3])
|
|
280
260
|
continue
|
|
281
261
|
if value[1] == "str":
|
|
282
|
-
|
|
262
|
+
_write_parameter_data(name, value, f, enc_key, plain_data)
|
|
283
263
|
continue
|
|
284
264
|
if isinstance(value[2], np.ndarray):
|
|
285
|
-
|
|
265
|
+
_write_parameter_data(name, value, f, enc_key, plain_data)
|
|
286
266
|
continue
|
|
287
267
|
if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
|
|
288
268
|
_write_hugeparameter(name, value, f)
|
|
289
269
|
continue
|
|
290
270
|
|
|
291
|
-
|
|
271
|
+
_write_parameter_bytes_data(name, value, f, enc_key, plain_data)
|
|
292
272
|
|
|
293
273
|
if enc_key is not None:
|
|
294
274
|
plain_data.seek(0)
|
|
@@ -298,10 +278,7 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
|
|
|
298
278
|
f.write(_encrypt(block_data, len(block_data), enc_key, len(enc_key), enc_mode))
|
|
299
279
|
block_data = plain_data.read(max_block_size)
|
|
300
280
|
|
|
301
|
-
|
|
302
|
-
f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big'))
|
|
303
|
-
|
|
304
|
-
os.chmod(ckpt_file_name, stat.S_IRUSR)
|
|
281
|
+
os.chmod(ckpt_file_name, stat.S_IRUSR)
|
|
305
282
|
|
|
306
283
|
except BaseException as e:
|
|
307
284
|
logger.critical("Failed to save the checkpoint file %s. Maybe don't have the permission to write files, "
|
|
@@ -321,7 +298,7 @@ def _write_random_seed(name, value, f):
|
|
|
321
298
|
f.write(checkpoint_list.SerializeToString())
|
|
322
299
|
|
|
323
300
|
|
|
324
|
-
def _write_parameter_data(name, value, f, enc_key, plain_data
|
|
301
|
+
def _write_parameter_data(name, value, f, enc_key, plain_data):
|
|
325
302
|
"""Write parameter data into protobuf file."""
|
|
326
303
|
data_size = value[2].nbytes / 1024
|
|
327
304
|
if data_size > SLICE_SIZE:
|
|
@@ -340,17 +317,12 @@ def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_ch
|
|
|
340
317
|
param_tensor.tensor_content = param_slice.tobytes()
|
|
341
318
|
|
|
342
319
|
if enc_key is None:
|
|
343
|
-
|
|
344
|
-
if crc_check:
|
|
345
|
-
crc_num = binascii.crc32(output_data, crc_num)
|
|
346
|
-
f.write(output_data)
|
|
320
|
+
f.write(checkpoint_list.SerializeToString())
|
|
347
321
|
else:
|
|
348
322
|
plain_data.write(checkpoint_list.SerializeToString())
|
|
349
323
|
|
|
350
|
-
return crc_num
|
|
351
|
-
|
|
352
324
|
|
|
353
|
-
def _write_parameter_bytes_data(name, value, f, enc_key, plain_data
|
|
325
|
+
def _write_parameter_bytes_data(name, value, f, enc_key, plain_data):
|
|
354
326
|
"""Write parameter bytes data into protobuf file."""
|
|
355
327
|
bytes_value = value[2].get_bytes()
|
|
356
328
|
chunk_size = 1024 * SLICE_SIZE
|
|
@@ -365,15 +337,10 @@ def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0,
|
|
|
365
337
|
param_tensor.tensor_content = bytes_value[i:i + chunk_size]
|
|
366
338
|
|
|
367
339
|
if enc_key is None:
|
|
368
|
-
|
|
369
|
-
if crc_check:
|
|
370
|
-
crc_num = binascii.crc32(output_data, crc_num)
|
|
371
|
-
f.write(output_data)
|
|
340
|
+
f.write(checkpoint_list.SerializeToString())
|
|
372
341
|
else:
|
|
373
342
|
plain_data.write(checkpoint_list.SerializeToString())
|
|
374
343
|
|
|
375
|
-
return crc_num
|
|
376
|
-
|
|
377
344
|
|
|
378
345
|
def _write_mapparameter(name, value, f, map_param_inc=False):
|
|
379
346
|
"""Write map parameter into protobuf file."""
|
|
@@ -434,14 +401,10 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
|
|
|
434
401
|
|
|
435
402
|
|
|
436
403
|
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
437
|
-
async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None,
|
|
438
|
-
crc_check=False, **kwargs):
|
|
404
|
+
async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None, **kwargs):
|
|
439
405
|
r"""
|
|
440
406
|
Save checkpoint to a specified file.
|
|
441
407
|
|
|
442
|
-
Note:
|
|
443
|
-
The `enc_mode` and `crc_check` parameters are mutually exclusive and cannot be configured simultaneously.
|
|
444
|
-
|
|
445
408
|
Args:
|
|
446
409
|
save_obj (Union[Cell, list, dict]): The object to be saved. The data type can be :class:`mindspore.nn.Cell`,
|
|
447
410
|
list, or dict. If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
|
|
@@ -463,8 +426,6 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
463
426
|
If returns ``True`` , the Parameter that matching the custom condition will be saved.
|
|
464
427
|
If returns ``False`` , the Parameter that not matching the custom condition will not
|
|
465
428
|
be saved. Default: ``None`` .
|
|
466
|
-
crc_check (bool) : Whether to perform crc32 calculation when saving checkpoint and save the calculation
|
|
467
|
-
result to the file. Default: ``False`` .
|
|
468
429
|
kwargs (dict): Configuration options dictionary.
|
|
469
430
|
|
|
470
431
|
Raises:
|
|
@@ -504,32 +465,24 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
504
465
|
append_dict = _check_append_dict(append_dict)
|
|
505
466
|
enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
|
|
506
467
|
enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
|
|
507
|
-
crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
|
|
508
468
|
map_param_inc = kwargs.get('incremental', False)
|
|
509
469
|
logger.info("Execute the process of saving checkpoint files.")
|
|
510
|
-
global_step_num = kwargs.get('global_step_num', None)
|
|
511
470
|
|
|
512
471
|
save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
|
|
513
472
|
|
|
514
473
|
if append_dict:
|
|
515
474
|
append_info_list = []
|
|
516
475
|
for k_name, value in append_dict.items():
|
|
517
|
-
if isinstance(value,
|
|
518
|
-
value = value.get_state()
|
|
519
|
-
elif not isinstance(value, str):
|
|
476
|
+
if not isinstance(value, str):
|
|
520
477
|
value = Tensor(value)
|
|
521
478
|
append_info_list.append({"name": k_name, "data": value})
|
|
522
479
|
save_obj.extend(append_info_list)
|
|
523
480
|
|
|
524
481
|
data_list = OrderedDict()
|
|
525
|
-
data_list_np = OrderedDict()
|
|
526
482
|
with _ckpt_mutex:
|
|
527
483
|
for param in save_obj:
|
|
528
484
|
if param["name"] == "random_op":
|
|
529
|
-
|
|
530
|
-
data_list_np["random_op"] = param["data"]
|
|
531
|
-
else:
|
|
532
|
-
data_list["random_op"] = param["data"]
|
|
485
|
+
data_list["random_op"] = param["data"]
|
|
533
486
|
continue
|
|
534
487
|
key = param["name"]
|
|
535
488
|
data_list[key] = []
|
|
@@ -545,39 +498,28 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
545
498
|
_save_param_list_data(data_list, key, param)
|
|
546
499
|
|
|
547
500
|
if isinstance(param["data"], str):
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
data_list[key].append('str')
|
|
553
|
-
data = np.array(param["data"])
|
|
554
|
-
data_list[key].append(data)
|
|
501
|
+
data_list[key].append([0])
|
|
502
|
+
data_list[key].append('str')
|
|
503
|
+
data = np.array(param["data"])
|
|
504
|
+
data_list[key].append(data)
|
|
555
505
|
else:
|
|
556
506
|
if isinstance(param["data"], Parameter):
|
|
557
507
|
param["data"].init_data()
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
data_list[key].append(data)
|
|
569
|
-
|
|
570
|
-
if os.getenv("AITURBO") == "1":
|
|
571
|
-
import aiturbo
|
|
572
|
-
ckpt_name = os.path.basename(ckpt_file_name)
|
|
573
|
-
aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np)
|
|
574
|
-
elif async_save:
|
|
508
|
+
dims = []
|
|
509
|
+
for dim in param['data'].shape:
|
|
510
|
+
dims.append(dim)
|
|
511
|
+
data_list[key].append(dims)
|
|
512
|
+
tensor_type = str(param["data"].dtype)
|
|
513
|
+
data_list[key].append(tensor_type)
|
|
514
|
+
data = param["data"]
|
|
515
|
+
data_list[key].append(data)
|
|
516
|
+
|
|
517
|
+
if async_save:
|
|
575
518
|
data_copy = copy.deepcopy(data_list)
|
|
576
|
-
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_copy, enc_key, enc_mode,
|
|
577
|
-
name="asyn_save_ckpt")
|
|
519
|
+
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_copy, enc_key, enc_mode), name="asyn_save_ckpt")
|
|
578
520
|
thr.start()
|
|
579
521
|
else:
|
|
580
|
-
_exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc
|
|
522
|
+
_exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc)
|
|
581
523
|
|
|
582
524
|
logger.info("Saving checkpoint process is finished.")
|
|
583
525
|
|
|
@@ -736,9 +678,9 @@ def _check_append_dict(append_dict):
|
|
|
736
678
|
raise TypeError("For 'save_checkpoint', the argument 'append_dict' must be dict, but got "
|
|
737
679
|
"{}.".format(type(append_dict)))
|
|
738
680
|
for key, value in append_dict.items():
|
|
739
|
-
if not isinstance(key, str) or not isinstance(value, (int, float, bool, str, Parameter, Tensor
|
|
681
|
+
if not isinstance(key, str) or not isinstance(value, (int, float, bool, str, Parameter, Tensor)):
|
|
740
682
|
raise TypeError(f"For 'save_checkpoint', the type of dict 'append_info' must be key: string, "
|
|
741
|
-
f"value: int, float
|
|
683
|
+
f"value: int, float or bool, but got key: {type(key)}, value: {type(value)}")
|
|
742
684
|
return append_dict
|
|
743
685
|
|
|
744
686
|
|
|
@@ -1069,76 +1011,12 @@ def obfuscate_model(obf_config, **kwargs):
|
|
|
1069
1011
|
obf_net = nn.GraphCell(obf_graph)
|
|
1070
1012
|
if obf_random_seed != 0:
|
|
1071
1013
|
append_y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
|
|
1072
|
-
model_inputs += [append_y_tensor]
|
|
1014
|
+
model_inputs += [append_y_tensor,]
|
|
1073
1015
|
export(obf_net, *model_inputs, file_name=saved_path, file_format="MINDIR", **kwargs)
|
|
1074
1016
|
|
|
1075
1017
|
|
|
1076
|
-
def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
|
|
1077
|
-
dec_mode, crc_check):
|
|
1078
|
-
"""load parameter into parameter_dict"""
|
|
1079
|
-
ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
|
|
1080
|
-
checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check)
|
|
1081
|
-
try:
|
|
1082
|
-
param_data_list = []
|
|
1083
|
-
map_data_list = [[], [], []]
|
|
1084
|
-
map_shape_list = [0, 0, 0]
|
|
1085
|
-
if specify_prefix:
|
|
1086
|
-
logger.warning("For load_checkpoint, this parameter `specity_prefix` will be deprecated, "
|
|
1087
|
-
"please use `choice_func` instead.")
|
|
1088
|
-
if filter_prefix:
|
|
1089
|
-
logger.warning("For load_checkpoint, this parameter `filter_prefix` will be deprecated, "
|
|
1090
|
-
"please use `choice_func` instead.")
|
|
1091
|
-
for element_id, element in enumerate(checkpoint_list.value):
|
|
1092
|
-
if element.tag == "random_op":
|
|
1093
|
-
parameter_dict["random_op"] = element.tensor.tensor_content
|
|
1094
|
-
continue
|
|
1095
|
-
if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
|
|
1096
|
-
continue
|
|
1097
|
-
if specify_prefix is None and filter_prefix is None and \
|
|
1098
|
-
choice_func is not None and not choice_func(element.tag):
|
|
1099
|
-
continue
|
|
1100
|
-
if element.tensor.ByteSize() == 0:
|
|
1101
|
-
_load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list,
|
|
1102
|
-
parameter_dict)
|
|
1103
|
-
if element.tag in parameter_dict:
|
|
1104
|
-
map_data_list = [[], [], []]
|
|
1105
|
-
map_shape_list = [0, 0, 0]
|
|
1106
|
-
continue
|
|
1107
|
-
data = element.tensor.tensor_content
|
|
1108
|
-
data_type = element.tensor.tensor_type
|
|
1109
|
-
np_type = tensor_to_np_type.get(data_type)
|
|
1110
|
-
ms_type = tensor_to_ms_type[data_type]
|
|
1111
|
-
if data_type == 'str':
|
|
1112
|
-
str_length = int(len(data) / 4)
|
|
1113
|
-
np_type = np_type + str(str_length)
|
|
1114
|
-
param_data_list.append(data)
|
|
1115
|
-
if (element_id == len(checkpoint_list.value) - 1) or \
|
|
1116
|
-
(element.tag != checkpoint_list.value[element_id + 1].tag):
|
|
1117
|
-
new_data = b"".join(param_data_list)
|
|
1118
|
-
param_data_list.clear()
|
|
1119
|
-
dims = element.tensor.dims
|
|
1120
|
-
if data_type == 'str':
|
|
1121
|
-
str_value = np.frombuffer(new_data, np_type)
|
|
1122
|
-
parameter_dict[element.tag] = str(str_value[0])
|
|
1123
|
-
else:
|
|
1124
|
-
if dims == [0]:
|
|
1125
|
-
dims = []
|
|
1126
|
-
param_data = Tensor_.convert_bytes_to_tensor(new_data, tuple(dims), ms_type)
|
|
1127
|
-
parameter = Parameter(param_data, name=element.tag)
|
|
1128
|
-
parameter_dict[element.tag] = parameter
|
|
1129
|
-
_offload_if_config(parameter)
|
|
1130
|
-
|
|
1131
|
-
logger.info("Loading checkpoint files process is finished.")
|
|
1132
|
-
|
|
1133
|
-
except BaseException as e:
|
|
1134
|
-
logger.critical("Failed to load the checkpoint file '%s'.", ckpt_file_name)
|
|
1135
|
-
raise ValueError(e.__str__() + "\nFor 'load_checkpoint', "
|
|
1136
|
-
"failed to load the checkpoint file {}.".format(ckpt_file_name)) from e
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
1018
|
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
|
|
1140
|
-
dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None
|
|
1141
|
-
crc_check=False):
|
|
1019
|
+
dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None):
|
|
1142
1020
|
"""
|
|
1143
1021
|
Load checkpoint info from a specified file.
|
|
1144
1022
|
|
|
@@ -1169,7 +1047,6 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1169
1047
|
and the return value is a bool. If returns ``True`` , the Parameter
|
|
1170
1048
|
that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
|
|
1171
1049
|
matches the custom condition will be removed. Default: ``None`` .
|
|
1172
|
-
crc_check (bool) : Whether to perform crc32 validation when loading checkpoint. Default: ``False`` .
|
|
1173
1050
|
|
|
1174
1051
|
Returns:
|
|
1175
1052
|
Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
|
|
@@ -1214,29 +1091,70 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1214
1091
|
- `Saving and Loading the Model - Saving and Loading the Model Weight
|
|
1215
1092
|
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
|
|
1216
1093
|
"""
|
|
1094
|
+
ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
|
|
1217
1095
|
specify_prefix = _check_prefix(specify_prefix)
|
|
1218
1096
|
filter_prefix = _check_prefix(filter_prefix)
|
|
1219
1097
|
dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
|
|
1220
1098
|
dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
|
|
1221
|
-
crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
|
|
1222
1099
|
logger.info("Execute the process of loading checkpoint files.")
|
|
1100
|
+
checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode)
|
|
1223
1101
|
|
|
1224
1102
|
parameter_dict = {}
|
|
1103
|
+
try:
|
|
1104
|
+
param_data_list = []
|
|
1105
|
+
map_data_list = [[], [], []]
|
|
1106
|
+
map_shape_list = [0, 0, 0]
|
|
1107
|
+
if specify_prefix:
|
|
1108
|
+
logger.warning("For load_checkpoint, this parameter `specity_prefix` will be deprecated, "
|
|
1109
|
+
"please use `choice_func` instead.")
|
|
1110
|
+
if filter_prefix:
|
|
1111
|
+
logger.warning("For load_checkpoint, this parameter `filter_prefix` will be deprecated, "
|
|
1112
|
+
"please use `choice_func` instead.")
|
|
1113
|
+
for element_id, element in enumerate(checkpoint_list.value):
|
|
1114
|
+
if element.tag == "random_op":
|
|
1115
|
+
parameter_dict["random_op"] = element.tensor.tensor_content
|
|
1116
|
+
continue
|
|
1117
|
+
if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
|
|
1118
|
+
continue
|
|
1119
|
+
if specify_prefix is None and filter_prefix is None and \
|
|
1120
|
+
choice_func is not None and not choice_func(element.tag):
|
|
1121
|
+
continue
|
|
1122
|
+
if element.tensor.ByteSize() == 0:
|
|
1123
|
+
_load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list, parameter_dict)
|
|
1124
|
+
if element.tag in parameter_dict:
|
|
1125
|
+
map_data_list = [[], [], []]
|
|
1126
|
+
map_shape_list = [0, 0, 0]
|
|
1127
|
+
continue
|
|
1128
|
+
data = element.tensor.tensor_content
|
|
1129
|
+
data_type = element.tensor.tensor_type
|
|
1130
|
+
np_type = tensor_to_np_type.get(data_type)
|
|
1131
|
+
ms_type = tensor_to_ms_type[data_type]
|
|
1132
|
+
if data_type == 'str':
|
|
1133
|
+
str_length = int(len(data) / 4)
|
|
1134
|
+
np_type = np_type + str(str_length)
|
|
1135
|
+
param_data_list.append(data)
|
|
1136
|
+
if (element_id == len(checkpoint_list.value) - 1) or \
|
|
1137
|
+
(element.tag != checkpoint_list.value[element_id + 1].tag):
|
|
1138
|
+
new_data = b"".join(param_data_list)
|
|
1139
|
+
param_data_list.clear()
|
|
1140
|
+
dims = element.tensor.dims
|
|
1141
|
+
if dims == [0] and data_type == 'str':
|
|
1142
|
+
str_value = np.frombuffer(new_data, np_type)
|
|
1143
|
+
parameter_dict[element.tag] = str(str_value[0])
|
|
1144
|
+
else:
|
|
1145
|
+
if dims == [0]:
|
|
1146
|
+
dims = []
|
|
1147
|
+
param_data = Tensor_.convert_bytes_to_tensor(new_data, tuple(dims), ms_type)
|
|
1148
|
+
parameter = Parameter(param_data, name=element.tag)
|
|
1149
|
+
parameter_dict[element.tag] = parameter
|
|
1150
|
+
_offload_if_config(parameter)
|
|
1225
1151
|
|
|
1226
|
-
|
|
1227
|
-
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
for key, value in np_dict.items():
|
|
1233
|
-
if isinstance(value, str):
|
|
1234
|
-
parameter_dict[key] = value
|
|
1235
|
-
else:
|
|
1236
|
-
parameter_dict[key] = Parameter(Tensor(value), name=key)
|
|
1237
|
-
else:
|
|
1238
|
-
_load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
|
|
1239
|
-
dec_mode, crc_check)
|
|
1152
|
+
logger.info("Loading checkpoint files process is finished.")
|
|
1153
|
+
|
|
1154
|
+
except BaseException as e:
|
|
1155
|
+
logger.critical("Failed to load the checkpoint file '%s'.", ckpt_file_name)
|
|
1156
|
+
raise ValueError(e.__str__() + "\nFor 'load_checkpoint', "
|
|
1157
|
+
"failed to load the checkpoint file {}.".format(ckpt_file_name)) from e
|
|
1240
1158
|
|
|
1241
1159
|
if not parameter_dict:
|
|
1242
1160
|
raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
|
|
@@ -1403,28 +1321,17 @@ def _check_prefix(prefix):
|
|
|
1403
1321
|
return prefix
|
|
1404
1322
|
|
|
1405
1323
|
|
|
1406
|
-
def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode
|
|
1324
|
+
def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode):
|
|
1407
1325
|
"""Parse checkpoint protobuf."""
|
|
1408
1326
|
checkpoint_list = Checkpoint()
|
|
1409
1327
|
try:
|
|
1410
1328
|
if dec_key is None:
|
|
1411
|
-
with
|
|
1329
|
+
with open(ckpt_file_name, "rb") as f:
|
|
1412
1330
|
pb_content = f.read()
|
|
1413
1331
|
else:
|
|
1414
1332
|
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
|
|
1415
1333
|
if pb_content is None:
|
|
1416
1334
|
raise ValueError("For 'load_checkpoint', failed to decrypt the checkpoint file.")
|
|
1417
|
-
if crc_check and pb_content[-17:-10] == b"crc_num":
|
|
1418
|
-
logger.warning("For 'load_checkpoint', the ckpt file do not contain the crc code, please check the file.")
|
|
1419
|
-
if pb_content[-17:-10] == b"crc_num":
|
|
1420
|
-
crc_num_bytes = pb_content[-10:]
|
|
1421
|
-
pb_content = pb_content[:-17]
|
|
1422
|
-
if crc_check:
|
|
1423
|
-
crc_num = int.from_bytes(crc_num_bytes, byteorder='big')
|
|
1424
|
-
cal_crc_num = binascii.crc32(pb_content, 0)
|
|
1425
|
-
if cal_crc_num != crc_num:
|
|
1426
|
-
raise ValueError("For 'load_checkpoint', the crc check is failed, "
|
|
1427
|
-
"please check whether the ckpt file is damaged.")
|
|
1428
1335
|
checkpoint_list.ParseFromString(pb_content)
|
|
1429
1336
|
except BaseException as e:
|
|
1430
1337
|
if _is_cipher_file(ckpt_file_name):
|
|
@@ -1457,33 +1364,13 @@ def _whether_load_param(specify_prefix, filter_prefix, param_name):
|
|
|
1457
1364
|
|
|
1458
1365
|
def _init_parameter_data_in_parallel_mode(net, parameter_dict):
|
|
1459
1366
|
"""In parallel mode, only init the paraemters in ckpt."""
|
|
1460
|
-
is_train_phase = net.phase.startswith('train')
|
|
1461
1367
|
for _, param in net.parameters_and_names():
|
|
1462
|
-
if param.name in parameter_dict and param.from_ckpt and not is_train_phase:
|
|
1463
|
-
param.shape = tuple(parameter_dict[param.name].shape)
|
|
1464
|
-
continue
|
|
1465
1368
|
if param.name in parameter_dict and param.has_init:
|
|
1466
1369
|
logger.warning("{} is not init while load ckpt.".format(param.name))
|
|
1467
1370
|
new_tensor = param.init_data()
|
|
1468
1371
|
param._update_tensor_data(new_tensor)
|
|
1469
1372
|
|
|
1470
1373
|
|
|
1471
|
-
def _check_load_param_into_net(net, parameter_dict):
|
|
1472
|
-
"""check load_param_into_net"""
|
|
1473
|
-
if not isinstance(net, nn.Cell):
|
|
1474
|
-
logger.critical("Failed to combine the net and the parameters.")
|
|
1475
|
-
msg = ("For 'load_param_into_net', the argument 'net' should be a Cell, but got {}.".format(type(net)))
|
|
1476
|
-
raise TypeError(msg)
|
|
1477
|
-
if not isinstance(parameter_dict, dict):
|
|
1478
|
-
logger.critical("Failed to combine the net and the parameters.")
|
|
1479
|
-
msg = ("For 'load_param_into_net', the argument 'parameter_dict' should be a dict, "
|
|
1480
|
-
"but got {}.".format(type(parameter_dict)))
|
|
1481
|
-
raise TypeError(msg)
|
|
1482
|
-
if "random_op" in parameter_dict.keys():
|
|
1483
|
-
net._add_attr("random_op_snapshot", parameter_dict["random_op"])
|
|
1484
|
-
parameter_dict.pop("random_op")
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
1374
|
def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
1488
1375
|
"""
|
|
1489
1376
|
Load parameters into network, return parameter list that are not loaded in the network.
|
|
@@ -1520,7 +1407,18 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1520
1407
|
- `Saving and Loading the Model - Saving and Loading the Model Weight
|
|
1521
1408
|
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
|
|
1522
1409
|
"""
|
|
1523
|
-
|
|
1410
|
+
if not isinstance(net, nn.Cell):
|
|
1411
|
+
logger.critical("Failed to combine the net and the parameters.")
|
|
1412
|
+
msg = ("For 'load_param_into_net', the argument 'net' should be a Cell, but got {}.".format(type(net)))
|
|
1413
|
+
raise TypeError(msg)
|
|
1414
|
+
if not isinstance(parameter_dict, dict):
|
|
1415
|
+
logger.critical("Failed to combine the net and the parameters.")
|
|
1416
|
+
msg = ("For 'load_param_into_net', the argument 'parameter_dict' should be a dict, "
|
|
1417
|
+
"but got {}.".format(type(parameter_dict)))
|
|
1418
|
+
raise TypeError(msg)
|
|
1419
|
+
if "random_op" in parameter_dict.keys():
|
|
1420
|
+
net._add_attr("random_op_snapshot", parameter_dict["random_op"])
|
|
1421
|
+
parameter_dict.pop("random_op")
|
|
1524
1422
|
for key, value in parameter_dict.items():
|
|
1525
1423
|
if not isinstance(key, str) or not isinstance(value, (Parameter, str, list)):
|
|
1526
1424
|
logger.critical("Load parameters into net failed.")
|
|
@@ -1530,8 +1428,6 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1530
1428
|
|
|
1531
1429
|
strict_load = Validator.check_bool(strict_load)
|
|
1532
1430
|
logger.info("Execute the process of loading parameters into net.")
|
|
1533
|
-
for _, param in net.parameters_and_names():
|
|
1534
|
-
param.from_ckpt = True
|
|
1535
1431
|
if not _is_in_auto_parallel_mode():
|
|
1536
1432
|
net.init_parameters_data()
|
|
1537
1433
|
else:
|
|
@@ -1546,7 +1442,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1546
1442
|
# Add has attr protection when load server checkpoint file on worker.
|
|
1547
1443
|
if not hasattr(parameter_dict[param.name], "data"):
|
|
1548
1444
|
continue
|
|
1549
|
-
new_param = parameter_dict[param.name]
|
|
1445
|
+
new_param = copy.deepcopy(parameter_dict[param.name])
|
|
1550
1446
|
_update_param(param, new_param, strict_load)
|
|
1551
1447
|
ckpt_not_load.remove(param.name)
|
|
1552
1448
|
else:
|
|
@@ -1562,14 +1458,6 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1562
1458
|
"'parameter_dict', please check whether the network structure is consistent "
|
|
1563
1459
|
"when training and loading checkpoint.".format(len(param_not_load)))
|
|
1564
1460
|
logger.warning("{} are not loaded.".format(param_not_load))
|
|
1565
|
-
if os.getenv("AITURBO") == "1" and net.parameter_layout_dict is not None:
|
|
1566
|
-
param_layout = net.parameter_layout_dict
|
|
1567
|
-
param_redundancy = get_parameter_redundancy(param_layout)
|
|
1568
|
-
remove_param_redundancy_dict = remove_param_redundancy(param_redundancy)
|
|
1569
|
-
target_parameter_name_set = set(parameter_dict.keys())
|
|
1570
|
-
for rank_id, param_name_set in remove_param_redundancy_dict:
|
|
1571
|
-
if param_name_set == target_parameter_name_set:
|
|
1572
|
-
parameter_broadcast(net, param_layout, rank_id)
|
|
1573
1461
|
return param_not_load, ckpt_not_load
|
|
1574
1462
|
|
|
1575
1463
|
|
|
@@ -2321,45 +2209,6 @@ def _save_dataset_to_mindir(model, dataset):
|
|
|
2321
2209
|
model.preprocessor.op[-1].offload = op['offload'] if 'offload' in op.keys() else False
|
|
2322
2210
|
|
|
2323
2211
|
|
|
2324
|
-
def check_checkpoint(ckpt_file_name):
|
|
2325
|
-
"""
|
|
2326
|
-
Check whether the checkpoint is valid.
|
|
2327
|
-
|
|
2328
|
-
Args:
|
|
2329
|
-
ckpt_file_name (str): Checkpoint file name.
|
|
2330
|
-
|
|
2331
|
-
Returns:
|
|
2332
|
-
bool, whether the checkpoint is valid.
|
|
2333
|
-
|
|
2334
|
-
Examples:
|
|
2335
|
-
>>> import mindspore as ms
|
|
2336
|
-
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
|
|
2337
|
-
>>> check_result = ms.check_checkpoint(ckpt_file_name)
|
|
2338
|
-
>>> print(check_result)
|
|
2339
|
-
True
|
|
2340
|
-
"""
|
|
2341
|
-
if not ckpt_file_name.endswith('.ckpt'):
|
|
2342
|
-
return False
|
|
2343
|
-
checkpoint_list = Checkpoint()
|
|
2344
|
-
with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
|
|
2345
|
-
pb_content = f.read()
|
|
2346
|
-
if pb_content[-17:-10] == b"crc_num":
|
|
2347
|
-
crc_num_bytes = pb_content[-10:]
|
|
2348
|
-
pb_content = pb_content[:-17]
|
|
2349
|
-
crc_num = int.from_bytes(crc_num_bytes, byteorder='big')
|
|
2350
|
-
cal_crc_num = binascii.crc32(pb_content, 0)
|
|
2351
|
-
if cal_crc_num != crc_num:
|
|
2352
|
-
logger.warning("For 'check_checkpoint', the ckpt crc check is failed.")
|
|
2353
|
-
return False
|
|
2354
|
-
try:
|
|
2355
|
-
checkpoint_list.ParseFromString(pb_content)
|
|
2356
|
-
except google.protobuf.message.DecodeError as e:
|
|
2357
|
-
logger.warning("For 'check_checkpoint', the ckpt parse is failed.")
|
|
2358
|
-
logger.warning(e)
|
|
2359
|
-
return False
|
|
2360
|
-
return True
|
|
2361
|
-
|
|
2362
|
-
|
|
2363
2212
|
def parse_print(print_file_name):
|
|
2364
2213
|
"""
|
|
2365
2214
|
Parse data file generated by :class:`mindspore.ops.Print`.
|
|
@@ -190,7 +190,7 @@ def _nptype_to_prototype(np_value):
|
|
|
190
190
|
if proto is None:
|
|
191
191
|
raise TypeError("Transform numpy type failed in Summary, expect numpy type is one of ['np.bool_', 'np.int8', "
|
|
192
192
|
"'np.int16', 'np.int32', 'np.int64', 'np.uint8', 'np.uint16', 'np.uint32', 'np.uint64', "
|
|
193
|
-
"'np.float16', 'np.
|
|
193
|
+
"'np.float16', 'np.float', 'np.float64'].")
|
|
194
194
|
|
|
195
195
|
return proto
|
|
196
196
|
|
mindspore/version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '2.3.
|
|
1
|
+
__version__ = '2.3.0rc2'
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: mindspore
|
|
3
|
-
Version: 2.3.
|
|
3
|
+
Version: 2.3.0rc2
|
|
4
4
|
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
|
|
5
5
|
Home-page: https://www.mindspore.cn
|
|
6
6
|
Download-URL: https://github.com/mindspore-ai/mindspore/tags
|
|
@@ -26,7 +26,7 @@ Classifier: Topic :: Software Development :: Libraries
|
|
|
26
26
|
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
27
27
|
Requires-Python: >=3.7
|
|
28
28
|
Description-Content-Type: text/markdown
|
|
29
|
-
Requires-Dist: numpy (
|
|
29
|
+
Requires-Dist: numpy (>=1.17.0)
|
|
30
30
|
Requires-Dist: protobuf (>=3.13.0)
|
|
31
31
|
Requires-Dist: asttokens (>=2.0.4)
|
|
32
32
|
Requires-Dist: pillow (>=6.2.0)
|