mindspore 2.7.0rc1__cp311-cp311-win_amd64.whl → 2.7.1__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +5 -2
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
- mindspore/_extends/parse/parser.py +28 -22
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +23 -2
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
- mindspore/amp.py +0 -18
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/__init__.py +18 -12
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +371 -96
- mindspore/common/_utils.py +7 -43
- mindspore/common/api.py +434 -135
- mindspore/common/dtype.py +98 -57
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/hook_handle.py +82 -3
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +17 -127
- mindspore/common/recompute.py +4 -13
- mindspore/common/tensor.py +50 -217
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +20 -106
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +35 -1
- mindspore/dataset/engine/datasets.py +338 -319
- mindspore/dataset/engine/datasets_user_defined.py +38 -22
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/api/cell.h +28 -4
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +0 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +5 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +6 -1
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +4 -3
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/__init__.py +4 -0
- mindspore/mint/distributed/distributed.py +392 -69
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/_functions.py +1 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +10 -10
- mindspore/mint/nn/layer/normalization.py +11 -16
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +231 -239
- mindspore/nn/layer/activation.py +4 -2
- mindspore/nn/layer/basic.py +56 -14
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +32 -127
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +1 -4
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +2 -4
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +39 -5
- mindspore/nn/wrap/grad_reducer.py +4 -89
- mindspore/numpy/array_creations.py +4 -4
- mindspore/numpy/fft.py +9 -9
- mindspore/numpy/utils_const.py +1 -1
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +1 -5
- mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
- mindspore/ops/auto_generate/gen_extend_func.py +6 -11
- mindspore/ops/auto_generate/gen_ops_def.py +385 -154
- mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +16 -2
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +2 -0
- mindspore/ops/function/array_func.py +24 -18
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +7 -6
- mindspore/ops/function/grad/grad_func.py +4 -12
- mindspore/ops/function/math_func.py +89 -86
- mindspore/ops/function/nn_func.py +92 -313
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +4 -1
- mindspore/ops/functional_overload.py +377 -30
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +12 -50
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +5 -50
- mindspore/ops/operations/comm_ops.py +95 -17
- mindspore/ops/operations/custom_ops.py +237 -22
- mindspore/ops/operations/debug_ops.py +33 -35
- mindspore/ops/operations/manually_defined/ops_def.py +39 -318
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +3 -3
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +4 -27
- mindspore/ops/tensor_method.py +88 -10
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_auto_parallel_context.py +5 -15
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +4 -6
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_utils.py +34 -17
- mindspore/parallel/auto_parallel.py +23 -9
- mindspore/parallel/checkpoint_transform.py +20 -2
- mindspore/parallel/cluster/process_entity/_api.py +28 -33
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/parallel/cluster/run.py +5 -3
- mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/function/reshard_func.py +6 -5
- mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
- mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
- mindspore/parallel/shard.py +7 -21
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +127 -20
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +40 -4
- mindspore/profiler/common/path_manager.py +65 -24
- mindspore/profiler/common/profiler_context.py +27 -14
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +10 -6
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/dynamic_profiler.py +91 -46
- mindspore/profiler/envprofiler.py +30 -5
- mindspore/profiler/experimental_config.py +18 -2
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +34 -7
- mindspore/profiler/profiler.py +193 -145
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +9 -6
- mindspore/runtime/executor.py +35 -0
- mindspore/runtime/memory.py +113 -0
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +8 -21
- mindspore/train/amp.py +6 -7
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +1 -17
- mindspore/train/callback/_flops_collector.py +10 -6
- mindspore/train/callback/_train_fault_tolerance.py +72 -25
- mindspore/train/data_sink.py +5 -9
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +41 -230
- mindspore/train/serialization.py +160 -401
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +152 -16
- mindspore/version.py +1 -1
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/__init__.py +0 -23
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/train/memory_profiling_pb2.py +0 -298
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/train/serialization.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -36,13 +36,12 @@ from functools import partial
 import math
 import sys
 import time
-import numpy as np
 from safetensors.numpy import save_file
+import numpy as np
 import google

 from mindspore.train.checkpoint_pb2 import Checkpoint
 from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
-from mindspore.train.print_pb2 import Print

 import mindspore
 import mindspore.nn as nn
@@ -52,15 +51,13 @@ from mindspore.log import vlog_print
 from mindspore._checkparam import check_input_data, check_input_dataset
 from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
-from mindspore.common import np_dtype
 from mindspore.common.api import _cell_graph_executor as _executor
 from mindspore.common.api import _JitExecutor
 from mindspore.common.api import _get_parameter_layout
-from mindspore.common.initializer import initializer
+from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter, _offload_if_config
 from mindspore.common.tensor import Tensor
 from mindspore._c_expression import TensorPy as Tensor_
-from mindspore.common._utils import is_shape_unknown
 from mindspore.common.file_system import FileSystem, _register_basic_file_system, _register_mindio_file_system
 from mindspore.communication.management import get_rank, get_group_size
 from mindspore.experimental import MapParameter
@@ -76,9 +73,9 @@ from mindspore.parallel.checkpoint_transform import load_distributed_checkpoint
 from mindspore.parallel.checkpoint_transform import merge_sliced_parameter as new_merge_sliced_parameter
 from mindspore.parallel.checkpoint_transform import build_searched_strategy as new_build_searched_strategy
 from mindspore.parallel.transform_safetensors import _fast_safe_open
-from mindspore.train._utils import
+from mindspore.train._utils import get_parameter_redundancy, _progress_bar, _load_and_transform
 from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, \
-    split_mindir, split_dynamic_mindir
+    split_mindir, split_dynamic_mindir, _get_snapshot_params
 from mindspore.common.generator import Generator

 tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
@@ -86,12 +83,9 @@ tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype
                      "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
                      "Bool": mstype.bool_, "str": mstype.string, "BFloat16": mstype.bfloat16, "Int4": mstype.qint4x2}

-
-
-
-
-if hasattr(np_dtype, "bfloat16"):
-    tensor_to_np_type["BFloat16"] = np_dtype.bfloat16
+_tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UInt16": np.uint16,
+                      "Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
+                      "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}

 np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64}

@@ -114,6 +108,21 @@ INT_64_MAX = 9223372036854775807
 cpu_cast = Cast().set_device("CPU")

 _ckpt_fs = FileSystem()
+_ckpt_fs_initialized = False
+
+
+def tensor_to_np_type(tensor_type_str):
+    """tensor to numpy type"""
+    if tensor_type_str == "BFloat16":
+        from mindspore.common import np_dtype
+        if not np_dtype.np_dtype_valid(True):
+            raise TypeError(
+                "The Numpy bfloat16 data type is not supported now, please ensure that the current "
+                "Numpy version is not less than the version when the mindspore is compiled, "
+                "and the major versions are same."
+            )
+        return np_dtype.bfloat16
+    return _tensor_to_np_type.get(tensor_type_str)


 def init_ckpt_file_system(fs: FileSystem):
@@ -123,8 +132,12 @@ def init_ckpt_file_system(fs: FileSystem):
     _register_basic_file_system(fs)


-
-
+def _ensure_ckpt_fs_initialized():
+    """Ensure checkpoint file system is initialized"""
+    global _ckpt_fs_initialized
+    if not _ckpt_fs_initialized:
+        init_ckpt_file_system(_ckpt_fs)
+        _ckpt_fs_initialized = True


 def _wait_async_process_save_ckpt():
@@ -401,9 +414,6 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
                                                                       crc_num, crc_check,
                                                                       ckpt_total_io_time)
                 continue
-            if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
-                _write_hugeparameter(name, value, f)
-                continue

             crc_num, ckpt_total_io_time = _write_parameter_bytes_data(name, value, f, enc_key, plain_data,
                                                                       crc_num, crc_check,
@@ -458,7 +468,7 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
                                 f"simultaneously modified a file.")
         elif _ckpt_fs.backend != "mindio":
             os.rename(tmp_name, ckpt_file_name)
-
+            os.chmod(ckpt_file_name, stat.S_IRUSR)
     except BaseException as e:
         logger.critical("Failed to save the checkpoint file %s. Maybe don't have the permission to write files, "
                         "or the disk space is insufficient and so on.", ckpt_file_name)
@@ -546,27 +556,6 @@ def _write_mapparameter(name, value, f, map_param_inc=False):
             break


-def _write_hugeparameter(name, value, f):
-    """Write huge parameter into protobuf file."""
-    slice_num = value[2].slice_num
-    offset = 0
-    max_size = value[0][0]
-    for param_slice in range(slice_num):
-        checkpoint_list = Checkpoint()
-        param_value = checkpoint_list.value.add()
-        param_value.tag = name
-        param_tensor = param_value.tensor
-        param_tensor.dims.extend(value[0])
-        param_tensor.tensor_type = value[1]
-        param_key = value[3]
-        numpy_data = value[2].asnumpy_of_slice_persistent_data(param_key, param_slice)
-        if offset + numpy_data.shape[0] > max_size:
-            numpy_data = numpy_data[:max_size - offset]
-        param_tensor.tensor_content = numpy_data.tobytes()
-        f.write(checkpoint_list.SerializeToString())
-        offset += numpy_data.shape[0]
-
-
 def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format):
     """Check save_obj and ckpt_file_name for save_checkpoint."""
     if format not in ["safetensors", "ckpt"]:
@@ -718,6 +707,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
         <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
     """
     start_save_time = time.time()
+    _ensure_ckpt_fs_initialized()
     ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format)
     integrated_save = Validator.check_bool(integrated_save)
     async_save = _check_async_save(async_save)
@@ -767,9 +757,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
             data_list[param["name"]].append(param["data"])
             continue
         if isinstance(param["data"], list):
-            if param["data"][0] == "
-                _save_param_list_data(data_list, key, param)
-            elif param["data"][0] == "offload_parameter":
+            if param["data"][0] == "offload_parameter":
                 data_list[key].append("offload_parameter")
                 _save_param_list_data(data_list, key, param)

@@ -955,6 +943,8 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
     if not is_parallel_mode:
         save_obj.init_parameters_data()
     param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func, is_parallel_mode)
+    enable_ckpt_d2h_sync = os.getenv('MS_ENABLE_D2H_ASYNC') == '1'
+    param_snapshot = _get_snapshot_params() if enable_ckpt_d2h_sync else {}
     for (key, value) in param_dict.items():
         each_param = {"name": key}
         if isinstance(value, MapParameter):
@@ -962,10 +952,7 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
             param_list.append(each_param)
             continue

-        if value.data.
-            # list save persistent_data: [Tensor, shape, type, param.key]
-            param_data = ["persistent_data", value.data, value.param_info.origin_shape, str(value.dtype), value.key]
-        elif value.data.offload_file_path() != "":
+        if value.data.offload_file_path() != "":
             # list save offload data: [Param, shape, type, param.key]
             param_data = ["offload_parameter"]
             param_tensor = value.data
@@ -980,7 +967,8 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
         if append_dict and "__exception_save__" in append_dict:
             param_data = Tensor(Tensor_.move_to(value, "CPU", False))
         else:
-
+            # when enable MS_ENABLE_D2H_ASYNC=1, fetch param from sanpshot in priority
+            param_data = param_snapshot.get(key, Tensor(value.data))

         # in automatic model parallel scenario, some parameters were split to all the devices,
         # which should be combined before saving
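The hunks above wire save_checkpoint to an optional device-to-host snapshot: when the environment variable MS_ENABLE_D2H_ASYNC is set to '1', parameter data is taken from _get_snapshot_params() and only falls back to a synchronous Tensor(value.data) read. A minimal opt-in sketch, assuming a normal in-process save; only the variable name and the '1' check come from the diff, and whether the runtime also needs the variable set before process launch is not shown here:

    import os
    import mindspore as ms
    from mindspore import nn

    os.environ["MS_ENABLE_D2H_ASYNC"] = "1"   # value checked by _convert_cell_to_param_list in 2.7.1
    net = nn.Dense(2, 2)                      # any Cell with parameters
    ms.save_checkpoint(net, "net.ckpt")       # falls back to Tensor(value.data) when no snapshot entry exists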
@@ -1004,13 +992,16 @@ def _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choi

         return _handle_shared_param_for_pipeline_parallel(save_obj)

-
+    if isinstance(save_obj, nn.Cell):
+        return _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func)
+
+    raise TypeError("For 'save_checkpoint', the argument 'save_obj' must be list、dict or nn.cell, "
+                    "but got {}.".format(type(save_obj)))


 def _save_param_list_data(data_list, key, param):
     """Save persistent data into save_obj."""
     dims = []
-    # persistent_data shape can not be ()
     for dim in param['data'][2]:
         dims.append(dim)
     data_list[key].append(dims)
@@ -1268,11 +1259,7 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
             continue
         data = element.tensor.tensor_content
         data_type = element.tensor.tensor_type
-        np_type = tensor_to_np_type.get(data_type)
         ms_type = tensor_to_ms_type[data_type]
-        if data_type == 'str':
-            str_length = int(len(data) / 4)
-            np_type = np_type + str(str_length)
         param_data_list.append(data)
         if (element_id == len(checkpoint_list.value) - 1) or \
                 (element.tag != checkpoint_list.value[element_id + 1].tag):
@@ -1280,6 +1267,8 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
             param_data_list.clear()
             dims = element.tensor.dims
             if data_type == 'str':
+                str_length = int(len(data) / 4)
+                np_type = "U" + str(str_length)
                 str_value = np.frombuffer(new_data, np_type)
                 parameter_dict[element.tag] = str(str_value[0])
             else:
@@ -1288,7 +1277,6 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
                 param_data = Tensor_.convert_bytes_to_tensor(new_data, tuple(dims), ms_type)
                 parameter = Parameter(param_data, name=element.tag)
                 parameter_dict[element.tag] = parameter
-                _offload_if_config(parameter)

     logger.info("Loading checkpoint files process is finished.")
     return remove_redundancy
@@ -1386,6 +1374,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
     """
     start_load_time = time.time()
     vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin load checkpoint.")
+    _ensure_ckpt_fs_initialized()
     specify_prefix = _check_prefix(specify_prefix)
     filter_prefix = _check_prefix(filter_prefix)
     dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
@@ -2133,6 +2122,7 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
     if file_format == 'AIR':
         _save_air(net, file_name, *inputs, **kwargs)
     elif file_format == 'ONNX':
+        logger.warning("mindspore.export(file_format='ONNX') will be deleted, please use mindspore.onnx.export()")
         _save_onnx(net, file_name, *inputs, **kwargs)
     elif file_format == 'MINDIR':
         _save_mindir(net, file_name, *inputs, **kwargs)
@@ -2198,6 +2188,11 @@ def _save_onnx(net, file_name, *inputs, **kwargs):
         file_name += ".onnx"
     if os.path.exists(file_name):
         os.chmod(file_name, stat.S_IWUSR)
+    else:
+        dir_path = os.path.dirname(file_name)
+        if not os.path.exists(dir_path):
+            os.makedirs(dir_path, mode=0o700, exist_ok=True)
+            os.chmod(dir_path, 0o700)
     with open(file_name, 'wb') as f:
         f.write(onnx_stream)
     os.chmod(file_name, stat.S_IRUSR)
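The two ONNX hunks above only add a deprecation warning and create the output directory; the replacement entry point lives in the new mindspore/onnx/onnx_export.py module listed at the top of this diff. A migration sketch, with the new call kept commented out because its exact signature is not shown anywhere in this diff:

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    net = nn.Dense(2, 2)
    dummy = Tensor(np.ones((1, 2), np.float32))

    # Old path: still works in 2.7.1 but now logs the warning added above.
    ms.export(net, dummy, file_name="dense", file_format="ONNX")

    # New path named by the warning; check mindspore/onnx/onnx_export.py for the
    # actual signature before relying on it (the arguments below are assumed).
    # ms.onnx.export(net, dummy, file_name="dense")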
@@ -2477,147 +2472,6 @@ def _save_dataset_to_mindir(model, dataset):
         model.preprocessor.op[-1].offload = op['offload'] if 'offload' in op.keys() else False


-def check_checkpoint(ckpt_file_name):
-    """
-    Check whether the checkpoint is valid.
-
-    Note:
-        The interface is deprecated from version 2.5 and will be removed in a future version.
-
-    Args:
-        ckpt_file_name (str): Checkpoint file name.
-
-    Returns:
-        bool, whether the checkpoint is valid.
-
-    Examples:
-        >>> import mindspore as ms
-        >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
-        >>> check_result = ms.check_checkpoint(ckpt_file_name)
-        >>> print(check_result)
-        True
-    """
-    logger.warning("The interface 'mindspore.check_checkpoint' is deprecated from version 2.5 "
-                   "and will be removed in a future version.")
-    if not ckpt_file_name.endswith('.ckpt'):
-        return False
-    checkpoint_list = Checkpoint()
-    with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
-        pb_content = f.read()
-        if pb_content[-17:-10] == b"crc_num":
-            crc_num_bytes = pb_content[-10:]
-            pb_content = pb_content[:-17]
-            crc_num = int.from_bytes(crc_num_bytes, byteorder='big')
-            cal_crc_num = binascii.crc32(pb_content, 0)
-            if cal_crc_num != crc_num:
-                logger.warning("For 'check_checkpoint', the ckpt crc check is failed.")
-                return False
-        try:
-            checkpoint_list.ParseFromString(pb_content)
-        except google.protobuf.message.DecodeError as e:
-            logger.warning("For 'check_checkpoint', the ckpt parse is failed.")
-            logger.warning(e)
-            return False
-    return True
-
-
-def parse_print(print_file_name):
-    """
-    Parse data file generated by :class:`mindspore.ops.Print`.
-
-    Note:
-        The interface is deprecated from version 2.5 and will be removed in a future version.
-
-    Args:
-        print_file_name (str): The file name needs to be parsed.
-
-    Returns:
-        List, element of list is Tensor.
-
-    Raises:
-        ValueError: The print file does not exist or is empty.
-        RuntimeError: Failed to parse the file.
-
-    Examples:
-        >>> import numpy as np
-        >>> import mindspore as ms
-        >>> from mindspore import nn, Tensor, ops
-        >>> ms.set_context(mode=ms.GRAPH_MODE, print_file_path='log.data')
-        >>> class PrintInputTensor(nn.Cell):
-        ...     def __init__(self):
-        ...         super().__init__()
-        ...         self.print = ops.Print()
-        ...
-        ...     def construct(self, input_pra):
-        ...         self.print('print:', input_pra)
-        ...         return input_pra
-        >>> x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float32)
-        >>> input_pra = Tensor(x)
-        >>> net = PrintInputTensor()
-        >>> net(input_pra)
-        >>>
-        >>> data = ms.parse_print('./log.data')
-        >>> print(data)
-        ['print:', Tensor(shape=[2, 4], dtype=Float32, value=
-        [[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
-        [ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
-    """
-    logger.warning("The interface 'mindspore.parse_print' is deprecated from version 2.5 "
-                   "and will be removed in a future version.")
-    print_file_path = os.path.realpath(print_file_name)
-
-    if os.path.getsize(print_file_path) == 0:
-        raise ValueError("For 'parse_print', the print file may be empty, please make sure enter the correct "
-                         "'print_file_name'.")
-
-    logger.info("Execute load print process.")
-    print_list = Print()
-
-    try:
-        with open(print_file_path, "rb") as f:
-            pb_content = f.read()
-        print_list.ParseFromString(pb_content)
-    except BaseException as e:
-        logger.critical("Failed to read the print file %s, please check whether the file is "
-                        "correct.", print_file_name)
-        raise ValueError(e.__str__() + "\nFailed to read the print file {}, please check whether "
-                         "the file is correct.".format(print_file_name)) from e
-
-    tensor_list = []
-
-    try:
-        for print_ in print_list.value:
-            # String type
-            if print_.HasField("desc"):
-                tensor_list.append(print_.desc)
-            elif print_.HasField("tensor"):
-                dims = print_.tensor.dims
-                data_type = print_.tensor.tensor_type
-                data = print_.tensor.tensor_content
-                np_type = tensor_to_np_type.get(data_type)
-                param_data = np.fromstring(data, np_type)
-                ms_type = tensor_to_ms_type.get(data_type)
-                if dims and dims != [0]:
-                    param_value = param_data.reshape(dims)
-                    tensor_list.append(Tensor(param_value, ms_type))
-                # Scalar type
-                else:
-                    data_type_ = data_type.lower()
-                    if 'float' in data_type_:
-                        param_data = float(param_data[0])
-                    elif 'int' in data_type_:
-                        param_data = int(param_data[0])
-                    elif 'bool' in data_type_:
-                        param_data = bool(param_data[0])
-                    tensor_list.append(Tensor(param_data, ms_type))
-
-    except BaseException as e:
-        logger.critical("Failed to load the print file %s.", print_list)
-        raise RuntimeError(e.__str__() + "\nFailed to load the print file {}.".format(print_list)) from e
-
-    return tensor_list
-
-
 def async_ckpt_thread_status():
     """
     Get the status of asynchronous save checkpoint thread.
@@ -2652,170 +2506,132 @@ def _calculation_net_size(net):
     return data_total


-def
+def _load_file_and_convert_name(path, name_map=None, format="ckpt"):
     """
-
-
-    Note:
-        1. Parsing encrypted MindIR file is not supported.
-        2. Parsing dynamic shape MindIR file is not supported.
+    Load file, during load convert name by name_map.

     Args:
-
+        path (str): The file path.
+        name_map (dict): Convert the name of parameter by name_map.
+        format (str): The format of the file. Option: 'ckpt', 'safetensors'

     Returns:
-
-
-    Raises:
-        TypeError: If the parameter file_name is not `str`.
-        RuntimeError: MindIR's input is not tensor type or has no dims.
-
-    Examples:
-        >>> input_tensor = get_mindir_inputs("lenet.mindir")
+        Dict, key is parameter name, value is a Parameter or string.
     """
-
-
-
-    input_tensor = []
-
-    for ele_input in model.graph.input:
-        input_shape = []
-        if not hasattr(ele_input, "tensor") or not hasattr(ele_input.tensor[0], "dims"):
-            raise RuntimeError("MindIR's inputs has no tensor or tensor has no dims, please check MindIR file.")
-
-        for ele_shape in ele_input.tensor[0].dims:
-            input_shape.append(ele_shape)
-        if is_shape_unknown(input_shape):
-            raise RuntimeError(f"MindIR input's shape is: {input_shape}, dynamic shape is not supported.")
-
-        mindir_type = ele_input.tensor[0].data_type
-        if mindir_type not in mindir_to_tensor_type:
-            raise RuntimeError(f"MindIR input's type: {mindir_type} is not supported.")
-
-        input_type = mindir_to_tensor_type.get(mindir_type)
-        input_tensor.append(Tensor(shape=input_shape, dtype=input_type, init=One()))
-
-    if not input_tensor:
-        logger.warning("The MindIR model has no input, return None.")
-        return None
-    return input_tensor[0] if len(input_tensor) == 1 else input_tensor
-
-
-def convert_model(mindir_file, convert_file, file_format):
-    """
-    Convert mindir model to other format model. The current version only supports conversion to ONNX models.
-
-    Note:
-        The interface is deprecated from version 2.5 and will be removed in a future version.
-
-    Args:
-        mindir_file (str): MindIR file name.
-        convert_file (str): Convert model file name.
-        file_format (str): Convert model's format, current version only supports "ONNX".
-
-    Raises:
-        TypeError: If the parameter `mindir_file` is not `str`.
-        TypeError: If the parameter `convert_file` is not `str`.
-        ValueError: If the parameter `file_format` is not "ONNX".
-
-    Examples:
-        >>> import mindspore as ms
-        >>> ms.convert_model("lenet.mindir", "lenet.onnx", "ONNX")
-    """
-    logger.warning("The interface 'mindspore.train.serialization.convert_model' is deprecated from version 2.5 "
-                   "and will be removed in a future version.")
-    Validator.check_file_name_by_regular(mindir_file)
-    Validator.check_file_name_by_regular(convert_file)
-    if file_format != "ONNX":
-        raise ValueError(f"For 'convert_model', 'file_format' must be 'ONNX', but got {file_format}.")
-    net_input = _get_mindir_inputs(mindir_file)
-    graph = load(mindir_file)
-    net = nn.GraphCell(graph)
-    if isinstance(net_input, Tensor):
-        export(net, net_input, file_name=convert_file, file_format=file_format)
-    else:
-        export(net, *net_input, file_name=convert_file, file_format=file_format)
-
-
-def _load_ckpt_to_new_name_map(path, name_map=None):
-    return _load_and_transform(path, name_map, mindspore.load_checkpoint, None)
-
+    if name_map is not None:
+        load_func = partial(mindspore.load_checkpoint, format=format)
+        return _load_and_transform(path, name_map, load_func)

-
-    load_func = partial(mindspore.load_checkpoint, format="safetensors")
-    return _load_and_transform(path, name_map, load_func, None)
+    return mindspore.load_checkpoint(path, format=format)


 def _process_file(file_info):
-
-
-
+    """Rrocess file."""
+    cur_path, name_map, save_path, file, dst_format = file_info
+    if dst_format == "safetensors":
+        param_dict = _load_file_and_convert_name(cur_path, name_map, format="ckpt")
+        safetensors_filename = file.replace(".ckpt", ".safetensors")
+        dst_file = os.path.join(save_path, safetensors_filename)
+        mindspore.save_checkpoint(param_dict, dst_file, format='safetensors')
     else:
-        param_dict =
-
-
-
+        param_dict = _load_file_and_convert_name(cur_path, name_map, format="safetensors")
+        ckpt_filename = file.replace(".safetensors", ".ckpt")
+        dst_file = os.path.join(save_path, ckpt_filename)
+        mindspore.save_checkpoint(param_dict, dst_file)


-def
-
-    if
-
+def _gather_all_tasks(file_path, save_path, file_name_regex, name_map, dst_format="ckpt"):
+    """gather transform rank together"""
+    if dst_format == "ckpt":
+        cur_file_suffix = ".safetensors"
     else:
-
-        ckpt_filename = file.replace(".safetensors", ".ckpt")
-        dst_file = os.path.join(save_path, ckpt_filename)
-        mindspore.save_checkpoint(param_dict, dst_file)
+        cur_file_suffix = ".ckpt"

-
-def _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map):
-    """gather transform rank together"""
-    tasks = []
+    tasks_list = []
     for root, dirs, _ in os.walk(file_path):
         if root != file_path:
             continue

         rank_dirs = [d for d in dirs if d.startswith('rank')]
         if not rank_dirs:
-
-
+            if dst_format == "safetensors":
+                raise ValueError(
+                    f"For 'ckpt_to_safetensors', no directories starting with 'rank' found in {file_path}.")
+            if dst_format == "ckpt":
+                raise ValueError(
+                    f"For 'safetensors_to_ckpt', no directories starting with 'rank' found in {file_path}.")
+
+            raise ValueError(f"For '_gather_all_tasks', error args 'format': {dst_format}.")

         for rank_dir in rank_dirs:
             rank_dir_path = os.path.join(root, rank_dir)
-
-
+            if save_path is not None:
+                dst_root = os.path.join(save_path, os.path.relpath(rank_dir_path, file_path))
+            else:
+                dst_root = rank_dir_path
+
             os.makedirs(dst_root, exist_ok=True)
-            tasks.extend(
-                (os.path.join(rank_dir_path, file), name_map, dst_root, file)
-                for file in os.listdir(rank_dir_path)
-                if file.endswith(".safetensors") and (file_name_regex is None or re.findall(file_name_regex, file))
-            )
-    return tasks

+            for file in os.listdir(rank_dir_path):
+                if file.endswith(cur_file_suffix) and (file_name_regex is None or re.search(file_name_regex, file)):
+                    tasks_list.append((os.path.join(rank_dir_path, file), name_map, dst_root, file, dst_format))

-
-    """gather transform rank together"""
-    tasks = []
-    for root, dirs, _ in os.walk(file_path):
-        if root != file_path:
-            continue
+    return tasks_list

-        rank_dirs = [d for d in dirs if d.startswith('rank')]
-        if not rank_dirs:
-            raise ValueError(
-                f"For 'ckpt_to_safetensors', no directories starting with 'rank' found in {file_path}")

-
-
-
-
-
-
-
-
-
-
-
+def _convert_checkpoint_file(file_path, save_path=None, name_map=None, file_name_regex=None,
+                             processes_num=1, dst_format="safetensors"):
+    """
+    Converts MindSpore checkpoint files format and saves them to `save_path`.
+    Safetensors is a reliable and portable machine learning model storage format introduced by Huggingface,
+    used for securely storing Tensors with fast speed (zero copy).
+
+    Args:
+        file_path (str): Path to the directory containing checkpoint files or a single checkpoint file (.ckpt).
+        save_path (str, optional): Directory path where safetensors files will be saved. Default: ``None``.
+        name_map (dict, optional): Dictionary mapping original parameter names to new names. Default: ``None``.
+        file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
+            Default: ``None``.
+        processes_num (int, optional): Number of processes to use for parallel processing. Default: 1.
+        dst_format (str): dst file format. Default: "safetensors".
+    """
+    if dst_format == "safetensors":
+        src_format = "ckpt"
+        src_file_suffix = ".ckpt"
+        dst_file_suffix = ".safetensors"
+        func_name = "ckpt_to_safetensors"
+    else:
+        src_format = "safetensors"
+        src_file_suffix = ".safetensors"
+        dst_file_suffix = ".ckpt"
+        func_name = "safetensors_to_ckpt"
+    is_dir = os.path.isdir(file_path)
+    is_file = os.path.isfile(file_path)
+    if not is_dir and not is_file:
+        raise ValueError(f"For {func_name}, the input path must be a valid path or file, but got {file_path}")
+    if save_path and os.path.splitext(save_path)[1]:
+        raise ValueError(f"For {func_name}, the save_path must be a directory, but got '{save_path}'")
+    if name_map is not None and not isinstance(name_map, dict):
+        raise ValueError(
+            f"For {func_name}, the type of 'name_map' must be a directory, but got '{type(name_map)}'")
+
+    if is_dir:
+        tasks_list = _gather_all_tasks(file_path, save_path, file_name_regex, name_map, dst_format=dst_format)
+        with mp.Pool(processes=processes_num) as pool:
+            list(_progress_bar(pool.imap(_process_file, tasks_list), total=len(tasks_list)))
+    elif is_file:
+        if not file_path.endswith(src_file_suffix):
+            raise ValueError(f"For {func_name}, the input file must be a {src_file_suffix} file, but got {file_path}")
+        if file_name_regex is not None and not re.findall(file_name_regex, file_path):
+            raise ValueError(f"For {func_name}, the input file does not match the regular expression.")
+        if save_path and not os.path.exists(save_path):
+            os.makedirs(save_path, exist_ok=True)
+
+        param_dict = _load_file_and_convert_name(file_path, name_map, format=src_format)
+
+        file_filename = os.path.basename(file_path).replace(src_file_suffix, dst_file_suffix)
+        dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), file_filename)
+        mindspore.save_checkpoint(param_dict, dst_file, format=dst_format)


 def ckpt_to_safetensors(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
@@ -2834,11 +2650,11 @@ def ckpt_to_safetensors(file_path, save_path=None, name_map=None, file_name_rege

     Args:
         file_path (str): Path to the directory containing checkpoint files or a single checkpoint file (.ckpt).
-        save_path (str, optional): Directory path where safetensors files will be saved.
-        name_map (dict, optional): Dictionary mapping original parameter names to new names.
+        save_path (str, optional): Directory path where safetensors files will be saved. Default: ``None``.
+        name_map (dict, optional): Dictionary mapping original parameter names to new names. Default: ``None``.
         file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
-
-        processes_num (int, optional): Number of processes to use for parallel processing.
+            Default: ``None``.
+        processes_num (int, optional): Number of processes to use for parallel processing. Default: 1.
     Raises:
         ValueError: If the input path is invalid or the save_path is not a directory,
             or the file_path does not end with '.ckpt'.
@@ -2854,36 +2670,8 @@ def ckpt_to_safetensors(file_path, save_path=None, name_map=None, file_name_rege
         >>> namemap = {"lin.weight":"new_name"}
         >>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt", "./new_path/", namemap)
     """
-
-
-    if not is_dir and not is_file:
-        raise ValueError(f"For 'ckpt_to_safetensors', the input path must be a valid path or file, but got {file_path}")
-    if save_path and os.path.splitext(save_path)[1]:
-        raise ValueError(f"For 'ckpt_to_safetensors', the save_path must be a directory, but got '{save_path}'")
-    if name_map is not None and not isinstance(name_map, dict):
-        raise ValueError(
-            f"For 'ckpt_to_safetensors', the type of 'name_map' must be a directory, but got '{type(name_map)}'")
-
-    if is_dir:
-        tasks = _gather_tasks_covert(file_path, save_path, file_name_regex, name_map)
-        with mp.Pool(processes=processes_num) as pool:
-            list(_progress_bar(pool.imap(_process_file, tasks), total=len(tasks)))
-    elif is_file:
-        if not file_path.endswith(".ckpt"):
-            raise ValueError(f"For 'ckpt_to_safetensors', the input file must be a .ckpt file, but got {file_path}")
-        if file_name_regex is not None and not re.findall(file_name_regex, file_path):
-            raise ValueError(f"For 'ckpt_to_safetensors', the input file does not match the regular expression.")
-        if save_path and not os.path.exists(save_path):
-            os.makedirs(save_path, exist_ok=True)
-
-        if name_map is not None:
-            param_dict = _load_ckpt_to_new_name_map(file_path, name_map)
-        else:
-            param_dict = mindspore.load_checkpoint(file_path)
-
-        safetensors_filename = os.path.basename(file_path).replace(".ckpt", ".safetensors")
-        dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), safetensors_filename)
-        mindspore.save_checkpoint(param_dict, dst_file, format='safetensors')
+    _convert_checkpoint_file(file_path, save_path, name_map,
+                             file_name_regex, processes_num, "safetensors")


 def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
@@ -2898,11 +2686,11 @@ def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_rege

     Args:
         file_path (str): Path to the directory containing safetensors files or a single safetensors file (.safetensors).
-        save_path (str, optional): Directory path where checkpoint files will be saved.
-        name_map (dict, optional): Dictionary mapping original parameter names to new names.
+        save_path (str, optional): Directory path where checkpoint files will be saved. Default: ``None``.
+        name_map (dict, optional): Dictionary mapping original parameter names to new names. Default: ``None``.
         file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
-
-        processes_num (int, optional): Number of processes to use for parallel processing.
+            Default: ``None``.
+        processes_num (int, optional): Number of processes to use for parallel processing. Default: 1.

     Raises:
         ValueError: If the input path is invalid, the save_path is not a directory,
@@ -2919,37 +2707,8 @@ def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_rege
         >>> namemap = {"lin.weight":"new_name"}
        >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/", namemap)
     """
-
-
-    if not is_dir and not is_file:
-        raise ValueError(f"For 'safetensors_to_ckpt', the input path must be a valid path or file, but got {file_path}")
-    if save_path and os.path.splitext(save_path)[1]:
-        raise ValueError(f"For 'safetensors_to_ckpt', the save_path must be a directory, but got '{save_path}'")
-    if name_map is not None and not isinstance(name_map, dict):
-        raise ValueError(
-            f"For 'safetensors_to_ckpt', the type of 'name_map' must be a directory, but got '{type(name_map)}'")
-
-    if is_dir:
-        tasks = _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map)
-        with mp.Pool(processes=processes_num) as pool:
-            list(_progress_bar(pool.imap(_process_file_safetensors, tasks), total=len(tasks)))
-    elif is_file:
-        if not file_path.endswith(".safetensors"):
-            raise ValueError(
-                f"For 'safetensors_to_ckpt', the input file must be a .safetensors file, but got {file_path}")
-        if file_name_regex is not None and not re.findall(file_name_regex, file_path):
-            raise ValueError(f"For 'safetensors_to_ckpt', the input file does not match the regular expression.")
-        if save_path and not os.path.exists(save_path):
-            os.makedirs(save_path, exist_ok=True)
-
-        if name_map is not None:
-            param_dict = _load_sf_to_new_name_map(file_path, name_map)
-        else:
-            param_dict = mindspore.load_checkpoint(file_path, format="safetensors")
-
-        ckpt_filename = os.path.basename(file_path).replace(".safetensors", ".ckpt")
-        dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), ckpt_filename)
-        mindspore.save_checkpoint(param_dict, dst_file)
+    _convert_checkpoint_file(file_path, save_path, name_map,
+                             file_name_regex, processes_num, "ckpt")


 def restore_group_info_list(group_info_file_name):