mindspore 2.6.0rc1__cp39-cp39-win_amd64.whl → 2.7.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +1 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +40 -9
- mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
- mindspore/_extends/optimize/cell_utils.py +96 -0
- mindspore/_extends/parse/__init__.py +2 -2
- mindspore/_extends/parse/compile_config.py +44 -22
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -1
- mindspore/_extends/parse/parser.py +37 -62
- mindspore/_extends/parse/resources.py +39 -0
- mindspore/_extends/parse/standard_method.py +43 -13
- mindspore/_extends/parse/trope.py +8 -1
- mindspore/_extends/pijit/__init__.py +1 -2
- mindspore/amp.py +4 -4
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +4 -4
- mindspore/common/__init__.py +27 -2
- mindspore/common/_grad_function.py +2 -1
- mindspore/common/_pijit_context.py +28 -7
- mindspore/common/_stub_tensor.py +1 -209
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +77 -16
- mindspore/common/api.py +238 -113
- mindspore/common/dtype.py +21 -11
- mindspore/common/dump.py +10 -15
- mindspore/common/generator.py +5 -3
- mindspore/common/hook_handle.py +11 -2
- mindspore/common/jit_config.py +1 -1
- mindspore/common/jit_trace.py +84 -105
- mindspore/common/parameter.py +26 -12
- mindspore/common/recompute.py +3 -3
- mindspore/common/sparse_tensor.py +0 -3
- mindspore/common/symbol.py +0 -1
- mindspore/common/tensor.py +81 -81
- mindspore/communication/_comm_helper.py +46 -4
- mindspore/communication/management.py +79 -7
- mindspore/context.py +58 -40
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/datasets.py +20 -7
- mindspore/dataset/engine/datasets_user_defined.py +33 -3
- mindspore/dataset/engine/iterators.py +2 -2
- mindspore/dataset/engine/obs/config_loader.py +2 -2
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
- mindspore/dataset/transforms/py_transforms.py +7 -3
- mindspore/dataset/transforms/transforms.py +7 -3
- mindspore/dataset/vision/validators.py +1 -0
- mindspore/device_context/ascend/device.py +1 -1
- mindspore/device_context/gpu/__init__.py +2 -2
- mindspore/device_context/gpu/device.py +1 -1
- mindspore/device_context/gpu/op_precision.py +4 -2
- mindspore/device_context/gpu/op_tuning.py +6 -3
- mindspore/device_manager.py +16 -9
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +3 -7
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/optim/adadelta.py +13 -20
- mindspore/experimental/optim/adagrad.py +15 -22
- mindspore/experimental/optim/adam.py +17 -24
- mindspore/experimental/optim/adamax.py +14 -22
- mindspore/experimental/optim/adamw.py +28 -34
- mindspore/experimental/optim/asgd.py +15 -25
- mindspore/experimental/optim/lr_scheduler.py +27 -45
- mindspore/experimental/optim/nadam.py +14 -24
- mindspore/experimental/optim/optimizer.py +13 -23
- mindspore/experimental/optim/radam.py +18 -24
- mindspore/experimental/optim/rmsprop.py +14 -25
- mindspore/experimental/optim/rprop.py +15 -26
- mindspore/experimental/optim/sgd.py +9 -19
- mindspore/hal/__init__.py +4 -4
- mindspore/hal/contiguous_tensors_handle.py +2 -2
- mindspore/hal/memory.py +27 -7
- mindspore/include/api/cell.h +37 -1
- mindspore/include/api/delegate.h +10 -0
- mindspore/include/api/model.h +3 -0
- mindspore/include/api/types.h +2 -2
- mindspore/include/c_api/model_c.h +0 -58
- mindspore/include/c_api/tensor_c.h +0 -26
- mindspore/include/dataset/vision_ascend.h +1 -1
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/cifar10.py +60 -11
- mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mindspore_ops_host.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +6 -46
- mindspore/mint/distributed/__init__.py +1 -0
- mindspore/mint/distributed/distributed.py +212 -9
- mindspore/mint/nn/__init__.py +1 -1
- mindspore/mint/nn/functional.py +53 -6
- mindspore/mint/nn/layer/_functions.py +164 -294
- mindspore/mint/nn/layer/activation.py +8 -6
- mindspore/mint/nn/layer/conv.py +137 -101
- mindspore/mint/nn/layer/normalization.py +8 -22
- mindspore/mint/optim/adam.py +19 -18
- mindspore/mint/optim/adamw.py +14 -8
- mindspore/mint/optim/sgd.py +5 -5
- mindspore/nn/cell.py +328 -502
- mindspore/nn/grad/cell_grad.py +11 -12
- mindspore/nn/layer/activation.py +32 -34
- mindspore/nn/layer/basic.py +67 -64
- mindspore/nn/layer/channel_shuffle.py +4 -4
- mindspore/nn/layer/combined.py +4 -2
- mindspore/nn/layer/conv.py +117 -110
- mindspore/nn/layer/dense.py +9 -7
- mindspore/nn/layer/embedding.py +50 -52
- mindspore/nn/layer/image.py +37 -39
- mindspore/nn/layer/math.py +111 -112
- mindspore/nn/layer/normalization.py +56 -44
- mindspore/nn/layer/pooling.py +58 -63
- mindspore/nn/layer/rnn_cells.py +33 -33
- mindspore/nn/layer/rnns.py +56 -56
- mindspore/nn/layer/thor_layer.py +74 -73
- mindspore/nn/layer/transformer.py +11 -1
- mindspore/nn/learning_rate_schedule.py +20 -20
- mindspore/nn/loss/loss.py +79 -81
- mindspore/nn/optim/adam.py +3 -3
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -0
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -1
- mindspore/nn/probability/distribution/poisson.py +2 -1
- mindspore/nn/sparse/sparse.py +3 -3
- mindspore/nn/wrap/cell_wrapper.py +34 -37
- mindspore/nn/wrap/grad_reducer.py +37 -37
- mindspore/nn/wrap/loss_scale.py +72 -74
- mindspore/numpy/array_creations.py +5 -5
- mindspore/numpy/fft.py +1 -1
- mindspore/numpy/math_ops.py +5 -5
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
- mindspore/ops/_vmap/vmap_array_ops.py +31 -13
- mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +42 -11
- mindspore/ops/auto_generate/gen_extend_func.py +23 -141
- mindspore/ops/auto_generate/gen_ops_def.py +727 -321
- mindspore/ops/auto_generate/gen_ops_prim.py +1721 -984
- mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
- mindspore/ops/composite/__init__.py +10 -0
- mindspore/ops/composite/base.py +8 -4
- mindspore/ops/composite/multitype_ops/__init__.py +12 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +133 -109
- mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
- mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
- mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
- mindspore/ops/function/__init__.py +3 -1
- mindspore/ops/function/_add_attr_func.py +11 -6
- mindspore/ops/function/array_func.py +9 -96
- mindspore/ops/function/debug_func.py +4 -3
- mindspore/ops/function/grad/grad_func.py +1 -1
- mindspore/ops/function/math_func.py +33 -540
- mindspore/ops/function/nn_func.py +28 -74
- mindspore/ops/function/other_func.py +4 -1
- mindspore/ops/function/random_func.py +44 -5
- mindspore/ops/function/vmap_func.py +2 -1
- mindspore/ops/functional.py +2 -3
- mindspore/ops/functional_overload.py +571 -6
- mindspore/ops/op_info_register.py +21 -0
- mindspore/ops/operations/__init__.py +16 -11
- mindspore/ops/operations/_custom_ops_utils.py +689 -34
- mindspore/ops/operations/_inner_ops.py +3 -6
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +2 -2
- mindspore/ops/operations/comm_ops.py +185 -26
- mindspore/ops/operations/custom_ops.py +294 -174
- mindspore/ops/operations/debug_ops.py +59 -4
- mindspore/ops/operations/image_ops.py +13 -13
- mindspore/ops/operations/manually_defined/ops_def.py +15 -16
- mindspore/ops/operations/math_ops.py +3 -4
- mindspore/ops/operations/nn_ops.py +7 -39
- mindspore/ops/primitive.py +6 -10
- mindspore/ops/tensor_method.py +47 -8
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
- mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
- mindspore/ops_generate/api/functions_cc_generator.py +58 -10
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
- mindspore/ops_generate/common/base_generator.py +14 -0
- mindspore/ops_generate/common/gen_constants.py +8 -3
- mindspore/ops_generate/common/gen_utils.py +0 -19
- mindspore/ops_generate/common/op_proto.py +11 -4
- mindspore/ops_generate/common/template.py +88 -11
- mindspore/ops_generate/gen_ops.py +1 -1
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +0 -3
- mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
- mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
- mindspore/parallel/_auto_parallel_context.py +11 -8
- mindspore/parallel/_cell_wrapper.py +113 -45
- mindspore/parallel/_parallel_serialization.py +1 -1
- mindspore/parallel/_ps_context.py +4 -6
- mindspore/parallel/_tensor.py +167 -12
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/transformer.py +13 -8
- mindspore/parallel/auto_parallel.py +14 -7
- mindspore/parallel/checkpoint_convert.py +3 -3
- mindspore/parallel/checkpoint_transform.py +11 -7
- mindspore/parallel/cluster/process_entity/_api.py +84 -48
- mindspore/parallel/cluster/process_entity/_utils.py +95 -7
- mindspore/parallel/cluster/run.py +43 -4
- mindspore/parallel/function/__init__.py +8 -1
- mindspore/parallel/function/reshard_func.py +6 -7
- mindspore/parallel/nn/__init__.py +15 -2
- mindspore/parallel/nn/parallel_cell_wrapper.py +9 -10
- mindspore/parallel/nn/parallel_grad_reducer.py +7 -6
- mindspore/parallel/shard.py +3 -4
- mindspore/parallel/transform_safetensors.py +463 -174
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +12 -6
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
- mindspore/profiler/analysis/task_manager.py +1 -1
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +42 -22
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
- mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
- mindspore/profiler/common/constant.py +16 -0
- mindspore/profiler/common/profiler_context.py +25 -27
- mindspore/profiler/common/profiler_info.py +0 -16
- mindspore/profiler/common/profiler_op_analyse.py +235 -0
- mindspore/profiler/common/profiler_output_path.py +23 -8
- mindspore/profiler/common/profiler_parameters.py +128 -35
- mindspore/profiler/dynamic_profile/__init__.py +0 -0
- mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
- mindspore/profiler/dynamic_profiler.py +305 -314
- mindspore/profiler/envprofiler.py +12 -7
- mindspore/profiler/experimental_config.py +96 -6
- mindspore/profiler/mstx.py +33 -12
- mindspore/profiler/platform/__init__.py +2 -3
- mindspore/profiler/platform/npu_profiler.py +29 -19
- mindspore/profiler/profiler.py +35 -19
- mindspore/profiler/profiler_action_controller.py +64 -76
- mindspore/profiler/schedule.py +10 -4
- mindspore/rewrite/common/config.py +1 -0
- mindspore/rewrite/common/namer.py +1 -0
- mindspore/rewrite/common/namespace.py +1 -0
- mindspore/rewrite/node/node.py +31 -11
- mindspore/rewrite/parsers/assign_parser.py +1 -1
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +7 -10
- mindspore/runtime/__init__.py +5 -5
- mindspore/runtime/event.py +10 -4
- mindspore/runtime/executor.py +60 -45
- mindspore/runtime/memory.py +30 -32
- mindspore/runtime/thread_bind_core.py +298 -164
- mindspore/safeguard/rewrite_obfuscation.py +12 -13
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +14 -4
- mindspore/train/amp.py +43 -20
- mindspore/train/callback/__init__.py +5 -5
- mindspore/train/callback/_checkpoint.py +3 -6
- mindspore/train/callback/_flops_collector.py +1 -1
- mindspore/train/callback/_landscape.py +0 -1
- mindspore/train/callback/_train_fault_tolerance.py +97 -16
- mindspore/train/data_sink.py +11 -2
- mindspore/train/dataset_helper.py +9 -0
- mindspore/train/model.py +135 -55
- mindspore/train/serialization.py +133 -111
- mindspore/train/summary/summary_record.py +13 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +0 -6
- mindspore/utils/runtime_execution_order_check.py +163 -77
- mindspore/utils/sdc_detect.py +68 -0
- mindspore/utils/utils.py +6 -9
- mindspore/version.py +1 -1
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/METADATA +5 -4
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/RECORD +333 -371
- mindspore/_deprecated/jit.py +0 -198
- mindspore/experimental/es/__init__.py +0 -22
- mindspore/experimental/es/embedding_service.py +0 -891
- mindspore/experimental/es/embedding_service_layer.py +0 -581
- mindspore/profiler/parser/__init__.py +0 -14
- mindspore/profiler/parser/aicpu_data_parser.py +0 -272
- mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
- mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
- mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
- mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
- mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
- mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
- mindspore/profiler/parser/ascend_flops_generator.py +0 -116
- mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
- mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
- mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
- mindspore/profiler/parser/ascend_memory_generator.py +0 -185
- mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
- mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
- mindspore/profiler/parser/ascend_op_generator.py +0 -334
- mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
- mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
- mindspore/profiler/parser/base_timeline_generator.py +0 -483
- mindspore/profiler/parser/container.py +0 -229
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
- mindspore/profiler/parser/flops_parser.py +0 -531
- mindspore/profiler/parser/framework_enum.py +0 -111
- mindspore/profiler/parser/framework_parser.py +0 -464
- mindspore/profiler/parser/framework_struct.py +0 -61
- mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
- mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
- mindspore/profiler/parser/hccl_parser.py +0 -573
- mindspore/profiler/parser/hwts_log_parser.py +0 -122
- mindspore/profiler/parser/integrator.py +0 -526
- mindspore/profiler/parser/memory_usage_parser.py +0 -277
- mindspore/profiler/parser/minddata_analyzer.py +0 -800
- mindspore/profiler/parser/minddata_parser.py +0 -186
- mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
- mindspore/profiler/parser/op_intermediate_parser.py +0 -149
- mindspore/profiler/parser/optime_parser.py +0 -250
- mindspore/profiler/parser/profiler_info.py +0 -213
- mindspore/profiler/parser/step_trace_parser.py +0 -666
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/top_level.txt +0 -0
mindspore/parallel/transform_safetensors.py
@@ -21,27 +21,263 @@ import glob
 import math
 import json
 import re
-
+import mmap
+import stat
+from collections import defaultdict, OrderedDict

 import time
 import multiprocessing as mp
+
+from safetensors.numpy import save_file, load_file
 import psutil
 import numpy as np
-from safetensors.numpy import save_file, load_file
-from safetensors import safe_open

 import mindspore as ms
 from mindspore import log as logger
 from mindspore.log import vlog_print
+from mindspore.common.parameter import Parameter
+from mindspore.common.tensor import Tensor
+from mindspore.common import np_dtype
 from mindspore.parallel._parallel_serialization import _get_device_num_from_strategy, _make_dir, \
     _extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
     _insert_opt_shard_reshape, _extract_src_dst_layout_map_by_src, _insert_expand_layout_reshape
 from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_tensor_layout, \
     _get_needed_rank_transform_operator_map_by_layouts, \
     _generate_transform_operator_stack, _apply_tensor_transform_operators, _construct_tensor_layout_for_opt_shard, \
-    _extract_layout_item,
+    _extract_layout_item, _apply_operator
 from mindspore.parallel._parallel_serialization import _build_searched_strategy, _load_protobuf_strategy, \
     _convert_to_list
+from mindspore.common import dtype as mstype
+
+safetensors_to_mstype = {'Int4': mstype.qint4x2}
+
+np.bfloat16 = np_dtype.bfloat16
+
+MAX_HEADER_SIZE = 100 * 1000 * 1000
+
+dtype_size = {
+    "BOOL": 1,
+    "U8": 1,
+    "I8": 1,
+    "I16": 2,
+    "U16": 2,
+    "I32": 4,
+    "U32": 4,
+    "I64": 8,
+    "U64": 8,
+    "F16": 2,
+    "BF16": 2,
+    "F32": 4,
+    "F64": 8,
+}
+np_dtype_size = {
+    "bool_": 1,
+    "uint8": 1,
+    "int8": 1,
+    "int16": 2,
+    "uint16": 2,
+    "int32": 4,
+    "uint32": 4,
+    "int64": 8,
+    "uint64": 8,
+    "float16": 2,
+    "bfloat16": 2,
+    "float32": 4,
+    "float64": 8,
+}
+numpy_dtype = {
+    "BOOL": np.bool_,
+    "U8": np.uint8,
+    "I8": np.int8,
+    "I16": np.int16,
+    "U16": np.uint16,
+    "I32": np.int32,
+    "U32": np.uint32,
+    "I64": np.int64,
+    "U64": np.uint64,
+    "F16": np.float16,
+    "BF16": np.bfloat16,  # no bf16
+    "F32": np.float32,
+    "F64": np.float64,
+}
+
+
+def getSize(fileobject):
+    fileobject.seek(0, 2)  # move the cursor to the end of the file
+    size = fileobject.tell()
+    fileobject.seek(0)  # move the cursor to the start of the file
+    return size
+
+
+def _save_file_atomically(transform_param_dict, save_file_name, metadata=None):
+    """Atomically save file using temporary name and rename."""
+    if metadata is None:
+        metadata = {"format": "ms"}
+    file_name_list = list(os.path.splitext(save_file_name))
+    file_name_list[1] = file_name_list[1].replace('.safetensors', '.tmp')
+    tmp_name = ''.join(file_name_list)
+    try:
+        if os.path.exists(save_file_name):
+            os.chmod(save_file_name, stat.S_IWUSR)
+            os.remove(save_file_name)
+        if os.path.exists(tmp_name):
+            os.chmod(tmp_name, stat.S_IWUSR)
+            os.remove(tmp_name)
+        save_file(transform_param_dict, tmp_name, metadata=metadata)
+        os.rename(tmp_name, save_file_name)
+        os.chmod(save_file_name, stat.S_IRUSR)
+    except Exception as e:
+        if not os.path.exists(save_file_name):
+            logger.warning(f"Save failed, {save_file_name} not found. "
+                           f"This may indicate multiple processes modifying the same file "
+                           f"or insufficient disk space.")
+        raise e
+
+
+def metadata_validate(metadata):
+    """validation metadata"""
+    start = 0
+    for key, info in metadata.items():
+        s, e = info["data_offsets"]
+        if s != start or e < s:
+            raise ValueError(f"SafeTensorError::InvalidOffset({key})")
+        start = e
+        nelements = np.prod(info["shape"])
+        nbytes = nelements * dtype_size[info["dtype"]]
+        if (e - s) != nbytes:
+            raise ValueError("SafeTensorError::TensorInvalidInfo")
+    return start
+
+
+def read_metadata(buffer):
+    """read metadata by buffer"""
+    buffer_len = getSize(buffer)
+    if buffer_len < 8:
+        raise ValueError("SafeTensorError::HeaderTooSmall")
+
+    n = np.frombuffer(buffer.read(8), dtype=np.uint64).item()
+    if n > MAX_HEADER_SIZE:
+        raise ValueError("SafeTensorError::HeaderTooLarge")
+
+    stop = n + 8
+    if stop > buffer_len:
+        raise ValueError("SafeTensorError::InvalidHeaderLength")
+
+    tensors = json.loads(buffer.read(n), object_pairs_hook=OrderedDict)
+    metadata = tensors.pop("__metadata__", None)
+    buffer_end = metadata_validate(tensors)
+
+    if buffer_end + 8 + n != buffer_len:
+        raise ValueError("SafeTensorError::MetadataIncompleteBuffer")
+
+    return stop, tensors, metadata
+
+
+class PySafeSlice:
+    """Create PySafeSlice by file"""
+
+    def __init__(self, info, bufferfile, base_ptr, buffermmap):
+        self.info = info
+        self.bufferfile = bufferfile
+        self.buffermmap = buffermmap
+        self.base_ptr = base_ptr
+
+        self.start = [0 for dim in self.shape]
+        self.stop = [dim for dim in self.shape]
+        self.step = [1 for dim in self.shape]
+
+    @property
+    def ndim(self):
+        return len(self.shape)
+
+    def get(self, *args, **kwargs):
+        """Get tensor from buffer by data_offset"""
+        nbytes = int(np.prod(self.shape)) * np.dtype(self.dtype).itemsize
+        offset = self.start_offset
+        tensor = np.frombuffer(self.buffermmap, dtype=self.dtype, offset=offset,
+                               count=nbytes // np.dtype(self.dtype).itemsize)
+        tensor = tensor.reshape(self.shape)
+        if not tensor.flags["ALIGNED"]:
+            logger.info("This safetensors file is not aligned.")
+            tensor = tensor.copy()
+        return tensor
+
+    @property
+    def start_offset(self):
+        return self.base_ptr + self.info["data_offsets"][0]
+
+    def get_shape(self):
+        return self.shape
+
+    @property
+    def shape(self):
+        return self.info["shape"]
+
+    @property
+    def dtype(self):
+        return numpy_dtype[self.info["dtype"]]
+
+    @property
+    def nelements(self):
+        return np.prod(self.info["shape"])
+
+    @property
+    def bits(self):
+        return dtype_size[self.info["dtype"]]
+
+    @property
+    def nbytes(self):
+        return self.nelements * dtype_size[self.info["dtype"]]
+
+
+class _fast_safe_open:
+    """
+    Open a safetensors file and access its metadata and tensors efficiently.
+
+    This function is designed to work similarly to `safetensors.safe_open`,
+    providing a fast way to open and interact with safetensors files.
+    """
+
+    def __init__(self, filename, framework=None, device="cpu"):
+        self.filename = filename
+        self.framework = framework
+        self.file = open(self.filename, "rb")
+        self.file_mmap = mmap.mmap(self.file.fileno(), 0, access=mmap.ACCESS_COPY)
+        try:
+            self.base, self.tensors_decs, self.__metadata__ = read_metadata(self.file)
+        except ValueError:
+            raise ValueError(f"Fail to parse the input safetensors file: '{self.filename}'. "
+                             f"Please check the correctness of the file.")
+        self.tensors = OrderedDict()
+        for key, info in self.tensors_decs.items():
+            self.tensors[key] = PySafeSlice(info, self.file, self.base, self.file_mmap)
+            self.tensors[key].key = key
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.file.close()
+
+    def metadata(self):
+        return self.__metadata__
+
+    def keys(self):
+        return list(self.tensors.keys())
+
+    def get_tensor(self, name):
+        return self.tensors[name].get()
+
+
+def _fast_load_file(filename):
+    """
+    Load safetensors info from a specified file.
+    """
+    result = {}
+    with _fast_safe_open(filename, framework="np") as f:
+        for k in f.keys():
+            result[k] = f.get_tensor(k)
+    return result


 def _progress_bar(iterable, total=None):
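Note on the hunk above: `read_metadata` and `_fast_safe_open` re-implement safetensors header parsing on top of `mmap` instead of calling `safetensors.safe_open`. As a quick reference, here is a minimal sketch (not part of the diff; the file name `demo.safetensors` is only illustrative) of the on-disk layout they parse: an 8-byte little-endian header length, a JSON header mapping each tensor name to its dtype, shape and data_offsets plus an optional `__metadata__` block, followed by the raw tensor bytes.

# Sketch only: reads a safetensors header the same way read_metadata() above does.
import json
import struct

import numpy as np
from safetensors.numpy import save_file

save_file({"w": np.zeros((2, 3), dtype=np.float32)}, "demo.safetensors",
          metadata={"format": "ms"})

with open("demo.safetensors", "rb") as f:
    header_len = struct.unpack("<Q", f.read(8))[0]  # length checked against MAX_HEADER_SIZE above
    header = json.loads(f.read(header_len))
    print(header["__metadata__"])  # {'format': 'ms'}
    print(header["w"])             # {'dtype': 'F32', 'shape': [2, 3], 'data_offsets': [0, 24]}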
@@ -267,15 +503,22 @@ def _transform_safetensors_with_parallel(needed_rank_list_map, all_safetensor_fi
         pipe_param_list[layout[6][0]].append(name)
     part_list_dict = _distribute_files_by_size(all_safetensor_files_map, needed_rank_list_map, process_num)
     processes = []
-
-
-
-
-
-
-
-
-        p
+    if process_num > 1:
+        for i in range(process_num):
+            p = mp.Process(target=_transform_safetensors_single, args=(
+                part_list_dict[i], all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
+                src_strategy_dict, dst_strategy_dict, origin_src_strategy_list, origin_dst_strategy_list,
+                ckpt_prefix, dst_safetensors_dir, output_format, _transform_param_list, pipe_param_list[i]))
+            p.start()
+            processes.append(p)
+        for p in processes:
+            p.join()
+    else:
+        _transform_safetensors_single(part_list_dict[0], all_safetensor_files_map, src_stage_device_num,
+                                      dst_stage_device_num, src_strategy_dict, dst_strategy_dict,
+                                      origin_src_strategy_list, origin_dst_strategy_list, ckpt_prefix,
+                                      dst_safetensors_dir, output_format, _transform_param_list,
+                                      pipe_param_list[0])


 def _count_redundancy_list(rank_num, param_name, redundancy_dict, device_num):
@@ -288,7 +531,7 @@ def _count_redundancy_list(rank_num, param_name, redundancy_dict, device_num):
     return set()


-def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict,
+def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict, safetensor_dict, redundancy_dict,
                                     needed_rank, device_num, choice_func):
     """Find the rank_id under redundant groups."""
     io_time = 0
@@ -305,7 +548,7 @@ def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dic
                 break
         if open_file_id is not None:
             start_time = time.time()
-            output = file_dict[open_file_id].
+            output = file_dict[open_file_id].get_tensor(param_name)
             end_time = time.time()
             cost_time = end_time - start_time
             io_time += cost_time
@@ -316,7 +559,7 @@ def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dic
                 if not isinstance(choice_out, (bool, str)):
                     raise ValueError("For 'unified_safetensors', the return value type of the function "
                                      f"'choice_func' must be bool or str, but got {type(choice_out)}.")
-
+            safetensor_dict[param_name] = output
         else:
             raise ValueError(f"For _transform_safetensors_single, {param_name} should be in "
                              f"{redundancy_ranks}, but in {single_param_dict[param_name]}.")
@@ -334,6 +577,7 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
     Transforms safetensors files to a specified format without using parallel processing.
     """
     io_cost_time = 0
+    meta_data = {"format": "ms"}
     if src_strategy_file is not None:
         from mindspore.train._utils import get_parameter_redundancy
         redundancy_dict_tmp = get_parameter_redundancy(src_strategy_file, initial_rank=0)
@@ -353,13 +597,15 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
     file_dict = {}
     single_param_dict = {}
     for file_id, _ in all_safetensor_files_map.items():
-        f =
+        f = _fast_safe_open(all_safetensor_files_map.get(file_id), framework="np")
         file_dict[file_id] = f
         for param_name in f.keys():
             if param_name not in single_param_dict.keys():
                 single_param_dict[param_name] = {file_id}
             else:
                 single_param_dict[param_name].add(file_id)
+        if f.metadata() is not None:
+            meta_data.update(f.metadata())
     src_strategy_list_keys = _convert_to_list(src_strategy_dict).keys() if src_strategy_dict else []
     dst_strategy_list_keys = _convert_to_list(dst_strategy_dict).keys() if dst_strategy_dict else []
     for needed_rank_list_key, transform_rank_list in needed_rank_list_map.items():
@@ -368,27 +614,29 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
         needed_rank_list = needed_rank_list_key.split("-")
         for needed_rank in needed_rank_list:
             if pipe_param_list:
-
+                safetensor_dict = dict()
                 if src_strategy_file is not None:
                     io_time = _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict,
-
+                                                              safetensor_dict, redundancy_dict, needed_rank,
                                                               device_num, choice_func)
                     io_cost_time += io_time
                 else:
-                    with
+                    with _fast_safe_open(all_safetensor_files_map.get(int(needed_rank)), framework="np") as f:
                         if not unified_flag:
                             all_param_name_set = set(f.keys())
                             src_param_name_set = set(src_strategy_list_keys)
                             dst_param_name_set = set(dst_strategy_list_keys)
                             hyper_param_set = all_param_name_set - (src_param_name_set & dst_param_name_set)
                             pipe_param_list.extend(list(hyper_param_set))
+                        if f.metadata() is not None:
+                            meta_data.update(f.metadata())
                         io_time = 0
                         for param_name in pipe_param_list:
                             if param_name not in f.keys():
                                 # param not in ckpt file, check reason
                                 continue
                             start_time = time.time()
-                            output = f.
+                            output = f.get_tensor(param_name)
                             end_time = time.time()
                             cost_time = end_time - start_time
                             io_time += cost_time
@@ -400,15 +648,15 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
                             if not isinstance(choice_out, (bool, str)):
                                 raise ValueError("For 'unified_safetensors', the return value type of the function "
                                                  f"'choice_func' must be bool or str, but got {type(choice_out)}.")
-
+                            safetensor_dict[param_name] = output
             else:
                 start_time = time.time()
-
+                safetensor_dict = load_file(all_safetensor_files_map.get(int(needed_rank)))
                 end_time = time.time()
                 cost_time = end_time - start_time
                 io_cost_time += cost_time

-            for param_name, param in
+            for param_name, param in safetensor_dict.items():
                 src_rank = int(needed_rank) % src_stage_device_num
                 param_total_dict[param_name][src_rank] = param
                 param_attr_dict[param_name][src_rank] = (True, False)
@@ -442,11 +690,11 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
         else:
             if transform_param_dict:
                 if output_format == "safetensors":
-
+                    _save_file_atomically(transform_param_dict, save_file_name, metadata=meta_data)
                 else:
-                    transform_param_dict = _load_and_transform(transform_param_dict,
-
-
+                    transform_param_dict = _load_and_transform(transform_param_dict, None, None,
+                                                               transform_func=lambda v, name: Parameter(v,
+                                                                                                        name=name))
                     ms.save_checkpoint(transform_param_dict, save_file_name)
         del param_total_dict_keys
         del param_total_dict
@@ -464,10 +712,10 @@ def _save_final_safetensors(_transform_param_list, output_format):
             new_transform_dict[save_file_name].update(transform_param_dict)
     for save_file_name, transform_param_dict in new_transform_dict.items():
         if output_format == "safetensors":
-
+            _save_file_atomically(transform_param_dict, save_file_name, metadata={"format": "ms"})
         else:
             transform_param_dict = _load_and_transform(transform_param_dict, None, None,
-                                                        transform_func=lambda v, name:
+                                                        transform_func=lambda v, name: Parameter(v, name=name))
             ms.save_checkpoint(transform_param_dict, save_file_name)

@@ -501,8 +749,8 @@ def transform_safetensors_by_stage(src_safetensors_dir, dst_safetensors_dir, ckp
         if not os.path.exists(local_file):
             raise ValueError("safetensor file {} in rank {} not exits: ".format(local_file, rank))
     for rank, file_name in safetensor_files_map.items():
-
-        for param_name, param in
+        safetensor_dict = load_file(file_name)
+        for param_name, param in safetensor_dict.items():
             # cut the parameter not in the pipeline stage.
             if _parameter_not_in_local_stage(param_name, origin_src_strategy_list, src_strategy_list) \
                     and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list, dst_strategy_list):
@@ -520,7 +768,7 @@ def transform_safetensors_by_stage(src_safetensors_dir, dst_safetensors_dir, ckp
     if not os.path.exists(save_safetensor_file_dir):
         _make_dir(save_safetensor_file_dir, "path")
     save_safetensor_file_name = os.path.join(save_safetensor_file_dir, save_safetensor_file)
-
+    _save_file_atomically(transform_param_dict, save_safetensor_file_name, metadata={"format": "ms"})


 def transform_safetensors_by_rank(rank_id, safetensor_files_map, save_safetensor_file_name,
@@ -556,8 +804,8 @@ def transform_safetensors_by_rank(rank_id, safetensor_files_map, save_safetensor
     origin_dst_strategy_list = _extract_layout_map(dst_strategy_file)
     origin_src_strategy_list = _extract_layout_map(src_strategy_file)
     for rank, file_name in safetensor_files_map.items():
-
-        for param_name, param in
+        safetensor_dict = load_file(file_name)
+        for param_name, param in safetensor_dict.items():
             # cut the parameter not in the pipeline stage.
             if _parameter_not_in_local_stage(param_name, origin_src_strategy_list, src_strategy_list) \
                     and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list, dst_strategy_list):
@@ -572,7 +820,7 @@ def transform_safetensors_by_rank(rank_id, safetensor_files_map, save_safetensor
     transform_param_dict = _transform_parallel_safetensor(local_rank_id, param_total_dict,
                                                            param_attr_dict, src_strategy_list, dst_strategy_list,
                                                            param_type_dict)
-
+    _save_file_atomically(transform_param_dict, save_safetensor_file_name, metadata={"format": "ms"})


 def _extrace_number(file_name):
@@ -628,7 +876,7 @@ def _find_needed_ranks(src_strategy_dict, dst_strategy_dict):

 def load_file_by_param_name(filename, parme_name_list):
     result = {}
-    with
+    with _fast_safe_open(filename, framework="np") as f:
         for k in parme_name_list:
             result[k] = f.get_tensor(k)
     return result
@@ -644,10 +892,7 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
     device_num = -1
     param_total_dict_keys = list(param_total_dict.keys()) if param_total_dict_keys is None else param_total_dict_keys
     for param_name in param_total_dict_keys:
-
-            tensor_shape = list(param_total_dict[param_name].values())[0].get_shape()
-        else:
-            tensor_shape = list(param_total_dict[param_name].values())[0].shape
+        tensor_shape = list(param_total_dict[param_name].values())[0].shape
         from_dev_matrix = [1]
         from_tensor_map = [-1] * len(tensor_shape)
         from_opt_shard_step = 0
@@ -695,7 +940,7 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
         # when the from_layout is less devices, the safetensor_map for map[device_num] should using map[0]
         device_list = list(range(0, np.prod(from_tensor_layout[0])))
         if rank_id % device_num not in param_attr_dict[param_name] and src_strategy_file is None:
-            raise ValueError("The
+            raise ValueError("The param: {} in rank {} is missing.".format(param_name, rank_id % device_num))
         param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
                                                                              device_list, rank_id)

@@ -711,8 +956,6 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
             if isinstance(choice_out, str):
                 param_name = choice_out
         transform_param_dict[param_name] = param_total_dict_copy[rank_id % device_num]
-        if str(type(transform_param_dict[param_name])) == "<class 'builtins.PySafeSlice'>":
-            transform_param_dict[param_name] = transform_param_dict[param_name][:]

     # Handle those parameter like learning_rate, global_step which not in strategy_file.
     for param_name in param_total_dict_keys:
@@ -722,33 +965,14 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
             continue
         if param_name not in transform_param_dict:
             transform_para = param_total_dict[param_name][rank_id % device_num]
-            if str(type(transform_para)) == "<class 'builtins.PySafeSlice'>":
-                transform_para = transform_para[:]
             transform_param_dict[param_name] = transform_para
     return transform_param_dict


 def _cal_param_size(shape, dtype):
     """cal param size by dtype and shape"""
-    dtype_size = {
-        "BOOL": 1,
-        "U8": 1,
-        "I8": 1,
-        "F8_E5M2": 1,
-        "F8_E4M3": 1,
-        "I16": 2,
-        "U16": 2,
-        "I32": 4,
-        "U32": 4,
-        "I64": 8,
-        "U64": 8,
-        "F16": 2,
-        "BF16": 2,
-        "F32": 4,
-        "F64": 8,
-    }
     num_elements = math.prod(shape)
-    element_size =
+    element_size = np_dtype_size.get(dtype, 4)
     total_bytes = num_elements * element_size
     return total_bytes

@@ -769,14 +993,15 @@ def _split_weight_dict(weights, num_groups):
 def _save_hyper_param(split_dst_file, all_safetensor_files_map, name_list, dst_dir):
     """save hyper param"""
     if not split_dst_file or (split_dst_file and split_dst_file[0] == 1):
-        with
+        with _fast_safe_open(all_safetensor_files_map.get(0), framework="np") as f:
             all_key = f.keys()
             hyper_parameter = set(all_key) - set(name_list)
             if hyper_parameter:
                 hyper_dict = {}
                 for key in hyper_parameter:
                     hyper_dict[key] = f.get_tensor(key)
-
+                _save_file_atomically(hyper_dict, os.path.join(dst_dir, "hyper_param.safetensors"),
+                                      metadata={"format": "ms"})


 def _save_parameter_map_json(split_list, choice_func, split_dst_file, dst_dir, param_total_size):
@@ -826,14 +1051,57 @@ def _get_dst_shape(param_name, param_shape, src_strategy_list):
     return to_full_tensor_shape


+def _check_remove_redundancy(merge_with_redundancy, f):
+    """Check whether remove_redundancy is consistent with the safetensors file."""
+    if f.metadata() is not None and "remove_redundancy" in f.metadata().keys():
+        if f.metadata()["remove_redundancy"] == "True" and merge_with_redundancy:
+            logger.warning("For 'unified_safetensors', the safetensors file is deduplicated, "
+                           "but merge_with_redundancy is set to True.")
+            return False
+        if f.metadata()["remove_redundancy"] == "False" and not merge_with_redundancy:
+            logger.warning("For 'unified_safetensors', the safetensors file is non-deduplicated, "
+                           "but merge_with_redundancy is set to False.")
+            return True
+    return merge_with_redundancy
+
+
+def set_affinity_pid():
+    """Set CPU affinity pid"""
+    pid = os.getpid()
+    total_cores = os.cpu_count()
+    all_cores = set(range(total_cores))
+    os.sched_setaffinity(pid, all_cores)
+
+
+def _validate_safetensors_files(target_directory, expected_file_ids):
+    """Validate whether safetensors files are completely generated in the target directory."""
+    missing_file_ids = []
+    for file_id in expected_file_ids:
+        safetensors_file = os.path.join(target_directory, f"part{file_id}.safetensors")
+        if os.path.exists(safetensors_file):
+            continue
+        missing_file_ids.append(file_id)
+
+    if missing_file_ids:
+        logger.warning(
+            f"For unified_safetensors, target file part {missing_file_ids} does not exist. "
+            f"Possible causes: file rename failed, insufficient permissions, or disk space shortage."
+        )
+
+
 def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundancy=True, file_suffix=None,
                         max_process_num=64, choice_func=None, split_dst_file=()):
     """
     Merge multiple safetensor files into a unified safetensor file.

+    Note:
+        When merging weights, it will verify whether the `merge_with_redundancy` parameter differs from
+        the deduplication flag in the merged safetensors files. If they are the same, the merging will be performed
+        according to the deduplication flag in the files.
+
     Args:
         src_dir (str): Source weight saving directory.
-        src_strategy_file (str): Source weight segmentation strategy file.
+        src_strategy_file (str): Source weight segmentation strategy file with the file extension `.ckpt` .
         dst_dir (str): Target save directory.
         merge_with_redundancy (bool, optional): Whether the merged source weight files are de-duplicated and
             saved safetensors files. Default: ``True``, indicating that the merged source weight files are complete.
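Note on the hunk above: `_check_remove_redundancy` reconciles the caller's `merge_with_redundancy` argument with the `remove_redundancy` flag stored as a string in each file's metadata; when the two disagree, the flag recorded in the file wins. A minimal standalone restatement of that decision (sketch only, not part of the diff):

# Sketch only: mirrors the _check_remove_redundancy() logic added above.
def reconcile(merge_with_redundancy, file_metadata):
    flag = (file_metadata or {}).get("remove_redundancy")
    if flag == "True" and merge_with_redundancy:
        return False  # files were saved deduplicated
    if flag == "False" and not merge_with_redundancy:
        return True   # files still hold full (non-deduplicated) copies
    return merge_with_redundancy

assert reconcile(True, {"format": "ms", "remove_redundancy": "True"}) is False
assert reconcile(False, {"format": "ms", "remove_redundancy": "False"}) is True
assert reconcile(True, {"format": "ms"}) is True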
@@ -861,10 +1129,7 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
         >>> dst_dir = "/usr/safetensors/llama31B/merge_llama31B_4p/"
         >>> ms.parallel.unified_safetensors(src_dir, src_strategy_file, dst_dir)
     """
-
-    total_cores = os.cpu_count()
-    all_cores = set(range(total_cores))
-    os.sched_setaffinity(pid, all_cores)
+    set_affinity_pid()
     _check_transform_safetensors(src_dir, "", src_strategy_file, None)
     _make_dir(dst_dir, "path")
     if os.path.isfile(src_dir):
@@ -890,8 +1155,9 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan

     actual_params = set()
     for _, file_name in all_safetensor_files_map.items():
-        with
+        with _fast_safe_open(file_name, framework="np") as f:
             actual_params.update(f.keys())
+            merge_with_redundancy = _check_remove_redundancy(merge_with_redundancy, f)

     params_to_store = actual_params & set(layout_map.keys())

@@ -904,21 +1170,22 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
     param_size_dict = {}
     param_total_size = 0
     for _, file_name in all_safetensor_files_map.items():
-        with
+        with _fast_safe_open(file_name, framework="np") as f:
             for k in f.keys():
                 if k in name_list:
-                    py_slice = f.
-                    param_total_size += _cal_param_size(py_slice.
-                    param_dst_shape = _get_dst_shape(k, py_slice.
+                    py_slice = f.get_tensor(k)
+                    param_total_size += _cal_param_size(py_slice.shape, py_slice.dtype)
+                    param_dst_shape = _get_dst_shape(k, py_slice.shape, origin_src_strategy_list)
                     # Convert the shape of np.int32 type to int type to prevent overflow in subsequent calculations.
                     param_dst_shape = [int(item) for item in param_dst_shape]
                     if choice_func is not None:
                         choice_out = choice_func(k)
                         if isinstance(choice_out, bool):
                             if not choice_out:
+                                name_list.remove(k)
                                 continue
                     if k not in param_size_dict:
-                        param_size_dict[k] = _cal_param_size(param_dst_shape, py_slice.
+                        param_size_dict[k] = _cal_param_size(param_dst_shape, py_slice.dtype)
     split_num = math.ceil(sum(param_size_dict.values()) / 1024 / 1024 / 1024 / 3)
     split_num = min(split_num, len(name_list))
     split_list = _split_weight_dict(param_size_dict, split_num)
@@ -932,37 +1199,44 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
         start_index = (avg_length * (current_machine_num - 1)) + min(current_machine_num - 1, remainder)
         end_index = start_index + avg_length + (1 if current_machine_num <= remainder else 0)
         sub_list = []
-        for i in
+        for i, item in enumerate(split_list):
             if start_index <= i < end_index:
-                sub_list.append(
+                sub_list.append(item)
             else:
                 sub_list.append([-1])
+        split_num = end_index - start_index
+        res = list(range(start_index, end_index))
     else:
         sub_list = split_list
+        res = [i for i in range(split_num)]

     _save_hyper_param(split_dst_file, all_safetensor_files_map, name_list, dst_dir)
     _save_parameter_map_json(split_list, choice_func, split_dst_file, dst_dir, param_total_size)

-    if split_dst_file:
-        split_num = end_index - start_index
-        res = list(range(start_index, end_index))
-    else:
-        res = [i for i in range(split_num)]
     max_process = min(split_num, max_process_num)
+    file_ids = res[:]
     res = _split_list(res, max_process)
     processes = []
     src_strategy_name = None
     if not merge_with_redundancy:
         src_strategy_name = src_strategy_file
-
-
-
-
-
-
-
-
-        p
+    if max_process > 1:
+        for i in range(max_process):
+            p = mp.Process(target=_transform_safetensors_single_semaphore, args=(
+                needed_rank_list_map, all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
+                src_strategy_dict, None, origin_src_strategy_list, origin_dst_strategy_list,
+                "", dst_dir, "safetensors", None, sub_list, res[i], True, src_strategy_name, choice_func))
+            p.start()
+            processes.append(p)
+        for p in processes:
+            p.join()
+    else:
+        _transform_safetensors_single_semaphore(needed_rank_list_map, all_safetensor_files_map, src_stage_device_num,
+                                                dst_stage_device_num, src_strategy_dict, None,
+                                                origin_src_strategy_list, origin_dst_strategy_list, "",
+                                                dst_dir, "safetensors", None, sub_list,
+                                                res[0], True, src_strategy_name, choice_func)
+    _validate_safetensors_files(dst_dir, file_ids)


 def _transform_safetensors_single_semaphore(needed_rank_list_map, all_safetensor_files_map,
@@ -997,7 +1271,7 @@ def _split_list(split_list, split_num):
 def _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_num):
     """apply safetensors object operators"""
     if not transform_operator_stack:
-        return sf_obj
+        return sf_obj
     level = transform_operator_stack[-1][1]
     level_operators = []
     while True:
@@ -1022,7 +1296,7 @@ def _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_n
                 allgather_list = [sf_obj for _ in operator[1][:-1]]
                 tmp_tensor_dict[rank_id % device_num] = _apply_operator(operator[0])(allgather_list, operator)
             if op_name == "AllConcat":
-                for
+                for _, value in tmp_tensor_dict.items():
                     sf_obj = value
         level_operators.clear()
         if not transform_operator_stack:
@@ -1037,13 +1311,26 @@ def _process_hyper_params(file_list, total_safetensors_dir, total_param):
     """process hyper params"""
     if 'hyper_param.safetensors' in file_list:
         hyper_parameter_file_name = os.path.join(total_safetensors_dir, "hyper_param.safetensors")
-        with
+        with _fast_safe_open(hyper_parameter_file_name, framework="np") as f:
             for key in f.keys():
-                total_param[key] =
+                total_param[key] = Parameter(Tensor.from_numpy(f.get_tensor(key)))
     return total_param
 
 
-def
+def _get_param_name_map_by_file(file_name, file_list, name_map):
+    """get param_name_map by file"""
+    with _fast_safe_open(file_name, framework="np") as f:
+        keys = f.keys()
+        values = len(keys) * [file_list[0]]
+        if name_map:
+            flipped_name_map = {value: key for key, value in name_map.items()}
+            keys = [flipped_name_map.get(key, key) for key in keys]
+        param_name_map = dict(zip(keys, values))
+    return param_name_map
+
+
+def _cal_param_name_map_and_param_list(file_list, total_safetensors_dir, json_files,
+                                       dst_strategy_file, rank_id, name_map=None):
     """calculate param_name_map and param_list"""
     if len(file_list) == 1:
         logger.info("There is only one weight file in the directory, which will be automatically mapped.")
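For context on the new _get_param_name_map_by_file helper added above: when a name_map (in-memory parameter name -> on-disk key) is supplied, the keys read from the single safetensors file are translated back through the flipped map before being paired with the file name. A small self-contained sketch of that mapping step, with plain lists standing in for the safetensors handle and hypothetical key names:

# Keys as stored in the single unified safetensors file (hypothetical names).
file_keys = ["model.embed.weight", "model.head.weight"]
file_list = ["unified.safetensors"]

# name_map: in-memory parameter name -> on-disk key.
name_map = {"embedding.weight": "model.embed.weight"}

# Flip it so on-disk keys can be translated back to in-memory names.
flipped_name_map = {value: key for key, value in name_map.items()}
keys = [flipped_name_map.get(key, key) for key in file_keys]

# With only one weight file, every parameter maps to that file.
param_name_map = dict(zip(keys, len(keys) * [file_list[0]]))
print(param_name_map)
# {'embedding.weight': 'unified.safetensors', 'model.head.weight': 'unified.safetensors'}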
@@ -1052,10 +1339,7 @@ def _cal_param_name_map_and_param_list(file_list, total_safetensors_dir, json_fi
         if not is_file:
             raise ValueError(f"For 'load_parallel_checkpoint', weight files must be included "
                              f"in the `unified_safetensors_dir`.")
-
-            keys = f.keys()
-            values = len(keys) * [file_list[0]]
-            param_name_map = dict(zip(keys, values))
+        param_name_map = _get_param_name_map_by_file(file_name, file_list, name_map)
     else:
         if not json_files:
             raise ValueError(
@@ -1076,19 +1360,71 @@ def _cal_param_name_map_and_param_list(file_list, total_safetensors_dir, json_fi
     return param_name_map, param_list, dst_strategy_list
 
 
+def _cal_transform_operator_stack_and_device_num(from_dev_matrix, from_tensor_map, from_opt_shard_step,
+                                                 from_opt_shard_size, param_name, dst_strategy_list, tensor_shape,
+                                                 local_rank_id):
+    """cal transform_operator_stack and device_num"""
+    to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
+        dst_strategy_list.get(param_name))
+
+    device_num = np.prod(from_dev_matrix)
+    param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
+    origin_tensor_shape = ()
+    for i, item in enumerate(tensor_shape):
+        if i == 0 and from_opt_shard_size > 0:
+            origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,)
+            continue
+        origin_tensor_shape += (item * param_strategy[i],)
+
+    has_layout_from = any(isinstance(i, (list, tuple)) for i in from_tensor_map)
+    has_layout_to = any(isinstance(i, (list, tuple)) for i in to_tensor_map_origin)
+
+    from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
+        from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
+    to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
+        to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape)
+    # Convert tensor layout to same device num
+    from_tensor_layout, to_tensor_layout = _construct_from_to_tensor_layout(from_full_tensor_shape,
+                                                                            from_dev_matrix,
+                                                                            from_tensor_map,
+                                                                            to_full_tensor_shape,
+                                                                            to_dev_matrix, to_tensor_map)
+
+    # when the from_layout is less devices, the safetensor_map for map[device_num] should using map[0]
+    device_list = list(range(0, np.prod(from_tensor_layout[0])))
+    param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
+                                                                        device_list, local_rank_id)
+
+    from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
+    to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
+    _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
+    _insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple,
+                                  has_layout_from, has_layout_to)
+    transform_operator_stack = _generate_transform_operator_stack(param_rank_map, local_rank_id)
+    return transform_operator_stack, device_num
+
+
+def check_param_dtype(file, param_name):
+    dtype_need_changed = False
+    changed_dtype = None
+    if file.metadata() is not None and param_name in file.metadata().keys():
+        dtype_need_changed = True
+        sf_dtype = file.metadata()[param_name]
+        changed_dtype = safetensors_to_mstype[sf_dtype]
+    return dtype_need_changed, changed_dtype
+
+
 def _load_parallel_checkpoint(file_info):
     """load parallel safetensors by merged file."""
     total_safetensors_dir, dst_strategy_file, net, dst_safetensors_dir, \
-
-
-    total_cores = os.cpu_count()
-    all_cores = set(range(total_cores))
-    os.sched_setaffinity(pid, all_cores)
+        rank_id, output_format, name_map, return_param_dict = file_info
+    set_affinity_pid()
     file_list = os.listdir(total_safetensors_dir)
     json_files = [file for file in file_list if file == "param_name_map.json"]
-
+    sf_files = [file for file in file_list if file.endswith('.safetensors')]
+    param_name_map, param_list, dst_strategy_list = _cal_param_name_map_and_param_list(sf_files, total_safetensors_dir,
                                                                                        json_files, dst_strategy_file,
-                                                                                       rank_id)
+                                                                                       rank_id, name_map)
     total_param = dict()
     dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) if dst_strategy_list \
         is not None else 1
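The check_param_dtype helper introduced above consults the file-level safetensors metadata: a parameter is cast back to its recorded dtype (via the module's safetensors_to_mstype table) only if the unified file stored an entry for it. A hedged sketch of that decision with ordinary dicts standing in for file.metadata() and the dtype table (keys and values here are purely illustrative):

# Hypothetical stand-ins for file.metadata() and safetensors_to_mstype.
metadata = {"layer.weight": "FP16"}
sf_to_ms_dtype = {"FP16": "float16", "BF16": "bfloat16"}

def needs_dtype_change(param_name):
    # Mirrors check_param_dtype: cast only parameters listed in the metadata.
    if metadata is not None and param_name in metadata:
        return True, sf_to_ms_dtype[metadata[param_name]]
    return False, None

print(needs_dtype_change("layer.weight"))  # (True, 'float16')
print(needs_dtype_change("layer.bias"))    # (False, None)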
@@ -1098,13 +1434,14 @@ def _load_parallel_checkpoint(file_info):
         if param_name not in param_name_map:
             continue
         file_name = os.path.join(total_safetensors_dir, param_name_map[param_name])
-        with
+        with _fast_safe_open(file_name, framework="np") as f:
             cur_param_name = name_map.get(param_name) if name_map is not None and param_name in name_map else param_name
             if cur_param_name not in f.keys():
                 continue
-            sf_obj = f.
+            sf_obj = f.get_tensor(cur_param_name)
+            dtype_need_changed, changed_dtype = check_param_dtype(f, param_name)
 
-            tensor_shape = sf_obj.
+            tensor_shape = sf_obj.shape
             from_dev_matrix = [1]
             from_tensor_map = [-1] * len(tensor_shape)
             from_opt_shard_step = 0
@@ -1112,43 +1449,14 @@ def _load_parallel_checkpoint(file_info):
             if dst_strategy_list is not None:
                 if param_name not in dst_strategy_list:
                     continue
-
-
-
-
-
-
-
-
-                        origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,)
-                        continue
-                    origin_tensor_shape += (item * param_strategy[i],)
-
-                has_layout_from = any(isinstance(i, (list, tuple)) for i in from_tensor_map)
-                has_layout_to = any(isinstance(i, (list, tuple)) for i in to_tensor_map_origin)
-
-                from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
-                    from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
-                to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
-                    to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape)
-                # Convert tensor layout to same device num
-                from_tensor_layout, to_tensor_layout = _construct_from_to_tensor_layout(from_full_tensor_shape,
-                                                                                        from_dev_matrix,
-                                                                                        from_tensor_map,
-                                                                                        to_full_tensor_shape,
-                                                                                        to_dev_matrix, to_tensor_map)
-
-                # when the from_layout is less devices, the safetensor_map for map[device_num] should using map[0]
-                device_list = list(range(0, np.prod(from_tensor_layout[0])))
-                param_rank_map = _get_needed_rank_transform_operator_map_by_layouts(from_tensor_layout, to_tensor_layout,
-                                                                                    device_list, local_rank_id)
-
-                from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
-                to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
-                _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
-                _insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple,
-                                              has_layout_from, has_layout_to)
-                transform_operator_stack = _generate_transform_operator_stack(param_rank_map, local_rank_id)
+                transform_operator_stack, device_num = _cal_transform_operator_stack_and_device_num(from_dev_matrix,
+                                                                                                    from_tensor_map,
+                                                                                                    from_opt_shard_step,
+                                                                                                    from_opt_shard_size,
+                                                                                                    param_name,
+                                                                                                    dst_strategy_list,
+                                                                                                    tensor_shape,
+                                                                                                    local_rank_id)
                 start_time = time.time()
                 slice_param = _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_num)
                 end_time = time.time()
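In the rewritten loop above, the layout-transform logic that used to be inlined is delegated to the new _cal_transform_operator_stack_and_device_num helper. Each parameter starts from an unsharded source layout, which the surrounding context lines define as a single-device layout with no split dimensions; a toy illustration for a 2-D weight (names below are local to this sketch):

# Source layout for a weight read whole from the merged safetensors file:
tensor_shape = (1024, 4096)
from_dev_matrix = [1]                       # one device holds the full tensor
from_tensor_map = [-1] * len(tensor_shape)  # -1 means: dimension not split
from_opt_shard_step, from_opt_shard_size = 0, 0  # no optimizer sharding on the source side

# The helper compares this layout against the destination strategy for the
# parameter and returns the operator stack consumed by
# _apply_sf_obj_transform_operators, plus the device count of the layout.
print(from_dev_matrix, from_tensor_map)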
@@ -1156,11 +1464,15 @@ def _load_parallel_checkpoint(file_info):
                 total_io_cost_time += cost_time
             else:
                 start_time = time.time()
-                slice_param = sf_obj
+                slice_param = sf_obj
                 end_time = time.time()
                 cost_time = end_time - start_time
                 total_io_cost_time += cost_time
-
+            slice_param_copy = np.copy(slice_param)
+            if dtype_need_changed:
+                total_param[param_name] = Parameter(Tensor(slice_param_copy, dtype=changed_dtype))
+            else:
+                total_param[param_name] = Parameter(Tensor.from_numpy(slice_param_copy))
     vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
                f"load distributed safetensors io cost time:{total_io_cost_time}.")
     total_param = _process_hyper_params(file_list, total_safetensors_dir, total_param)
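One detail in the new assignment above: the transformed slice is passed through np.copy before being wrapped, which gives Tensor.from_numpy a contiguous array that owns its buffer (from_numpy shares memory rather than copying). A minimal sketch of the two branches; the flags mimic what check_param_dtype returns and are set by hand here:

import numpy as np
from mindspore import Parameter, Tensor

slice_param = np.arange(6, dtype=np.float32).reshape(2, 3)[:, ::2]  # non-contiguous view
slice_param_copy = np.copy(slice_param)  # contiguous copy that owns its memory

dtype_need_changed, changed_dtype = False, None  # e.g. as returned by check_param_dtype
if dtype_need_changed:
    param = Parameter(Tensor(slice_param_copy, dtype=changed_dtype))
else:
    param = Parameter(Tensor.from_numpy(slice_param_copy))
print(param.shape)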
@@ -1177,28 +1489,5 @@ def _load_parallel_checkpoint(file_info):
     return None
 
 
-def _get_slice(rank_id, sf_obj, param_name, dst_strategy_list):
-    """get slice op"""
-    tensor_shape = sf_obj.get_shape()
-    to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size = _extract_layout_item(
-        dst_strategy_list.get(param_name))
-    # Add optimizer sharding dim for tensor layout
-    to_dev_matrix, to_tensor_map, _ = _construct_tensor_layout_for_opt_shard(
-        to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, tensor_shape)
-    slice_op = _load_tensor_shape(to_dev_matrix, to_tensor_map, full_shape=tensor_shape, rank_id=rank_id)
-    shape = None
-    if to_opt_shard_size > 0:
-        to_tensor_strategy = _get_tensor_strategy(to_dev_matrix_origin, to_tensor_map_origin)
-        to_slice_tensor_shape = ()
-        for i, item in enumerate(tensor_shape):
-            if i == 0 and to_opt_shard_size > 0:
-                to_slice_tensor_shape += (item // (to_tensor_strategy[i] * to_opt_shard_size),)
-                continue
-            to_slice_tensor_shape += (item // to_tensor_strategy[i],)
-        shape = list(to_slice_tensor_shape)
-
-    return slice_op, shape
-
-
 __all__ = ["_transform_safetensors", "transform_safetensors_by_stage",
            "transform_safetensors_by_rank", "unified_safetensors"]