mindspore 2.7.0__cp311-cp311-win_amd64.whl → 2.7.1__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +8 -1
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +275 -64
- mindspore/common/_utils.py +0 -44
- mindspore/common/api.py +285 -35
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
- mindspore/common/hook_handle.py +60 -0
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/parameter.py +13 -107
- mindspore/common/recompute.py +4 -11
- mindspore/common/tensor.py +16 -169
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +5 -85
- mindspore/dataset/engine/datasets.py +8 -4
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dnnl.dll +0 -0
- mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +1 -1
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/distributed.py +182 -62
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +4 -4
- mindspore/mint/nn/layer/normalization.py +8 -13
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +16 -66
- mindspore/nn/layer/basic.py +49 -1
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +31 -124
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +0 -1
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +0 -1
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/wrap/grad_reducer.py +4 -74
- mindspore/numpy/array_creations.py +2 -2
- mindspore/numpy/fft.py +9 -9
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +0 -5
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
- mindspore/ops/auto_generate/gen_extend_func.py +2 -7
- mindspore/ops/auto_generate/gen_ops_def.py +98 -141
- mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +15 -1
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +1 -0
- mindspore/ops/function/array_func.py +14 -12
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +3 -4
- mindspore/ops/function/math_func.py +45 -54
- mindspore/ops/function/nn_func.py +75 -294
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +2 -0
- mindspore/ops/functional_overload.py +354 -18
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +1 -38
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +1 -0
- mindspore/ops/operations/comm_ops.py +94 -2
- mindspore/ops/operations/custom_ops.py +228 -19
- mindspore/ops/operations/debug_ops.py +27 -29
- mindspore/ops/operations/manually_defined/ops_def.py +27 -306
- mindspore/ops/operations/nn_ops.py +2 -2
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +1 -17
- mindspore/ops/tensor_method.py +72 -3
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +1 -4
- mindspore/parallel/_utils.py +29 -6
- mindspore/parallel/checkpoint_transform.py +18 -2
- mindspore/parallel/cluster/process_entity/_api.py +24 -32
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +117 -16
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +38 -2
- mindspore/profiler/common/path_manager.py +56 -24
- mindspore/profiler/common/profiler_context.py +2 -12
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/experimental_config.py +2 -1
- mindspore/profiler/platform/npu_profiler.py +33 -6
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +3 -2
- mindspore/runtime/executor.py +11 -3
- mindspore/runtime/memory.py +112 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +5 -18
- mindspore/train/amp.py +6 -4
- mindspore/train/callback/_checkpoint.py +0 -9
- mindspore/train/callback/_train_fault_tolerance.py +69 -18
- mindspore/train/data_sink.py +1 -5
- mindspore/train/model.py +38 -211
- mindspore/train/serialization.py +126 -387
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +144 -8
- mindspore/version.py +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/train/memory_profiling_pb2.py +0 -298
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/train/serialization.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -36,13 +36,12 @@ from functools import partial
|
|
|
36
36
|
import math
|
|
37
37
|
import sys
|
|
38
38
|
import time
|
|
39
|
-
import numpy as np
|
|
40
39
|
from safetensors.numpy import save_file
|
|
40
|
+
import numpy as np
|
|
41
41
|
import google
|
|
42
42
|
|
|
43
43
|
from mindspore.train.checkpoint_pb2 import Checkpoint
|
|
44
44
|
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
|
|
45
|
-
from mindspore.train.print_pb2 import Print
|
|
46
45
|
|
|
47
46
|
import mindspore
|
|
48
47
|
import mindspore.nn as nn
|
|
@@ -55,11 +54,10 @@ from mindspore.common import dtype as mstype
|
|
|
55
54
|
from mindspore.common.api import _cell_graph_executor as _executor
|
|
56
55
|
from mindspore.common.api import _JitExecutor
|
|
57
56
|
from mindspore.common.api import _get_parameter_layout
|
|
58
|
-
from mindspore.common.initializer import initializer
|
|
57
|
+
from mindspore.common.initializer import initializer
|
|
59
58
|
from mindspore.common.parameter import Parameter, _offload_if_config
|
|
60
59
|
from mindspore.common.tensor import Tensor
|
|
61
60
|
from mindspore._c_expression import TensorPy as Tensor_
|
|
62
|
-
from mindspore.common._utils import is_shape_unknown
|
|
63
61
|
from mindspore.common.file_system import FileSystem, _register_basic_file_system, _register_mindio_file_system
|
|
64
62
|
from mindspore.communication.management import get_rank, get_group_size
|
|
65
63
|
from mindspore.experimental import MapParameter
|
|
@@ -75,9 +73,9 @@ from mindspore.parallel.checkpoint_transform import load_distributed_checkpoint
|
|
|
75
73
|
from mindspore.parallel.checkpoint_transform import merge_sliced_parameter as new_merge_sliced_parameter
|
|
76
74
|
from mindspore.parallel.checkpoint_transform import build_searched_strategy as new_build_searched_strategy
|
|
77
75
|
from mindspore.parallel.transform_safetensors import _fast_safe_open
|
|
78
|
-
from mindspore.train._utils import read_proto, get_parameter_redundancy, _progress_bar, _load_and_transform
|
|
76
|
+
from mindspore.train._utils import get_parameter_redundancy, _progress_bar, _load_and_transform
|
|
79
77
|
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, \
|
|
80
|
-
split_mindir, split_dynamic_mindir
|
|
78
|
+
split_mindir, split_dynamic_mindir, _get_snapshot_params
|
|
81
79
|
from mindspore.common.generator import Generator
|
|
82
80
|
|
|
83
81
|
tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
|
|
@@ -416,9 +414,6 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
|
|
|
416
414
|
crc_num, crc_check,
|
|
417
415
|
ckpt_total_io_time)
|
|
418
416
|
continue
|
|
419
|
-
if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
|
|
420
|
-
_write_hugeparameter(name, value, f)
|
|
421
|
-
continue
|
|
422
417
|
|
|
423
418
|
crc_num, ckpt_total_io_time = _write_parameter_bytes_data(name, value, f, enc_key, plain_data,
|
|
424
419
|
crc_num, crc_check,
|
|
@@ -561,27 +556,6 @@ def _write_mapparameter(name, value, f, map_param_inc=False):
|
|
|
561
556
|
break
|
|
562
557
|
|
|
563
558
|
|
|
564
|
-
def _write_hugeparameter(name, value, f):
|
|
565
|
-
"""Write huge parameter into protobuf file."""
|
|
566
|
-
slice_num = value[2].slice_num
|
|
567
|
-
offset = 0
|
|
568
|
-
max_size = value[0][0]
|
|
569
|
-
for param_slice in range(slice_num):
|
|
570
|
-
checkpoint_list = Checkpoint()
|
|
571
|
-
param_value = checkpoint_list.value.add()
|
|
572
|
-
param_value.tag = name
|
|
573
|
-
param_tensor = param_value.tensor
|
|
574
|
-
param_tensor.dims.extend(value[0])
|
|
575
|
-
param_tensor.tensor_type = value[1]
|
|
576
|
-
param_key = value[3]
|
|
577
|
-
numpy_data = value[2].asnumpy_of_slice_persistent_data(param_key, param_slice)
|
|
578
|
-
if offset + numpy_data.shape[0] > max_size:
|
|
579
|
-
numpy_data = numpy_data[:max_size - offset]
|
|
580
|
-
param_tensor.tensor_content = numpy_data.tobytes()
|
|
581
|
-
f.write(checkpoint_list.SerializeToString())
|
|
582
|
-
offset += numpy_data.shape[0]
|
|
583
|
-
|
|
584
|
-
|
|
585
559
|
def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format):
|
|
586
560
|
"""Check save_obj and ckpt_file_name for save_checkpoint."""
|
|
587
561
|
if format not in ["safetensors", "ckpt"]:
|
|
@@ -783,9 +757,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
783
757
|
data_list[param["name"]].append(param["data"])
|
|
784
758
|
continue
|
|
785
759
|
if isinstance(param["data"], list):
|
|
786
|
-
if param["data"][0] == "persistent_data":
|
|
787
|
-
_save_param_list_data(data_list, key, param)
|
|
788
|
-
elif param["data"][0] == "offload_parameter":
|
|
760
|
+
if param["data"][0] == "offload_parameter":
|
|
789
761
|
data_list[key].append("offload_parameter")
|
|
790
762
|
_save_param_list_data(data_list, key, param)
|
|
791
763
|
|
|
@@ -971,6 +943,8 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
|
|
|
971
943
|
if not is_parallel_mode:
|
|
972
944
|
save_obj.init_parameters_data()
|
|
973
945
|
param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func, is_parallel_mode)
|
|
946
|
+
enable_ckpt_d2h_sync = os.getenv('MS_ENABLE_D2H_ASYNC') == '1'
|
|
947
|
+
param_snapshot = _get_snapshot_params() if enable_ckpt_d2h_sync else {}
|
|
974
948
|
for (key, value) in param_dict.items():
|
|
975
949
|
each_param = {"name": key}
|
|
976
950
|
if isinstance(value, MapParameter):
|
|
@@ -978,10 +952,7 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
|
|
|
978
952
|
param_list.append(each_param)
|
|
979
953
|
continue
|
|
980
954
|
|
|
981
|
-
if value.data.is_persistent_data():
|
|
982
|
-
# list save persistent_data: [Tensor, shape, type, param.key]
|
|
983
|
-
param_data = ["persistent_data", value.data, value.param_info.origin_shape, str(value.dtype), value.key]
|
|
984
|
-
elif value.data.offload_file_path() != "":
|
|
955
|
+
if value.data.offload_file_path() != "":
|
|
985
956
|
# list save offload data: [Param, shape, type, param.key]
|
|
986
957
|
param_data = ["offload_parameter"]
|
|
987
958
|
param_tensor = value.data
|
|
@@ -996,7 +967,8 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
|
|
|
996
967
|
if append_dict and "__exception_save__" in append_dict:
|
|
997
968
|
param_data = Tensor(Tensor_.move_to(value, "CPU", False))
|
|
998
969
|
else:
|
|
999
|
-
|
|
970
|
+
# when enable MS_ENABLE_D2H_ASYNC=1, fetch param from sanpshot in priority
|
|
971
|
+
param_data = param_snapshot.get(key, Tensor(value.data))
|
|
1000
972
|
|
|
1001
973
|
# in automatic model parallel scenario, some parameters were split to all the devices,
|
|
1002
974
|
# which should be combined before saving
|
|
@@ -1020,13 +992,16 @@ def _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choi
|
|
|
1020
992
|
|
|
1021
993
|
return _handle_shared_param_for_pipeline_parallel(save_obj)
|
|
1022
994
|
|
|
1023
|
-
|
|
995
|
+
if isinstance(save_obj, nn.Cell):
|
|
996
|
+
return _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func)
|
|
997
|
+
|
|
998
|
+
raise TypeError("For 'save_checkpoint', the argument 'save_obj' must be list、dict or nn.cell, "
|
|
999
|
+
"but got {}.".format(type(save_obj)))
|
|
1024
1000
|
|
|
1025
1001
|
|
|
1026
1002
|
def _save_param_list_data(data_list, key, param):
|
|
1027
1003
|
"""Save persistent data into save_obj."""
|
|
1028
1004
|
dims = []
|
|
1029
|
-
# persistent_data shape can not be ()
|
|
1030
1005
|
for dim in param['data'][2]:
|
|
1031
1006
|
dims.append(dim)
|
|
1032
1007
|
data_list[key].append(dims)
|
|
@@ -1302,7 +1277,6 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
|
|
|
1302
1277
|
param_data = Tensor_.convert_bytes_to_tensor(new_data, tuple(dims), ms_type)
|
|
1303
1278
|
parameter = Parameter(param_data, name=element.tag)
|
|
1304
1279
|
parameter_dict[element.tag] = parameter
|
|
1305
|
-
_offload_if_config(parameter)
|
|
1306
1280
|
|
|
1307
1281
|
logger.info("Loading checkpoint files process is finished.")
|
|
1308
1282
|
return remove_redundancy
|
|
@@ -2148,6 +2122,7 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
|
|
|
2148
2122
|
if file_format == 'AIR':
|
|
2149
2123
|
_save_air(net, file_name, *inputs, **kwargs)
|
|
2150
2124
|
elif file_format == 'ONNX':
|
|
2125
|
+
logger.warning("mindspore.export(file_format='ONNX') will be deleted, please use mindspore.onnx.export()")
|
|
2151
2126
|
_save_onnx(net, file_name, *inputs, **kwargs)
|
|
2152
2127
|
elif file_format == 'MINDIR':
|
|
2153
2128
|
_save_mindir(net, file_name, *inputs, **kwargs)
|
|
@@ -2497,147 +2472,6 @@ def _save_dataset_to_mindir(model, dataset):
|
|
|
2497
2472
|
model.preprocessor.op[-1].offload = op['offload'] if 'offload' in op.keys() else False
|
|
2498
2473
|
|
|
2499
2474
|
|
|
2500
|
-
def check_checkpoint(ckpt_file_name):
|
|
2501
|
-
"""
|
|
2502
|
-
Check whether the checkpoint is valid.
|
|
2503
|
-
|
|
2504
|
-
Note:
|
|
2505
|
-
The interface is deprecated from version 2.5 and will be removed in a future version.
|
|
2506
|
-
|
|
2507
|
-
Args:
|
|
2508
|
-
ckpt_file_name (str): Checkpoint file name.
|
|
2509
|
-
|
|
2510
|
-
Returns:
|
|
2511
|
-
bool, whether the checkpoint is valid.
|
|
2512
|
-
|
|
2513
|
-
Examples:
|
|
2514
|
-
>>> import mindspore as ms
|
|
2515
|
-
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
|
|
2516
|
-
>>> check_result = ms.check_checkpoint(ckpt_file_name)
|
|
2517
|
-
>>> print(check_result)
|
|
2518
|
-
True
|
|
2519
|
-
"""
|
|
2520
|
-
logger.warning("The interface 'mindspore.check_checkpoint' is deprecated from version 2.5 "
|
|
2521
|
-
"and will be removed in a future version.")
|
|
2522
|
-
if not ckpt_file_name.endswith('.ckpt'):
|
|
2523
|
-
return False
|
|
2524
|
-
checkpoint_list = Checkpoint()
|
|
2525
|
-
with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
|
|
2526
|
-
pb_content = f.read()
|
|
2527
|
-
if pb_content[-17:-10] == b"crc_num":
|
|
2528
|
-
crc_num_bytes = pb_content[-10:]
|
|
2529
|
-
pb_content = pb_content[:-17]
|
|
2530
|
-
crc_num = int.from_bytes(crc_num_bytes, byteorder='big')
|
|
2531
|
-
cal_crc_num = binascii.crc32(pb_content, 0)
|
|
2532
|
-
if cal_crc_num != crc_num:
|
|
2533
|
-
logger.warning("For 'check_checkpoint', the ckpt crc check is failed.")
|
|
2534
|
-
return False
|
|
2535
|
-
try:
|
|
2536
|
-
checkpoint_list.ParseFromString(pb_content)
|
|
2537
|
-
except google.protobuf.message.DecodeError as e:
|
|
2538
|
-
logger.warning("For 'check_checkpoint', the ckpt parse is failed.")
|
|
2539
|
-
logger.warning(e)
|
|
2540
|
-
return False
|
|
2541
|
-
return True
|
|
2542
|
-
|
|
2543
|
-
|
|
2544
|
-
def parse_print(print_file_name):
|
|
2545
|
-
"""
|
|
2546
|
-
Parse data file generated by :class:`mindspore.ops.Print`.
|
|
2547
|
-
|
|
2548
|
-
Note:
|
|
2549
|
-
The interface is deprecated from version 2.5 and will be removed in a future version.
|
|
2550
|
-
|
|
2551
|
-
Args:
|
|
2552
|
-
print_file_name (str): The file name needs to be parsed.
|
|
2553
|
-
|
|
2554
|
-
Returns:
|
|
2555
|
-
List, element of list is Tensor.
|
|
2556
|
-
|
|
2557
|
-
Raises:
|
|
2558
|
-
ValueError: The print file does not exist or is empty.
|
|
2559
|
-
RuntimeError: Failed to parse the file.
|
|
2560
|
-
|
|
2561
|
-
Examples:
|
|
2562
|
-
>>> import numpy as np
|
|
2563
|
-
>>> import mindspore as ms
|
|
2564
|
-
>>> from mindspore import nn, Tensor, ops
|
|
2565
|
-
>>> ms.set_context(mode=ms.GRAPH_MODE, print_file_path='log.data')
|
|
2566
|
-
>>> class PrintInputTensor(nn.Cell):
|
|
2567
|
-
... def __init__(self):
|
|
2568
|
-
... super().__init__()
|
|
2569
|
-
... self.print = ops.Print()
|
|
2570
|
-
...
|
|
2571
|
-
... def construct(self, input_pra):
|
|
2572
|
-
... self.print('print:', input_pra)
|
|
2573
|
-
... return input_pra
|
|
2574
|
-
>>> x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]).astype(np.float32)
|
|
2575
|
-
>>> input_pra = Tensor(x)
|
|
2576
|
-
>>> net = PrintInputTensor()
|
|
2577
|
-
>>> net(input_pra)
|
|
2578
|
-
>>>
|
|
2579
|
-
>>> data = ms.parse_print('./log.data')
|
|
2580
|
-
>>> print(data)
|
|
2581
|
-
['print:', Tensor(shape=[2, 4], dtype=Float32, value=
|
|
2582
|
-
[[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
|
|
2583
|
-
[ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
|
|
2584
|
-
"""
|
|
2585
|
-
logger.warning("The interface 'mindspore.parse_print' is deprecated from version 2.5 "
|
|
2586
|
-
"and will be removed in a future version.")
|
|
2587
|
-
print_file_path = os.path.realpath(print_file_name)
|
|
2588
|
-
|
|
2589
|
-
if os.path.getsize(print_file_path) == 0:
|
|
2590
|
-
raise ValueError("For 'parse_print', the print file may be empty, please make sure enter the correct "
|
|
2591
|
-
"'print_file_name'.")
|
|
2592
|
-
|
|
2593
|
-
logger.info("Execute load print process.")
|
|
2594
|
-
print_list = Print()
|
|
2595
|
-
|
|
2596
|
-
try:
|
|
2597
|
-
with open(print_file_path, "rb") as f:
|
|
2598
|
-
pb_content = f.read()
|
|
2599
|
-
print_list.ParseFromString(pb_content)
|
|
2600
|
-
except BaseException as e:
|
|
2601
|
-
logger.critical("Failed to read the print file %s, please check whether the file is "
|
|
2602
|
-
"correct.", print_file_name)
|
|
2603
|
-
raise ValueError(e.__str__() + "\nFailed to read the print file {}, please check whether "
|
|
2604
|
-
"the file is correct.".format(print_file_name)) from e
|
|
2605
|
-
|
|
2606
|
-
tensor_list = []
|
|
2607
|
-
|
|
2608
|
-
try:
|
|
2609
|
-
for print_ in print_list.value:
|
|
2610
|
-
# String type
|
|
2611
|
-
if print_.HasField("desc"):
|
|
2612
|
-
tensor_list.append(print_.desc)
|
|
2613
|
-
elif print_.HasField("tensor"):
|
|
2614
|
-
dims = print_.tensor.dims
|
|
2615
|
-
data_type = print_.tensor.tensor_type
|
|
2616
|
-
data = print_.tensor.tensor_content
|
|
2617
|
-
np_type = tensor_to_np_type(data_type)
|
|
2618
|
-
param_data = np.fromstring(data, np_type)
|
|
2619
|
-
ms_type = tensor_to_ms_type.get(data_type)
|
|
2620
|
-
if dims and dims != [0]:
|
|
2621
|
-
param_value = param_data.reshape(dims)
|
|
2622
|
-
tensor_list.append(Tensor(param_value, ms_type))
|
|
2623
|
-
# Scalar type
|
|
2624
|
-
else:
|
|
2625
|
-
data_type_ = data_type.lower()
|
|
2626
|
-
if 'float' in data_type_:
|
|
2627
|
-
param_data = float(param_data[0])
|
|
2628
|
-
elif 'int' in data_type_:
|
|
2629
|
-
param_data = int(param_data[0])
|
|
2630
|
-
elif 'bool' in data_type_:
|
|
2631
|
-
param_data = bool(param_data[0])
|
|
2632
|
-
tensor_list.append(Tensor(param_data, ms_type))
|
|
2633
|
-
|
|
2634
|
-
except BaseException as e:
|
|
2635
|
-
logger.critical("Failed to load the print file %s.", print_list)
|
|
2636
|
-
raise RuntimeError(e.__str__() + "\nFailed to load the print file {}.".format(print_list)) from e
|
|
2637
|
-
|
|
2638
|
-
return tensor_list
|
|
2639
|
-
|
|
2640
|
-
|
|
2641
2475
|
def async_ckpt_thread_status():
|
|
2642
2476
|
"""
|
|
2643
2477
|
Get the status of asynchronous save checkpoint thread.
|
|
@@ -2672,170 +2506,132 @@ def _calculation_net_size(net):
|
|
|
2672
2506
|
return data_total
|
|
2673
2507
|
|
|
2674
2508
|
|
|
2675
|
-
def _get_mindir_inputs(file_name):
|
|
2509
|
+
def _load_file_and_convert_name(path, name_map=None, format="ckpt"):
|
|
2676
2510
|
"""
|
|
2677
|
-
|
|
2678
|
-
|
|
2679
|
-
Note:
|
|
2680
|
-
1. Parsing encrypted MindIR file is not supported.
|
|
2681
|
-
2. Parsing dynamic shape MindIR file is not supported.
|
|
2511
|
+
Load file, during load convert name by name_map.
|
|
2682
2512
|
|
|
2683
2513
|
Args:
|
|
2684
|
-
|
|
2514
|
+
path (str): The file path.
|
|
2515
|
+
name_map (dict): Convert the name of parameter by name_map.
|
|
2516
|
+
format (str): The format of the file. Option: 'ckpt', 'safetensors'
|
|
2685
2517
|
|
|
2686
2518
|
Returns:
|
|
2687
|
-
|
|
2688
|
-
|
|
2689
|
-
Raises:
|
|
2690
|
-
TypeError: If the parameter file_name is not `str`.
|
|
2691
|
-
RuntimeError: MindIR's input is not tensor type or has no dims.
|
|
2692
|
-
|
|
2693
|
-
Examples:
|
|
2694
|
-
>>> input_tensor = get_mindir_inputs("lenet.mindir")
|
|
2695
|
-
"""
|
|
2696
|
-
Validator.check_file_name_by_regular(file_name)
|
|
2697
|
-
file_name = os.path.realpath(file_name)
|
|
2698
|
-
model = read_proto(file_name)
|
|
2699
|
-
input_tensor = []
|
|
2700
|
-
|
|
2701
|
-
for ele_input in model.graph.input:
|
|
2702
|
-
input_shape = []
|
|
2703
|
-
if not hasattr(ele_input, "tensor") or not hasattr(ele_input.tensor[0], "dims"):
|
|
2704
|
-
raise RuntimeError("MindIR's inputs has no tensor or tensor has no dims, please check MindIR file.")
|
|
2705
|
-
|
|
2706
|
-
for ele_shape in ele_input.tensor[0].dims:
|
|
2707
|
-
input_shape.append(ele_shape)
|
|
2708
|
-
if is_shape_unknown(input_shape):
|
|
2709
|
-
raise RuntimeError(f"MindIR input's shape is: {input_shape}, dynamic shape is not supported.")
|
|
2710
|
-
|
|
2711
|
-
mindir_type = ele_input.tensor[0].data_type
|
|
2712
|
-
if mindir_type not in mindir_to_tensor_type:
|
|
2713
|
-
raise RuntimeError(f"MindIR input's type: {mindir_type} is not supported.")
|
|
2714
|
-
|
|
2715
|
-
input_type = mindir_to_tensor_type.get(mindir_type)
|
|
2716
|
-
input_tensor.append(Tensor(shape=input_shape, dtype=input_type, init=One()))
|
|
2717
|
-
|
|
2718
|
-
if not input_tensor:
|
|
2719
|
-
logger.warning("The MindIR model has no input, return None.")
|
|
2720
|
-
return None
|
|
2721
|
-
return input_tensor[0] if len(input_tensor) == 1 else input_tensor
|
|
2722
|
-
|
|
2723
|
-
|
|
2724
|
-
def convert_model(mindir_file, convert_file, file_format):
|
|
2725
|
-
"""
|
|
2726
|
-
Convert mindir model to other format model. The current version only supports conversion to ONNX models.
|
|
2727
|
-
|
|
2728
|
-
Note:
|
|
2729
|
-
The interface is deprecated from version 2.5 and will be removed in a future version.
|
|
2730
|
-
|
|
2731
|
-
Args:
|
|
2732
|
-
mindir_file (str): MindIR file name.
|
|
2733
|
-
convert_file (str): Convert model file name.
|
|
2734
|
-
file_format (str): Convert model's format, current version only supports "ONNX".
|
|
2735
|
-
|
|
2736
|
-
Raises:
|
|
2737
|
-
TypeError: If the parameter `mindir_file` is not `str`.
|
|
2738
|
-
TypeError: If the parameter `convert_file` is not `str`.
|
|
2739
|
-
ValueError: If the parameter `file_format` is not "ONNX".
|
|
2740
|
-
|
|
2741
|
-
Examples:
|
|
2742
|
-
>>> import mindspore as ms
|
|
2743
|
-
>>> ms.convert_model("lenet.mindir", "lenet.onnx", "ONNX")
|
|
2519
|
+
Dict, key is parameter name, value is a Parameter or string.
|
|
2744
2520
|
"""
|
|
2745
|
-
|
|
2746
|
-
|
|
2747
|
-
|
|
2748
|
-
Validator.check_file_name_by_regular(convert_file)
|
|
2749
|
-
if file_format != "ONNX":
|
|
2750
|
-
raise ValueError(f"For 'convert_model', 'file_format' must be 'ONNX', but got {file_format}.")
|
|
2751
|
-
net_input = _get_mindir_inputs(mindir_file)
|
|
2752
|
-
graph = load(mindir_file)
|
|
2753
|
-
net = nn.GraphCell(graph)
|
|
2754
|
-
if isinstance(net_input, Tensor):
|
|
2755
|
-
export(net, net_input, file_name=convert_file, file_format=file_format)
|
|
2756
|
-
else:
|
|
2757
|
-
export(net, *net_input, file_name=convert_file, file_format=file_format)
|
|
2758
|
-
|
|
2759
|
-
|
|
2760
|
-
def _load_ckpt_to_new_name_map(path, name_map=None):
|
|
2761
|
-
return _load_and_transform(path, name_map, mindspore.load_checkpoint, None)
|
|
2762
|
-
|
|
2521
|
+
if name_map is not None:
|
|
2522
|
+
load_func = partial(mindspore.load_checkpoint, format=format)
|
|
2523
|
+
return _load_and_transform(path, name_map, load_func)
|
|
2763
2524
|
|
|
2764
|
-
|
|
2765
|
-
load_func = partial(mindspore.load_checkpoint, format="safetensors")
|
|
2766
|
-
return _load_and_transform(path, name_map, load_func, None)
|
|
2525
|
+
return mindspore.load_checkpoint(path, format=format)
|
|
2767
2526
|
|
|
2768
2527
|
|
|
2769
2528
|
def _process_file(file_info):
|
|
2770
|
-
|
|
2771
|
-
|
|
2772
|
-
|
|
2529
|
+
"""Rrocess file."""
|
|
2530
|
+
cur_path, name_map, save_path, file, dst_format = file_info
|
|
2531
|
+
if dst_format == "safetensors":
|
|
2532
|
+
param_dict = _load_file_and_convert_name(cur_path, name_map, format="ckpt")
|
|
2533
|
+
safetensors_filename = file.replace(".ckpt", ".safetensors")
|
|
2534
|
+
dst_file = os.path.join(save_path, safetensors_filename)
|
|
2535
|
+
mindspore.save_checkpoint(param_dict, dst_file, format='safetensors')
|
|
2773
2536
|
else:
|
|
2774
|
-
param_dict =
|
|
2775
|
-
|
|
2776
|
-
|
|
2777
|
-
|
|
2537
|
+
param_dict = _load_file_and_convert_name(cur_path, name_map, format="safetensors")
|
|
2538
|
+
ckpt_filename = file.replace(".safetensors", ".ckpt")
|
|
2539
|
+
dst_file = os.path.join(save_path, ckpt_filename)
|
|
2540
|
+
mindspore.save_checkpoint(param_dict, dst_file)
|
|
2778
2541
|
|
|
2779
2542
|
|
|
2780
|
-
def
|
|
2781
|
-
|
|
2782
|
-
if
|
|
2783
|
-
|
|
2543
|
+
def _gather_all_tasks(file_path, save_path, file_name_regex, name_map, dst_format="ckpt"):
|
|
2544
|
+
"""gather transform rank together"""
|
|
2545
|
+
if dst_format == "ckpt":
|
|
2546
|
+
cur_file_suffix = ".safetensors"
|
|
2784
2547
|
else:
|
|
2785
|
-
|
|
2786
|
-
ckpt_filename = file.replace(".safetensors", ".ckpt")
|
|
2787
|
-
dst_file = os.path.join(save_path, ckpt_filename)
|
|
2788
|
-
mindspore.save_checkpoint(param_dict, dst_file)
|
|
2789
|
-
|
|
2548
|
+
cur_file_suffix = ".ckpt"
|
|
2790
2549
|
|
|
2791
|
-
|
|
2792
|
-
"""gather transform rank together"""
|
|
2793
|
-
tasks = []
|
|
2550
|
+
tasks_list = []
|
|
2794
2551
|
for root, dirs, _ in os.walk(file_path):
|
|
2795
2552
|
if root != file_path:
|
|
2796
2553
|
continue
|
|
2797
2554
|
|
|
2798
2555
|
rank_dirs = [d for d in dirs if d.startswith('rank')]
|
|
2799
2556
|
if not rank_dirs:
|
|
2800
|
-
|
|
2801
|
-
|
|
2557
|
+
if dst_format == "safetensors":
|
|
2558
|
+
raise ValueError(
|
|
2559
|
+
f"For 'ckpt_to_safetensors', no directories starting with 'rank' found in {file_path}.")
|
|
2560
|
+
if dst_format == "ckpt":
|
|
2561
|
+
raise ValueError(
|
|
2562
|
+
f"For 'safetensors_to_ckpt', no directories starting with 'rank' found in {file_path}.")
|
|
2563
|
+
|
|
2564
|
+
raise ValueError(f"For '_gather_all_tasks', error args 'format': {dst_format}.")
|
|
2802
2565
|
|
|
2803
2566
|
for rank_dir in rank_dirs:
|
|
2804
2567
|
rank_dir_path = os.path.join(root, rank_dir)
|
|
2805
|
-
|
|
2806
|
-
|
|
2568
|
+
if save_path is not None:
|
|
2569
|
+
dst_root = os.path.join(save_path, os.path.relpath(rank_dir_path, file_path))
|
|
2570
|
+
else:
|
|
2571
|
+
dst_root = rank_dir_path
|
|
2572
|
+
|
|
2807
2573
|
os.makedirs(dst_root, exist_ok=True)
|
|
2808
|
-
tasks.extend(
|
|
2809
|
-
(os.path.join(rank_dir_path, file), name_map, dst_root, file)
|
|
2810
|
-
for file in os.listdir(rank_dir_path)
|
|
2811
|
-
if file.endswith(".safetensors") and (file_name_regex is None or re.findall(file_name_regex, file))
|
|
2812
|
-
)
|
|
2813
|
-
return tasks
|
|
2814
2574
|
|
|
2575
|
+
for file in os.listdir(rank_dir_path):
|
|
2576
|
+
if file.endswith(cur_file_suffix) and (file_name_regex is None or re.search(file_name_regex, file)):
|
|
2577
|
+
tasks_list.append((os.path.join(rank_dir_path, file), name_map, dst_root, file, dst_format))
|
|
2815
2578
|
|
|
2816
|
-
|
|
2817
|
-
"""gather transform rank together"""
|
|
2818
|
-
tasks = []
|
|
2819
|
-
for root, dirs, _ in os.walk(file_path):
|
|
2820
|
-
if root != file_path:
|
|
2821
|
-
continue
|
|
2579
|
+
return tasks_list
|
|
2822
2580
|
|
|
2823
|
-
rank_dirs = [d for d in dirs if d.startswith('rank')]
|
|
2824
|
-
if not rank_dirs:
|
|
2825
|
-
raise ValueError(
|
|
2826
|
-
f"For 'ckpt_to_safetensors', no directories starting with 'rank' found in {file_path}")
|
|
2827
2581
|
|
|
2828
|
-
|
|
2829
|
-
|
|
2830
|
-
|
|
2831
|
-
|
|
2832
|
-
|
|
2833
|
-
|
|
2834
|
-
|
|
2835
|
-
|
|
2836
|
-
|
|
2837
|
-
|
|
2838
|
-
|
|
2582
|
+
def _convert_checkpoint_file(file_path, save_path=None, name_map=None, file_name_regex=None,
|
|
2583
|
+
processes_num=1, dst_format="safetensors"):
|
|
2584
|
+
"""
|
|
2585
|
+
Converts MindSpore checkpoint files format and saves them to `save_path`.
|
|
2586
|
+
Safetensors is a reliable and portable machine learning model storage format introduced by Huggingface,
|
|
2587
|
+
used for securely storing Tensors with fast speed (zero copy).
|
|
2588
|
+
|
|
2589
|
+
Args:
|
|
2590
|
+
file_path (str): Path to the directory containing checkpoint files or a single checkpoint file (.ckpt).
|
|
2591
|
+
save_path (str, optional): Directory path where safetensors files will be saved. Default: ``None``.
|
|
2592
|
+
name_map (dict, optional): Dictionary mapping original parameter names to new names. Default: ``None``.
|
|
2593
|
+
file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
|
|
2594
|
+
Default: ``None``.
|
|
2595
|
+
processes_num (int, optional): Number of processes to use for parallel processing. Default: 1.
|
|
2596
|
+
dst_format (str): dst file format. Default: "safetensors".
|
|
2597
|
+
"""
|
|
2598
|
+
if dst_format == "safetensors":
|
|
2599
|
+
src_format = "ckpt"
|
|
2600
|
+
src_file_suffix = ".ckpt"
|
|
2601
|
+
dst_file_suffix = ".safetensors"
|
|
2602
|
+
func_name = "ckpt_to_safetensors"
|
|
2603
|
+
else:
|
|
2604
|
+
src_format = "safetensors"
|
|
2605
|
+
src_file_suffix = ".safetensors"
|
|
2606
|
+
dst_file_suffix = ".ckpt"
|
|
2607
|
+
func_name = "safetensors_to_ckpt"
|
|
2608
|
+
is_dir = os.path.isdir(file_path)
|
|
2609
|
+
is_file = os.path.isfile(file_path)
|
|
2610
|
+
if not is_dir and not is_file:
|
|
2611
|
+
raise ValueError(f"For {func_name}, the input path must be a valid path or file, but got {file_path}")
|
|
2612
|
+
if save_path and os.path.splitext(save_path)[1]:
|
|
2613
|
+
raise ValueError(f"For {func_name}, the save_path must be a directory, but got '{save_path}'")
|
|
2614
|
+
if name_map is not None and not isinstance(name_map, dict):
|
|
2615
|
+
raise ValueError(
|
|
2616
|
+
f"For {func_name}, the type of 'name_map' must be a directory, but got '{type(name_map)}'")
|
|
2617
|
+
|
|
2618
|
+
if is_dir:
|
|
2619
|
+
tasks_list = _gather_all_tasks(file_path, save_path, file_name_regex, name_map, dst_format=dst_format)
|
|
2620
|
+
with mp.Pool(processes=processes_num) as pool:
|
|
2621
|
+
list(_progress_bar(pool.imap(_process_file, tasks_list), total=len(tasks_list)))
|
|
2622
|
+
elif is_file:
|
|
2623
|
+
if not file_path.endswith(src_file_suffix):
|
|
2624
|
+
raise ValueError(f"For {func_name}, the input file must be a {src_file_suffix} file, but got {file_path}")
|
|
2625
|
+
if file_name_regex is not None and not re.findall(file_name_regex, file_path):
|
|
2626
|
+
raise ValueError(f"For {func_name}, the input file does not match the regular expression.")
|
|
2627
|
+
if save_path and not os.path.exists(save_path):
|
|
2628
|
+
os.makedirs(save_path, exist_ok=True)
|
|
2629
|
+
|
|
2630
|
+
param_dict = _load_file_and_convert_name(file_path, name_map, format=src_format)
|
|
2631
|
+
|
|
2632
|
+
file_filename = os.path.basename(file_path).replace(src_file_suffix, dst_file_suffix)
|
|
2633
|
+
dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), file_filename)
|
|
2634
|
+
mindspore.save_checkpoint(param_dict, dst_file, format=dst_format)
|
|
2839
2635
|
|
|
2840
2636
|
|
|
2841
2637
|
def ckpt_to_safetensors(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
|
|
@@ -2854,11 +2650,11 @@ def ckpt_to_safetensors(file_path, save_path=None, name_map=None, file_name_rege
|
|
|
2854
2650
|
|
|
2855
2651
|
Args:
|
|
2856
2652
|
file_path (str): Path to the directory containing checkpoint files or a single checkpoint file (.ckpt).
|
|
2857
|
-
save_path (str, optional): Directory path where safetensors files will be saved.
|
|
2858
|
-
name_map (dict, optional): Dictionary mapping original parameter names to new names.
|
|
2653
|
+
save_path (str, optional): Directory path where safetensors files will be saved. Default: ``None``.
|
|
2654
|
+
name_map (dict, optional): Dictionary mapping original parameter names to new names. Default: ``None``.
|
|
2859
2655
|
file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
|
|
2860
|
-
|
|
2861
|
-
processes_num (int, optional): Number of processes to use for parallel processing.
|
|
2656
|
+
Default: ``None``.
|
|
2657
|
+
processes_num (int, optional): Number of processes to use for parallel processing. Default: 1.
|
|
2862
2658
|
Raises:
|
|
2863
2659
|
ValueError: If the input path is invalid or the save_path is not a directory,
|
|
2864
2660
|
or the file_path does not end with '.ckpt'.
|
|
@@ -2874,36 +2670,8 @@ def ckpt_to_safetensors(file_path, save_path=None, name_map=None, file_name_rege
|
|
|
2874
2670
|
>>> namemap = {"lin.weight":"new_name"}
|
|
2875
2671
|
>>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt", "./new_path/", namemap)
|
|
2876
2672
|
"""
|
|
2877
|
-
|
|
2878
|
-
|
|
2879
|
-
if not is_dir and not is_file:
|
|
2880
|
-
raise ValueError(f"For 'ckpt_to_safetensors', the input path must be a valid path or file, but got {file_path}")
|
|
2881
|
-
if save_path and os.path.splitext(save_path)[1]:
|
|
2882
|
-
raise ValueError(f"For 'ckpt_to_safetensors', the save_path must be a directory, but got '{save_path}'")
|
|
2883
|
-
if name_map is not None and not isinstance(name_map, dict):
|
|
2884
|
-
raise ValueError(
|
|
2885
|
-
f"For 'ckpt_to_safetensors', the type of 'name_map' must be a directory, but got '{type(name_map)}'")
|
|
2886
|
-
|
|
2887
|
-
if is_dir:
|
|
2888
|
-
tasks = _gather_tasks_covert(file_path, save_path, file_name_regex, name_map)
|
|
2889
|
-
with mp.Pool(processes=processes_num) as pool:
|
|
2890
|
-
list(_progress_bar(pool.imap(_process_file, tasks), total=len(tasks)))
|
|
2891
|
-
elif is_file:
|
|
2892
|
-
if not file_path.endswith(".ckpt"):
|
|
2893
|
-
raise ValueError(f"For 'ckpt_to_safetensors', the input file must be a .ckpt file, but got {file_path}")
|
|
2894
|
-
if file_name_regex is not None and not re.findall(file_name_regex, file_path):
|
|
2895
|
-
raise ValueError(f"For 'ckpt_to_safetensors', the input file does not match the regular expression.")
|
|
2896
|
-
if save_path and not os.path.exists(save_path):
|
|
2897
|
-
os.makedirs(save_path, exist_ok=True)
|
|
2898
|
-
|
|
2899
|
-
if name_map is not None:
|
|
2900
|
-
param_dict = _load_ckpt_to_new_name_map(file_path, name_map)
|
|
2901
|
-
else:
|
|
2902
|
-
param_dict = mindspore.load_checkpoint(file_path)
|
|
2903
|
-
|
|
2904
|
-
safetensors_filename = os.path.basename(file_path).replace(".ckpt", ".safetensors")
|
|
2905
|
-
dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), safetensors_filename)
|
|
2906
|
-
mindspore.save_checkpoint(param_dict, dst_file, format='safetensors')
|
|
2673
|
+
_convert_checkpoint_file(file_path, save_path, name_map,
|
|
2674
|
+
file_name_regex, processes_num, "safetensors")
|
|
2907
2675
|
|
|
2908
2676
|
|
|
2909
2677
|
def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
|
|
@@ -2918,11 +2686,11 @@ def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_rege
|
|
|
2918
2686
|
|
|
2919
2687
|
Args:
|
|
2920
2688
|
file_path (str): Path to the directory containing safetensors files or a single safetensors file (.safetensors).
|
|
2921
|
-
save_path (str, optional): Directory path where checkpoint files will be saved.
|
|
2922
|
-
name_map (dict, optional): Dictionary mapping original parameter names to new names.
|
|
2689
|
+
save_path (str, optional): Directory path where checkpoint files will be saved. Default: ``None``.
|
|
2690
|
+
name_map (dict, optional): Dictionary mapping original parameter names to new names. Default: ``None``.
|
|
2923
2691
|
file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
|
|
2924
|
-
|
|
2925
|
-
processes_num (int, optional): Number of processes to use for parallel processing.
|
|
2692
|
+
Default: ``None``.
|
|
2693
|
+
processes_num (int, optional): Number of processes to use for parallel processing. Default: 1.
|
|
2926
2694
|
|
|
2927
2695
|
Raises:
|
|
2928
2696
|
ValueError: If the input path is invalid, the save_path is not a directory,
|
|
@@ -2939,37 +2707,8 @@ def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_rege
|
|
|
2939
2707
|
>>> namemap = {"lin.weight":"new_name"}
|
|
2940
2708
|
>>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/", namemap)
|
|
2941
2709
|
"""
|
|
2942
|
-
|
|
2943
|
-
|
|
2944
|
-
if not is_dir and not is_file:
|
|
2945
|
-
raise ValueError(f"For 'safetensors_to_ckpt', the input path must be a valid path or file, but got {file_path}")
|
|
2946
|
-
if save_path and os.path.splitext(save_path)[1]:
|
|
2947
|
-
raise ValueError(f"For 'safetensors_to_ckpt', the save_path must be a directory, but got '{save_path}'")
|
|
2948
|
-
if name_map is not None and not isinstance(name_map, dict):
|
|
2949
|
-
raise ValueError(
|
|
2950
|
-
f"For 'safetensors_to_ckpt', the type of 'name_map' must be a directory, but got '{type(name_map)}'")
|
|
2951
|
-
|
|
2952
|
-
if is_dir:
|
|
2953
|
-
tasks = _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map)
|
|
2954
|
-
with mp.Pool(processes=processes_num) as pool:
|
|
2955
|
-
list(_progress_bar(pool.imap(_process_file_safetensors, tasks), total=len(tasks)))
|
|
2956
|
-
elif is_file:
|
|
2957
|
-
if not file_path.endswith(".safetensors"):
|
|
2958
|
-
raise ValueError(
|
|
2959
|
-
f"For 'safetensors_to_ckpt', the input file must be a .safetensors file, but got {file_path}")
|
|
2960
|
-
if file_name_regex is not None and not re.findall(file_name_regex, file_path):
|
|
2961
|
-
raise ValueError(f"For 'safetensors_to_ckpt', the input file does not match the regular expression.")
|
|
2962
|
-
if save_path and not os.path.exists(save_path):
|
|
2963
|
-
os.makedirs(save_path, exist_ok=True)
|
|
2964
|
-
|
|
2965
|
-
if name_map is not None:
|
|
2966
|
-
param_dict = _load_sf_to_new_name_map(file_path, name_map)
|
|
2967
|
-
else:
|
|
2968
|
-
param_dict = mindspore.load_checkpoint(file_path, format="safetensors")
|
|
2969
|
-
|
|
2970
|
-
ckpt_filename = os.path.basename(file_path).replace(".safetensors", ".ckpt")
|
|
2971
|
-
dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), ckpt_filename)
|
|
2972
|
-
mindspore.save_checkpoint(param_dict, dst_file)
|
|
2710
|
+
_convert_checkpoint_file(file_path, save_path, name_map,
|
|
2711
|
+
file_name_regex, processes_num, "ckpt")
|
|
2973
2712
|
|
|
2974
2713
|
|
|
2975
2714
|
def restore_group_info_list(group_info_file_name):
|