mindspore 2.4.10__cp310-cp310-win_amd64.whl → 2.5.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +8 -3
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +0 -5
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/compile_config.py +64 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
- mindspore/_extends/parse/parser.py +23 -5
- mindspore/_extends/parse/standard_method.py +123 -27
- mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
- mindspore/amp.py +7 -1
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/boost_cell_wrapper.py +136 -41
- mindspore/common/__init__.py +3 -1
- mindspore/common/_register_for_tensor.py +0 -1
- mindspore/common/_stub_tensor.py +25 -4
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +6132 -0
- mindspore/common/api.py +98 -21
- mindspore/common/dtype.py +34 -34
- mindspore/common/dump.py +2 -1
- mindspore/common/file_system.py +8 -3
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +3 -1
- mindspore/common/initializer.py +3 -4
- mindspore/common/lazy_inline.py +8 -2
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/parameter.py +31 -15
- mindspore/common/tensor.py +713 -1337
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +5 -0
- mindspore/communication/comm_func.py +215 -173
- mindspore/communication/management.py +23 -20
- mindspore/context.py +285 -191
- mindspore/dataset/__init__.py +23 -19
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +84 -3
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +5 -4
- mindspore/dataset/engine/datasets.py +192 -149
- mindspore/dataset/engine/datasets_audio.py +14 -0
- mindspore/dataset/engine/datasets_standard_format.py +11 -11
- mindspore/dataset/engine/datasets_text.py +38 -1
- mindspore/dataset/engine/datasets_user_defined.py +100 -66
- mindspore/dataset/engine/datasets_vision.py +81 -8
- mindspore/dataset/engine/iterators.py +281 -63
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +26 -2
- mindspore/dataset/engine/serializer_deserializer.py +1 -1
- mindspore/dataset/engine/validators.py +43 -11
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +29 -12
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +94 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +127 -0
- mindspore/device_context/cpu/__init__.py +25 -0
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +134 -0
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/optim/adadelta.py +26 -22
- mindspore/experimental/optim/adam.py +3 -0
- mindspore/experimental/optim/lr_scheduler.py +33 -24
- mindspore/experimental/optim/radam.py +33 -30
- mindspore/hal/device.py +28 -0
- mindspore/hal/event.py +17 -0
- mindspore/hal/memory.py +94 -3
- mindspore/hal/stream.py +91 -6
- mindspore/include/api/context.h +0 -1
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +12 -0
- mindspore/mindrecord/__init__.py +1 -1
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +824 -218
- mindspore/mint/distributed/__init__.py +66 -4
- mindspore/mint/distributed/distributed.py +2594 -44
- mindspore/mint/linalg/__init__.py +6 -0
- mindspore/mint/nn/__init__.py +473 -14
- mindspore/mint/nn/functional.py +486 -11
- mindspore/mint/nn/layer/__init__.py +17 -4
- mindspore/mint/nn/layer/_functions.py +330 -0
- mindspore/mint/nn/layer/activation.py +169 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +727 -0
- mindspore/mint/nn/layer/normalization.py +215 -19
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +170 -0
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/cell.py +126 -19
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +6 -6
- mindspore/nn/layer/basic.py +35 -25
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/embedding.py +3 -3
- mindspore/nn/layer/normalization.py +8 -7
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +47 -13
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +48 -26
- mindspore/nn/learning_rate_schedule.py +5 -3
- mindspore/nn/loss/loss.py +31 -36
- mindspore/nn/optim/ada_grad.py +1 -0
- mindspore/nn/optim/adadelta.py +2 -2
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/rprop.py +2 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/cell_wrapper.py +4 -6
- mindspore/nn/wrap/loss_scale.py +3 -4
- mindspore/numpy/array_creations.py +60 -62
- mindspore/numpy/array_ops.py +148 -143
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +16 -16
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +2 -1
- mindspore/ops/_grad_experimental/grad_comm_ops.py +94 -13
- mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_vmap/vmap_array_ops.py +20 -19
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
- mindspore/ops/_vmap/vmap_math_ops.py +11 -9
- mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
- mindspore/ops/auto_generate/gen_extend_func.py +554 -60
- mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
- mindspore/ops/auto_generate/gen_ops_prim.py +8024 -3409
- mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
- mindspore/ops/composite/base.py +1 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
- mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
- mindspore/ops/function/__init__.py +12 -0
- mindspore/ops/function/array_func.py +561 -159
- mindspore/ops/function/clip_func.py +64 -0
- mindspore/ops/function/debug_func.py +28 -20
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +5 -4
- mindspore/ops/function/math_func.py +1659 -290
- mindspore/ops/function/nn_func.py +988 -317
- mindspore/ops/function/parameter_func.py +3 -56
- mindspore/ops/function/random_func.py +243 -33
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/functional.py +18 -5
- mindspore/ops/functional_overload.py +897 -0
- mindspore/ops/operations/__init__.py +3 -2
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -34
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +38 -8
- mindspore/ops/operations/array_ops.py +45 -303
- mindspore/ops/operations/comm_ops.py +19 -16
- mindspore/ops/operations/custom_ops.py +11 -55
- mindspore/ops/operations/debug_ops.py +42 -47
- mindspore/ops/operations/inner_ops.py +6 -4
- mindspore/ops/operations/linalg_ops.py +3 -2
- mindspore/ops/operations/manually_defined/ops_def.py +185 -104
- mindspore/ops/operations/math_ops.py +11 -216
- mindspore/ops/operations/nn_ops.py +146 -308
- mindspore/ops/primitive.py +23 -21
- mindspore/ops/tensor_method.py +1669 -0
- mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
- mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
- mindspore/ops_generate/arg_handler.py +0 -61
- mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
- mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/base_generator.py +11 -0
- mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
- mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
- mindspore/ops_generate/functional_overload_py_generator.py +110 -0
- mindspore/ops_generate/functions_cc_generator.py +233 -0
- mindspore/ops_generate/gen_aclnn_implement.py +110 -114
- mindspore/ops_generate/gen_constants.py +157 -3
- mindspore/ops_generate/gen_ops.py +245 -990
- mindspore/ops_generate/gen_pyboost_func.py +97 -998
- mindspore/ops_generate/gen_utils.py +119 -33
- mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
- mindspore/ops_generate/op_api_proto.py +206 -0
- mindspore/ops_generate/op_def_py_generator.py +131 -0
- mindspore/ops_generate/op_prim_py_generator.py +480 -0
- mindspore/ops_generate/op_proto.py +373 -108
- mindspore/ops_generate/op_template_parser.py +436 -0
- mindspore/ops_generate/ops_def_cc_generator.py +288 -0
- mindspore/ops_generate/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/ops_name_h_generator.py +68 -0
- mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
- mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
- mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
- mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
- mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
- mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
- mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
- mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
- mindspore/ops_generate/pyboost_utils.py +92 -33
- mindspore/ops_generate/template.py +294 -44
- mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
- mindspore/parallel/__init__.py +3 -3
- mindspore/parallel/_auto_parallel_context.py +24 -33
- mindspore/parallel/_parallel_serialization.py +13 -2
- mindspore/parallel/_utils.py +4 -1
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +44 -0
- mindspore/parallel/cluster/process_entity/_api.py +131 -37
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +20 -3
- mindspore/parallel/parameter_broadcast.py +1 -1
- mindspore/parallel/shard.py +3 -0
- mindspore/parallel/transform_safetensors.py +119 -253
- mindspore/profiler/__init__.py +17 -4
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +174 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +202 -0
- mindspore/profiler/common/path_manager.py +371 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +476 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +210 -0
- mindspore/profiler/common/profiler_path_manager.py +120 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +270 -37
- mindspore/profiler/envprofiler.py +138 -0
- mindspore/profiler/mstx.py +199 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +309 -0
- mindspore/profiler/profiler.py +580 -93
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +114 -0
- mindspore/profiler/schedule.py +208 -0
- mindspore/rewrite/api/symbol_tree.py +1 -2
- mindspore/run_check/_check_version.py +2 -6
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +148 -0
- mindspore/runtime/memory.py +392 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +2 -2
- mindspore/train/_utils.py +53 -18
- mindspore/train/amp.py +8 -4
- mindspore/train/callback/_checkpoint.py +32 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +105 -69
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_summary_collector.py +44 -6
- mindspore/train/callback/_tft_register.py +31 -10
- mindspore/train/dataset_helper.py +11 -11
- mindspore/train/metrics/precision.py +4 -5
- mindspore/train/mind_ir_pb2.py +167 -46
- mindspore/train/model.py +13 -15
- mindspore/train/serialization.py +462 -76
- mindspore/train/summary/summary_record.py +1 -2
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +4 -2
- mindspore/utils/dryrun.py +138 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/METADATA +2 -3
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/RECORD +362 -238
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0
|
@@ -16,16 +16,20 @@
|
|
|
16
16
|
from __future__ import absolute_import
|
|
17
17
|
|
|
18
18
|
import os
|
|
19
|
-
import time
|
|
20
19
|
import glob
|
|
21
|
-
import re
|
|
22
20
|
import math
|
|
23
21
|
import json
|
|
22
|
+
import re
|
|
24
23
|
from collections import defaultdict
|
|
25
24
|
|
|
25
|
+
import time
|
|
26
26
|
import multiprocessing as mp
|
|
27
27
|
import numpy as np
|
|
28
|
+
from safetensors.numpy import save_file, load_file
|
|
29
|
+
from safetensors import safe_open
|
|
30
|
+
|
|
28
31
|
import mindspore as ms
|
|
32
|
+
from mindspore import log as logger
|
|
29
33
|
from mindspore.parallel._parallel_serialization import _get_device_num_from_strategy, _make_dir, \
|
|
30
34
|
_extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
|
|
31
35
|
_insert_opt_shard_reshape, _extract_src_dst_layout_map_by_src
|
|
@@ -36,70 +40,6 @@ from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_
|
|
|
36
40
|
from mindspore.parallel._parallel_serialization import _build_searched_strategy, _load_protobuf_strategy, \
|
|
37
41
|
_convert_to_list
|
|
38
42
|
|
|
39
|
-
from safetensors.numpy import save_file, load_file
|
|
40
|
-
from safetensors import safe_open
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
def _load_and_transform(path, name_map, load_func, transform_func):
|
|
44
|
-
if load_func is not None:
|
|
45
|
-
param_dict = load_func(path)
|
|
46
|
-
else:
|
|
47
|
-
param_dict = path
|
|
48
|
-
transform_dict = {}
|
|
49
|
-
for k, v in param_dict.items():
|
|
50
|
-
new_name = name_map.get(k, k) if name_map is not None else k
|
|
51
|
-
transform_dict[new_name] = transform_func(v, new_name)
|
|
52
|
-
return transform_dict
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
def _transform_tensor_to_numpy(path, name_map=None):
|
|
56
|
-
return _load_and_transform(path, name_map, ms.load_checkpoint, lambda v, new_name: v.asnumpy())
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
def _transform_numpy_to_tensor(path, name_map=None):
|
|
60
|
-
return _load_and_transform(path, name_map, load_file, lambda v, new_name: ms.Parameter(v, name=new_name))
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
def _process_file(file_info):
|
|
64
|
-
cur_ckpt_path, name_map, save_path, file = file_info
|
|
65
|
-
param_dict_numpy = _transform_tensor_to_numpy(cur_ckpt_path, name_map)
|
|
66
|
-
safetensors_filename = file.replace(".ckpt", ".safetensors")
|
|
67
|
-
dst_file = os.path.join(save_path, safetensors_filename)
|
|
68
|
-
save_file(param_dict_numpy, dst_file)
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
def _process_file_safetensors(file_info):
|
|
72
|
-
cur_safe_path, name_map, save_path, file = file_info
|
|
73
|
-
param_dict_tensor = _transform_numpy_to_tensor(cur_safe_path, name_map)
|
|
74
|
-
ckpt_filename = file.replace(".safetensors", ".ckpt")
|
|
75
|
-
dst_file = os.path.join(save_path, ckpt_filename)
|
|
76
|
-
ms.save_checkpoint(param_dict_tensor, dst_file)
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
def _gather_tasks(file_path, save_path, file_name_regex, name_map):
|
|
80
|
-
"""gather transform rank together"""
|
|
81
|
-
tasks = []
|
|
82
|
-
for root, dirs, _ in os.walk(file_path):
|
|
83
|
-
if root != file_path:
|
|
84
|
-
continue
|
|
85
|
-
|
|
86
|
-
rank_dirs = [d for d in dirs if d.startswith('rank')]
|
|
87
|
-
if not rank_dirs:
|
|
88
|
-
raise ValueError(
|
|
89
|
-
f"For 'ckpt_to_safetensors', no directories starting with 'rank' found in {file_path}")
|
|
90
|
-
|
|
91
|
-
for rank_dir in rank_dirs:
|
|
92
|
-
rank_dir_path = os.path.join(root, rank_dir)
|
|
93
|
-
dst_root = os.path.join(save_path,
|
|
94
|
-
os.path.relpath(rank_dir_path, file_path)) if save_path else rank_dir_path
|
|
95
|
-
os.makedirs(dst_root, exist_ok=True)
|
|
96
|
-
tasks.extend(
|
|
97
|
-
(os.path.join(rank_dir_path, file), name_map, dst_root, file)
|
|
98
|
-
for file in os.listdir(rank_dir_path)
|
|
99
|
-
if file.endswith(".ckpt") and (file_name_regex is None or re.findall(file_name_regex, file))
|
|
100
|
-
)
|
|
101
|
-
return tasks
|
|
102
|
-
|
|
103
43
|
|
|
104
44
|
def _progress_bar(iterable, total=None):
|
|
105
45
|
"""
|
|
@@ -134,155 +74,16 @@ def _progress_bar(iterable, total=None):
|
|
|
134
74
|
print_progress_bar(i)
|
|
135
75
|
|
|
136
76
|
|
|
137
|
-
def
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
verification, an error will be generated when performing the conversion.
|
|
148
|
-
The safetensors format currently does not support crc verification function. If ckpt contains crc verification
|
|
149
|
-
information, the crc verification information will be lost after conversion to safetensors.
|
|
150
|
-
|
|
151
|
-
Args:
|
|
152
|
-
file_path (str): Path to the directory containing checkpoint files or a single checkpoint file (.ckpt).
|
|
153
|
-
save_path (str, optional): Directory path where safetensors files will be saved. Defaults: ``None``.
|
|
154
|
-
name_map (dict, optional): Dictionary mapping original parameter names to new names. Defaults: ``None``.
|
|
155
|
-
file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
|
|
156
|
-
Defaults: ``None``.
|
|
157
|
-
processes_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.
|
|
158
|
-
Raises:
|
|
159
|
-
ValueError: If the input path is invalid or the save_path is not a directory,
|
|
160
|
-
or the file_path does not end with '.ckpt'.
|
|
161
|
-
|
|
162
|
-
Supported Platforms:
|
|
163
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
164
|
-
|
|
165
|
-
Examples:
|
|
166
|
-
>>> import mindspore as ms
|
|
167
|
-
>>> ms.ckpt_to_safetensors("./ckpt_save_path")
|
|
168
|
-
>>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt")
|
|
169
|
-
>>> ms.ckpt_to_safetensors(file_path="./ckpt_save_path/rank0/checkpoint_0.ckpt", save_path="./new_path/")
|
|
170
|
-
>>> namemap = {"lin.weight":"new_name"}
|
|
171
|
-
>>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt", "./new_path/", namemap)
|
|
172
|
-
"""
|
|
173
|
-
is_dir = os.path.isdir(file_path)
|
|
174
|
-
is_file = os.path.isfile(file_path)
|
|
175
|
-
if not is_dir and not is_file:
|
|
176
|
-
raise ValueError(f"For 'ckpt_to_safetensors', the input path must be a valid path or file, but got {file_path}")
|
|
177
|
-
if save_path and os.path.splitext(save_path)[1]:
|
|
178
|
-
raise ValueError(f"For 'ckpt_to_safetensors', the save_path must be a directory, but got '{save_path}'")
|
|
179
|
-
if name_map is not None and not isinstance(name_map, dict):
|
|
180
|
-
raise ValueError(
|
|
181
|
-
f"For 'ckpt_to_safetensors', the type of 'name_map' must be a directory, but got '{type(name_map)}'")
|
|
182
|
-
|
|
183
|
-
if is_dir:
|
|
184
|
-
tasks = _gather_tasks(file_path, save_path, file_name_regex, name_map)
|
|
185
|
-
with mp.Pool(processes=processes_num) as pool:
|
|
186
|
-
list(_progress_bar(pool.imap(_process_file, tasks), total=len(tasks)))
|
|
187
|
-
elif is_file:
|
|
188
|
-
if not file_path.endswith(".ckpt"):
|
|
189
|
-
raise ValueError(f"For 'ckpt_to_safetensors', the input file must be a .ckpt file, but got {file_path}")
|
|
190
|
-
if file_name_regex is not None and not re.findall(file_name_regex, file_path):
|
|
191
|
-
raise ValueError(f"For 'ckpt_to_safetensors', the input file does not match the regular expression.")
|
|
192
|
-
if save_path and not os.path.exists(save_path):
|
|
193
|
-
os.makedirs(save_path, exist_ok=True)
|
|
194
|
-
|
|
195
|
-
param_dict_numpy = _transform_tensor_to_numpy(file_path, name_map)
|
|
196
|
-
safetensors_filename = os.path.basename(file_path).replace(".ckpt", ".safetensors")
|
|
197
|
-
dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), safetensors_filename)
|
|
198
|
-
save_file(param_dict_numpy, dst_file)
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
def _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map):
|
|
202
|
-
"""gather transform rank together"""
|
|
203
|
-
tasks = []
|
|
204
|
-
for root, dirs, _ in os.walk(file_path):
|
|
205
|
-
if root != file_path:
|
|
206
|
-
continue
|
|
207
|
-
|
|
208
|
-
rank_dirs = [d for d in dirs if d.startswith('rank')]
|
|
209
|
-
if not rank_dirs:
|
|
210
|
-
raise ValueError(
|
|
211
|
-
f"For 'safetensors_to_ckpt', no directories starting with 'rank' found in {file_path}")
|
|
212
|
-
|
|
213
|
-
for rank_dir in rank_dirs:
|
|
214
|
-
rank_dir_path = os.path.join(root, rank_dir)
|
|
215
|
-
dst_root = os.path.join(save_path,
|
|
216
|
-
os.path.relpath(rank_dir_path, file_path)) if save_path else rank_dir_path
|
|
217
|
-
os.makedirs(dst_root, exist_ok=True)
|
|
218
|
-
tasks.extend(
|
|
219
|
-
(os.path.join(rank_dir_path, file), name_map, dst_root, file)
|
|
220
|
-
for file in os.listdir(rank_dir_path)
|
|
221
|
-
if file.endswith(".safetensors") and (file_name_regex is None or re.findall(file_name_regex, file))
|
|
222
|
-
)
|
|
223
|
-
return tasks
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
|
|
227
|
-
"""
|
|
228
|
-
Converts safetensors files into MindSpore checkpoint format and saves them to `save_path`.
|
|
229
|
-
Safetensors is a reliable and portable machine learning model storage format introduced by Huggingface,
|
|
230
|
-
used for securely storing Tensors with fast speed (zero copy).
|
|
231
|
-
|
|
232
|
-
Note:
|
|
233
|
-
The number of multiprocess settings is related to the size of the host, and it is not recommended to set it
|
|
234
|
-
too large, otherwise it may cause freezing.
|
|
235
|
-
|
|
236
|
-
Args:
|
|
237
|
-
file_path (str): Path to the directory containing safetensors files or a single safetensors file (.safetensors).
|
|
238
|
-
save_path (str, optional): Directory path where checkpoint files will be saved. Defaults: ``None``.
|
|
239
|
-
name_map (dict, optional): Dictionary mapping original parameter names to new names. Defaults: ``None``.
|
|
240
|
-
file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
|
|
241
|
-
Defaults: ``None``.
|
|
242
|
-
processes_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.
|
|
243
|
-
|
|
244
|
-
Raises:
|
|
245
|
-
ValueError: If the input path is invalid, the save_path is not a directory,
|
|
246
|
-
or the file_path does not end with '.safetensors'.
|
|
247
|
-
|
|
248
|
-
Supported Platforms:
|
|
249
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
250
|
-
|
|
251
|
-
Examples:
|
|
252
|
-
>>> import mindspore as ms
|
|
253
|
-
>>> ms.safetensors_to_ckpt("./safetensors_save_path")
|
|
254
|
-
>>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors")
|
|
255
|
-
>>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/")
|
|
256
|
-
>>> namemap = {"lin.weight":"new_name"}
|
|
257
|
-
>>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/", namemap)
|
|
258
|
-
"""
|
|
259
|
-
is_dir = os.path.isdir(file_path)
|
|
260
|
-
is_file = os.path.isfile(file_path)
|
|
261
|
-
if not is_dir and not is_file:
|
|
262
|
-
raise ValueError(f"For 'safetensors_to_ckpt', the input path must be a valid path or file, but got {file_path}")
|
|
263
|
-
if save_path and os.path.splitext(save_path)[1]:
|
|
264
|
-
raise ValueError(f"For 'safetensors_to_ckpt', the save_path must be a directory, but got '{save_path}'")
|
|
265
|
-
if name_map is not None and not isinstance(name_map, dict):
|
|
266
|
-
raise ValueError(
|
|
267
|
-
f"For 'safetensors_to_ckpt', the type of 'name_map' must be a directory, but got '{type(name_map)}'")
|
|
268
|
-
|
|
269
|
-
if is_dir:
|
|
270
|
-
tasks = _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map)
|
|
271
|
-
with mp.Pool(processes=processes_num) as pool:
|
|
272
|
-
list(_progress_bar(pool.imap(_process_file_safetensors, tasks), total=len(tasks)))
|
|
273
|
-
elif is_file:
|
|
274
|
-
if not file_path.endswith(".safetensors"):
|
|
275
|
-
raise ValueError(
|
|
276
|
-
f"For 'safetensors_to_ckpt', the input file must be a .safetensors file, but got {file_path}")
|
|
277
|
-
if file_name_regex is not None and not re.findall(file_name_regex, file_path):
|
|
278
|
-
raise ValueError(f"For 'safetensors_to_ckpt', the input file does not match the regular expression.")
|
|
279
|
-
if save_path and not os.path.exists(save_path):
|
|
280
|
-
os.makedirs(save_path, exist_ok=True)
|
|
281
|
-
|
|
282
|
-
param_dict_tensor = _transform_numpy_to_tensor(file_path, name_map)
|
|
283
|
-
ckpt_filename = os.path.basename(file_path).replace(".safetensors", ".ckpt")
|
|
284
|
-
dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), ckpt_filename)
|
|
285
|
-
ms.save_checkpoint(param_dict_tensor, dst_file)
|
|
77
|
+
def _load_and_transform(path, name_map, load_func, transform_func):
|
|
78
|
+
if load_func is not None:
|
|
79
|
+
param_dict = load_func(path)
|
|
80
|
+
else:
|
|
81
|
+
param_dict = path
|
|
82
|
+
transform_dict = {}
|
|
83
|
+
for k, v in param_dict.items():
|
|
84
|
+
new_name = name_map.get(k, k) if name_map is not None else k
|
|
85
|
+
transform_dict[new_name] = transform_func(v, new_name)
|
|
86
|
+
return transform_dict
|
|
286
87
|
|
|
287
88
|
|
|
288
89
|
def _check_transform_safetensors(src_safetensors_dir, ckpt_prefix, src_strategy_file, dst_strategy_file):
|
|
@@ -460,7 +261,6 @@ def _transform_safetensors_with_parallel(needed_rank_list_map, all_safetensor_fi
|
|
|
460
261
|
|
|
461
262
|
for name, layout in layout_map.items():
|
|
462
263
|
pipe_param_list[layout[6][0]].append(name)
|
|
463
|
-
|
|
464
264
|
part_list_dict = _distribute_files_by_size(all_safetensor_files_map, needed_rank_list_map, process_num)
|
|
465
265
|
processes = []
|
|
466
266
|
for i in range(process_num):
|
|
@@ -512,7 +312,7 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
|
|
|
512
312
|
origin_dst_strategy_list,
|
|
513
313
|
ckpt_prefix, dst_safetensors_dir, output_format,
|
|
514
314
|
_transform_param_list, pipe_param_list=None, file_index=None, unified_flag=False,
|
|
515
|
-
src_strategy_file=None):
|
|
315
|
+
src_strategy_file=None, choice_func=None):
|
|
516
316
|
"""
|
|
517
317
|
Transforms safetensors files to a specified format without using parallel processing.
|
|
518
318
|
"""
|
|
@@ -567,7 +367,18 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
|
|
|
567
367
|
# param not in ckpt file, check reason
|
|
568
368
|
continue
|
|
569
369
|
output = f.get_tensor(param_name)
|
|
570
|
-
|
|
370
|
+
save_param_name = param_name
|
|
371
|
+
if choice_func is not None:
|
|
372
|
+
choice_out = choice_func(param_name)
|
|
373
|
+
if isinstance(choice_out, bool):
|
|
374
|
+
if not choice_out:
|
|
375
|
+
continue
|
|
376
|
+
elif isinstance(choice_out, str):
|
|
377
|
+
save_param_name = choice_out
|
|
378
|
+
else:
|
|
379
|
+
raise ValueError("For 'unified_safetensors', the return value type of the function "
|
|
380
|
+
f"'choice_func' must be bool or str, but got {type(choice_out)}.")
|
|
381
|
+
saftensor_dict[save_param_name] = output
|
|
571
382
|
else:
|
|
572
383
|
saftensor_dict = load_file(all_safetensor_files_map.get(int(needed_rank)))
|
|
573
384
|
for param_name, param in saftensor_dict.items():
|
|
@@ -735,6 +546,12 @@ def transform_safetensors_by_rank(rank_id, safetensor_files_map, save_safetensor
|
|
|
735
546
|
save_file(transform_param_dict, save_safetensor_file_name)
|
|
736
547
|
|
|
737
548
|
|
|
549
|
+
def _extrace_number(file_name):
|
|
550
|
+
"""get file last two number"""
|
|
551
|
+
number_ls = re.findall(r'\d+', file_name)
|
|
552
|
+
number_ls = [int(i) for i in number_ls]
|
|
553
|
+
return number_ls[-2:]
|
|
554
|
+
|
|
738
555
|
def _collect_safetensor_files(src_safetensors_dir, format='safetensors', file_suffix=None):
|
|
739
556
|
"""
|
|
740
557
|
Collects all safetensors files from the specified directory and its subdirectories.
|
|
@@ -758,12 +575,9 @@ def _collect_safetensor_files(src_safetensors_dir, format='safetensors', file_su
|
|
|
758
575
|
else:
|
|
759
576
|
safetensor_file_name = os.path.join(safetensor_dir, f"*{file_suffix}.{format}")
|
|
760
577
|
rank_ckpts = glob.glob(safetensor_file_name)
|
|
761
|
-
rank_ckpts.sort()
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
ms.log.warning("{} is not a safetensor file.".format(safetensor_file))
|
|
765
|
-
continue
|
|
766
|
-
all_safetensor_files_map[rank_id] = safetensor_file
|
|
578
|
+
rank_ckpts.sort(key=_extrace_number)
|
|
579
|
+
if rank_ckpts:
|
|
580
|
+
all_safetensor_files_map[rank_id] = rank_ckpts[-1]
|
|
767
581
|
return all_safetensor_files_map
|
|
768
582
|
|
|
769
583
|
|
|
@@ -865,7 +679,8 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
|
|
|
865
679
|
return transform_param_dict
|
|
866
680
|
|
|
867
681
|
|
|
868
|
-
def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundancy=True, file_suffix=None
|
|
682
|
+
def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundancy=True, file_suffix=None,
|
|
683
|
+
max_process_num=64, choice_func=None):
|
|
869
684
|
"""
|
|
870
685
|
Merge multiple safetensor files into a unified safetensor file.
|
|
871
686
|
|
|
@@ -877,6 +692,9 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
|
|
|
877
692
|
saved safetensors files. Default: ``True``, indicating that the merged source weight files are complete.
|
|
878
693
|
file_suffix (str, optional): Specify the filename suffix for merging safetensors files. Default: ``None``,
|
|
879
694
|
meaning all safetensors files in the source weight directory will be merged.
|
|
695
|
+
max_process_num (int): Maximum number of processes. Default: 64.
|
|
696
|
+
choice_func (callable): A callable function used to filter parameters or modify parameter names.
|
|
697
|
+
The return value of the function must be of type str (string) or bool (boolean). Default: None.
|
|
880
698
|
|
|
881
699
|
Raises:
|
|
882
700
|
ValueError: If the safetensors file of rank is missing.
|
|
@@ -943,13 +761,21 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
|
|
|
943
761
|
param_name_dict = dict()
|
|
944
762
|
for index, part_list in enumerate(split_list):
|
|
945
763
|
for name in part_list:
|
|
946
|
-
|
|
764
|
+
save_param_name = name
|
|
765
|
+
if choice_func is not None:
|
|
766
|
+
choice_out = choice_func(name)
|
|
767
|
+
if isinstance(choice_out, bool):
|
|
768
|
+
if not choice_out:
|
|
769
|
+
continue
|
|
770
|
+
elif isinstance(choice_out, str):
|
|
771
|
+
save_param_name = choice_out
|
|
772
|
+
param_name_dict[save_param_name] = f"part{index}.safetensors"
|
|
947
773
|
json_str = json.dumps(param_name_dict, indent=4)
|
|
948
774
|
map_file = os.path.join(dst_dir, "param_name_map.json")
|
|
949
775
|
with open(map_file, 'w') as f:
|
|
950
776
|
f.write(json_str)
|
|
951
777
|
|
|
952
|
-
max_process = min(split_num,
|
|
778
|
+
max_process = min(split_num, max_process_num)
|
|
953
779
|
res = [i for i in range(split_num)]
|
|
954
780
|
res = _split_list(res, max_process)
|
|
955
781
|
processes = []
|
|
@@ -960,7 +786,7 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
|
|
|
960
786
|
p = mp.Process(target=_transform_safetensors_single_semaphore, args=(
|
|
961
787
|
needed_rank_list_map, all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
|
|
962
788
|
src_strategy_dict, None, origin_src_strategy_list, origin_dst_strategy_list,
|
|
963
|
-
"", dst_dir, "safetensors", None, split_list, res[i], True, src_strategy_name))
|
|
789
|
+
"", dst_dir, "safetensors", None, split_list, res[i], True, src_strategy_name, choice_func))
|
|
964
790
|
p.start()
|
|
965
791
|
processes.append(p)
|
|
966
792
|
for p in processes:
|
|
@@ -974,13 +800,14 @@ def _transform_safetensors_single_semaphore(needed_rank_list_map, all_safetensor
|
|
|
974
800
|
origin_dst_strategy_list,
|
|
975
801
|
ckpt_prefix, dst_safetensors_dir, output_format,
|
|
976
802
|
_transform_param_list, pipe_param_list=None, file_index=None,
|
|
977
|
-
unified_flag=False, src_strategy_file=None):
|
|
803
|
+
unified_flag=False, src_strategy_file=None, choice_func=None):
|
|
978
804
|
for i in file_index:
|
|
979
805
|
_transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map, src_stage_device_num,
|
|
980
806
|
dst_stage_device_num, src_strategy_dict, dst_strategy_dict,
|
|
981
807
|
origin_src_strategy_list,
|
|
982
808
|
origin_dst_strategy_list, ckpt_prefix, dst_safetensors_dir, output_format,
|
|
983
|
-
_transform_param_list, pipe_param_list[i], i, unified_flag, src_strategy_file
|
|
809
|
+
_transform_param_list, pipe_param_list[i], i, unified_flag, src_strategy_file,
|
|
810
|
+
choice_func)
|
|
984
811
|
|
|
985
812
|
|
|
986
813
|
def _split_list(split_list, split_num):
|
|
@@ -1027,24 +854,64 @@ def _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_n
|
|
|
1027
854
|
return sf_obj
|
|
1028
855
|
|
|
1029
856
|
|
|
1030
|
-
def
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
857
|
+
def _check_name_map_value_is_str(value):
|
|
858
|
+
"""check input is bool"""
|
|
859
|
+
if not isinstance(value, str):
|
|
860
|
+
raise ValueError(
|
|
861
|
+
f"For 'load_distributed_checkpoint', the value of name_map must be str, but got {type(value)}.")
|
|
862
|
+
|
|
863
|
+
|
|
864
|
+
def _process_hyper_params(file_list, total_safetensors_dir, name_map, total_param):
|
|
865
|
+
"""process hyper params"""
|
|
866
|
+
if 'hyper_param.safetensors' in file_list:
|
|
867
|
+
hyper_parameter_file_name = os.path.join(total_safetensors_dir, "hyper_param.safetensors")
|
|
868
|
+
with safe_open(hyper_parameter_file_name, framework="np") as f:
|
|
869
|
+
for key in f.keys():
|
|
870
|
+
cur_param_name = name_map.get(key) if name_map is not None and key in name_map else key
|
|
871
|
+
_check_name_map_value_is_str(cur_param_name)
|
|
872
|
+
total_param[cur_param_name] = ms.Parameter(ms.Tensor.from_numpy(f.get_tensor(key)))
|
|
873
|
+
return total_param
|
|
874
|
+
|
|
875
|
+
|
|
876
|
+
def _cal_param_name_map_and_param_list(file_list, total_safetensors_dir, json_files, dst_strategy_file, rank_id):
|
|
877
|
+
"""calculate param_name_map and param_list"""
|
|
878
|
+
if len(file_list) == 1:
|
|
879
|
+
logger.info("There is only one weight file in the directory, which will be automatically mapped.")
|
|
880
|
+
file_name = os.path.join(total_safetensors_dir, file_list[0])
|
|
881
|
+
is_file = os.path.isfile(file_name)
|
|
882
|
+
if not is_file:
|
|
883
|
+
raise ValueError(f"For 'load_parallel_checkpoint', weight files must be included "
|
|
884
|
+
f"in the `unified_safetensors_dir`.")
|
|
885
|
+
with safe_open(file_name, framework="np") as f:
|
|
886
|
+
keys = f.keys()
|
|
887
|
+
values = len(keys) * [file_list[0]]
|
|
888
|
+
param_name_map = dict(zip(keys, values))
|
|
889
|
+
else:
|
|
890
|
+
if len(json_files) != 1:
|
|
891
|
+
raise ValueError(f"For 'load_parallel_checkpoint', the number of json files in 'total_safetensors_dir' "
|
|
892
|
+
f"must be 1, but got {len(json_files)}.")
|
|
893
|
+
param_name_json = os.path.join(total_safetensors_dir, json_files[0])
|
|
894
|
+
with open(param_name_json, 'r') as f:
|
|
895
|
+
param_name_map = json.load(f)
|
|
896
|
+
|
|
1041
897
|
if dst_strategy_file is not None:
|
|
1042
898
|
_, dst_strategy_list = _extract_src_dst_layout_map(rank_id, None, dst_strategy_file)
|
|
1043
899
|
param_list = dst_strategy_list.keys()
|
|
1044
900
|
else:
|
|
1045
901
|
dst_strategy_list = None
|
|
1046
902
|
param_list = param_name_map.keys()
|
|
903
|
+
return param_name_map, param_list, dst_strategy_list
|
|
1047
904
|
|
|
905
|
+
|
|
906
|
+
def _load_parallel_checkpoint(file_info):
|
|
907
|
+
"""load parallel safetensors by merged file."""
|
|
908
|
+
total_safetensors_dir, dst_strategy_file, net, dst_safetensors_dir, \
|
|
909
|
+
rank_id, output_format, name_map, return_param_dict = file_info
|
|
910
|
+
file_list = os.listdir(total_safetensors_dir)
|
|
911
|
+
json_files = [file for file in file_list if file.endswith('.json')]
|
|
912
|
+
param_name_map, param_list, dst_strategy_list = _cal_param_name_map_and_param_list(file_list, total_safetensors_dir,
|
|
913
|
+
json_files, dst_strategy_file,
|
|
914
|
+
rank_id)
|
|
1048
915
|
total_param = dict()
|
|
1049
916
|
dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) if dst_strategy_list \
|
|
1050
917
|
is not None else 1
|
|
@@ -1102,20 +969,19 @@ def _load_parallel_checkpoint(total_safetensors_dir, dst_strategy_file, net=None
|
|
|
1102
969
|
slice_param = _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_num)
|
|
1103
970
|
else:
|
|
1104
971
|
slice_param = sf_obj[:]
|
|
972
|
+
cur_param_name = name_map.get(param_name) if name_map is not None and param_name in name_map else param_name
|
|
973
|
+
_check_name_map_value_is_str(cur_param_name)
|
|
974
|
+
total_param[cur_param_name] = ms.Parameter(ms.Tensor.from_numpy(slice_param))
|
|
1105
975
|
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
if 'hyper_param.safetensors' in file_list:
|
|
1109
|
-
hyper_parameter_file_name = os.path.join(total_safetensors_dir, "hyper_param.safetensors")
|
|
1110
|
-
with safe_open(hyper_parameter_file_name, framework="np") as f:
|
|
1111
|
-
for key in f.keys():
|
|
1112
|
-
total_param[key] = ms.Parameter(f.get_tensor(key))
|
|
976
|
+
total_param = _process_hyper_params(file_list, total_safetensors_dir, name_map, total_param)
|
|
1113
977
|
if net is not None:
|
|
1114
|
-
|
|
1115
|
-
|
|
978
|
+
if not return_param_dict:
|
|
979
|
+
param_not_load, ckpt_not_load = ms.load_param_into_net(net, total_param)
|
|
980
|
+
return param_not_load, ckpt_not_load
|
|
981
|
+
return total_param
|
|
1116
982
|
_make_dir(os.path.join(dst_safetensors_dir, f"rank_{rank_id}"), "path")
|
|
1117
|
-
ms.save_checkpoint(total_param, os.path.join(dst_safetensors_dir, f"rank_{rank_id}", f"net.
|
|
1118
|
-
format=
|
|
983
|
+
ms.save_checkpoint(total_param, os.path.join(dst_safetensors_dir, f"rank_{rank_id}", f"net.{output_format}"),
|
|
984
|
+
format=output_format)
|
|
1119
985
|
return None
|
|
1120
986
|
|
|
1121
987
|
|
|
@@ -1143,4 +1009,4 @@ def _get_slice(rank_id, sf_obj, param_name, dst_strategy_list):
|
|
|
1143
1009
|
|
|
1144
1010
|
|
|
1145
1011
|
__all__ = ["_transform_safetensors", "transform_safetensors_by_stage",
|
|
1146
|
-
"transform_safetensors_by_rank", "
|
|
1012
|
+
"transform_safetensors_by_rank", "unified_safetensors"]
|
mindspore/profiler/__init__.py
CHANGED
|
@@ -20,9 +20,22 @@ Users can visualize the results using the MindInsight tool.
|
|
|
20
20
|
Now, Profiler supports AICORE operator, AICPU operator, HostCPU operator, memory,
|
|
21
21
|
correspondence, cluster, etc data analysis.
|
|
22
22
|
"""
|
|
23
|
-
__all__ = [
|
|
23
|
+
__all__ = [
|
|
24
|
+
"tensor_board_trace_handler",
|
|
25
|
+
"schedule",
|
|
26
|
+
"Profiler",
|
|
27
|
+
"EnvProfiler",
|
|
28
|
+
"ProfilerLevel",
|
|
29
|
+
"ProfilerActivity",
|
|
30
|
+
"AicoreMetrics",
|
|
31
|
+
"DynamicProfilerMonitor",
|
|
32
|
+
"mstx"
|
|
33
|
+
]
|
|
24
34
|
|
|
25
|
-
from mindspore.profiler.
|
|
26
|
-
from mindspore.profiler.
|
|
27
|
-
from mindspore.profiler.
|
|
35
|
+
from mindspore.profiler.mstx import Mstx as mstx
|
|
36
|
+
from mindspore.profiler.profiler import Profiler
|
|
37
|
+
from mindspore.profiler.profiler import tensor_board_trace_handler
|
|
38
|
+
from mindspore.profiler.schedule import Schedule as schedule
|
|
39
|
+
from mindspore.profiler.envprofiler import EnvProfiler
|
|
40
|
+
from mindspore.profiler.common.constant import ProfilerLevel, ProfilerActivity, AicoreMetrics
|
|
28
41
|
from mindspore.profiler.dynamic_profiler import DynamicProfilerMonitor
|
|
File without changes
|
|
File without changes
|