mindspore-2.5.0-cp310-cp310-win_amd64.whl → mindspore-2.6.0rc1-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +6 -4
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -33
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +19 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +24 -193
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +97 -74
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +4 -4
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +4 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -1
- mindspore/common/_stub_tensor.py +5 -10
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +1915 -3287
- mindspore/common/api.py +341 -354
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +5 -2
- mindspore/common/dump.py +7 -5
- mindspore/common/file_system.py +3 -0
- mindspore/common/hook_handle.py +5 -3
- mindspore/common/initializer.py +10 -6
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +2 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +106 -39
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +297 -714
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +47 -2
- mindspore/communication/comm_func.py +70 -53
- mindspore/communication/management.py +83 -17
- mindspore/context.py +214 -560
- mindspore/dataset/__init__.py +44 -20
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +102 -120
- mindspore/dataset/engine/datasets_audio.py +22 -22
- mindspore/dataset/engine/datasets_standard_format.py +43 -24
- mindspore/dataset/engine/datasets_text.py +78 -85
- mindspore/dataset/engine/datasets_user_defined.py +108 -76
- mindspore/dataset/engine/datasets_vision.py +111 -108
- mindspore/dataset/engine/iterators.py +5 -3
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/samplers.py +279 -57
- mindspore/dataset/engine/serializer_deserializer.py +2 -1
- mindspore/dataset/engine/validators.py +10 -0
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/device_context/ascend/op_debug.py +60 -1
- mindspore/device_context/ascend/op_tuning.py +0 -4
- mindspore/device_manager.py +39 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +22 -26
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +4 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +40 -22
- mindspore/experimental/optim/radam.py +5 -5
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -81
- mindspore/hal/event.py +38 -55
- mindspore/hal/memory.py +93 -144
- mindspore/hal/stream.py +81 -125
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +40 -2
- mindspore/mindrecord/__init__.py +20 -7
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +131 -700
- mindspore/mint/distributed/__init__.py +5 -1
- mindspore/mint/distributed/distributed.py +194 -109
- mindspore/mint/linalg/__init__.py +2 -0
- mindspore/mint/nn/__init__.py +280 -18
- mindspore/mint/nn/functional.py +282 -64
- mindspore/mint/nn/layer/__init__.py +4 -0
- mindspore/mint/nn/layer/_functions.py +7 -3
- mindspore/mint/nn/layer/activation.py +120 -13
- mindspore/mint/nn/layer/conv.py +218 -24
- mindspore/mint/nn/layer/normalization.py +15 -16
- mindspore/mint/nn/layer/padding.py +1 -1
- mindspore/mint/nn/layer/pooling.py +66 -1
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1250 -176
- mindspore/nn/layer/activation.py +23 -21
- mindspore/nn/layer/basic.py +22 -16
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +9 -8
- mindspore/nn/layer/normalization.py +48 -42
- mindspore/nn/layer/pooling.py +75 -31
- mindspore/nn/layer/transformer.py +11 -10
- mindspore/nn/learning_rate_schedule.py +4 -2
- mindspore/nn/loss/loss.py +27 -19
- mindspore/nn/optim/ada_grad.py +6 -5
- mindspore/nn/optim/adadelta.py +9 -7
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +16 -12
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +9 -7
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +178 -117
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +3 -3
- mindspore/numpy/array_creations.py +3 -3
- mindspore/numpy/array_ops.py +1 -1
- mindspore/numpy/math_ops.py +4 -4
- mindspore/numpy/utils.py +1 -2
- mindspore/numpy/utils_const.py +1 -2
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +3 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
- mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
- mindspore/ops/_vmap/vmap_array_ops.py +7 -6
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
- mindspore/ops/_vmap/vmap_math_ops.py +4 -7
- mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +102 -49
- mindspore/ops/auto_generate/gen_extend_func.py +281 -135
- mindspore/ops/auto_generate/gen_ops_def.py +2574 -2326
- mindspore/ops/auto_generate/gen_ops_prim.py +8566 -2755
- mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +19 -24
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -3
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +28 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +1629 -2345
- mindspore/ops/function/clip_func.py +38 -45
- mindspore/ops/function/debug_func.py +36 -44
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +46 -78
- mindspore/ops/function/math_func.py +3035 -3705
- mindspore/ops/function/nn_func.py +676 -241
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +17 -30
- mindspore/ops/function/random_func.py +204 -361
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +5 -5
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +6 -4
- mindspore/ops/functional_overload.py +547 -4
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +10 -5
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_grad_ops.py +1 -10
- mindspore/ops/operations/_inner_ops.py +5 -76
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +37 -22
- mindspore/ops/operations/comm_ops.py +150 -107
- mindspore/ops/operations/custom_ops.py +221 -23
- mindspore/ops/operations/debug_ops.py +115 -16
- mindspore/ops/operations/inner_ops.py +1 -1
- mindspore/ops/operations/linalg_ops.py +1 -58
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +746 -79
- mindspore/ops/operations/math_ops.py +21 -18
- mindspore/ops/operations/nn_ops.py +65 -191
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +43 -32
- mindspore/ops/tensor_method.py +232 -13
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
- mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
- mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
- mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
- mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
- mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
- mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
- mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
- mindspore/ops_generate/{template.py → common/template.py} +96 -84
- mindspore/ops_generate/gen_ops.py +23 -325
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
- mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -7
- mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
- mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
- mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
- mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
- mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
- mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
- mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
- mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
- mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
- mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
- mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
- mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
- mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
- mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +6 -2
- mindspore/parallel/_auto_parallel_context.py +133 -6
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +95 -4
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +198 -25
- mindspore/parallel/algo_parameter_config.py +3 -3
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +656 -37
- mindspore/parallel/cluster/process_entity/_api.py +151 -19
- mindspore/parallel/cluster/run.py +1 -1
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +24 -13
- mindspore/parallel/shard.py +137 -61
- mindspore/parallel/transform_safetensors.py +287 -95
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +9 -5
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
- mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +22 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
- mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/common/constant.py +12 -0
- mindspore/profiler/common/msprof_cmd_tool.py +42 -23
- mindspore/profiler/common/path_manager.py +24 -0
- mindspore/profiler/common/profiler_context.py +26 -2
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_parameters.py +59 -18
- mindspore/profiler/common/profiler_path_manager.py +66 -7
- mindspore/profiler/dynamic_profiler.py +112 -79
- mindspore/profiler/envprofiler.py +26 -1
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +57 -14
- mindspore/profiler/platform/npu_profiler.py +33 -7
- mindspore/profiler/profiler.py +541 -45
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +4 -0
- mindspore/profiler/schedule.py +57 -22
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +25 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +2 -2
- mindspore/runtime/executor.py +40 -11
- mindspore/runtime/memory.py +25 -8
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +35 -7
- mindspore/train/amp.py +1 -1
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +24 -40
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_flops_collector.py +2 -3
- mindspore/train/callback/_history.py +7 -4
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +8 -13
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +179 -103
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +4 -5
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +8 -6
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +19 -12
- mindspore/train/model.py +176 -103
- mindspore/train/serialization.py +246 -988
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +4 -2
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +2 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +483 -438
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_constants.py +0 -190
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
- /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/parallel/transform_safetensors.py CHANGED

@@ -16,6 +16,7 @@
 from __future__ import absolute_import
 
 import os
+import sys
 import glob
 import math
 import json
@@ -24,15 +25,17 @@ from collections import defaultdict
 
 import time
 import multiprocessing as mp
+import psutil
 import numpy as np
 from safetensors.numpy import save_file, load_file
 from safetensors import safe_open
 
 import mindspore as ms
 from mindspore import log as logger
+from mindspore.log import vlog_print
 from mindspore.parallel._parallel_serialization import _get_device_num_from_strategy, _make_dir, \
     _extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
-    _insert_opt_shard_reshape, _extract_src_dst_layout_map_by_src
+    _insert_opt_shard_reshape, _extract_src_dst_layout_map_by_src, _insert_expand_layout_reshape
 from mindspore.parallel._tensor import _get_tensor_strategy, _construct_from_to_tensor_layout, \
     _get_needed_rank_transform_operator_map_by_layouts, \
     _generate_transform_operator_stack, _apply_tensor_transform_operators, _construct_tensor_layout_for_opt_shard, \
@@ -65,6 +68,7 @@ def _progress_bar(iterable, total=None):
         elapsed_time_str = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
         remaining_time_str = time.strftime("%H:%M:%S", time.gmtime(remaining_time))
 
+        sys.stdout.reconfigure(encoding="utf-8")
         print(f'\r{percent}%|{bar}|[{elapsed_time_str}<{remaining_time_str}]', end='')
         if iteration == total:
             print()
@@ -285,8 +289,9 @@ def _count_redundancy_list(rank_num, param_name, redundancy_dict, device_num):
 
 
 def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict, saftensor_dict, redundancy_dict,
-                                    needed_rank, device_num):
+                                    needed_rank, device_num, choice_func):
     """Find the rank_id under redundant groups."""
+    io_time = 0
     for param_name in pipe_param_list:
         rank_num = int(needed_rank)
         redundancy_ranks = _count_redundancy_list(rank_num, param_name, redundancy_dict, device_num)
@@ -299,11 +304,23 @@ def _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dic
                 open_file_id = real_rank
                 break
         if open_file_id is not None:
-            output = file_dict[open_file_id].get_slice(param_name)
+            start_time = time.time()
+            output = file_dict[open_file_id].get_slice(param_name)
+            end_time = time.time()
+            cost_time = end_time - start_time
+            io_time += cost_time
+            if choice_func is not None:
+                choice_out = choice_func(param_name)
+                if isinstance(choice_out, bool) and not choice_out:
+                    continue
+                if not isinstance(choice_out, (bool, str)):
+                    raise ValueError("For 'unified_safetensors', the return value type of the function "
+                                     f"'choice_func' must be bool or str, but got {type(choice_out)}.")
             saftensor_dict[param_name] = output
         else:
             raise ValueError(f"For _transform_safetensors_single, {param_name} should be in "
                              f"{redundancy_ranks}, but in {single_param_dict[param_name]}.")
+    return io_time
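Note on the choice_func contract enforced above: returning False drops a parameter, True keeps it, a str renames it, and any other return type raises ValueError. A minimal conforming callback might look like this (a sketch; the "adam_" and "network." prefixes are illustrative, not taken from the package):

    def choice_func(param_name):
        # Drop optimizer accumulators entirely.
        if param_name.startswith("adam_"):
            return False
        # Strip a wrapper prefix, i.e. save the weight under a new name.
        if param_name.startswith("network."):
            return param_name[len("network."):]
        # Keep the parameter under its original name.
        return True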
@@ -316,9 +333,10 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
     """
     Transforms safetensors files to a specified format without using parallel processing.
     """
+    io_cost_time = 0
     if src_strategy_file is not None:
         from mindspore.train._utils import get_parameter_redundancy
-        redundancy_dict_tmp = get_parameter_redundancy(src_strategy_file)
+        redundancy_dict_tmp = get_parameter_redundancy(src_strategy_file, initial_rank=0)
         redundancy_dict = {}
         device_num = 0
         for param_name, redundancy in redundancy_dict_tmp.items():
@@ -352,8 +370,10 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
         if pipe_param_list:
             saftensor_dict = dict()
             if src_strategy_file is not None:
-                _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict,
-                                                saftensor_dict, redundancy_dict, needed_rank, device_num)
+                io_time = _find_remove_redundancy_rank_id(pipe_param_list, single_param_dict, file_dict,
+                                                          saftensor_dict, redundancy_dict, needed_rank,
+                                                          device_num, choice_func)
+                io_cost_time += io_time
             else:
                 with safe_open(all_safetensor_files_map.get(int(needed_rank)), framework="np") as f:
                     if not unified_flag:
@@ -362,25 +382,32 @@
                         dst_param_name_set = set(dst_strategy_list_keys)
                         hyper_param_set = all_param_name_set - (src_param_name_set & dst_param_name_set)
                         pipe_param_list.extend(list(hyper_param_set))
+                    io_time = 0
                     for param_name in pipe_param_list:
                         if param_name not in f.keys():
                            # param not in ckpt file, check reason
                            continue
-                        save_param_name = param_name
-                        output = f.get_slice(param_name)
+                        start_time = time.time()
+                        output = f.get_slice(param_name)
+                        end_time = time.time()
+                        cost_time = end_time - start_time
+                        io_time += cost_time
+                        io_cost_time += io_time
                         if choice_func is not None:
                             choice_out = choice_func(param_name)
-                            if isinstance(choice_out, bool):
-                                if not choice_out:
-                                    continue
-                            elif isinstance(choice_out, str):
-                                save_param_name = choice_out
-                            else:
+                            if isinstance(choice_out, bool) and not choice_out:
+                                continue
+                            if not isinstance(choice_out, (bool, str)):
                                 raise ValueError("For 'unified_safetensors', the return value type of the function "
                                                  f"'choice_func' must be bool or str, but got {type(choice_out)}.")
-                        saftensor_dict[save_param_name] = output
+                        saftensor_dict[param_name] = output
         else:
+            start_time = time.time()
             saftensor_dict = load_file(all_safetensor_files_map.get(int(needed_rank)))
+            end_time = time.time()
+            cost_time = end_time - start_time
+            io_cost_time += cost_time
+
         for param_name, param in saftensor_dict.items():
             src_rank = int(needed_rank) % src_stage_device_num
             param_total_dict[param_name][src_rank] = param
@@ -399,7 +426,7 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
         local_rank_id = transform_rank % dst_stage_device_num
         transform_param_dict = _transform_parallel_safetensor(local_rank_id, param_total_dict,
                                                               param_attr_dict, src_strategy_list, dst_strategy_list,
-                                                              param_total_dict_keys, src_strategy_file)
+                                                              param_total_dict_keys, src_strategy_file, choice_func)
         if file_index is not None:
             save_safetensor_file = f"part{file_index}.{output_format}"
             save_safetensor_file_dir = dst_safetensors_dir
@@ -413,15 +440,17 @@ def _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map
         if _transform_param_list is not None:
             _transform_param_list.append({save_file_name: transform_param_dict})
         else:
-            if output_format == "safetensors":
-                save_file(transform_param_dict, save_file_name)
-            else:
-                transform_param_dict = _load_and_transform(transform_param_dict,
-                                                           None, None, transform_func=
-                                                           lambda v, name: ms.Parameter(v, name=name))
-                ms.save_checkpoint(transform_param_dict, save_file_name)
+            if transform_param_dict:
+                if output_format == "safetensors":
+                    save_file(transform_param_dict, save_file_name)
+                else:
+                    transform_param_dict = _load_and_transform(transform_param_dict,
+                                                               None, None, transform_func=
+                                                               lambda v, name: ms.Parameter(v, name=name))
+                    ms.save_checkpoint(transform_param_dict, save_file_name)
     del param_total_dict_keys
     del param_total_dict
+    return io_cost_time
@@ -552,6 +581,7 @@ def _extrace_number(file_name):
     number_ls = [int(i) for i in number_ls]
     return number_ls[-2:]
 
+
 def _collect_safetensor_files(src_safetensors_dir, format='safetensors', file_suffix=None):
     """
     Collects all safetensors files from the specified directory and its subdirectories.
@@ -589,7 +619,7 @@ def _find_needed_ranks(src_strategy_dict, dst_strategy_dict):
     dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
     dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
     dst_device_num = dst_stage_device_num * dst_stage_num
-    for rank in
+    for rank in range(dst_device_num):
         needed_rank_list = ms.rank_list_for_transform(rank, src_strategy_dict, dst_strategy_dict)
         needed_rank_list_key = "-".join([str(r) for r in needed_rank_list])
         needed_rank_list_map[needed_rank_list_key].append(rank)
@@ -605,7 +635,8 @@ def load_file_by_param_name(filename, parme_name_list):
 
 
 def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, src_strategy_list,
-                                   dst_strategy_list, param_total_dict_keys=None, src_strategy_file=None):
+                                   dst_strategy_list, param_total_dict_keys=None, src_strategy_file=None,
+                                   choice_func=None):
     """
     Transform model parallel dimension for distributed safetensor files.
     """
@@ -613,7 +644,10 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
     device_num = -1
     param_total_dict_keys = list(param_total_dict.keys()) if param_total_dict_keys is None else param_total_dict_keys
     for param_name in param_total_dict_keys:
-        tensor_shape = list(param_total_dict[param_name].values())[0].shape
+        if str(type(list(param_total_dict[param_name].values())[0])) == "<class 'builtins.PySafeSlice'>":
+            tensor_shape = list(param_total_dict[param_name].values())[0].get_shape()
+        else:
+            tensor_shape = list(param_total_dict[param_name].values())[0].shape
         from_dev_matrix = [1]
         from_tensor_map = [-1] * len(tensor_shape)
         from_opt_shard_step = 0
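The branch just added distinguishes lazily loaded PySafeSlice objects, which is what safetensors' get_slice returns, from in-memory numpy arrays: get_shape() touches only metadata, while indexing with [:] is what actually reads the tensor bytes, as the later [:] materializations in this file rely on. A small standalone sketch of that API (the file and tensor names are hypothetical):

    from safetensors import safe_open

    with safe_open("part0.safetensors", framework="np") as f:
        lazy = f.get_slice("embedding.weight")   # PySafeSlice: no tensor data read yet
        shape = lazy.get_shape()                 # shape from metadata only, cheap
        array = lazy[:]                          # reads and returns the full numpy array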
@@ -646,6 +680,9 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
             continue
         origin_tensor_shape += (item * param_strategy[i],)
 
+        has_layout_from = any(isinstance(i, (list, tuple)) for i in from_tensor_map)
+        has_layout_to = any(isinstance(i, (list, tuple)) for i in to_tensor_map_origin)
+
         from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
             from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
         to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
@@ -665,22 +702,132 @@ def _transform_parallel_safetensor(rank_id, param_total_dict, param_attr_dict, s
         from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
         to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
         _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
+        _insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple, has_layout_from, has_layout_to)
         transform_operator_stack = _generate_transform_operator_stack(param_rank_map, rank_id)
         param_total_dict_copy = param_total_dict[param_name].copy()
         _apply_tensor_transform_operators(transform_operator_stack, param_total_dict_copy, device_num)
-
+        if choice_func is not None:
+            choice_out = choice_func(param_name)
+            if isinstance(choice_out, str):
+                param_name = choice_out
         transform_param_dict[param_name] = param_total_dict_copy[rank_id % device_num]
+        if str(type(transform_param_dict[param_name])) == "<class 'builtins.PySafeSlice'>":
+            transform_param_dict[param_name] = transform_param_dict[param_name][:]
 
     # Handle those parameter like learning_rate, global_step which not in strategy_file.
     for param_name in param_total_dict_keys:
+        if choice_func is not None:
+            choice_out = choice_func(param_name)
+            if isinstance(choice_out, str):
+                continue
         if param_name not in transform_param_dict:
             transform_para = param_total_dict[param_name][rank_id % device_num]
+            if str(type(transform_para)) == "<class 'builtins.PySafeSlice'>":
+                transform_para = transform_para[:]
             transform_param_dict[param_name] = transform_para
     return transform_param_dict
 
 
+def _cal_param_size(shape, dtype):
+    """cal param size by dtype and shape"""
+    dtype_size = {
+        "BOOL": 1,
+        "U8": 1,
+        "I8": 1,
+        "F8_E5M2": 1,
+        "F8_E4M3": 1,
+        "I16": 2,
+        "U16": 2,
+        "I32": 4,
+        "U32": 4,
+        "I64": 8,
+        "U64": 8,
+        "F16": 2,
+        "BF16": 2,
+        "F32": 4,
+        "F64": 8,
+    }
+    num_elements = math.prod(shape)
+    element_size = dtype_size.get(dtype, 4)
+    total_bytes = num_elements * element_size
+    return total_bytes
+
+
+def _split_weight_dict(weights, num_groups):
+    """split weights by num"""
+    sorted_items = sorted(weights.items(), key=lambda x: -x[1])
+    groups = [[] for _ in range(num_groups)]
+    total_bytes = [0] * num_groups
+    for weight_name, byte_size in sorted_items:
+        min_index = total_bytes.index(min(total_bytes))
+        groups[min_index].append(weight_name)
+        total_bytes[min_index] += byte_size
+
+    return groups
+
+
+def _save_hyper_param(split_dst_file, all_safetensor_files_map, name_list, dst_dir):
+    """save hyper param"""
+    if not split_dst_file or (split_dst_file and split_dst_file[0] == 1):
+        with safe_open(all_safetensor_files_map.get(0), framework="np") as f:
+            all_key = f.keys()
+            hyper_parameter = set(all_key) - set(name_list)
+            if hyper_parameter:
+                hyper_dict = {}
+                for key in hyper_parameter:
+                    hyper_dict[key] = f.get_tensor(key)
+                save_file(hyper_dict, os.path.join(dst_dir, "hyper_param.safetensors"))
+
+
+def _save_parameter_map_json(split_list, choice_func, split_dst_file, dst_dir, param_total_size):
+    """save parameter map json file"""
+    param_name_dict = dict()
+    for index, part_list in enumerate(split_list):
+        for name in part_list:
+            save_param_name = name
+            if choice_func is not None:
+                choice_out = choice_func(name)
+                if isinstance(choice_out, str):
+                    save_param_name = choice_out
+            if save_param_name == -1:
+                break
+            param_name_dict[save_param_name] = f"part{index}.safetensors"
+    output_dict = {"metadata": {"total_size": param_total_size}, "weight_map": param_name_dict}
+    if not split_dst_file or (split_dst_file and split_dst_file[0] == 1):
+        json_str = json.dumps(output_dict, indent=4)
+        map_file = os.path.join(dst_dir, "param_name_map.json")
+        with open(map_file, 'w') as f:
+            f.write(json_str)
+
+
+def _get_dst_shape(param_name, param_shape, src_strategy_list):
+    """get dst shape by strategy"""
+    from_dev_matrix = [1]
+    from_tensor_map = [-1] * len(param_shape)
+    from_opt_shard_size = 0
+    if src_strategy_list is not None:
+        from_dev_matrix, from_tensor_map, _, from_opt_shard_size = _extract_layout_item(
+            src_strategy_list.get(param_name))
+    to_dev_matrix_origin = [1]
+    to_tensor_map_origin = [-1] * len(param_shape)
+    to_opt_shard_step = 0
+    to_opt_shard_size = 0
+
+    param_strategy = _get_tensor_strategy(from_dev_matrix, from_tensor_map)
+    origin_tensor_shape = ()
+    for i, item in enumerate(param_shape):
+        if i == 0 and from_opt_shard_size > 0:
+            origin_tensor_shape += (item * param_strategy[i] * from_opt_shard_size,)
+            continue
+        origin_tensor_shape += (item * param_strategy[i],)
+
+    _, _, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
+        to_dev_matrix_origin, to_tensor_map_origin, to_opt_shard_step, to_opt_shard_size, origin_tensor_shape)
+    return to_full_tensor_shape
+
+
 def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundancy=True, file_suffix=None,
-                        max_process_num=64, choice_func=None):
+                        max_process_num=64, choice_func=None, split_dst_file=()):
     """
     Merge multiple safetensor files into a unified safetensor file.
 
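_split_weight_dict above is a greedy largest-first packer: weights are sorted by byte size in descending order and each one is appended to the currently lightest group, which keeps the resulting part files roughly balanced. A worked toy example (sizes in bytes):

    weights = {"w1": 8, "w2": 5, "w3": 4, "w4": 3}
    _split_weight_dict(weights, 2)
    # w1 -> group 0 (total 8), w2 -> group 1 (5), w3 -> group 1 (9), w4 -> group 0 (11)
    # returns [["w1", "w4"], ["w2", "w3"]]

The sizes come from _cal_param_size: for example, an "F16" tensor of shape (1024, 1024) is 1024 * 1024 * 2 = 2097152 bytes, and an unrecognized dtype string falls back to 4 bytes per element.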
@@ -692,9 +839,14 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
             saved safetensors files. Default: ``True``, indicating that the merged source weight files are complete.
         file_suffix (str, optional): Specify the filename suffix for merging safetensors files. Default: ``None``,
             meaning all safetensors files in the source weight directory will be merged.
-        max_process_num (int): Maximum number of processes. Default: 64
-        choice_func (callable): A callable function used to filter parameters or modify parameter names.
-            The return value of the function must be of type str (string) or bool (boolean). Default: None
+        max_process_num (int, optional): Maximum number of processes. Default: ``64``.
+        choice_func (callable, optional): A callable function used to filter parameters or modify parameter names.
+            The return value of the function must be of type str (string) or bool (boolean). Default: ``None``.
+        split_dst_file (tuple, optional): A parameter used to manually split a task into multiple subtasks for
+            execution, represented as a tuple containing two elements. The first element indicates the number of
+            the current subtask, and the second element indicates the total number of tasks. This parameter supports
+            splitting and executing tasks multiple times on a single machine, and also supports executing different
+            subtasks on multiple machines respectively. Default: ``()``.
 
     Raises:
         ValueError: If the safetensors file of rank is missing.
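A usage sketch of the new split_dst_file parameter, reusing the paths from the docstring example: each subtask writes only its own share of the part{i}.safetensors files, and subtask 1 additionally writes hyper_param.safetensors and param_name_map.json (see _save_hyper_param and _save_parameter_map_json above, which are gated on split_dst_file[0] == 1):

    import mindspore as ms
    src_dir = "/usr/safetensors/llama31B/4p_safetensors/"
    src_strategy_file = "/usr/safetensors/llama31B/strategy_4p.ckpt"
    dst_dir = "/usr/safetensors/llama31B/merge_llama31B_4p/"
    # machine 1 runs subtask 1 of 2
    ms.parallel.unified_safetensors(src_dir, src_strategy_file, dst_dir, split_dst_file=(1, 2))
    # machine 2 runs subtask 2 of 2
    ms.parallel.unified_safetensors(src_dir, src_strategy_file, dst_dir, split_dst_file=(2, 2))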
@@ -707,8 +859,12 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
         >>> src_dir = "/usr/safetensors/llama31B/4p_safetensors/"
         >>> src_strategy_file = "/usr/safetensors/llama31B/strategy_4p.ckpt"
         >>> dst_dir = "/usr/safetensors/llama31B/merge_llama31B_4p/"
-        >>> ms.unified_safetensors(src_dir, src_strategy_file, dst_dir)
+        >>> ms.parallel.unified_safetensors(src_dir, src_strategy_file, dst_dir)
     """
+    pid = os.getpid()
+    total_cores = os.cpu_count()
+    all_cores = set(range(total_cores))
+    os.sched_setaffinity(pid, all_cores)
     _check_transform_safetensors(src_dir, "", src_strategy_file, None)
     _make_dir(dst_dir, "path")
     if os.path.isfile(src_dir):
@@ -732,13 +888,11 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
                              "but it is missing.".format(needed_rank, rank))
     layout_map = _convert_to_list(src_strategy_dict)
 
-    total_size = 0
     actual_params = set()
     for _, file_name in all_safetensor_files_map.items():
-        total_size += os.path.getsize(file_name) / 1024 / 1024 / 1024
         with safe_open(file_name, framework="np") as f:
             actual_params.update(f.keys())
-
+
     params_to_store = actual_params & set(layout_map.keys())
 
     name_list = []
@@ -746,37 +900,55 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
         if name.startswith("accu_grads"):
             continue
         name_list.append(name)
-    split_list = _split_list(name_list, split_num)
-
-    with safe_open(all_safetensor_files_map.get(0), framework="np") as f:
-        all_key = f.keys()
-        hyper_parameter = set(all_key) - set(name_list)
-        if hyper_parameter:
-            hyper_dict = {}
-            for key in hyper_parameter:
-                hyper_dict[key] = f.get_tensor(key)
-            save_file(hyper_dict, os.path.join(dst_dir, "hyper_param.safetensors"))
-
-    # save parameter map json
-    param_name_dict = dict()
-    for index, part_list in enumerate(split_list):
-        for name in part_list:
-            save_param_name = name
-            if choice_func is not None:
-                choice_out = choice_func(name)
-                if isinstance(choice_out, bool):
-                    if not choice_out:
-                        continue
-                elif isinstance(choice_out, str):
-                    save_param_name = choice_out
-            param_name_dict[save_param_name] = f"part{index}.safetensors"
-    json_str = json.dumps(param_name_dict, indent=4)
-    map_file = os.path.join(dst_dir, "param_name_map.json")
-    with open(map_file, 'w') as f:
-        f.write(json_str)
 
+    param_size_dict = {}
+    param_total_size = 0
+    for _, file_name in all_safetensor_files_map.items():
+        with safe_open(file_name, framework="np") as f:
+            for k in f.keys():
+                if k in name_list:
+                    py_slice = f.get_slice(k)
+                    param_total_size += _cal_param_size(py_slice.get_shape(), py_slice.get_dtype())
+                    param_dst_shape = _get_dst_shape(k, py_slice.get_shape(), origin_src_strategy_list)
+                    # Convert the shape of np.int32 type to int type to prevent overflow in subsequent calculations.
+                    param_dst_shape = [int(item) for item in param_dst_shape]
+                    if choice_func is not None:
+                        choice_out = choice_func(k)
+                        if isinstance(choice_out, bool):
+                            if not choice_out:
+                                continue
+                    if k not in param_size_dict:
+                        param_size_dict[k] = _cal_param_size(param_dst_shape, py_slice.get_dtype())
+    split_num = math.ceil(sum(param_size_dict.values()) / 1024 / 1024 / 1024 / 3)
+    split_num = min(split_num, len(name_list))
+    split_list = _split_weight_dict(param_size_dict, split_num)
+
+    if split_dst_file:
+        current_machine_num = split_dst_file[0]
+        total_machine_num = split_dst_file[1]
+        n = len(split_list)
+        avg_length = n // total_machine_num
+        remainder = n % total_machine_num
+        start_index = (avg_length * (current_machine_num - 1)) + min(current_machine_num - 1, remainder)
+        end_index = start_index + avg_length + (1 if current_machine_num <= remainder else 0)
+        sub_list = []
+        for i in range(len(split_list)):
+            if start_index <= i < end_index:
+                sub_list.append(split_list[i])
+            else:
+                sub_list.append([-1])
+    else:
+        sub_list = split_list
+
+    _save_hyper_param(split_dst_file, all_safetensor_files_map, name_list, dst_dir)
+    _save_parameter_map_json(split_list, choice_func, split_dst_file, dst_dir, param_total_size)
+
+    if split_dst_file:
+        split_num = end_index - start_index
+        res = list(range(start_index, end_index))
+    else:
+        res = [i for i in range(split_num)]
     max_process = min(split_num, max_process_num)
-    res = [i for i in range(split_num)]
     res = _split_list(res, max_process)
     processes = []
    src_strategy_name = None
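The subtask slicing above assigns each machine a contiguous run of parts, with the first remainder machines getting one extra part. A quick worked example with 10 parts over 3 machines (avg_length = 3, remainder = 1): machine 1 gets start_index = 0 and end_index = 4, so parts 0 through 3; machine 2 gets start_index = 4 and end_index = 7, so parts 4 through 6; machine 3 gets start_index = 7 and end_index = 10, so parts 7 through 9. Parts outside the local range are masked as [-1] in sub_list, so part numbering stays globally consistent across machines.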
@@ -786,7 +958,7 @@ def unified_safetensors(src_dir, src_strategy_file, dst_dir, merge_with_redundan
         p = mp.Process(target=_transform_safetensors_single_semaphore, args=(
             needed_rank_list_map, all_safetensor_files_map, src_stage_device_num, dst_stage_device_num,
             src_strategy_dict, None, origin_src_strategy_list, origin_dst_strategy_list,
-            "", dst_dir, "safetensors", None,
+            "", dst_dir, "safetensors", None, sub_list, res[i], True, src_strategy_name, choice_func))
         p.start()
         processes.append(p)
     for p in processes:
@@ -801,13 +973,20 @@ def _transform_safetensors_single_semaphore(needed_rank_list_map, all_safetensor
                                             ckpt_prefix, dst_safetensors_dir, output_format,
                                             _transform_param_list, pipe_param_list=None, file_index=None,
                                             unified_flag=False, src_strategy_file=None, choice_func=None):
+    """transform safetensors single semaphore"""
+    total_io_cost_time = 0
     for i in file_index:
-        _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map,
-                                      src_stage_device_num, dst_stage_device_num, src_strategy_dict,
-                                      dst_strategy_dict, origin_src_strategy_list,
-                                      origin_dst_strategy_list, ckpt_prefix, dst_safetensors_dir,
-                                      output_format, _transform_param_list, pipe_param_list[i], i,
-                                      unified_flag, src_strategy_file, choice_func)
+        io_cost_time = _transform_safetensors_single(needed_rank_list_map, all_safetensor_files_map,
+                                                     src_stage_device_num, dst_stage_device_num, src_strategy_dict,
+                                                     dst_strategy_dict, origin_src_strategy_list,
+                                                     origin_dst_strategy_list, ckpt_prefix, dst_safetensors_dir,
+                                                     output_format, _transform_param_list, pipe_param_list[i], i,
+                                                     unified_flag, src_strategy_file, choice_func)
+        while psutil.virtual_memory().percent > 50:
+            time.sleep(1)
+        total_io_cost_time += io_cost_time
+    vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+               f"Unified safetensors io cost time:{total_io_cost_time}.")
@@ -854,22 +1033,13 @@ def _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_n
     return sf_obj
 
 
-def _check_name_map_value_is_str(value):
-    """check input is bool"""
-    if not isinstance(value, str):
-        raise ValueError(
-            f"For 'load_distributed_checkpoint', the value of name_map must be str, but got {type(value)}.")
-
-
-def _process_hyper_params(file_list, total_safetensors_dir, name_map, total_param):
+def _process_hyper_params(file_list, total_safetensors_dir, total_param):
     """process hyper params"""
     if 'hyper_param.safetensors' in file_list:
         hyper_parameter_file_name = os.path.join(total_safetensors_dir, "hyper_param.safetensors")
         with safe_open(hyper_parameter_file_name, framework="np") as f:
             for key in f.keys():
-                cur_param_name = name_map.get(key) if name_map is not None and key in name_map else key
-                _check_name_map_value_is_str(cur_param_name)
-                total_param[cur_param_name] = ms.Parameter(ms.Tensor.from_numpy(f.get_tensor(key)))
+                total_param[key] = ms.Parameter(ms.Tensor.from_numpy(f.get_tensor(key)))
     return total_param
 
 
@@ -887,12 +1057,15 @@ def _cal_param_name_map_and_param_list(file_list, total_safetensors_dir, json_fi
         values = len(keys) * [file_list[0]]
         param_name_map = dict(zip(keys, values))
     else:
-        if
-        raise ValueError(
-
+        if not json_files:
+            raise ValueError(
+                f"For 'load_parallel_checkpoint', there must be a JSON file named 'param_name_map.json' in "
+                f"the 'total_safetensors_dir'.")
        param_name_json = os.path.join(total_safetensors_dir, json_files[0])
        with open(param_name_json, 'r') as f:
            param_name_map = json.load(f)
+        if "weight_map" in param_name_map:
+            param_name_map = param_name_map["weight_map"]
 
     if dst_strategy_file is not None:
         _, dst_strategy_list = _extract_src_dst_layout_map(rank_id, None, dst_strategy_file)
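With this change the loader accepts both the old flat name-to-file mapping and the new index format written by _save_parameter_map_json, which nests the mapping under "weight_map" alongside a size total. The param_name_map.json on disk then looks like this (hypothetical parameter names; total_size is in bytes):

    {
        "metadata": {"total_size": 4194304},
        "weight_map": {
            "embedding.weight": "part0.safetensors",
            "layer0.dense.weight": "part1.safetensors"
        }
    }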
@@ -907,8 +1080,12 @@ def _load_parallel_checkpoint(file_info):
     """load parallel safetensors by merged file."""
     total_safetensors_dir, dst_strategy_file, net, dst_safetensors_dir, \
         rank_id, output_format, name_map, return_param_dict = file_info
+    pid = os.getpid()
+    total_cores = os.cpu_count()
+    all_cores = set(range(total_cores))
+    os.sched_setaffinity(pid, all_cores)
     file_list = os.listdir(total_safetensors_dir)
-    json_files = [file for file in file_list if file.endswith(".json")]
+    json_files = [file for file in file_list if file == "param_name_map.json"]
     param_name_map, param_list, dst_strategy_list = _cal_param_name_map_and_param_list(file_list, total_safetensors_dir,
                                                                                        json_files, dst_strategy_file,
                                                                                        rank_id)
@@ -916,14 +1093,16 @@ def _load_parallel_checkpoint(file_info):
     dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) if dst_strategy_list \
         is not None else 1
     local_rank_id = rank_id % dst_stage_device_num
-    for param_name in param_list:
+    total_io_cost_time = 0
+    for param_name in _progress_bar(param_list):
         if param_name not in param_name_map:
             continue
         file_name = os.path.join(total_safetensors_dir, param_name_map[param_name])
         with safe_open(file_name, framework="np") as f:
-            if param_name not in f.keys():
+            cur_param_name = name_map.get(param_name) if name_map is not None and param_name in name_map else param_name
+            if cur_param_name not in f.keys():
                 continue
-            sf_obj = f.get_slice(param_name)
+            sf_obj = f.get_slice(cur_param_name)
 
             tensor_shape = sf_obj.get_shape()
             from_dev_matrix = [1]
@@ -945,6 +1124,9 @@ def _load_parallel_checkpoint(file_info):
                     continue
                 origin_tensor_shape += (item * param_strategy[i],)
 
+            has_layout_from = any(isinstance(i, (list, tuple)) for i in from_tensor_map)
+            has_layout_to = any(isinstance(i, (list, tuple)) for i in to_tensor_map_origin)
+
             from_dev_matrix, from_tensor_map, from_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
                 from_dev_matrix, from_tensor_map, from_opt_shard_step, from_opt_shard_size, origin_tensor_shape)
             to_dev_matrix, to_tensor_map, to_full_tensor_shape = _construct_tensor_layout_for_opt_shard(
@@ -964,19 +1146,29 @@ def _load_parallel_checkpoint(file_info):
             from_info_tuple = (from_opt_shard_size, from_dev_matrix, from_tensor_map, from_full_tensor_shape)
             to_info_tuple = (to_opt_shard_size, to_dev_matrix_origin, to_tensor_map_origin, origin_tensor_shape)
             _insert_opt_shard_reshape(param_rank_map, from_info_tuple, to_info_tuple)
+            _insert_expand_layout_reshape(param_rank_map, from_info_tuple, to_info_tuple,
+                                          has_layout_from, has_layout_to)
             transform_operator_stack = _generate_transform_operator_stack(param_rank_map, local_rank_id)
-
+            start_time = time.time()
             slice_param = _apply_sf_obj_transform_operators(transform_operator_stack, sf_obj, device_num)
+            end_time = time.time()
+            cost_time = end_time - start_time
+            total_io_cost_time += cost_time
         else:
+            start_time = time.time()
             slice_param = sf_obj[:]
-        cur_param_name = name_map.get(param_name) if name_map is not None and param_name in name_map else param_name
-        _check_name_map_value_is_str(cur_param_name)
-        total_param[cur_param_name] = ms.Parameter(ms.Tensor.from_numpy(slice_param))
-
-    total_param = _process_hyper_params(file_list, total_safetensors_dir, name_map, total_param)
+            end_time = time.time()
+            cost_time = end_time - start_time
+            total_io_cost_time += cost_time
+        total_param[param_name] = ms.Parameter(ms.Tensor.from_numpy(slice_param))
+    vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
+               f"load distributed safetensors io cost time:{total_io_cost_time}.")
+    total_param = _process_hyper_params(file_list, total_safetensors_dir, total_param)
     if net is not None:
         if not return_param_dict:
+            logger.info("start load param into net...")
             param_not_load, ckpt_not_load = ms.load_param_into_net(net, total_param)
+            logger.info("load param into net is end...")
             return param_not_load, ckpt_not_load
         return total_param
     _make_dir(os.path.join(dst_safetensors_dir, f"rank_{rank_id}"), "path")
mindspore/pgodb140.dll CHANGED (binary file)
mindspore/pgort140.dll CHANGED (binary file)