mindspore-2.4.10-cp310-cp310-win_amd64.whl → mindspore-2.6.0rc1-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic. See the registry's advisory page for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +46 -197
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +217 -98
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +435 -371
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +951 -1992
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +314 -566
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +182 -116
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +157 -117
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +796 -759
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +921 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1370 -189
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +17 -13
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +365 -363
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +27 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
- mindspore/ops/auto_generate/gen_extend_func.py +764 -124
- mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
- mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4501 -3802
- mindspore/ops/function/nn_func.py +1726 -620
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +440 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +22 -7
- mindspore/ops/functional_overload.py +1440 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +13 -7
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +232 -78
- mindspore/ops/operations/debug_ops.py +153 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +210 -498
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1888 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +152 -34
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +698 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -58
- mindspore/parallel/transform_safetensors.py +363 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +409 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +88 -25
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +184 -113
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -16,19 +16,26 @@
|
|
|
16
16
|
from __future__ import absolute_import
|
|
17
17
|
|
|
18
18
|
import os
|
|
19
|
-
import
|
|
19
|
+
import sys
|
|
20
20
|
import glob
|
|
21
|
-
import re
|
|
22
21
|
import math
|
|
23
22
|
import json
|
|
23
|
+
import re
|
|
24
24
|
from collections import defaultdict
|
|
25
25
|
|
|
26
|
+
import time
|
|
26
27
|
import multiprocessing as mp
|
|
28
|
+
import psutil
|
|
27
29
|
import numpy as np
|
|
30
|
+
from safetensors.numpy import save_file, load_file
|
|
31
|
+
from safetensors import safe_open
|
|
32
|
+
|
|
28
33
|
import mindspore as ms
|
|
34
|
+
from mindspore import log as logger
|
|
35
|
+
from mindspore.log import vlog_print
|
|
29
36
|
from mindspore.parallel._parallel_serialization import _get_device_num_from_strategy, _make_dir, \
|
|
30
37
|
_extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
|
|
31
|
-
_insert_opt_shard_reshape, _extract_src_dst_layout_map_by_src
|
|
38
|
+
_insert_opt_shard_reshape, _extract_src_dst_layout_map_by_src, _insert_expand_layout_reshape
|
|
32
39
|
from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_tensor_layout, \
|
|
33
40
|
_get_needed_rank_transform_operator_map_by_layouts, \
|
|
34
41
|
_generate_transform_operator_stack, _apply_tensor_transform_operators, _construct_tensor_layout_for_opt_shard, \
|
|
@@ -36,70 +43,6 @@ from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_
|
|
|
36
43
|
from mindspore.parallel._parallel_serialization import _build_searched_strategy, _load_protobuf_strategy, \
|
|
37
44
|
_convert_to_list
|
|
38
45
|
|
|
39
|
-
from safetensors.numpy import save_file, load_file
|
|
40
|
-
from safetensors import safe_open
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
def _load_and_transform(path, name_map, load_func, transform_func):
|
|
44
|
-
if load_func is not None:
|
|
45
|
-
param_dict = load_func(path)
|
|
46
|
-
else:
|
|
47
|
-
param_dict = path
|
|
48
|
-
transform_dict = {}
|
|
49
|
-
for k, v in param_dict.items():
|
|
50
|
-
new_name = name_map.get(k, k) if name_map is not None else k
|
|
51
|
-
transform_dict[new_name] = transform_func(v, new_name)
|
|
52
|
-
return transform_dict
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
def _transform_tensor_to_numpy(path, name_map=None):
|
|
56
|
-
return _load_and_transform(path, name_map, ms.load_checkpoint, lambda v, new_name: v.asnumpy())
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
def _transform_numpy_to_tensor(path, name_map=None):
|
|
60
|
-
return _load_and_transform(path, name_map, load_file, lambda v, new_name: ms.Parameter(v, name=new_name))
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
def _process_file(file_info):
|
|
64
|
-
cur_ckpt_path, name_map, save_path, file = file_info
|
|
65
|
-
param_dict_numpy = _transform_tensor_to_numpy(cur_ckpt_path, name_map)
|
|
66
|
-
safetensors_filename = file.replace(".ckpt", ".safetensors")
|
|
67
|
-
dst_file = os.path.join(save_path, safetensors_filename)
|
|
68
|
-
save_file(param_dict_numpy, dst_file)
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
def _process_file_safetensors(file_info):
|
|
72
|
-
cur_safe_path, name_map, save_path, file = file_info
|
|
73
|
-
param_dict_tensor = _transform_numpy_to_tensor(cur_safe_path, name_map)
|
|
74
|
-
ckpt_filename = file.replace(".safetensors", ".ckpt")
|
|
75
|
-
dst_file = os.path.join(save_path, ckpt_filename)
|
|
76
|
-
ms.save_checkpoint(param_dict_tensor, dst_file)
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
def _gather_tasks(file_path, save_path, file_name_regex, name_map):
|
|
80
|
-
"""gather transform rank together"""
|
|
81
|
-
tasks = []
|
|
82
|
-
for root, dirs, _ in os.walk(file_path):
|
|
83
|
-
if root != file_path:
|
|
84
|
-
continue
|
|
85
|
-
|
|
86
|
-
rank_dirs = [d for d in dirs if d.startswith('rank')]
|
|
87
|
-
if not rank_dirs:
|
|
88
|
-
raise ValueError(
|
|
89
|
-
f"For 'ckpt_to_safetensors', no directories starting with 'rank' found in {file_path}")
|
|
90
|
-
|
|
91
|
-
for rank_dir in rank_dirs:
|
|
92
|
-
rank_dir_path = os.path.join(root, rank_dir)
|
|
93
|
-
dst_root = os.path.join(save_path,
|
|
94
|
-
os.path.relpath(rank_dir_path, file_path)) if save_path else rank_dir_path
|
|
95
|
-
os.makedirs(dst_root, exist_ok=True)
|
|
96
|
-
tasks.extend(
|
|
97
|
-
(os.path.join(rank_dir_path, file), name_map, dst_root, file)
|
|
98
|
-
for file in os.listdir(rank_dir_path)
|
|
99
|
-
if file.endswith(".ckpt") and (file_name_regex is None or re.findall(file_name_regex, file))
|
|
100
|
-
)
|
|
101
|
-
return tasks
|
|
102
|
-
|
|
103
46
|
|
|
104
47
|
def _progress_bar(iterable, total=None):
|
|
105
48
|
"""
|
|
@@ -125,6 +68,7 @@ def _progress_bar(iterable, total=None):
|
|
|
125
68
|
elapsed_time_str = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
|
|
126
69
|
remaining_time_str = time.strftime("%H:%M:%S", time.gmtime(remaining_time))
|
|
127
70
|
|
|
71
|
+
sys.stdout.reconfigure(encoding="utf-8")
|
|
128
72
|
print(f'\r{percent}%|{bar}|[{elapsed_time_str}<{remaining_time_str}]', end='')
|
|
129
73
|
if iteration == total:
|
|
130
74
|
print()
|
|
@@ -134,155 +78,16 @@ def _progress_bar(iterable, total=None):
|
|
|
134
78
|
print_progress_bar(i)
|
|
135
79
|
|
|
136
80
|
|
|
137
|
-
def
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
verification, an error will be generated when performing the conversion.
|
|
148
|
-
The safetensors format currently does not support crc verification function. If ckpt contains crc verification
|
|
149
|
-
information, the crc verification information will be lost after conversion to safetensors.
|
|
150
|
-
|
|
151
|
-
Args:
|
|
152
|
-
file_path (str): Path to the directory containing checkpoint files or a single checkpoint file (.ckpt).
|
|
153
|
-
save_path (str, optional): Directory path where safetensors files will be saved. Defaults: ``None``.
|
|
154
|
-
name_map (dict, optional): Dictionary mapping original parameter names to new names. Defaults: ``None``.
|
|
155
|
-
file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
|
|
156
|
-
Defaults: ``None``.
|
|
157
|
-
processes_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.
|
|
158
|
-
Raises:
|
|
159
|
-
ValueError: If the input path is invalid or the save_path is not a directory,
|
|
160
|
-
or the file_path does not end with '.ckpt'.
|
|
161
|
-
|
|
162
|
-
Supported Platforms:
|
|
163
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
164
|
-
|
|
165
|
-
Examples:
|
|
166
|
-
>>> import mindspore as ms
|
|
167
|
-
>>> ms.ckpt_to_safetensors("./ckpt_save_path")
|
|
168
|
-
>>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt")
|
|
169
|
-
>>> ms.ckpt_to_safetensors(file_path="./ckpt_save_path/rank0/checkpoint_0.ckpt", save_path="./new_path/")
|
|
170
|
-
>>> namemap = {"lin.weight":"new_name"}
|
|
171
|
-
>>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt", "./new_path/", namemap)
|
|
172
|
-
"""
|
|
173
|
-
is_dir = os.path.isdir(file_path)
|
|
174
|
-
is_file = os.path.isfile(file_path)
|
|
175
|
-
if not is_dir and not is_file:
|
|
176
|
-
raise ValueError(f"For 'ckpt_to_safetensors', the input path must be a valid path or file, but got {file_path}")
|
|
177
|
-
if save_path and os.path.splitext(save_path)[1]:
|
|
178
|
-
raise ValueError(f"For 'ckpt_to_safetensors', the save_path must be a directory, but got '{save_path}'")
|
|
179
|
-
if name_map is not None and not isinstance(name_map, dict):
|
|
180
|
-
raise ValueError(
|
|
181
|
-
f"For 'ckpt_to_safetensors', the type of 'name_map' must be a directory, but got '{type(name_map)}'")
|
|
182
|
-
|
|
183
|
-
if is_dir:
|
|
184
|
-
tasks = _gather_tasks(file_path, save_path, file_name_regex, name_map)
|
|
185
|
-
with mp.Pool(processes=processes_num) as pool:
|
|
186
|
-
list(_progress_bar(pool.imap(_process_file, tasks), total=len(tasks)))
|
|
187
|
-
elif is_file:
|
|
188
|
-
if not file_path.endswith(".ckpt"):
|
|
189
|
-
raise ValueError(f"For 'ckpt_to_safetensors', the input file must be a .ckpt file, but got {file_path}")
|
|
190
|
-
if file_name_regex is not None and not re.findall(file_name_regex, file_path):
|
|
191
|
-
raise ValueError(f"For 'ckpt_to_safetensors', the input file does not match the regular expression.")
|
|
192
|
-
if save_path and not os.path.exists(save_path):
|
|
193
|
-
os.makedirs(save_path, exist_ok=True)
|
|
194
|
-
|
|
195
|
-
param_dict_numpy = _transform_tensor_to_numpy(file_path, name_map)
|
|
196
|
-
safetensors_filename = os.path.basename(file_path).replace(".ckpt", ".safetensors")
|
|
197
|
-
dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), safetensors_filename)
|
|
198
|
-
save_file(param_dict_numpy, dst_file)
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
def _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map):
|
|
202
|
-
"""gather transform rank together"""
|
|
203
|
-
tasks = []
|
|
204
|
-
for root, dirs, _ in os.walk(file_path):
|
|
205
|
-
if root != file_path:
|
|
206
|
-
continue
|
|
207
|
-
|
|
208
|
-
rank_dirs = [d for d in dirs if d.startswith('rank')]
|
|
209
|
-
if not rank_dirs:
|
|
210
|
-
raise ValueError(
|
|
211
|
-
f"For 'safetensors_to_ckpt', no directories starting with 'rank' found in {file_path}")
|
|
212
|
-
|
|
213
|
-
for rank_dir in rank_dirs:
|
|
214
|
-
rank_dir_path = os.path.join(root, rank_dir)
|
|
215
|
-
dst_root = os.path.join(save_path,
|
|
216
|
-
os.path.relpath(rank_dir_path, file_path)) if save_path else rank_dir_path
|
|
217
|
-
os.makedirs(dst_root, exist_ok=True)
|
|
218
|
-
tasks.extend(
|
|
219
|
-
(os.path.join(rank_dir_path, file), name_map, dst_root, file)
|
|
220
|
-
for file in os.listdir(rank_dir_path)
|
|
221
|
-
if file.endswith(".safetensors") and (file_name_regex is None or re.findall(file_name_regex, file))
|
|
222
|
-
)
|
|
223
|
-
return tasks
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
|
|
227
|
-
"""
|
|
228
|
-
Converts safetensors files into MindSpore checkpoint format and saves them to `save_path`.
|
|
229
|
-
Safetensors is a reliable and portable machine learning model storage format introduced by Huggingface,
|
|
230
|
-
used for securely storing Tensors with fast speed (zero copy).
|
|
231
|
-
|
|
232
|
-
Note:
|
|
233
|
-
The number of multiprocess settings is related to the size of the host, and it is not recommended to set it
|
|
234
|
-
too large, otherwise it may cause freezing.
|
|
235
|
-
|
|
236
|
-
Args:
|
|
237
|
-
file_path (str): Path to the directory containing safetensors files or a single safetensors file (.safetensors).
|
|
238
|
-
save_path (str, optional): Directory path where checkpoint files will be saved. Defaults: ``None``.
|
|
239
|
-
name_map (dict, optional): Dictionary mapping original parameter names to new names. Defaults: ``None``.
|
|
240
|
-
file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
|
|
241
|
-
Defaults: ``None``.
|
|
242
|
-
processes_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.
|
|
243
|
-
|
|
244
|
-
Raises:
|
|
245
|
-
ValueError: If the input path is invalid, the save_path is not a directory,
|
|
246
|
-
or the file_path does not end with '.safetensors'.
|
|
247
|
-
|
|
248
|
-
Supported Platforms:
|
|
249
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
250
|
-
|
|
251
|
-
Examples:
|
|
252
|
-
>>> import mindspore as ms
|
|
253
|
-
>>> ms.safetensors_to_ckpt("./safetensors_save_path")
|
|
254
|
-
>>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors")
|
|
255
|
-
>>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/")
|
|
256
|
-
>>> namemap = {"lin.weight":"new_name"}
|
|
257
|
-
>>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/", namemap)
|
|
258
|
-
"""
|
|
259
|
-
is_dir = os.path.isdir(file_path)
|
|
260
|
-
is_file = os.path.isfile(file_path)
|
|
261
|
-
if not is_dir and not is_file:
|
|
262
|
-
raise ValueError(f"For 'safetensors_to_ckpt', the input path must be a valid path or file, but got {file_path}")
|
|
263
|
-
if save_path and os.path.splitext(save_path)[1]:
|
|
264
|
-
raise ValueError(f"For 'safetensors_to_ckpt', the save_path must be a directory, but got '{save_path}'")
|
|
265
|
-
if name_map is not None and not isinstance(name_map, dict):
|
|
266
|
-
raise ValueError(
|
|
267
|
-
f"For 'safetensors_to_ckpt', the type of 'name_map' must be a directory, but got '{type(name_map)}'")
|
|
268
|
-
|
|
269
|
-
if is_dir:
|
|
270
|
-
tasks = _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map)
|
|
271
|
-
with mp.Pool(processes=processes_num) as pool:
|
|
272
|
-
list(_progress_bar(pool.imap(_process_file_safetensors, tasks), total=len(tasks)))
|
|
273
|
-
elif is_file:
|
|
274
|
-
if not file_path.endswith(".safetensors"):
|
|
275
|
-
raise ValueError(
|
|
276
|
-
f"For 'safetensors_to_ckpt', the input file must be a .safetensors file, but got {file_path}")
|
|
277
|
-
if file_name_regex is not None and not re.findall(file_name_regex, file_path):
|
|
278
|
-
raise ValueError(f"For 'safetensors_to_ckpt', the input file does not match the regular expression.")
|
|
279
|
-
if save_path and not os.path.exists(save_path):
|
|
280
|
-
os.makedirs(save_path, exist_ok=True)
|
|
281
|
-
|
|
282
|
-
param_dict_tensor = _transform_numpy_to_tensor(file_path, name_map)
|
|
283
|
-
ckpt_filename = os.path.basename(file_path).replace(".safetensors", ".ckpt")
|
|
284
|
-
dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), ckpt_filename)
|
|
285
|
-
ms.save_checkpoint(param_dict_tensor, dst_file)
|
|
81
|
+
def _load_and_transform(path, name_map, load_func, transform_func):
|
|
82
|
+
if load_func is not None:
|
|
83
|
+
param_dict = load_func(path)
|
|
84
|
+
else:
|
|
85
|
+
param_dict = path
|
|
86
|
+
transform_dict = {}
|
|
87
|
+
for k, v in param_dict.items():
|
|
88
|
+
new_name = name_map.get(k, k) if name_map is not None else k
|
|
89
|
+
transform_dict[new_name] = transform_func(v, new_name)
|
|
90
|
+
return transform_dict
|
|
286
91
|
|
|
287
92
|
|
|
288
93
|
def _check_transform_safetensors(src_safetensors_dir, ckpt_prefix, src_strategy_file, dst_strategy_file):
|
|
@@ -460,7 +265,6 @@ def _transform_safetensors_with_parallel(needed_rank_list_map, all_safetensor_fi
|
|
|
460
265
|
|
|
461
266
|
for name, layout in layout_map.items():
|
|
462
267
|
pipe_param_list[layout[6][0]].append(name)
|
|
463
|
-
|
|
464
268
|
part_list_dict = _distribute_files_by_size(all_safetensor_files_map, needed_rank_list_map, process_num)
|
|
465
269
|
processes = []
|
|
466
270
|
for i in range(process_num):
|
|
@@ -485,8 +289,9 @@ def _count_redundancy_list(rank_num, param_name, redundancy_dict, device_num):
|
|
|
485
289
|
|
|
486
290
|
|
|
487
291
|
def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict, saftensor_dict, redundancy_dict,
|
|
488
|
-
needed_rank, device_num):
|
|
292
|
+
needed_rank, device_num, choice_func):
|
|
489
293
|
"""Find the rank_id under redundant groups."""
|
|
294
|
+
io_time = 0
|
|
490
295
|
for param_name in pipe_param_list:
|
|
491
296
|
rank_num = int(needed_rank)
|
|
492
297
|
redundancy_ranks = _count_redundancy_list(rank_num, param_name, redundancy_dict, device_num)
|
|
@@ -499,11 +304,23 @@ def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dic
|
|
|
499
304
|
open_file_id = real_rank
|
|
500
305
|
break
|
|
501
306
|
if open_file_id is not None:
|
|
502
|
-
|
|
307
|
+
start_time = time.time()
|
|
308
|
+
output = file_dict[open_file_id].get_slice(param_name)
|
|
309
|
+
end_time = time.time()
|
|
310
|
+
cost_time = end_time - start_time
|
|
311
|
+
io_time += cost_time
|
|
312
|
+
if choice_func is not None:
|
|
313
|
+
choice_out = choice_func(param_name)
|
|
314
|
+
if isinstance(choice_out, bool) and not choice_out:
|
|
315
|
+
continue
|
|
316
|
+
if not isinstance(choice_out, (bool, str)):
|
|
317
|
+
raise ValueError("For 'unified_safetensors', the return value type of the function "
|
|
318
|
+
f"'choice_func' must be bool or str, but got {type(choice_out)}.")
|
|
503
319
|
saftensor_dict[param_name] = output
|
|
504
320
|
else:
|
|
505
321
|
raise ValueError(f"For _transform_safetensors_single, {param_name} should be in "
|
|
506
322
|
f"{redundancy_ranks}, but in {single_param_dict[param_name]}.")
|
|
323
|
+
return io_time
|
|
507
324
|
|
|
508
325
|
|
|
509
326
|
def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map, src_stage_device_num,
|
|
@@ -512,13 +329,14 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
|
|
|
512
329
|
origin_dst_strategy_list,
|
|
513
330
|
ckpt_prefix, dst_safetensors_dir, output_format,
|
|
514
331
|
_transform_param_list, pipe_param_list=None, file_index=None, unified_flag=False,
|
|
515
|
-
src_strategy_file=None):
|
|
332
|
+
src_strategy_file=None, choice_func=None):
|
|
516
333
|
"""
|
|
517
334
|
Transforms safetensors files to a specified format without using parallel processing.
|
|
518
335
|
"""
|
|
336
|
+
io_cost_time = 0
|
|
519
337
|
if src_strategy_file is not None:
|
|
520
338
|
from mindspore.train._utils import get_parameter_redundancy
|
|
521
|
-
redundancy_dict_tmp = get_parameter_redundancy(src_strategy_file)
|
|
339
|
+
redundancy_dict_tmp = get_parameter_redundancy(src_strategy_file, initial_rank=0)
|
|
522
340
|
redundancy_dict = {}
|
|
523
341
|
device_num = 0
|
|
524
342
|
for param_name, redundancy in redundancy_dict_tmp.items():
|
|
@@ -552,8 +370,10 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
|
|
|
552
370
|
if pipe_param_list:
|
|
553
371
|
saftensor_dict = dict()
|
|
554
372
|
if src_strategy_file is not None:
|
|
555
|
-
_find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict,
|
|
556
|
-
|
|
373
|
+
io_time = _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict,
|
|
374
|
+
saftensor_dict, redundancy_dict, needed_rank,
|
|
375
|
+
device_num, choice_func)
|
|
376
|
+
io_cost_time += io_time
|
|
557
377
|
else:
|
|
558
378
|
with safe_open(all_safetensor_files_map.get(int(needed_rank)), framework="np") as f:
|
|
559
379
|
if not unified_flag:
|
|
@@ -562,14 +382,32 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
|
|
|
562
382
|
dst_param_name_set = set(dst_strategy_list_keys)
|
|
563
383
|
hyper_param_set = all_param_name_set - (src_param_name_set & dst_param_name_set)
|
|
564
384
|
pipe_param_list.extend(list(hyper_param_set))
|
|
385
|
+
io_time = 0
|
|
565
386
|
for param_name in pipe_param_list:
|
|
566
387
|
if param_name not in f.keys():
|
|
567
388
|
# param not in ckpt file, check reason
|
|
568
389
|
continue
|
|
569
|
-
|
|
390
|
+
start_time = time.time()
|
|
391
|
+
output = f.get_slice(param_name)
|
|
392
|
+
end_time = time.time()
|
|
393
|
+
cost_time = end_time - start_time
|
|
394
|
+
io_time += cost_time
|
|
395
|
+
io_cost_time += io_time
|
|
396
|
+
if choice_func is not None:
|
|
397
|
+
choice_out = choice_func(param_name)
|
|
398
|
+
if isinstance(choice_out, bool) and not choice_out:
|
|
399
|
+
continue
|
|
400
|
+
if not isinstance(choice_out, (bool, str)):
|
|
401
|
+
raise ValueError("For 'unified_safetensors', the return value type of the function "
|
|
402
|
+
f"'choice_func' must be bool or str, but got {type(choice_out)}.")
|
|
570
403
|
saftensor_dict[param_name] = output
|
|
571
404
|
else:
|
|
405
|
+
start_time = time.time()
|
|
572
406
|
saftensor_dict = load_file(all_safetensor_files_map.get(int(needed_rank)))
|
|
407
|
+
end_time = time.time()
|
|
408
|
+
cost_time = end_time - start_time
|
|
409
|
+
io_cost_time += cost_time
|
|
410
|
+
|
|
573
411
|
for param_name, param in saftensor_dict.items():
|
|
574
412
|
src_rank = int(needed_rank) % src_stage_device_num
|
|
575
413
|
param_total_dict[param_name][src_rank] = param
|
|
@@ -588,7 +426,7 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
|
|
|
588
426
|
local_rank_id = transform_rank % dst_stage_device_num
|
|
589
427
|
transform_param_dict = _transform_parallel_safetensor(local_rank_id, param_total_dict,
|
|
590
428
|
param_attr_dict, src_strategy_list, dst_strategy_list,
|
|
591
|
-
param_total_dict_keys, src_strategy_file)
|
|
429
|
+
param_total_dict_keys, src_strategy_file, choice_func)
|
|
592
430
|
if file_index is not None:
|
|
593
431
|
save_safetensor_file = f"part{file_index}.{output_format}"
|
|
594
432
|
save_safetensor_file_dir = dst_safetensors_dir
|
|
@@ -602,15 +440,17 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
|
|
|
602
440
|
if _transform_param_list is not None:
|
|
603
441
|
_transform_param_list.append({save_file_name: transform_param_dict})
|
|
604
442
|
else:
|
|
605
|
-
if
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
443
|
+
if transform_param_dict:
|
|
444
|
+
if output_format == "safetensors":
|
|
445
|
+
save_file(transform_param_dict, save_file_name)
|
|
446
|
+
else:
|
|
447
|
+
transform_param_dict = _load_and_transform(transform_param_dict,
|
|
448
|
+
None, None, transform_func=
|
|
449
|
+
lambda v, name: ms.Parameter(v, name=name))
|
|
450
|
+
ms.save_checkpoint(transform_param_dict, save_file_name)
|
|
612
451
|
del param_total_dict_keys
|
|
613
452
|
del param_total_dict
|
|
453
|
+
return io_cost_time
|
|
614
454
|
|
|
615
455
|
|
|
616
456
|
def _save_final_safetensors(_transform_param_list, output_format):
|
|
@@ -735,6 +575,13 @@ def transform_safetensors_by_rank(rank_id, safetensor_files_map, save_safetensor
|
|
|
735
575
|
save_file(transform_param_dict, save_safetensor_file_name)
|
|
736
576
|
|
|
737
577
|
|
|
578
|
+
def _extrace_number(file_name):
|
|
579
|
+
"""get file last two number"""
|
|
580
|
+
number_ls = re.findall(r'\d+', file_name)
|
|
581
|
+
number_ls = [int(i) for i in number_ls]
|
|
582
|
+
return number_ls[-2:]
|
|
583
|
+
|
|
584
|
+
|
|
738
585
|
def _collect_safetensor_files(src_safetensors_dir, format='safetensors', file_suffix=None):
|
|
739
586
|
"""
|
|
740
587
|
Collects all safetensors files from the specified directory and its subdirectories.
|
|
@@ -758,12 +605,9 @@ def _collect_safetensor_files(src_safetensors_dir, format='safetensors', file_su
|
|
|
758
605
|
else:
|
|
759
606
|
safetensor_file_name = os.path.join(safetensor_dir, f"*{file_suffix}.{format}")
|
|
760
607
|
rank_ckpts = glob.glob(safetensor_file_name)
|
|
761
|
-
rank_ckpts.sort()
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
ms.log.warning("{} is not a safetensor file.".format(safetensor_file))
|
|
765
|
-
continue
|
|
766
|
-
all_safetensor_files_map[rank_id] = safetensor_file
|
|
608
|
+
rank_ckpts.sort(key=_extrace_number)
|
|
609
|
+
if rank_ckpts:
|
|
610
|
+
all_safetensor_files_map[rank_id] = rank_ckpts[-1]
|
|
767
611
|
return all_safetensor_files_map
|
|
768
612
|
|
|
769
613
|
|
|
@@ -775,7 +619,7 @@ def _find_needed_ranks(src_strategy_dict, dst_strategy_dict):
|
|
|
775
619
|
dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
|
|
776
620
|
dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
|
|
777
621
|
dst_device_num = dst_stage_device_num * dst_stage_num
|
|
778
|
-
for rank in
|
|
622
|
+
for rank in range(dst_device_num):
|
|
779
623
|
needed_rank_list = ms.rank_list_for_transform(rank, src_strategy_dict, dst_strategy_dict)
|
|
780
624
|
needed_rank_list_key = "-".join([str(r) for r in needed_rank_list])
|
|
781
625
|
needed_rank_list_map[needed_rank_list_key].append(rank)
|
|
@@ -791,7 +635,8 @@ def load_file_by_param_name(filename, parme_name_list):
|
|
|
791
635
|
|
|
792
636
|
|
|
793
637
|
def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, src_strategy_list,
|
|
794
|
-
dst_strategy_list, param_total_dict_keys=None, src_strategy_file=None
|
|
638
|
+
dst_strategy_list, param_total_dict_keys=None, src_strategy_file=None,
|
|
639
|
+
choice_func=None):
|
|
795
640
|
"""
|
|
796
641
|
Transform model parallel dimension for distributed safetensor files.
|
|
797
642
|
"""
|
|
@@ -799,7 +644,10 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
|
|
|
799
644
|
device_num = -1
|
|
800
645
|
param_total_dict_keys = list(param_total_dict.keys()) if param_total_dict_keys is None else param_total_dict_keys
|
|
801
646
|
for param_name in param_total_dict_keys:
|
|
802
|
-
|
|
647
|
+
if str(type(list(param_total_dict[param_name].values())[0])) == "<class 'builtins.PySafeSlice'>":
|
|
648
|
+
tensor_shape = list(param_total_dict[param_name].values())[0].get_shape()
|
|
649
|
+
else:
|
|
650
|
+
tensor_shape = list(param_total_dict[param_name].values())[0].shape
|
|
803
651
|
from_dev_matrix = [1]
|
|
804
652
|
from_tensor_map = [-1] * len(tensor_shape)
|
|
805
653
|
from_opt_shard_step = 0
|
|
@@ -832,6 +680,9 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
|
|
|
832
680
|
continue
|
|
833
681
|
origin_tensor_shape += (item * param_strategy[i],)
|
|
834
682
|
|
|
683
|
+
has_layout_from = any(isinstance(i, (list, tuple)) for i in from_tensor_map)
|
|
684
|
+
has_layout_to = any(isinstance(i, (list, tuple)) for i in to_tensor_map_origin)
|
|
685
|
+
|
|
835
686
|
from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
|
|
836
687
|
from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
|
|
837
688
|
to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
|
|
@@ -851,21 +702,132 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
|
|
|
851
702
|
from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
|
|
852
703
|
to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
|
|
853
704
|
_insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
|
|
705
|
+
_insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple, has_layout_from, has_layout_to)
|
|
854
706
|
transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
|
|
855
707
|
param_total_dict_copy = param_total_dict[param_name].copy()
|
|
856
708
|
_apply_tensor_transform_operators(transform_operator_stack, param_total_dict_copy, device_num)
|
|
857
|
-
|
|
709
|
+
if choice_func is not None:
|
|
710
|
+
choice_out = choice_func(param_name)
|
|
711
|
+
if isinstance(choice_out, str):
|
|
712
|
+
param_name = choice_out
|
|
858
713
|
transform_param_dict[param_name] = param_total_dict_copy[rank_id % device_num]
|
|
714
|
+
if str(type(transform_param_dict[param_name])) == "<class 'builtins.PySafeSlice'>":
|
|
715
|
+
transform_param_dict[param_name] = transform_param_dict[param_name][:]
|
|
859
716
|
|
|
860
717
|
# Handle those parameter like learning_rate, global_step which not in strategy_file.
|
|
861
718
|
for param_name in param_total_dict_keys:
|
|
719
|
+
if choice_func is not None:
|
|
720
|
+
choice_out = choice_func(param_name)
|
|
721
|
+
if isinstance(choice_out, str):
|
|
722
|
+
continue
|
|
862
723
|
if param_name not in transform_param_dict:
|
|
863
724
|
transform_para = param_total_dict[param_name][rank_id % device_num]
|
|
725
|
+
if str(type(transform_para)) == "<class 'builtins.PySafeSlice'>":
|
|
726
|
+
transform_para = transform_para[:]
|
|
864
727
|
transform_param_dict[param_name] = transform_para
|
|
865
728
|
return transform_param_dict
|
|
866
729
|
|
|
867
730
|
|
|
868
|
-
def
|
|
731
|
+
def _cal_param_size(shape, dtype):
|
|
732
|
+
"""cal param size by dtype and shape"""
|
|
733
|
+
dtype_size = {
|
|
734
|
+
"BOOL": 1,
|
|
735
|
+
"U8": 1,
|
|
736
|
+
"I8": 1,
|
|
737
|
+
"F8_E5M2": 1,
|
|
738
|
+
"F8_E4M3": 1,
|
|
739
|
+
"I16": 2,
|
|
740
|
+
"U16": 2,
|
|
741
|
+
"I32": 4,
|
|
742
|
+
"U32": 4,
|
|
743
|
+
"I64": 8,
|
|
744
|
+
"U64": 8,
|
|
745
|
+
"F16": 2,
|
|
746
|
+
"BF16": 2,
|
|
747
|
+
"F32": 4,
|
|
748
|
+
"F64": 8,
|
|
749
|
+
}
|
|
750
|
+
num_elements = math.prod(shape)
|
|
751
|
+
element_size = dtype_size.get(dtype, 4)
|
|
752
|
+
total_bytes = num_elements * element_size
|
|
753
|
+
return total_bytes
|
|
754
|
+
|
|
755
|
+
|
|
756
|
+
def _split_weight_dict(weights, num_groups):
|
|
757
|
+
"""split weights by num"""
|
|
758
|
+
sorted_items = sorted(weights.items(), key=lambda x: -x[1])
|
|
759
|
+
groups = [[] for _ in range(num_groups)]
|
|
760
|
+
total_bytes = [0] * num_groups
|
|
761
|
+
for weight_name, byte_size in sorted_items:
|
|
762
|
+
min_index = total_bytes.index(min(total_bytes))
|
|
763
|
+
groups[min_index].append(weight_name)
|
|
764
|
+
total_bytes[min_index] += byte_size
|
|
765
|
+
|
|
766
|
+
return groups
|
|
767
|
+
|
|
768
|
+
|
|
769
|
+
def _save_hyper_param(split_dst_file, all_safetensor_files_map, name_list, dst_dir):
|
|
770
|
+
"""save hyper param"""
|
|
771
|
+
if not split_dst_file or (split_dst_file and split_dst_file[0] == 1):
|
|
772
|
+
with safe_open(all_safetensor_files_map.get(0), framework="np") as f:
|
|
773
|
+
all_key = f.keys()
|
|
774
|
+
hyper_parameter = set(all_key) - set(name_list)
|
|
775
|
+
if hyper_parameter:
|
|
776
|
+
hyper_dict = {}
|
|
777
|
+
for key in hyper_parameter:
|
|
778
|
+
hyper_dict[key] = f.get_tensor(key)
|
|
779
|
+
save_file(hyper_dict, os.path.join(dst_dir, "hyper_param.safetensors"))
|
|
780
|
+
|
|
781
|
+
|
|
782
|
+
def _save_parameter_map_json(split_list, choice_func, split_dst_file, dst_dir, param_total_size):
|
|
783
|
+
"""save parameter map json file"""
|
|
784
|
+
param_name_dict = dict()
|
|
785
|
+
for index, part_list in enumerate(split_list):
|
|
786
|
+
for name in part_list:
|
|
787
|
+
save_param_name = name
|
|
788
|
+
if choice_func is not None:
|
|
789
|
+
choice_out = choice_func(name)
|
|
790
|
+
if isinstance(choice_out, str):
|
|
791
|
+
save_param_name = choice_out
|
|
792
|
+
if save_param_name == -1:
|
|
793
|
+
break
|
|
794
|
+
param_name_dict[save_param_name] = f"part{index}.safetensors"
|
|
795
|
+
output_dict = {"metadata": {"total_size": param_total_size}, "weight_map": param_name_dict}
|
|
796
|
+
if not split_dst_file or (split_dst_file and split_dst_file[0] == 1):
|
|
797
|
+
json_str = json.dumps(output_dict, indent=4)
|
|
798
|
+
map_file = os.path.join(dst_dir, "param_name_map.json")
|
|
799
|
+
with open(map_file, 'w') as f:
|
|
800
|
+
f.write(json_str)
|
|
801
|
+
|
|
802
|
+
|
|
803
|
+
def _get_dst_shape(param_name, param_shape, src_strategy_list):
|
|
804
|
+
"""get dst shape by strategy"""
|
|
805
|
+
from_dev_matrix = [1]
|
|
806
|
+
from_tensor_map = [-1] * len(param_shape)
|
|
807
|
+
from_opt_shard_size = 0
|
|
808
|
+
if src_strategy_list is not None:
|
|
809
|
+
from_dev_matrix, from_tensor_map, _, from_opt_shard_size = _extract_layout_item(
|
|
810
|
+
src_strategy_list.get(param_name))
|
|
811
|
+
to_dev_matrix_origin = [1]
|
|
812
|
+
to_tensor_map_origin = [-1] * len(param_shape)
|
|
813
|
+
to_opt_shard_step = 0
|
|
814
|
+
to_opt_shard_size = 0
|
|
815
|
+
|
|
816
|
+
param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
|
|
817
|
+
origin_tensor_shape = ()
|
|
818
|
+
for i, item in enumerate(param_shape):
|
|
819
|
+
if i == 0 and from_opt_shard_size > 0:
|
|
820
|
+
origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,)
|
|
821
|
+
continue
|
|
822
|
+
origin_tensor_shape += (item * param_strategy[i],)
|
|
823
|
+
|
|
824
|
+
_, _, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
|
|
825
|
+
to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape)
|
|
826
|
+
return to_full_tensor_shape
|
|
827
|
+
|
|
828
|
+
|
|
829
|
+
def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundancy=True, file_suffix=None,
|
|
830
|
+
max_process_num=64, choice_func=None, split_dst_file=()):
|
|
869
831
|
"""
|
|
870
832
|
Merge multiple safetensor files into a unified safetensor file.
|
|
871
833
|
|
|
@@ -877,6 +839,14 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
|
|
|
877
839
|
saved safetensors files. Default: ``True``, indicating that the merged source weight files are complete.
|
|
878
840
|
file_suffix (str, optional): Specify the filename suffix for merging safetensors files. Default: ``None``,
|
|
879
841
|
meaning all safetensors files in the source weight directory will be merged.
|
|
842
|
+
max_process_num (int, optional): Maximum number of processes. Default: ``64``.
|
|
843
|
+
choice_func (callable, optional): A callable function used to filter parameters or modify parameter names.
|
|
844
|
+
The return value of the function must be of type str (string) or bool (boolean). Default: ``None``.
|
|
845
|
+
split_dst_file (tuple, optional) - A parameter used to manually split a task into multiple subtasks for
|
|
846
|
+
execution, represented as a tuple containing two elements. The first element indicates the number of
|
|
847
|
+
the current subtask, and the second element indicates the total number of tasks. This parameter supports
|
|
848
|
+
splitting and executing tasks multiple times on a single machine, and also supports executing different
|
|
849
|
+
subtasks on multiple machines respectively. Default: ``()``.
|
|
880
850
|
|
|
881
851
|
Raises:
|
|
882
852
|
ValueError: If the safetensors file of rank is missing.
|
|
@@ -889,8 +859,12 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
|
|
|
889
859
|
>>> src_dir = "/usr/safetensors/llama31B/4p_safetensors/"
|
|
890
860
|
>>> src_strategy_file = "/usr/safetensors/llama31B/strategy_4p.ckpt"
|
|
891
861
|
>>> dst_dir = "/usr/safetensors/llama31B/merge_llama31B_4p/"
|
|
892
|
-
>>> ms.unified_safetensors(src_dir, src_strategy_file, dst_dir)
|
|
862
|
+
>>> ms.parallel.unified_safetensors(src_dir, src_strategy_file, dst_dir)
|
|
893
863
|
"""
|
|
864
|
+
pid = os.getpid()
|
|
865
|
+
total_cores = os.cpu_count()
|
|
866
|
+
all_cores = set(range(total_cores))
|
|
867
|
+
os.sched_setaffinity(pid, all_cores)
|
|
894
868
|
_check_transform_safetensors(src_dir, "", src_strategy_file, None)
|
|
895
869
|
_make_dir(dst_dir, "path")
|
|
896
870
|
if os.path.isfile(src_dir):
|
|
@@ -914,13 +888,11 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
|
|
|
914
888
|
"but it is missing.".format(needed_rank, rank))
|
|
915
889
|
layout_map = _convert_to_list(src_strategy_dict)
|
|
916
890
|
|
|
917
|
-
total_size = 0
|
|
918
891
|
actual_params = set()
|
|
919
892
|
for _, file_name in all_safetensor_files_map.items():
|
|
920
|
-
total_size += os.path.getsize(file_name) / 1024 / 1024 / 1024
|
|
921
893
|
with safe_open(file_name, framework="np") as f:
|
|
922
894
|
actual_params.update(f.keys())
|
|
923
|
-
|
|
895
|
+
|
|
924
896
|
params_to_store = actual_params & set(layout_map.keys())
|
|
925
897
|
|
|
926
898
|
name_list = []
|
|
@@ -928,29 +900,55 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
|
|
|
928
900
|
if name.startswith("accu_grads"):
|
|
929
901
|
continue
|
|
930
902
|
name_list.append(name)
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
903
|
+
|
|
904
|
+
param_size_dict = {}
|
|
905
|
+
param_total_size = 0
|
|
906
|
+
for _, file_name in all_safetensor_files_map.items():
|
|
907
|
+
with safe_open(file_name, framework="np") as f:
|
|
908
|
+
for k in f.keys():
|
|
909
|
+
if k in name_list:
|
|
910
|
+
py_slice = f.get_slice(k)
|
|
911
|
+
param_total_size += _cal_param_size(py_slice.get_shape(), py_slice.get_dtype())
|
|
912
|
+
param_dst_shape = _get_dst_shape(k, py_slice.get_shape(), origin_src_strategy_list)
|
|
913
|
+
# Convert the shape of np.int32 type to int type to prevent overflow in subsequent calculations.
|
|
914
|
+
param_dst_shape = [int(item) for item in param_dst_shape]
|
|
915
|
+
if choice_func is not None:
|
|
916
|
+
choice_out = choice_func(k)
|
|
917
|
+
if isinstance(choice_out, bool):
|
|
918
|
+
if not choice_out:
|
|
919
|
+
continue
|
|
920
|
+
if k not in param_size_dict:
|
|
921
|
+
param_size_dict[k] = _cal_param_size(param_dst_shape, py_slice.get_dtype())
|
|
922
|
+
split_num = math.ceil(sum(param_size_dict.values()) / 1024 / 1024 / 1024 / 3)
|
|
923
|
+
split_num = min(split_num, len(name_list))
|
|
924
|
+
split_list = _split_weight_dict(param_size_dict, split_num)
|
|
925
|
+
|
|
926
|
+
if split_dst_file:
|
|
927
|
+
current_machine_num = split_dst_file[0]
|
|
928
|
+
total_machine_num = split_dst_file[1]
|
|
929
|
+
n = len(split_list)
|
|
930
|
+
avg_length = n // total_machine_num
|
|
931
|
+
remainder = n % total_machine_num
|
|
932
|
+
start_index = (avg_length * (current_machine_num - 1)) + min(current_machine_num - 1, remainder)
|
|
933
|
+
end_index = start_index + avg_length + (1 if current_machine_num <= remainder else 0)
|
|
934
|
+
sub_list = []
|
|
935
|
+
for i in range(len(split_list)):
|
|
936
|
+
if start_index <= i < end_index:
|
|
937
|
+
sub_list.append(split_list[i])
|
|
938
|
+
else:
|
|
939
|
+
sub_list.append([-1])
|
|
940
|
+
else:
|
|
941
|
+
sub_list = split_list
|
|
942
|
+
|
|
943
|
+
_save_hyper_param(split_dst_file, all_safetensor_files_map, name_list, dst_dir)
|
|
944
|
+
_save_parameter_map_json(split_list, choice_func, split_dst_file, dst_dir, param_total_size)
|
|
945
|
+
|
|
946
|
+
if split_dst_file:
|
|
947
|
+
split_num = end_index - start_index
|
|
948
|
+
res = list(range(start_index, end_index))
|
|
949
|
+
else:
|
|
950
|
+
res = [i for i in range(split_num)]
|
|
951
|
+
max_process = min(split_num, max_process_num)
|
|
954
952
|
res = _split_list(res, max_process)
|
|
955
953
|
processes = []
|
|
956
954
|
src_strategy_name = None
|
|
@@ -960,7 +958,7 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
|
|
|
960
958
|
p = mp.Process(target=_transform_safetensors_single_semaphore, args=(
|
|
961
959
|
needed_rank_list_map, all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
|
|
962
960
|
src_strategy_dict, None, origin_src_strategy_list, origin_dst_strategy_list,
|
|
963
|
-
"", dst_dir, "safetensors", None,
|
|
961
|
+
"", dst_dir, "safetensors", None, sub_list, res[i], True, src_strategy_name, choice_func))
|
|
964
962
|
p.start()
|
|
965
963
|
processes.append(p)
|
|
966
964
|
for p in processes:
|
|
@@ -974,13 +972,21 @@ def _transform_safetensors_single_semaphore(needed_rank_list_map, all_safetensor
|
|
|
974
972
|
origin_dst_strategy_list,
|
|
975
973
|
ckpt_prefix, dst_safetensors_dir, output_format,
|
|
976
974
|
_transform_param_list, pipe_param_list=None, file_index=None,
|
|
977
|
-
unified_flag=False, src_strategy_file=None):
|
|
975
|
+
unified_flag=False, src_strategy_file=None, choice_func=None):
|
|
976
|
+
"""transform safetensors single semaphore"""
|
|
977
|
+
total_io_cost_time = 0
|
|
978
978
|
for i in file_index:
|
|
979
|
-
_transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map,
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
979
|
+
io_cost_time = _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map,
|
|
980
|
+
src_stage_device_num, dst_stage_device_num, src_strategy_dict,
|
|
981
|
+
dst_strategy_dict, origin_src_strategy_list,
|
|
982
|
+
origin_dst_strategy_list, ckpt_prefix, dst_safetensors_dir,
|
|
983
|
+
output_format, _transform_param_list, pipe_param_list[i], i,
|
|
984
|
+
unified_flag, src_strategy_file, choice_func)
|
|
985
|
+
while psutil.virtual_memory().percent > 50:
|
|
986
|
+
time.sleep(1)
|
|
987
|
+
total_io_cost_time += io_cost_time
|
|
988
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
|
|
989
|
+
f"Unified safetensors io cost time:{total_io_cost_time}.")
|
|
984
990
|
|
|
985
991
|
|
|
986
992
|
def _split_list(split_list, split_num):
|
|
@@ -1027,36 +1033,76 @@ def _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_n
|
|
|
1027
1033
|
return sf_obj
|
|
1028
1034
|
|
|
1029
1035
|
|
|
1030
|
-
def
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1036
|
+
def _process_hyper_params(file_list, total_safetensors_dir, total_param):
|
|
1037
|
+
"""process hyper params"""
|
|
1038
|
+
if 'hyper_param.safetensors' in file_list:
|
|
1039
|
+
hyper_parameter_file_name = os.path.join(total_safetensors_dir, "hyper_param.safetensors")
|
|
1040
|
+
with safe_open(hyper_parameter_file_name, framework="np") as f:
|
|
1041
|
+
for key in f.keys():
|
|
1042
|
+
total_param[key] = ms.Parameter(ms.Tensor.from_numpy(f.get_tensor(key)))
|
|
1043
|
+
return total_param
|
|
1044
|
+
|
|
1045
|
+
|
|
1046
|
+
def _cal_param_name_map_and_param_list(file_list, total_safetensors_dir, json_files, dst_strategy_file, rank_id):
|
|
1047
|
+
"""calculate param_name_map and param_list"""
|
|
1048
|
+
if len(file_list) == 1:
|
|
1049
|
+
logger.info("There is only one weight file in the directory, which will be automatically mapped.")
|
|
1050
|
+
file_name = os.path.join(total_safetensors_dir, file_list[0])
|
|
1051
|
+
is_file = os.path.isfile(file_name)
|
|
1052
|
+
if not is_file:
|
|
1053
|
+
raise ValueError(f"For 'load_parallel_checkpoint', weight files must be included "
|
|
1054
|
+
f"in the `unified_safetensors_dir`.")
|
|
1055
|
+
with safe_open(file_name, framework="np") as f:
|
|
1056
|
+
keys = f.keys()
|
|
1057
|
+
values = len(keys) * [file_list[0]]
|
|
1058
|
+
param_name_map = dict(zip(keys, values))
|
|
1059
|
+
else:
|
|
1060
|
+
if not json_files:
|
|
1061
|
+
raise ValueError(
|
|
1062
|
+
f"For 'load_parallel_checkpoint', there must be a JSON file named 'param_name_map.json' in "
|
|
1063
|
+
f"the 'total_safetensors_dir'.")
|
|
1064
|
+
param_name_json = os.path.join(total_safetensors_dir, json_files[0])
|
|
1065
|
+
with open(param_name_json, 'r') as f:
|
|
1066
|
+
param_name_map = json.load(f)
|
|
1067
|
+
if "weight_map" in param_name_map:
|
|
1068
|
+
param_name_map = param_name_map["weight_map"]
|
|
1069
|
+
|
|
1041
1070
|
if dst_strategy_file is not None:
|
|
1042
1071
|
_, dst_strategy_list = _extract_src_dst_layout_map(rank_id, None, dst_strategy_file)
|
|
1043
1072
|
param_list = dst_strategy_list.keys()
|
|
1044
1073
|
else:
|
|
1045
1074
|
dst_strategy_list = None
|
|
1046
1075
|
param_list = param_name_map.keys()
|
|
1076
|
+
return param_name_map, param_list, dst_strategy_list
|
|
1077
|
+
|
|
1047
1078
|
|
|
1079
|
+
def _load_parallel_checkpoint(file_info):
|
|
1080
|
+
"""load parallel safetensors by merged file."""
|
|
1081
|
+
total_safetensors_dir, dst_strategy_file, net, dst_safetensors_dir, \
|
|
1082
|
+
rank_id, output_format, name_map, return_param_dict = file_info
|
|
1083
|
+
pid = os.getpid()
|
|
1084
|
+
total_cores = os.cpu_count()
|
|
1085
|
+
all_cores = set(range(total_cores))
|
|
1086
|
+
os.sched_setaffinity(pid, all_cores)
|
|
1087
|
+
file_list = os.listdir(total_safetensors_dir)
|
|
1088
|
+
json_files = [file for file in file_list if file == "param_name_map.json"]
|
|
1089
|
+
param_name_map, param_list, dst_strategy_list = _cal_param_name_map_and_param_list(file_list, total_safetensors_dir,
|
|
1090
|
+
json_files, dst_strategy_file,
|
|
1091
|
+
rank_id)
|
|
1048
1092
|
total_param = dict()
|
|
1049
1093
|
dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) if dst_strategy_list \
|
|
1050
1094
|
is not None else 1
|
|
1051
1095
|
local_rank_id = rank_id % dst_stage_device_num
|
|
1052
|
-
|
|
1096
|
+
total_io_cost_time = 0
|
|
1097
|
+
for param_name in _progress_bar(param_list):
|
|
1053
1098
|
if param_name not in param_name_map:
|
|
1054
1099
|
continue
|
|
1055
1100
|
file_name = os.path.join(total_safetensors_dir, param_name_map[param_name])
|
|
1056
1101
|
with safe_open(file_name, framework="np") as f:
|
|
1057
|
-
|
|
1102
|
+
cur_param_name = name_map.get(param_name) if name_map is not None and param_name in name_map else param_name
|
|
1103
|
+
if cur_param_name not in f.keys():
|
|
1058
1104
|
continue
|
|
1059
|
-
sf_obj = f.get_slice(
|
|
1105
|
+
sf_obj = f.get_slice(cur_param_name)
|
|
1060
1106
|
|
|
1061
1107
|
tensor_shape = sf_obj.get_shape()
|
|
1062
1108
|
from_dev_matrix = [1]
|
|
@@ -1078,6 +1124,9 @@ def _load_parallel_checkpoint(total_safetensors_dir, dst_strategy_file, net=None
|
|
|
1078
1124
|
continue
|
|
1079
1125
|
origin_tensor_shape += (item * param_strategy[i],)
|
|
1080
1126
|
|
|
1127
|
+
has_layout_from = any(isinstance(i, (list, tuple)) for i in from_tensor_map)
|
|
1128
|
+
has_layout_to = any(isinstance(i, (list, tuple)) for i in to_tensor_map_origin)
|
|
1129
|
+
|
|
1081
1130
|
from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
|
|
1082
1131
|
from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
|
|
1083
1132
|
to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
|
|
@@ -1097,25 +1146,34 @@ def _load_parallel_checkpoint(total_safetensors_dir, dst_strategy_file, net=None
|
|
|
1097
1146
|
from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
|
|
1098
1147
|
to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
|
|
1099
1148
|
_insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
|
|
1149
|
+
_insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple,
|
|
1150
|
+
has_layout_from, has_layout_to)
|
|
1100
1151
|
transform_operator_stack = _generate_transform_operator_stack(param_rank_map, local_rank_id)
|
|
1101
|
-
|
|
1152
|
+
start_time = time.time()
|
|
1102
1153
|
slice_param = _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_num)
|
|
1154
|
+
end_time = time.time()
|
|
1155
|
+
cost_time = end_time - start_time
|
|
1156
|
+
total_io_cost_time += cost_time
|
|
1103
1157
|
else:
|
|
1158
|
+
start_time = time.time()
|
|
1104
1159
|
slice_param = sf_obj[:]
|
|
1105
|
-
|
|
1106
|
-
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
total_param[key] = ms.Parameter(f.get_tensor(key))
|
|
1160
|
+
end_time = time.time()
|
|
1161
|
+
cost_time = end_time - start_time
|
|
1162
|
+
total_io_cost_time += cost_time
|
|
1163
|
+
total_param[param_name] = ms.Parameter(ms.Tensor.from_numpy(slice_param))
|
|
1164
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
|
|
1165
|
+
f"load distributed safetensors io cost time:{total_io_cost_time}.")
|
|
1166
|
+
total_param = _process_hyper_params(file_list, total_safetensors_dir, total_param)
|
|
1113
1167
|
if net is not None:
|
|
1114
|
-
|
|
1115
|
-
|
|
1168
|
+
if not return_param_dict:
|
|
1169
|
+
logger.info("start load param into net...")
|
|
1170
|
+
param_not_load, ckpt_not_load = ms.load_param_into_net(net, total_param)
|
|
1171
|
+
logger.info("load param into net is end...")
|
|
1172
|
+
return param_not_load, ckpt_not_load
|
|
1173
|
+
return total_param
|
|
1116
1174
|
_make_dir(os.path.join(dst_safetensors_dir, f"rank_{rank_id}"), "path")
|
|
1117
|
-
ms.save_checkpoint(total_param, os.path.join(dst_safetensors_dir, f"rank_{rank_id}", f"net.
|
|
1118
|
-
format=
|
|
1175
|
+
ms.save_checkpoint(total_param, os.path.join(dst_safetensors_dir, f"rank_{rank_id}", f"net.{output_format}"),
|
|
1176
|
+
format=output_format)
|
|
1119
1177
|
return None
|
|
1120
1178
|
|
|
1121
1179
|
|
|
@@ -1143,4 +1201,4 @@ def _get_slice(rank_id, sf_obj, param_name, dst_strategy_list):
|
|
|
1143
1201
|
|
|
1144
1202
|
|
|
1145
1203
|
__all__ = ["_transform_safetensors", "transform_safetensors_by_stage",
|
|
1146
|
-
"transform_safetensors_by_rank", "
|
|
1204
|
+
"transform_safetensors_by_rank", "unified_safetensors"]
|