mindspore 2.5.0__cp39-cp39-win_amd64.whl → 2.6.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +6 -4
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -33
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +19 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +24 -193
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +97 -74
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +4 -4
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +4 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -1
- mindspore/common/_stub_tensor.py +5 -10
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +1915 -3287
- mindspore/common/api.py +341 -354
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +5 -2
- mindspore/common/dump.py +7 -5
- mindspore/common/file_system.py +3 -0
- mindspore/common/hook_handle.py +5 -3
- mindspore/common/initializer.py +10 -6
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +2 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +106 -39
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +297 -714
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +47 -2
- mindspore/communication/comm_func.py +70 -53
- mindspore/communication/management.py +83 -17
- mindspore/context.py +214 -560
- mindspore/dataset/__init__.py +44 -20
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +102 -120
- mindspore/dataset/engine/datasets_audio.py +22 -22
- mindspore/dataset/engine/datasets_standard_format.py +43 -24
- mindspore/dataset/engine/datasets_text.py +78 -85
- mindspore/dataset/engine/datasets_user_defined.py +108 -76
- mindspore/dataset/engine/datasets_vision.py +111 -108
- mindspore/dataset/engine/iterators.py +5 -3
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/samplers.py +279 -57
- mindspore/dataset/engine/serializer_deserializer.py +2 -1
- mindspore/dataset/engine/validators.py +10 -0
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/device_context/ascend/op_debug.py +60 -1
- mindspore/device_context/ascend/op_tuning.py +0 -4
- mindspore/device_manager.py +39 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +22 -26
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +4 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +40 -22
- mindspore/experimental/optim/radam.py +5 -5
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -81
- mindspore/hal/event.py +38 -55
- mindspore/hal/memory.py +93 -144
- mindspore/hal/stream.py +81 -125
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +40 -2
- mindspore/mindrecord/__init__.py +20 -7
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +131 -700
- mindspore/mint/distributed/__init__.py +5 -1
- mindspore/mint/distributed/distributed.py +194 -109
- mindspore/mint/linalg/__init__.py +2 -0
- mindspore/mint/nn/__init__.py +280 -18
- mindspore/mint/nn/functional.py +282 -64
- mindspore/mint/nn/layer/__init__.py +4 -0
- mindspore/mint/nn/layer/_functions.py +7 -3
- mindspore/mint/nn/layer/activation.py +120 -13
- mindspore/mint/nn/layer/conv.py +218 -24
- mindspore/mint/nn/layer/normalization.py +15 -16
- mindspore/mint/nn/layer/padding.py +1 -1
- mindspore/mint/nn/layer/pooling.py +66 -1
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1250 -176
- mindspore/nn/layer/activation.py +23 -21
- mindspore/nn/layer/basic.py +22 -16
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +9 -8
- mindspore/nn/layer/normalization.py +48 -42
- mindspore/nn/layer/pooling.py +75 -31
- mindspore/nn/layer/transformer.py +11 -10
- mindspore/nn/learning_rate_schedule.py +4 -2
- mindspore/nn/loss/loss.py +27 -19
- mindspore/nn/optim/ada_grad.py +6 -5
- mindspore/nn/optim/adadelta.py +9 -7
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +16 -12
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +9 -7
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +178 -117
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +3 -3
- mindspore/numpy/array_creations.py +3 -3
- mindspore/numpy/array_ops.py +1 -1
- mindspore/numpy/math_ops.py +4 -4
- mindspore/numpy/utils.py +1 -2
- mindspore/numpy/utils_const.py +1 -2
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +3 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
- mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
- mindspore/ops/_vmap/vmap_array_ops.py +7 -6
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
- mindspore/ops/_vmap/vmap_math_ops.py +4 -7
- mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +102 -49
- mindspore/ops/auto_generate/gen_extend_func.py +281 -135
- mindspore/ops/auto_generate/gen_ops_def.py +2574 -2326
- mindspore/ops/auto_generate/gen_ops_prim.py +8566 -2755
- mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +19 -24
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -3
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +28 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +1629 -2345
- mindspore/ops/function/clip_func.py +38 -45
- mindspore/ops/function/debug_func.py +36 -44
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +46 -78
- mindspore/ops/function/math_func.py +3035 -3705
- mindspore/ops/function/nn_func.py +676 -241
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +17 -30
- mindspore/ops/function/random_func.py +204 -361
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +5 -5
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +6 -4
- mindspore/ops/functional_overload.py +547 -4
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +10 -5
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_grad_ops.py +1 -10
- mindspore/ops/operations/_inner_ops.py +5 -76
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +37 -22
- mindspore/ops/operations/comm_ops.py +150 -107
- mindspore/ops/operations/custom_ops.py +221 -23
- mindspore/ops/operations/debug_ops.py +115 -16
- mindspore/ops/operations/inner_ops.py +1 -1
- mindspore/ops/operations/linalg_ops.py +1 -58
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +746 -79
- mindspore/ops/operations/math_ops.py +21 -18
- mindspore/ops/operations/nn_ops.py +65 -191
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +43 -32
- mindspore/ops/tensor_method.py +232 -13
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
- mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
- mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
- mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
- mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
- mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
- mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
- mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
- mindspore/ops_generate/{template.py → common/template.py} +96 -84
- mindspore/ops_generate/gen_ops.py +23 -325
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
- mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -7
- mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
- mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
- mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
- mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
- mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
- mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
- mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
- mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
- mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
- mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
- mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
- mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
- mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
- mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +6 -2
- mindspore/parallel/_auto_parallel_context.py +133 -6
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +95 -4
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +198 -25
- mindspore/parallel/algo_parameter_config.py +3 -3
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +656 -37
- mindspore/parallel/cluster/process_entity/_api.py +151 -19
- mindspore/parallel/cluster/run.py +1 -1
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +24 -13
- mindspore/parallel/shard.py +137 -61
- mindspore/parallel/transform_safetensors.py +287 -95
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +9 -5
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
- mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +22 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
- mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/common/constant.py +12 -0
- mindspore/profiler/common/msprof_cmd_tool.py +42 -23
- mindspore/profiler/common/path_manager.py +24 -0
- mindspore/profiler/common/profiler_context.py +26 -2
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_parameters.py +59 -18
- mindspore/profiler/common/profiler_path_manager.py +66 -7
- mindspore/profiler/dynamic_profiler.py +112 -79
- mindspore/profiler/envprofiler.py +26 -1
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +57 -14
- mindspore/profiler/platform/npu_profiler.py +33 -7
- mindspore/profiler/profiler.py +541 -45
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +4 -0
- mindspore/profiler/schedule.py +57 -22
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +25 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +2 -2
- mindspore/runtime/executor.py +40 -11
- mindspore/runtime/memory.py +25 -8
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +35 -7
- mindspore/train/amp.py +1 -1
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +24 -40
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_flops_collector.py +2 -3
- mindspore/train/callback/_history.py +7 -4
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +8 -13
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +179 -103
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +4 -5
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +8 -6
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +19 -12
- mindspore/train/model.py +176 -103
- mindspore/train/serialization.py +246 -988
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +4 -2
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +2 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +483 -438
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_constants.py +0 -190
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
- /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/train/serialization.py
CHANGED
|
@@ -27,9 +27,9 @@ import stat
|
|
|
27
27
|
import atexit
|
|
28
28
|
import threading
|
|
29
29
|
from threading import Thread, RLock
|
|
30
|
-
from multiprocessing import
|
|
30
|
+
from multiprocessing import active_children
|
|
31
31
|
import multiprocessing as mp
|
|
32
|
-
from collections import
|
|
32
|
+
from collections import OrderedDict
|
|
33
33
|
from io import BytesIO
|
|
34
34
|
|
|
35
35
|
import math
|
|
@@ -53,37 +53,33 @@ from mindspore.log import vlog_print
|
|
|
53
53
|
from mindspore._checkparam import check_input_data, check_input_dataset
|
|
54
54
|
from mindspore import _checkparam as Validator
|
|
55
55
|
from mindspore.common import dtype as mstype
|
|
56
|
+
from mindspore.common import np_dtype
|
|
56
57
|
from mindspore.common.api import _cell_graph_executor as _executor
|
|
57
|
-
from mindspore.common.api import
|
|
58
|
+
from mindspore.common.api import _JitExecutor
|
|
58
59
|
from mindspore.common.api import _get_parameter_layout
|
|
59
|
-
from mindspore.common.api import _generate_branch_control_input
|
|
60
60
|
from mindspore.common.initializer import initializer, One
|
|
61
61
|
from mindspore.common.parameter import Parameter, _offload_if_config
|
|
62
62
|
from mindspore.common.tensor import Tensor
|
|
63
|
-
from mindspore._c_expression import
|
|
63
|
+
from mindspore._c_expression import TensorPy as Tensor_
|
|
64
64
|
from mindspore.common._utils import is_shape_unknown
|
|
65
65
|
from mindspore.common.file_system import FileSystem, _register_basic_file_system, _register_mindio_file_system
|
|
66
66
|
from mindspore.communication.management import get_rank, get_group_size
|
|
67
67
|
from mindspore.experimental import MapParameter
|
|
68
68
|
from mindspore.ops import Cast
|
|
69
69
|
from mindspore.parallel._cell_wrapper import get_allgather_cell, _single_parameter_broadcast
|
|
70
|
-
from mindspore.parallel._tensor import
|
|
71
|
-
from mindspore.parallel.
|
|
72
|
-
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode, \
|
|
73
|
-
_get_device_num
|
|
74
|
-
from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
|
|
75
|
-
from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
|
|
76
|
-
_restore_group_info_list, _get_param_list_when_first_dim_sharded
|
|
70
|
+
from mindspore.parallel._tensor import _reshape_param_data
|
|
71
|
+
from mindspore.parallel._utils import _is_in_auto_parallel_mode
|
|
77
72
|
from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
|
|
78
73
|
_store_warm_up_ptr_by_tensor_list, _cache_enable
|
|
79
74
|
from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
|
|
80
|
-
from mindspore.parallel.
|
|
81
|
-
|
|
75
|
+
from mindspore.parallel.checkpoint_transform import restore_group_info_list as new_restore_group_info_list
|
|
76
|
+
from mindspore.parallel.checkpoint_transform import load_distributed_checkpoint as new_load_distributed_checkpoint
|
|
77
|
+
from mindspore.parallel.checkpoint_transform import merge_sliced_parameter as new_merge_sliced_parameter
|
|
78
|
+
from mindspore.parallel.checkpoint_transform import build_searched_strategy as new_build_searched_strategy
|
|
82
79
|
from mindspore.train._utils import read_proto, get_parameter_redundancy, _progress_bar, _load_and_transform
|
|
83
|
-
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file,
|
|
80
|
+
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, \
|
|
84
81
|
split_mindir, split_dynamic_mindir
|
|
85
82
|
from mindspore.common.generator import Generator
|
|
86
|
-
from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
|
|
87
83
|
|
|
88
84
|
tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
|
|
89
85
|
"Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
|
|
@@ -94,6 +90,9 @@ tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UIn
|
|
|
94
90
|
"Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
|
|
95
91
|
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}
|
|
96
92
|
|
|
93
|
+
if hasattr(np_dtype, "bfloat16"):
|
|
94
|
+
tensor_to_np_type["BFloat16"] = np_dtype.bfloat16
|
|
95
|
+
|
|
97
96
|
np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64}
|
|
98
97
|
|
|
99
98
|
mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4: mstype.uint16,
|
|
@@ -153,22 +152,28 @@ atexit.register(_async_save_close)
|
|
|
153
152
|
|
|
154
153
|
def _get_cur_rank_dp(parameter_layout_dict):
|
|
155
154
|
""" Get dp and tp from layout dict. """
|
|
156
|
-
pp_num = _get_auto_parallel_context("pipeline_stages")
|
|
157
|
-
dev_num = _get_device_num()
|
|
158
155
|
global_rank = get_rank()
|
|
159
|
-
|
|
160
|
-
initial_rank = (global_rank // pipe_size) * pipe_size
|
|
161
|
-
parameter_redundancy_dict = get_parameter_redundancy(
|
|
162
|
-
parameter_layout_dict, initial_rank)
|
|
156
|
+
parameter_redundancy_dict = get_parameter_redundancy(parameter_layout_dict)
|
|
163
157
|
value_len = sys.maxsize
|
|
164
158
|
min_value = ()
|
|
159
|
+
min_value_set = set()
|
|
165
160
|
for key, value in parameter_redundancy_dict.items():
|
|
166
|
-
if "accu_grads"
|
|
161
|
+
if key.startswith("accu_grads") or key.startswith("inputs"):
|
|
167
162
|
continue
|
|
168
163
|
for item in value:
|
|
169
|
-
if
|
|
164
|
+
if global_rank not in item:
|
|
165
|
+
continue
|
|
166
|
+
# if item is subset of min_value_set, update min_value_set and min_value
|
|
167
|
+
if len(item) < value_len:
|
|
168
|
+
if min_value_set and not set(item).issubset(min_value_set):
|
|
169
|
+
return (global_rank,)
|
|
170
170
|
value_len = len(item)
|
|
171
|
+
min_value_set = set(item)
|
|
171
172
|
min_value = item
|
|
173
|
+
# if value is not smaller than len of min_value len,
|
|
174
|
+
# check if min_value_set is subset of current item
|
|
175
|
+
elif not min_value_set.issubset(set(item)):
|
|
176
|
+
return (global_rank,)
|
|
172
177
|
return min_value
|
|
173
178
|
|
|
174
179
|
|
|
@@ -188,7 +193,7 @@ def get_ckpt_path_with_strategy(cur_ckpt_path, cur_strategy_path):
|
|
|
188
193
|
cur_strategy_path (str): strategy file path for current rank.
|
|
189
194
|
|
|
190
195
|
Returns:
|
|
191
|
-
- new_ckpt_file (
|
|
196
|
+
- new_ckpt_file (str), if found available checkpoint file , return it.
|
|
192
197
|
- None, if not found available checkpoint, return None.
|
|
193
198
|
|
|
194
199
|
Examples:
|
|
@@ -203,6 +208,9 @@ def get_ckpt_path_with_strategy(cur_ckpt_path, cur_strategy_path):
|
|
|
203
208
|
>>> ckpt_file_new = get_ckpt_path_with_strategy(ckpt_file, strategy_file)
|
|
204
209
|
>>> print(ckpt_file_new)
|
|
205
210
|
"""
|
|
211
|
+
cur_rank = get_rank()
|
|
212
|
+
if f"rank_{str(cur_rank)}" in cur_ckpt_path and os.path.isfile(cur_ckpt_path):
|
|
213
|
+
return cur_ckpt_path
|
|
206
214
|
dp = _get_cur_rank_dp(cur_strategy_path)
|
|
207
215
|
pattern = r'rank_\d+'
|
|
208
216
|
for i in dp:
|
|
@@ -358,6 +366,8 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
|
|
|
358
366
|
file_name_list = list(os.path.splitext(ckpt_file_name))
|
|
359
367
|
file_name_list[1] = file_name_list[1].replace(f".{format}", ".tmp")
|
|
360
368
|
tmp_name = ''.join(file_name_list)
|
|
369
|
+
if _ckpt_fs.backend == "mindio":
|
|
370
|
+
tmp_name = ckpt_file_name
|
|
361
371
|
if os.path.exists(ckpt_file_name):
|
|
362
372
|
os.chmod(ckpt_file_name, stat.S_IWUSR)
|
|
363
373
|
os.remove(ckpt_file_name)
|
|
@@ -365,7 +375,7 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
|
|
|
365
375
|
os.chmod(tmp_name, stat.S_IWUSR)
|
|
366
376
|
os.remove(tmp_name)
|
|
367
377
|
if format == "ckpt":
|
|
368
|
-
|
|
378
|
+
ckpt_total_io_time = 0
|
|
369
379
|
with _ckpt_fs.create(tmp_name, *_ckpt_fs.create_args) as f:
|
|
370
380
|
plain_data = None
|
|
371
381
|
if enc_key is not None:
|
|
@@ -382,20 +392,26 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
|
|
|
382
392
|
if value[0] == "offload_parameter":
|
|
383
393
|
new_value = value[1:]
|
|
384
394
|
new_value[2] = value[3]
|
|
385
|
-
_write_parameter_bytes_data(name, new_value, f, enc_key, plain_data)
|
|
395
|
+
_write_parameter_bytes_data(name, new_value, f, enc_key, plain_data, ckpt_total_io_time)
|
|
386
396
|
_offload_if_config(value[3])
|
|
387
397
|
continue
|
|
388
398
|
if value[1] == "str":
|
|
389
|
-
crc_num = _write_parameter_data(name, value, f, enc_key, plain_data,
|
|
399
|
+
crc_num, ckpt_total_io_time = _write_parameter_data(name, value, f, enc_key, plain_data,
|
|
400
|
+
crc_num, crc_check,
|
|
401
|
+
ckpt_total_io_time)
|
|
390
402
|
continue
|
|
391
403
|
if isinstance(value[2], np.ndarray):
|
|
392
|
-
crc_num = _write_parameter_data(name, value, f, enc_key, plain_data,
|
|
404
|
+
crc_num, ckpt_total_io_time = _write_parameter_data(name, value, f, enc_key, plain_data,
|
|
405
|
+
crc_num, crc_check,
|
|
406
|
+
ckpt_total_io_time)
|
|
393
407
|
continue
|
|
394
408
|
if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
|
|
395
409
|
_write_hugeparameter(name, value, f)
|
|
396
410
|
continue
|
|
397
411
|
|
|
398
|
-
crc_num = _write_parameter_bytes_data(name, value, f, enc_key, plain_data,
|
|
412
|
+
crc_num, ckpt_total_io_time = _write_parameter_bytes_data(name, value, f, enc_key, plain_data,
|
|
413
|
+
crc_num, crc_check,
|
|
414
|
+
ckpt_total_io_time)
|
|
399
415
|
|
|
400
416
|
if enc_key is not None:
|
|
401
417
|
plain_data.seek(0)
|
|
@@ -406,15 +422,22 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
|
|
|
406
422
|
block_data = plain_data.read(max_block_size)
|
|
407
423
|
if crc_check:
|
|
408
424
|
f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big'))
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
425
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
|
|
426
|
+
f"Save ckpt io cost time:{ckpt_total_io_time}.")
|
|
427
|
+
|
|
412
428
|
elif format == "safetensors":
|
|
413
429
|
save_dict = {}
|
|
414
430
|
crc_num = 0
|
|
415
431
|
for name in sorted(data_list.keys()):
|
|
416
432
|
value = data_list[name]
|
|
417
|
-
|
|
433
|
+
if isinstance(value[2], np.ndarray):
|
|
434
|
+
save_dict[name] = value[2]
|
|
435
|
+
else:
|
|
436
|
+
bytes_data = value[2].get_bytes()
|
|
437
|
+
np_type = tensor_to_np_type.get(value[1])
|
|
438
|
+
np_array = np.frombuffer(bytes_data, np_type)
|
|
439
|
+
new_np_array = np_array.reshape(value[0])
|
|
440
|
+
save_dict[name] = new_np_array
|
|
418
441
|
|
|
419
442
|
if crc_check:
|
|
420
443
|
crc_num = binascii.crc32(bytes(name, encoding='utf-8'), crc_num)
|
|
@@ -428,11 +451,11 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
|
|
|
428
451
|
save_file(save_dict, tmp_name)
|
|
429
452
|
safetensors_save_time_end = time.time()
|
|
430
453
|
cost_time = safetensors_save_time_end - safetensors_save_time_start
|
|
431
|
-
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save safetensors cost time:{cost_time}.")
|
|
454
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save safetensors io cost time:{cost_time}.")
|
|
432
455
|
if not os.path.exists(tmp_name):
|
|
433
456
|
logger.warning(f"Rename failed, can't find {tmp_name}, it is possible that multiple processes have "
|
|
434
457
|
f"simultaneously modified a file.")
|
|
435
|
-
|
|
458
|
+
elif _ckpt_fs.backend != "mindio":
|
|
436
459
|
os.rename(tmp_name, ckpt_file_name)
|
|
437
460
|
os.chmod(ckpt_file_name, stat.S_IRUSR)
|
|
438
461
|
except BaseException as e:
|
|
@@ -453,7 +476,7 @@ def _write_random_seed(name, value, f):
|
|
|
453
476
|
f.write(checkpoint_list.SerializeToString())
|
|
454
477
|
|
|
455
478
|
|
|
456
|
-
def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
|
|
479
|
+
def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False, ckpt_total_io_time=0):
|
|
457
480
|
"""Write parameter data into protobuf file."""
|
|
458
481
|
data_size = value[2].nbytes / 1024
|
|
459
482
|
if data_size > SLICE_SIZE:
|
|
@@ -475,14 +498,18 @@ def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_ch
|
|
|
475
498
|
output_data = checkpoint_list.SerializeToString()
|
|
476
499
|
if crc_check:
|
|
477
500
|
crc_num = binascii.crc32(output_data, crc_num)
|
|
501
|
+
io_start_time = time.time()
|
|
478
502
|
f.write(output_data)
|
|
503
|
+
io_end_time = time.time()
|
|
504
|
+
io_cost_time = io_end_time - io_start_time
|
|
505
|
+
ckpt_total_io_time += io_cost_time
|
|
479
506
|
else:
|
|
480
507
|
plain_data.write(checkpoint_list.SerializeToString())
|
|
481
508
|
|
|
482
|
-
return crc_num
|
|
509
|
+
return crc_num, ckpt_total_io_time
|
|
483
510
|
|
|
484
511
|
|
|
485
|
-
def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
|
|
512
|
+
def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False, ckpt_total_io_time=0):
|
|
486
513
|
"""Write parameter bytes data into protobuf file."""
|
|
487
514
|
bytes_value = value[2].get_bytes()
|
|
488
515
|
chunk_size = 1024 * SLICE_SIZE
|
|
@@ -500,11 +527,15 @@ def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0,
|
|
|
500
527
|
output_data = checkpoint_list.SerializeToString()
|
|
501
528
|
if crc_check:
|
|
502
529
|
crc_num = binascii.crc32(output_data, crc_num)
|
|
530
|
+
io_start_time = time.time()
|
|
503
531
|
f.write(output_data)
|
|
532
|
+
io_end_time = time.time()
|
|
533
|
+
io_cost_time = io_end_time - io_start_time
|
|
534
|
+
ckpt_total_io_time += io_cost_time
|
|
504
535
|
else:
|
|
505
536
|
plain_data.write(checkpoint_list.SerializeToString())
|
|
506
537
|
|
|
507
|
-
return crc_num
|
|
538
|
+
return crc_num, ckpt_total_io_time
|
|
508
539
|
|
|
509
540
|
|
|
510
541
|
def _write_mapparameter(name, value, f, map_param_inc=False):
|
|
@@ -583,15 +614,13 @@ def _check_load_checkpoint_upsupported_param(format, dec_key, dec_mode):
|
|
|
583
614
|
f"be set to default value '{default_value}', but got '{current_value}'.")
|
|
584
615
|
|
|
585
616
|
|
|
586
|
-
def _check_save_checkpoint_upsupported_param(format, enc_key, enc_mode,
|
|
587
|
-
global_step_num=None):
|
|
617
|
+
def _check_save_checkpoint_upsupported_param(format, enc_key, enc_mode, map_param_inc=False, global_step_num=None):
|
|
588
618
|
"""check save checkpoint unsupported param"""
|
|
589
619
|
if format != "safetensors":
|
|
590
620
|
return
|
|
591
621
|
default_params = {
|
|
592
622
|
"enc_key": None,
|
|
593
623
|
"enc_mode": "AES-GCM",
|
|
594
|
-
"async_save": False,
|
|
595
624
|
"map_param_inc": False,
|
|
596
625
|
"global_step_num": None
|
|
597
626
|
}
|
|
@@ -633,15 +662,18 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
633
662
|
|
|
634
663
|
Args:
|
|
635
664
|
save_obj (Union[Cell, list, dict]): The object to be saved. The data type can be :class:`mindspore.nn.Cell`,
|
|
636
|
-
list, or dict.
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
665
|
+
list, or dict.
|
|
666
|
+
|
|
667
|
+
- If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
|
|
668
|
+
elements(each element is a dictionary, like [{"name": param_name, "data": param_data},...], the type of
|
|
669
|
+
`param_name` must be string, and the type of `param_data` must be parameter or Tensor).
|
|
670
|
+
- If dict, it can be the returned value of :func:`mindspore.load_checkpoint`.
|
|
671
|
+
|
|
640
672
|
ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
|
|
641
673
|
integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: ``True`` .
|
|
642
|
-
async_save (Union[bool, str]): Whether to use asynchronous saving of the checkpoint file
|
|
643
|
-
the asynchronous thread is used by default. If the type
|
|
644
|
-
the method of asynchronous saving, it can be "process" or "thread".
|
|
674
|
+
async_save (Union[bool, str], optional): Whether to use asynchronous saving of the checkpoint file or
|
|
675
|
+
safetensors file, if True, the asynchronous thread is used by default. If the type
|
|
676
|
+
is string, the method of asynchronous saving, it can be "process" or "thread".
|
|
645
677
|
Default: ``False`` .
|
|
646
678
|
append_dict (dict): Additional information that needs to be saved. The key of dict must be str, the value
|
|
647
679
|
of dict must be one of int, float, bool, string, Parameter or Tensor. Default: ``None`` .
|
|
@@ -652,9 +684,12 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
652
684
|
Default: ``"AES-GCM"`` .
|
|
653
685
|
choice_func (function) : A function for saving custom selected parameters. The input value of `choice_func` is
|
|
654
686
|
a parameter name in string type, and the returned value is a bool.
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
687
|
+
Default: ``None`` .
|
|
688
|
+
|
|
689
|
+
- If returns ``True`` , the Parameter that matching the custom condition will be saved.
|
|
690
|
+
- If returns ``False`` , the Parameter that not matching the custom condition will not
|
|
691
|
+
be saved.
|
|
692
|
+
|
|
658
693
|
crc_check (bool) : Whether to perform crc32 calculation when saving checkpoint and save the calculation
|
|
659
694
|
result to the file. Default: ``False`` .
|
|
660
695
|
format (str): Format of the output file, can be "ckpt" or "safetensors". Default: "ckpt".
|
|
@@ -693,6 +728,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
693
728
|
- `Saving and Loading the Model - Saving and Loading the Model Weight
|
|
694
729
|
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
|
|
695
730
|
"""
|
|
731
|
+
start_save_time = time.time()
|
|
696
732
|
ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format)
|
|
697
733
|
integrated_save = Validator.check_bool(integrated_save)
|
|
698
734
|
async_save = _check_async_save(async_save)
|
|
@@ -703,12 +739,15 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
703
739
|
map_param_inc = kwargs.get('incremental', False)
|
|
704
740
|
logger.info("Execute the process of saving checkpoint files.")
|
|
705
741
|
global_step_num = kwargs.get('global_step_num', None)
|
|
706
|
-
_check_save_checkpoint_upsupported_param(format, enc_key, enc_mode,
|
|
742
|
+
_check_save_checkpoint_upsupported_param(format, enc_key, enc_mode, map_param_inc, global_step_num)
|
|
707
743
|
|
|
708
744
|
if append_dict and "__exception_save__" in append_dict:
|
|
709
745
|
s1 = mindspore.hal.Stream()
|
|
710
746
|
with mindspore.hal.StreamCtx(s1):
|
|
711
747
|
save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
|
|
748
|
+
for k_name, value in append_dict.items():
|
|
749
|
+
if isinstance(value, (Tensor, Parameter)):
|
|
750
|
+
append_dict[k_name] = Tensor(Tensor_.move_to(value, "CPU", False))
|
|
712
751
|
s1.synchronize()
|
|
713
752
|
else:
|
|
714
753
|
save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
|
|
@@ -779,9 +818,11 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
779
818
|
data_list[key].append(dims)
|
|
780
819
|
tensor_type = str(param["data"].dtype)
|
|
781
820
|
data_list[key].append(tensor_type)
|
|
782
|
-
data = param["data"] if async_save
|
|
821
|
+
data = param["data"] if async_save is False else param["data"].asnumpy()
|
|
783
822
|
data_list[key].append(data)
|
|
784
823
|
|
|
824
|
+
from mindspore.profiler import mstx
|
|
825
|
+
range_id = mstx.range_start('save_checkpoint', None)
|
|
785
826
|
if os.getenv("AITURBO") == "1":
|
|
786
827
|
from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
|
|
787
828
|
ckpt_name = os.path.basename(ckpt_file_name)
|
|
@@ -819,7 +860,32 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
819
860
|
else:
|
|
820
861
|
_exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
|
|
821
862
|
|
|
863
|
+
mstx.range_end(range_id)
|
|
822
864
|
logger.info("Saving checkpoint process is finished.")
|
|
865
|
+
end_save_time = time.time()
|
|
866
|
+
save_checkpoint_cost_time = end_save_time - start_save_time
|
|
867
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save checkpoint cost time {save_checkpoint_cost_time}.")
|
|
868
|
+
|
|
869
|
+
|
|
870
|
+
def _handle_shared_param_for_pipeline_parallel(save_obj):
|
|
871
|
+
""" Remove shared param for save_obj """
|
|
872
|
+
filtered_save_obj = []
|
|
873
|
+
for param_dict in save_obj:
|
|
874
|
+
cur_param = param_dict['data']
|
|
875
|
+
if isinstance(cur_param, Parameter):
|
|
876
|
+
if not cur_param.param_info.is_pipeline_shared_param:
|
|
877
|
+
filtered_save_obj.append(param_dict)
|
|
878
|
+
else:
|
|
879
|
+
filtered_save_obj.append(param_dict)
|
|
880
|
+
return filtered_save_obj
|
|
881
|
+
|
|
882
|
+
|
|
883
|
+
def _is_auto_parallel_mode(save_obj):
|
|
884
|
+
"""Check if in auto parallel mode by verifying parameter initialization."""
|
|
885
|
+
for _, param in save_obj.parameters_and_names():
|
|
886
|
+
if param.param_info.is_param_init:
|
|
887
|
+
return True
|
|
888
|
+
return False
|
|
823
889
|
|
|
824
890
|
|
|
825
891
|
def _convert_list_to_param_list(save_obj, choice_func):
|
|
@@ -860,7 +926,7 @@ def _convert_dict_to_param_dict(save_obj, choice_func):
|
|
|
860
926
|
"""Convert a dict of Parameter to param_list."""
|
|
861
927
|
param_list = []
|
|
862
928
|
for (key, value) in save_obj.items():
|
|
863
|
-
if isinstance(key, str) and isinstance(value, (Parameter, str)):
|
|
929
|
+
if isinstance(key, str) and (isinstance(value, (Parameter, str)) or _is_buffer_type(value)):
|
|
864
930
|
if choice_func is not None and not choice_func(key):
|
|
865
931
|
continue
|
|
866
932
|
each_param = {"name": key, "data": value}
|
|
@@ -872,15 +938,19 @@ def _convert_dict_to_param_dict(save_obj, choice_func):
|
|
|
872
938
|
return param_list
|
|
873
939
|
|
|
874
940
|
|
|
875
|
-
def _convert_cell_param_and_names_to_dict(save_obj, choice_func):
|
|
941
|
+
def _convert_cell_param_and_names_to_dict(save_obj, choice_func, is_parallel_mode):
|
|
876
942
|
"""Convert cell.parameters_and_names to OrderedDict."""
|
|
877
943
|
param_dict = OrderedDict()
|
|
878
944
|
for _, param in save_obj.parameters_and_names():
|
|
945
|
+
if param.name.startswith("accu_grads") or param.name.endswith("expert_load"):
|
|
946
|
+
continue
|
|
879
947
|
not_sliced = not param.sliced
|
|
880
948
|
is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
|
|
881
949
|
# All parameters are initialized immediately under PyNative mode, skip this judgement.
|
|
882
950
|
judgment = not_sliced or param.has_init
|
|
883
|
-
if
|
|
951
|
+
if param.param_info.is_pipeline_shared_param:
|
|
952
|
+
continue
|
|
953
|
+
if is_graph_mode and is_parallel_mode and judgment:
|
|
884
954
|
continue
|
|
885
955
|
if choice_func is not None and not choice_func(param.name):
|
|
886
956
|
continue
|
|
@@ -898,11 +968,12 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
|
|
|
898
968
|
sync_pipeline_shared_parameters(save_obj)
|
|
899
969
|
param_list = []
|
|
900
970
|
parameter_layout_dict = save_obj.parameter_layout_dict
|
|
901
|
-
|
|
971
|
+
is_parallel_mode = _is_auto_parallel_mode(save_obj)
|
|
972
|
+
if is_parallel_mode and not parameter_layout_dict:
|
|
902
973
|
parameter_layout_dict = _get_parameter_layout()
|
|
903
|
-
if not
|
|
974
|
+
if not is_parallel_mode:
|
|
904
975
|
save_obj.init_parameters_data()
|
|
905
|
-
param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func)
|
|
976
|
+
param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func, is_parallel_mode)
|
|
906
977
|
if append_dict and "random_op" in append_dict:
|
|
907
978
|
phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key
|
|
908
979
|
if phase in save_obj.compile_cache and _executor.has_compiled(phase):
|
|
@@ -950,11 +1021,14 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
|
|
|
950
1021
|
|
|
951
1022
|
def _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func):
|
|
952
1023
|
"""Convert a save_obj to param_list."""
|
|
953
|
-
if isinstance(save_obj, list):
|
|
954
|
-
|
|
1024
|
+
if isinstance(save_obj, (list, dict)):
|
|
1025
|
+
if isinstance(save_obj, list):
|
|
1026
|
+
save_obj = _convert_list_to_param_list(save_obj, choice_func)
|
|
955
1027
|
|
|
956
|
-
|
|
957
|
-
|
|
1028
|
+
if isinstance(save_obj, dict):
|
|
1029
|
+
save_obj = _convert_dict_to_param_dict(save_obj, choice_func)
|
|
1030
|
+
|
|
1031
|
+
return _handle_shared_param_for_pipeline_parallel(save_obj)
|
|
958
1032
|
|
|
959
1033
|
return _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func)
|
|
960
1034
|
|
|
@@ -985,11 +1059,8 @@ def _check_append_dict(append_dict):
|
|
|
985
1059
|
return append_dict
|
|
986
1060
|
|
|
987
1061
|
|
|
988
|
-
def
|
|
989
|
-
if
|
|
990
|
-
customized_func = _check_customized_func(kwargs.get('obf_func'))
|
|
991
|
-
clean_funcs()
|
|
992
|
-
add_opaque_predicate(customized_func.__name__, customized_func)
|
|
1062
|
+
def _is_buffer_type(value):
|
|
1063
|
+
if isinstance(value, Tensor) and getattr(value, "_is_buffer", False):
|
|
993
1064
|
return True
|
|
994
1065
|
return False
|
|
995
1066
|
|
|
@@ -1006,20 +1077,18 @@ def load(file_name, **kwargs):
|
|
|
1006
1077
|
kwargs (dict): Configuration options dictionary.
|
|
1007
1078
|
|
|
1008
1079
|
- dec_key (bytes): Byte-type key used for decryption. The valid length is 16, 24, or 32.
|
|
1009
|
-
- dec_mode (Union[str, function]):
|
|
1080
|
+
- dec_mode (Union[str, function], optional):
|
|
1081
|
+
Specifies the decryption mode, to take effect when dec_key is set.
|
|
1010
1082
|
|
|
1011
1083
|
- Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: ``'AES-GCM'``.
|
|
1012
1084
|
- For details of using the customized decryption, please check the `tutorial
|
|
1013
1085
|
<https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
|
|
1014
1086
|
|
|
1015
|
-
- obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to
|
|
1016
|
-
`obfuscate_model()
|
|
1017
|
-
<https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.obfuscate_model.html>`_.
|
|
1018
|
-
|
|
1019
1087
|
Returns:
|
|
1020
1088
|
GraphCell, a compiled graph that can executed by `GraphCell`.
|
|
1021
1089
|
|
|
1022
1090
|
Raises:
|
|
1091
|
+
NotImplementedError: Dynamic model structure obfuscation is no longer supported.
|
|
1023
1092
|
ValueError: MindIR file does not exist or `file_name` is not a string.
|
|
1024
1093
|
RuntimeError: Failed to parse MindIR file.
|
|
1025
1094
|
|
|
@@ -1046,6 +1115,8 @@ def load(file_name, **kwargs):
|
|
|
1046
1115
|
- `Saving and Loading the Model - Saving and Loading MindIR
|
|
1047
1116
|
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
|
|
1048
1117
|
"""
|
|
1118
|
+
if 'obf_func' in kwargs.keys():
|
|
1119
|
+
raise NotImplementedError("Dynamic model structure obfuscation is no longer supported.")
|
|
1049
1120
|
if not isinstance(file_name, str):
|
|
1050
1121
|
raise ValueError("For 'load', the argument 'file_name' must be string, but "
|
|
1051
1122
|
"got {}.".format(type(file_name)))
|
|
@@ -1057,9 +1128,6 @@ def load(file_name, **kwargs):
|
|
|
1057
1128
|
"please check whether the 'file_name' is correct.")
|
|
1058
1129
|
file_name = os.path.realpath(file_name)
|
|
1059
1130
|
|
|
1060
|
-
# set customized functions for dynamic obfuscation
|
|
1061
|
-
obfuscated = _check_load_obfuscate(**kwargs)
|
|
1062
|
-
|
|
1063
1131
|
logger.info("Execute the process of loading mindir.")
|
|
1064
1132
|
if 'dec_key' in kwargs.keys():
|
|
1065
1133
|
dec_key = Validator.check_isinstance('dec_key', kwargs.get('dec_key'), bytes)
|
|
@@ -1072,9 +1140,9 @@ def load(file_name, **kwargs):
|
|
|
1072
1140
|
else:
|
|
1073
1141
|
dec_mode = Validator.check_isinstance('dec_mode', kwargs.get('dec_mode'), str)
|
|
1074
1142
|
graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode,
|
|
1075
|
-
decrypt=dec_func
|
|
1143
|
+
decrypt=dec_func)
|
|
1076
1144
|
else:
|
|
1077
|
-
graph = load_mindir(file_name
|
|
1145
|
+
graph = load_mindir(file_name)
|
|
1078
1146
|
|
|
1079
1147
|
if graph is None:
|
|
1080
1148
|
if _is_cipher_file(file_name):
|
|
@@ -1141,181 +1209,12 @@ def _check_param_type(param_config, key, target_type, requested):
|
|
|
1141
1209
|
if key in param_config:
|
|
1142
1210
|
if not isinstance(param_config[key], target_type):
|
|
1143
1211
|
raise TypeError("The type of {} must be {}, but got {}.".format(key, target_type, type(param_config[key])))
|
|
1144
|
-
if key == 'obf_random_seed':
|
|
1145
|
-
if param_config[key] > INT_64_MAX or param_config[key] <= 0:
|
|
1146
|
-
raise ValueError(
|
|
1147
|
-
"'obf_random_seed' must be in (0, INT_64_MAX({})], but got {}.".format(INT_64_MAX,
|
|
1148
|
-
param_config[key]))
|
|
1149
1212
|
return param_config[key]
|
|
1150
1213
|
if requested:
|
|
1151
1214
|
raise ValueError("The parameter {} is requested, but not got.".format(key))
|
|
1152
|
-
if key == "obf_random_seed":
|
|
1153
|
-
return 0
|
|
1154
1215
|
return None
|
|
1155
1216
|
|
|
1156
1217
|
|
|
1157
|
-
def _check_customized_func(customized_func):
|
|
1158
|
-
""" check customized function of dynamic obfuscation """
|
|
1159
|
-
if not callable(customized_func):
|
|
1160
|
-
raise TypeError(
|
|
1161
|
-
"'customized_func' must be a function, but not got {}.".format(type(customized_func)))
|
|
1162
|
-
# test customized_func
|
|
1163
|
-
try:
|
|
1164
|
-
func_result = customized_func(1.0, 1.0)
|
|
1165
|
-
except Exception as ex:
|
|
1166
|
-
raise TypeError("customized_func must be a function with two inputs, but got exception: {}".format(ex))
|
|
1167
|
-
else:
|
|
1168
|
-
if not isinstance(func_result, bool):
|
|
1169
|
-
raise TypeError("Return value of customized_func must be boolean, but got: {}".format(type(func_result)))
|
|
1170
|
-
return customized_func
|
|
1171
|
-
|
|
1172
|
-
|
|
1173
|
-
def _check_obfuscate_params(obf_config):
|
|
1174
|
-
"""Check obfuscation parameters, including obf_random_seed, obf_ratio, customized_func"""
|
|
1175
|
-
if 'obf_random_seed' not in obf_config.keys() and 'customized_func' not in obf_config.keys():
|
|
1176
|
-
raise ValueError(
|
|
1177
|
-
"At least one of 'obf_random_seed' or 'customized_func' must be set in obf_config, but got None of them.")
|
|
1178
|
-
obfuscate_type = _check_param_type(obf_config, "type", str, False)
|
|
1179
|
-
if obfuscate_type not in (None, "dynamic"):
|
|
1180
|
-
raise ValueError("Only 'dynamic' type is supported by now, but got {}.".format(obfuscate_type))
|
|
1181
|
-
if ('obf_ratio' in obf_config) and isinstance(obf_config['obf_ratio'], str):
|
|
1182
|
-
if obf_config['obf_ratio'] not in ["small", "medium", "large"]:
|
|
1183
|
-
raise ValueError("'obf_ratio' can only be 'small', 'medium', 'large' or float, but got {}.".format(
|
|
1184
|
-
obf_config['obf_ratio']))
|
|
1185
|
-
ratio_dict = {"small": 0.1, "medium": 0.3, "large": 0.6}
|
|
1186
|
-
obf_config['obf_ratio'] = ratio_dict.get(obf_config['obf_ratio'])
|
|
1187
|
-
obf_ratio = _check_param_type(obf_config, "obf_ratio", float, True)
|
|
1188
|
-
if (obf_ratio <= 0) or (obf_ratio > 1):
|
|
1189
|
-
raise ValueError("'obf_ratio' must be in (0, 1] if it is a float, but got {}.".format(obf_config['obf_ratio']))
|
|
1190
|
-
customized_funcs = []
|
|
1191
|
-
if 'customized_func' in obf_config.keys():
|
|
1192
|
-
device_target = context.get_context('device_target')
|
|
1193
|
-
if device_target in ["GPU", "Ascend"]:
|
|
1194
|
-
raise ValueError(
|
|
1195
|
-
"Customized func mode only support 'device_target'='CPU, but got {}.".format(device_target))
|
|
1196
|
-
customized_funcs.append(_check_customized_func(obf_config['customized_func']))
|
|
1197
|
-
obf_random_seed = _check_param_type(obf_config, "obf_random_seed", int, False)
|
|
1198
|
-
return obf_ratio, customized_funcs, obf_random_seed
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
def obfuscate_model(obf_config, **kwargs):
|
|
1202
|
-
"""
|
|
1203
|
-
Obfuscate a model of MindIR format. Obfuscation means changing the struct of a network without affecting its
|
|
1204
|
-
predict correctness. The obfuscated model can prevent attackers from stealing the model.
|
|
1205
|
-
|
|
1206
|
-
Args:
|
|
1207
|
-
obf_config (dict): obfuscation config.
|
|
1208
|
-
|
|
1209
|
-
- type (str): The type of obfuscation, only 'dynamic' is supported until now.
|
|
1210
|
-
- original_model_path (str): The path of MindIR format model that need to be obfuscated. If the original
|
|
1211
|
-
model is encrypted, then enc_key and enc_mode should be provided.
|
|
1212
|
-
- save_model_path (str): The path to save the obfuscated model.
|
|
1213
|
-
- model_inputs (list(Tensor)): The inputs of the original model, the values of Tensor can be random, which
|
|
1214
|
-
is the same as using :func:`mindspore.export`.
|
|
1215
|
-
- obf_ratio (Union(float, str)): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
|
|
1216
|
-
should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are
|
|
1217
|
-
correspond to 0.1, 0.3, and 0.6 respectively.
|
|
1218
|
-
- customized_func (function): A python function used for customized function mode, which used for control
|
|
1219
|
-
the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
|
|
1220
|
-
Reference to 'my_func()' in
|
|
1221
|
-
`tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
|
|
1222
|
-
This function needs to ensure that its result is constant for any input. Users can refer to opaque
|
|
1223
|
-
predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface
|
|
1224
|
-
when loading obfuscated model.
|
|
1225
|
-
- obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
|
|
1226
|
-
structure of obfuscated models corresponding to different random seeds is different. If
|
|
1227
|
-
`obf_random_seed` is set, then it should be passed to :class:`mindspore.nn.GraphCell`
|
|
1228
|
-
interface when loading
|
|
1229
|
-
obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
|
|
1230
|
-
be set, and the latter mode would be applied if both of them are set.
|
|
1231
|
-
|
|
1232
|
-
kwargs (dict): Configuration options dictionary.
|
|
1233
|
-
|
|
1234
|
-
- enc_key (bytes): Byte type key used for encryption. The valid length is 16, 24, or 32.
|
|
1235
|
-
- enc_mode (str): Specifies the encryption mode, to take effect when dec_key is set.
|
|
1236
|
-
Options: ``'AES-GCM'`` | ``'AES-CBC'`` | ``'SM4-CBC'``. Default: ``'AES-GCM'``.
|
|
1237
|
-
|
|
1238
|
-
Raises:
|
|
1239
|
-
TypeError: If `obf_config` is not a dict.
|
|
1240
|
-
ValueError: If `enc_key` is passed and `enc_mode` is not in ["AES-GCM", "AES-CBC", "SM4-CBC"].
|
|
1241
|
-
ValueError: If `original_model_path` is not provided in `obf_config`.
|
|
1242
|
-
ValueError: If the model saved in `original_model_path` has been obfuscated.
|
|
1243
|
-
ValueError: If `save_model_path` is not provided in `obf_config`.
|
|
1244
|
-
ValueError: If `obf_ratio` is not provided in `obf_config`.
|
|
1245
|
-
ValueError: If both `customized_func` and `obf_random_seed` are not provided in `obf_config`.
|
|
1246
|
-
ValueError: If `obf_random_seed` is not in (0, 9223372036854775807].
|
|
1247
|
-
ValueError: If `original_model_path` does not exist or `original_model_path` does not end with '.mindir'.
|
|
1248
|
-
|
|
1249
|
-
Examples:
|
|
1250
|
-
>>> import mindspore as ms
|
|
1251
|
-
>>> import mindspore.nn as nn
|
|
1252
|
-
>>> import numpy as np
|
|
1253
|
-
>>> # Download ori_net.mindir
|
|
1254
|
-
>>> # https://gitee.com/mindspore/mindspore/blob/master/tests/ut/python/mindir/ori_net.mindir
|
|
1255
|
-
>>> input1 = ms.Tensor(np.ones((1, 1, 32, 32)).astype(np.float32))
|
|
1256
|
-
>>> obf_config = {'original_model_path': "./net.mindir",
|
|
1257
|
-
... 'save_model_path': "./obf_net",
|
|
1258
|
-
... 'model_inputs': [input1, ],
|
|
1259
|
-
... 'obf_ratio': 0.1, 'obf_random_seed': 173262358423}
|
|
1260
|
-
>>> ms.obfuscate_model(obf_config)
|
|
1261
|
-
>>> obf_func = ms.load("obf_net.mindir")
|
|
1262
|
-
>>> obf_net = nn.GraphCell(obf_func, obf_random_seed=173262358423)
|
|
1263
|
-
>>> print(obf_net(input1).asnumpy())
|
|
1264
|
-
"""
|
|
1265
|
-
if not isinstance(obf_config, dict):
|
|
1266
|
-
raise TypeError("'obf_config' must be a dict, but got {}.".format(type(obf_config)))
|
|
1267
|
-
file_path = _check_param_type(obf_config, "original_model_path", str, True)
|
|
1268
|
-
if not file_path.endswith(".mindir"):
|
|
1269
|
-
raise ValueError("For 'obfuscate_model', the argument 'file_path'(MindIR file) should end with '.mindir', "
|
|
1270
|
-
"please input the correct 'file_path'.")
|
|
1271
|
-
if not os.path.exists(file_path):
|
|
1272
|
-
raise ValueError("For 'obfuscate_model', the argument 'file_path'(MindIR file) does not exist, "
|
|
1273
|
-
"please check whether the 'file_path' is correct.")
|
|
1274
|
-
saved_path = _check_param_type(obf_config, "save_model_path", str, True)
|
|
1275
|
-
model_inputs = _check_param_type(obf_config, "model_inputs", list, True)
|
|
1276
|
-
for item in model_inputs:
|
|
1277
|
-
if not isinstance(item, Tensor):
|
|
1278
|
-
raise TypeError("The item in 'model_inputs' must be Tensor, but got {}.".format(type(item)))
|
|
1279
|
-
if -1 in item.shape:
|
|
1280
|
-
raise ValueError(
|
|
1281
|
-
"Dynamic shape input is not supported now, but got the shape of inputs: {}.".format(item.shape))
|
|
1282
|
-
obf_ratio, customized_funcs, obf_random_seed = _check_obfuscate_params(obf_config)
|
|
1283
|
-
if customized_funcs and obf_random_seed > 0:
|
|
1284
|
-
logger.warning("Although 'customized_func' and 'obf_random_seed' are set, the 'obf_random_seed' mode would be"
|
|
1285
|
-
" applied, remember to set 'obf_random_seed' when loading obfuscated model.")
|
|
1286
|
-
|
|
1287
|
-
if obf_random_seed == 0: # apply customized_func mode
|
|
1288
|
-
clean_funcs()
|
|
1289
|
-
for func in customized_funcs:
|
|
1290
|
-
add_opaque_predicate(func.__name__, func)
|
|
1291
|
-
branch_control_input = 0
|
|
1292
|
-
else: # apply password mode
|
|
1293
|
-
branch_control_input = _generate_branch_control_input(obf_random_seed)
|
|
1294
|
-
|
|
1295
|
-
if 'enc_key' in kwargs.keys():
|
|
1296
|
-
enc_key = Validator.check_isinstance('enc_key', kwargs.get('enc_key'), bytes)
|
|
1297
|
-
enc_mode = "AES-GCM"
|
|
1298
|
-
if 'enc_mode' in kwargs.keys():
|
|
1299
|
-
enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
|
|
1300
|
-
if enc_mode not in ["AES-GCM", "AES-CBC", "SM4-CBC"]:
|
|
1301
|
-
raise ValueError(
|
|
1302
|
-
"Only MindIR files that encrypted with 'AES-GCM', 'AES-CBC' or 'SM4-CBC' is supported for"
|
|
1303
|
-
"obfuscate_model(), but got {}.".format(enc_mode))
|
|
1304
|
-
obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio,
|
|
1305
|
-
branch_control_input=branch_control_input, dec_key=enc_key,
|
|
1306
|
-
key_len=len(enc_key),
|
|
1307
|
-
dec_mode=enc_mode)
|
|
1308
|
-
else:
|
|
1309
|
-
obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio,
|
|
1310
|
-
branch_control_input=branch_control_input)
|
|
1311
|
-
|
|
1312
|
-
obf_net = nn.GraphCell(obf_graph)
|
|
1313
|
-
if obf_random_seed != 0:
|
|
1314
|
-
append_y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
|
|
1315
|
-
model_inputs += [append_y_tensor]
|
|
1316
|
-
export(obf_net, *model_inputs, file_name=saved_path, file_format="MINDIR", **kwargs)
|
|
1317
|
-
|
|
1318
|
-
|
|
1319
1218
|
def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
|
|
1320
1219
|
dec_mode, crc_check, format):
|
|
1321
1220
|
"""load parameter into parameter_dict"""
|
|
@@ -1323,17 +1222,22 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
|
|
|
1323
1222
|
if format == "safetensors":
|
|
1324
1223
|
with safe_open(ckpt_file_name, framework='np') as f:
|
|
1325
1224
|
cal_crc_num = 0
|
|
1326
|
-
|
|
1225
|
+
total_io_cost_time = 0
|
|
1327
1226
|
for k in sorted(f.keys()):
|
|
1328
1227
|
if crc_check:
|
|
1329
1228
|
cal_crc_num = binascii.crc32(bytes(k, encoding='utf-8'), cal_crc_num)
|
|
1330
1229
|
cal_crc_num = binascii.crc32(bytes(f.get_tensor(k)), cal_crc_num)
|
|
1331
1230
|
if choice_func is not None and not choice_func(k):
|
|
1332
1231
|
continue
|
|
1333
|
-
|
|
1334
|
-
|
|
1335
|
-
|
|
1336
|
-
|
|
1232
|
+
io_start_time = time.time()
|
|
1233
|
+
value = f.get_tensor(k)
|
|
1234
|
+
io_end_time = time.time()
|
|
1235
|
+
io_cost_time = io_end_time - io_start_time
|
|
1236
|
+
total_io_cost_time += io_cost_time
|
|
1237
|
+
parameter_dict[k] = Parameter(Tensor.from_numpy(value))
|
|
1238
|
+
|
|
1239
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
|
|
1240
|
+
f"Load safetensors io cost time:{total_io_cost_time}.")
|
|
1337
1241
|
if crc_check:
|
|
1338
1242
|
if f.metadata() is None or f.metadata().get("crc_num") is None:
|
|
1339
1243
|
logger.warning(
|
|
@@ -1411,38 +1315,37 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1411
1315
|
Load checkpoint info from a specified file.
|
|
1412
1316
|
|
|
1413
1317
|
Note:
|
|
1414
|
-
- `specify_prefix` and `filter_prefix` do not affect each other.
|
|
1415
|
-
- If none of the parameters are loaded from checkpoint file, it will throw ValueError.
|
|
1416
1318
|
- `specify_prefix` and `filter_prefix` are in the process of being deprecated,
|
|
1417
|
-
`choice_func` is recommended instead.
|
|
1319
|
+
`choice_func` is recommended instead. `specify_prefix` and `filter_prefix` do not affect each other.
|
|
1418
1320
|
And using either of those two args will override `choice_func` at the same time.
|
|
1321
|
+
- If none of the parameters are loaded from checkpoint file, it will throw ValueError.
|
|
1419
1322
|
- When loading a checkpoint that has removed redundancy, the network should be compiled.
|
|
1420
1323
|
|
|
1421
1324
|
Args:
|
|
1422
1325
|
ckpt_file_name (str): Checkpoint file name.
|
|
1423
|
-
net (Cell): The network where the parameters will be loaded. Default: ``None`` .
|
|
1424
|
-
strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load
|
|
1425
|
-
into net when parameter name's suffix in checkpoint file is the same as the
|
|
1326
|
+
net (Cell, optional): The network where the parameters will be loaded. Default: ``None`` .
|
|
1327
|
+
strict_load (bool, optional): Whether to strict load the parameter into net. If ``False`` , it will load
|
|
1328
|
+
parameter into net when parameter name's suffix in checkpoint file is the same as the
|
|
1426
1329
|
parameter in the network. When the types are inconsistent perform type conversion
|
|
1427
1330
|
on the parameters of the same type, such as float32 to float16. Default: ``False`` .
|
|
1428
|
-
filter_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`).
|
|
1429
|
-
filter_prefix will not be loaded. Default: ``None`` .
|
|
1430
|
-
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` ,
|
|
1431
|
-
is not required. Default: ``None`` .
|
|
1432
|
-
dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the
|
|
1433
|
-
mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
|
|
1331
|
+
filter_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`).
|
|
1332
|
+
Parameters starting with the filter_prefix will not be loaded. Default: ``None`` .
|
|
1333
|
+
dec_key (Union[None, bytes], optional): Byte type key used for decryption. If the value is ``None`` ,
|
|
1334
|
+
the decryption is not required. Default: ``None`` .
|
|
1335
|
+
dec_mode (str, optional): This parameter is valid only when dec_key is not set to ``None`` . Specifies the
|
|
1336
|
+
decryption mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
|
|
1434
1337
|
Default: ``"AES-GCM"`` .
|
|
1435
|
-
specify_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`).
|
|
1436
|
-
specify_prefix will be loaded. Default: ``None`` .
|
|
1437
|
-
choice_func (Union[None, function]) : Input value of the function is a Parameter name of type string,
|
|
1338
|
+
specify_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`).
|
|
1339
|
+
Parameters starting with the specify_prefix will be loaded. Default: ``None`` .
|
|
1340
|
+
choice_func (Union[None, function], optional) : Input value of the function is a Parameter name of type string,
|
|
1438
1341
|
and the return value is a bool. If returns ``True`` , the Parameter
|
|
1439
1342
|
that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
|
|
1440
1343
|
matches the custom condition will be removed. Default: ``None`` .
|
|
1441
|
-
crc_check (bool) : Whether to perform crc32 validation when loading checkpoint. Default: ``False`` .
|
|
1442
|
-
remove_redundancy (bool): Whether to enable loading of checkpoint saved with redundancy removal.
|
|
1344
|
+
crc_check (bool, optional) : Whether to perform crc32 validation when loading checkpoint. Default: ``False`` .
|
|
1345
|
+
remove_redundancy (bool, optional): Whether to enable loading of checkpoint saved with redundancy removal.
|
|
1443
1346
|
Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means
|
|
1444
1347
|
redundant-free loading is not enabled.
|
|
1445
|
-
format (str): Format of the input file, can be "ckpt" or "safetensors". Default: "ckpt".
|
|
1348
|
+
format (str, optional): Format of the input file, can be "ckpt" or "safetensors". Default: "ckpt".
|
|
1446
1349
|
|
|
1447
1350
|
Returns:
|
|
1448
1351
|
Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
|
|
@@ -1487,6 +1390,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1487
1390
|
- `Saving and Loading the Model - Saving and Loading the Model Weight
|
|
1488
1391
|
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
|
|
1489
1392
|
"""
|
|
1393
|
+
start_load_time = time.time()
|
|
1490
1394
|
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin load checkpoint.")
|
|
1491
1395
|
specify_prefix = _check_prefix(specify_prefix)
|
|
1492
1396
|
filter_prefix = _check_prefix(filter_prefix)
|
|
@@ -1535,6 +1439,9 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1535
1439
|
_warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)
|
|
1536
1440
|
|
|
1537
1441
|
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Load checkpoint is finished.")
|
|
1442
|
+
end_load_time = time.time()
|
|
1443
|
+
load_checkpoint_cost_time = end_load_time - start_load_time
|
|
1444
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Load checkpoint cost time {load_checkpoint_cost_time}.")
|
|
1538
1445
|
return parameter_dict
|
|
1539
1446
|
|
|
1540
1447
|
|
|
@@ -1554,7 +1461,7 @@ def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_pr
|
|
|
1554
1461
|
And using either of those two args will override `choice_func` at the same time.
|
|
1555
1462
|
|
|
1556
1463
|
Args:
|
|
1557
|
-
ckpt_file_name (str): Checkpoint file name.
|
|
1464
|
+
ckpt_file_name (str): Checkpoint file name. The file extension must be `ckpt` or `safetensors` .
|
|
1558
1465
|
net (Cell, optional): The network where the parameters will be loaded. Default: ``None`` .
|
|
1559
1466
|
strict_load (bool, optional): Whether to strict load the parameter into net. If ``False`` , it will load
|
|
1560
1467
|
parameter into net when parameter name's suffix in checkpoint file is the
|
|
@@ -1612,10 +1519,11 @@ def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_pr
|
|
|
1612
1519
|
>>> model.train(2, dataset)
|
|
1613
1520
|
>>> print("param dict len: ", len(param_dict), flush=True)
|
|
1614
1521
|
"""
|
|
1522
|
+
format = "safetensors" if ckpt_file_name.endswith(".safetensors") else "ckpt"
|
|
1615
1523
|
from concurrent.futures import ThreadPoolExecutor
|
|
1616
1524
|
executor = ThreadPoolExecutor(max_workers=2)
|
|
1617
1525
|
param_dict_future = executor.submit(load_checkpoint, ckpt_file_name, net, strict_load, filter_prefix,
|
|
1618
|
-
dec_key, dec_mode, specify_prefix, choice_func)
|
|
1526
|
+
dec_key, dec_mode, specify_prefix, choice_func, format=format)
|
|
1619
1527
|
return ParamDictFuture(executor, param_dict_future)
|
|
1620
1528
|
|
|
1621
1529
|
|
|
@@ -1703,7 +1611,7 @@ def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check):
|
|
|
1703
1611
|
pb_content = f.read()
|
|
1704
1612
|
ckpt_load_time_end = time.time()
|
|
1705
1613
|
cost_time = ckpt_load_time_end - ckpt_load_time_start
|
|
1706
|
-
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Load ckpt cost time:{cost_time}.")
|
|
1614
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Load ckpt io cost time:{cost_time}.")
|
|
1707
1615
|
|
|
1708
1616
|
else:
|
|
1709
1617
|
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
|
|
@@ -1774,17 +1682,18 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
|
|
|
1774
1682
|
Load parameters into network, return parameter list that are not loaded in the network.
|
|
1775
1683
|
|
|
1776
1684
|
Note:
|
|
1777
|
-
|
|
1685
|
+
When loading a parameter dict that has removed redundancy, the network should be compiled.
|
|
1778
1686
|
|
|
1779
1687
|
Args:
|
|
1780
1688
|
net (Cell): The network where the parameters will be loaded.
|
|
1781
1689
|
parameter_dict (dict): The dictionary generated by load checkpoint file,
|
|
1782
1690
|
it is a dictionary consisting of key: parameters's name, value: parameter.
|
|
1783
|
-
strict_load (bool): Whether to strict load the parameter into net. If ``False`` ,
|
|
1691
|
+
strict_load (bool, optional): Whether to strict load the parameter into net. If ``False`` ,
|
|
1692
|
+
it will load parameter
|
|
1784
1693
|
into net when parameter name's suffix in checkpoint file is the same as the
|
|
1785
1694
|
parameter in the network. When the types are inconsistent perform type conversion
|
|
1786
1695
|
on the parameters of the same type, such as float32 to float16. Default: ``False`` .
|
|
1787
|
-
remove_redundancy (bool): Whether to enable loading of checkpoint saved with redundancy removal.
|
|
1696
|
+
remove_redundancy (bool, optional): Whether to enable loading of checkpoint saved with redundancy removal.
|
|
1788
1697
|
Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means
|
|
1789
1698
|
redundant-free loading is not enabled.
|
|
1790
1699
|
|
|
@@ -1825,6 +1734,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
|
|
|
1825
1734
|
param_not_load = []
|
|
1826
1735
|
ckpt_not_load = list(parameter_dict.keys())
|
|
1827
1736
|
for _, param in net.parameters_and_names():
|
|
1737
|
+
if param.param_info.is_pipeline_shared_param:
|
|
1738
|
+
continue
|
|
1828
1739
|
if param.name in parameter_dict:
|
|
1829
1740
|
if isinstance(param, MapParameter):
|
|
1830
1741
|
param.import_data(parameter_dict[param.name])
|
|
@@ -1843,31 +1754,24 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
|
|
|
1843
1754
|
if param_not_load and not strict_load:
|
|
1844
1755
|
_load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load)
|
|
1845
1756
|
|
|
1846
|
-
logger.info("Loading parameters into net is finished.")
|
|
1847
|
-
if param_not_load:
|
|
1848
|
-
logger.warning("For 'load_param_into_net', "
|
|
1849
|
-
"{} parameters in the 'net' are not loaded, because they are not in the "
|
|
1850
|
-
"'parameter_dict', please check whether the network structure is consistent "
|
|
1851
|
-
"when training and loading checkpoint. Another possibility is that "
|
|
1852
|
-
"the redundant loading is not enabled, but the loaded checkpoint is saved with "
|
|
1853
|
-
"redundancy removed. ".format(len(param_not_load)))
|
|
1854
|
-
logger.warning("{} are not loaded.".format(param_not_load))
|
|
1855
1757
|
if remove_redundancy:
|
|
1856
|
-
|
|
1857
|
-
if parallel_mode == "stand_alone":
|
|
1758
|
+
if get_group_size() == 1:
|
|
1858
1759
|
raise TypeError(f"The deduplication feature for loading checkpoint can only be used "
|
|
1859
|
-
f"in parallel scenarios, but got
|
|
1760
|
+
f"in parallel scenarios, but got stand_alone.")
|
|
1860
1761
|
if not net.compile_cache and not net.parameter_layout_dict:
|
|
1861
1762
|
raise ValueError("When loading a parameter dict that has removed redundancy, "
|
|
1862
1763
|
"the network should be compiled.")
|
|
1863
1764
|
param_layout = net.parameter_layout_dict
|
|
1864
|
-
|
|
1865
|
-
|
|
1866
|
-
stage_num = _get_auto_parallel_context("pipeline_stages")
|
|
1867
|
-
chunk_size = device_num // stage_num
|
|
1868
|
-
initial_rank = (rank_id // chunk_size) * chunk_size
|
|
1869
|
-
_single_parameter_broadcast(net, param_layout, rank_id, initial_rank)
|
|
1765
|
+
_single_parameter_broadcast(net, param_layout, param_not_load)
|
|
1766
|
+
mindspore.hal.synchronize()
|
|
1870
1767
|
|
|
1768
|
+
logger.info("Loading parameters into net is finished.")
|
|
1769
|
+
if param_not_load:
|
|
1770
|
+
logger.warning("For 'load_param_into_net', "
|
|
1771
|
+
"{} parameters in the 'net' are not loaded, because they are not in the "
|
|
1772
|
+
"'parameter_dict', please check whether the network structure is consistent "
|
|
1773
|
+
"when training and loading checkpoint.".format(len(param_not_load)))
|
|
1774
|
+
logger.warning("{} are not loaded.".format(param_not_load))
|
|
1871
1775
|
return param_not_load, ckpt_not_load
|
|
1872
1776
|
|
|
1873
1777
|
|
|
@@ -2050,9 +1954,6 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
|
|
|
2050
1954
|
elif opt_shard_group:
|
|
2051
1955
|
allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
|
|
2052
1956
|
tuple(after_reshape_slice_shape))
|
|
2053
|
-
elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
|
|
2054
|
-
allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
|
|
2055
|
-
tuple(after_reshape_slice_shape))
|
|
2056
1957
|
net.parallel_parameter_merge_net_dict[param_name] = allgather_net
|
|
2057
1958
|
if allgather_net:
|
|
2058
1959
|
param_data = allgather_net(param_data)
|
|
@@ -2106,27 +2007,6 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
2106
2007
|
|
|
2107
2008
|
- dataset (Dataset): Specifies the preprocessing method of the dataset, which is used to import the
|
|
2108
2009
|
preprocessing of the dataset into MindIR.
|
|
2109
|
-
|
|
2110
|
-
- obf_config (dict): obfuscation config.
|
|
2111
|
-
|
|
2112
|
-
- type (str): The type of obfuscation, only 'dynamic' is supported until now.
|
|
2113
|
-
- obf_ratio (float, str): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
|
|
2114
|
-
should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are
|
|
2115
|
-
correspond to 0.1, 0.3, and 0.6 respectively.
|
|
2116
|
-
- customized_func (function): A python function used for customized function mode, which used for control
|
|
2117
|
-
the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
|
|
2118
|
-
Reference to 'my_func()' in
|
|
2119
|
-
`tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
|
|
2120
|
-
This function needs to ensure that its result is constant for any input. Users can refer to opaque
|
|
2121
|
-
predicates. If customized_func is set, then it should be passed to `load()` interface when loading
|
|
2122
|
-
obfuscated model.
|
|
2123
|
-
- obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
|
|
2124
|
-
structure of obfuscated models corresponding to different random seeds is different. If
|
|
2125
|
-
`obf_random_seed` is set, then it should be passed
|
|
2126
|
-
to :class:`mindspore.nn.GraphCell` interface when loading
|
|
2127
|
-
obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
|
|
2128
|
-
be set, and the latter mode would be applied if both of them are set.
|
|
2129
|
-
|
|
2130
2010
|
- incremental (bool): export MindIR incrementally.
|
|
2131
2011
|
|
|
2132
2012
|
- custom_func (function): Functions for custom defined export policies. This function will be used to
|
|
@@ -2160,6 +2040,8 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
2160
2040
|
- `Saving and Loading the Model - Saving and Loading MindIR
|
|
2161
2041
|
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
|
|
2162
2042
|
"""
|
|
2043
|
+
if 'obf_func' in kwargs.keys():
|
|
2044
|
+
raise NotImplementedError("Dynamic model structure obfuscation is no longer supported.")
|
|
2163
2045
|
old_ms_jit_value = context.get_context("jit_syntax_level")
|
|
2164
2046
|
context.set_context(jit_syntax_level=mindspore.STRICT)
|
|
2165
2047
|
|
|
@@ -2241,8 +2123,6 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
|
|
|
2241
2123
|
It is an internal conversion function. Export the MindSpore prediction model to a file in the specified format.
|
|
2242
2124
|
"""
|
|
2243
2125
|
logger.info("exporting model file:%s format:%s.", file_name, file_format)
|
|
2244
|
-
if "obf_config" in kwargs and file_format != "MINDIR":
|
|
2245
|
-
raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.")
|
|
2246
2126
|
if "custom_func" in kwargs and file_format != "MINDIR" and kwargs["custom_func"] is not None:
|
|
2247
2127
|
raise ValueError(f"Currently only support custom_func for MindIR format, but got {file_format} format.")
|
|
2248
2128
|
if file_format == 'AIR':
|
|
@@ -2456,14 +2336,13 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
|
|
|
2456
2336
|
os.chmod(data_file_name, stat.S_IRUSR)
|
|
2457
2337
|
|
|
2458
2338
|
|
|
2459
|
-
def _msfunc_info(net, *inputs):
|
|
2339
|
+
def _msfunc_info(net, jit_executor, *inputs):
|
|
2460
2340
|
"""Get mindir stream and parameter dict of ms_function"""
|
|
2461
2341
|
# pylint: disable=protected-access
|
|
2462
2342
|
net_dict = OrderedDict()
|
|
2463
|
-
|
|
2464
|
-
|
|
2465
|
-
|
|
2466
|
-
params = _ms_func_executor._graph_executor.get_params(graph_id)
|
|
2343
|
+
graph_id = jit_executor.compile(net.__name__, *inputs)
|
|
2344
|
+
mindir_stream = jit_executor._get_func_graph_proto(net, graph_id, 'mind_ir')
|
|
2345
|
+
params = jit_executor._graph_executor.get_params(graph_id)
|
|
2467
2346
|
for name, value in params.items():
|
|
2468
2347
|
net_dict[name] = Parameter(value, name=name)
|
|
2469
2348
|
return mindir_stream, net_dict
|
|
@@ -2475,53 +2354,21 @@ def _cell_info(net, incremental, *inputs):
|
|
|
2475
2354
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
|
2476
2355
|
# pylint: disable=protected-access
|
|
2477
2356
|
mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir', incremental=incremental)
|
|
2478
|
-
# clean obfuscation config to prevent the next call
|
|
2479
|
-
_executor.obfuscate_config = None
|
|
2480
|
-
|
|
2481
2357
|
net_dict = net.parameters_dict()
|
|
2482
2358
|
return mindir_stream, net_dict
|
|
2483
2359
|
|
|
2484
2360
|
|
|
2485
|
-
def _set_obfuscate_config(**kwargs):
|
|
2486
|
-
"""Set obfuscation config for executor."""
|
|
2487
|
-
logger.warning("Obfuscate model.")
|
|
2488
|
-
if 'enc_mode' in kwargs.keys():
|
|
2489
|
-
enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
|
|
2490
|
-
if enc_mode not in ["AES-GCM", "AES-CBC", "SM4-CBC"]:
|
|
2491
|
-
raise ValueError(
|
|
2492
|
-
"Only MindIR files that encrypted with 'AES-GCM', 'AES-CBC' or 'SM4-CBC' is supported for"
|
|
2493
|
-
"obfuscation, but got {}.".format(enc_mode))
|
|
2494
|
-
obf_ratio, customized_funcs, obf_random_seed = _check_obfuscate_params(kwargs.get('obf_config'))
|
|
2495
|
-
if customized_funcs and obf_random_seed > 0:
|
|
2496
|
-
logger.warning("Although 'customized_func' and 'obf_random_seed' are set, the 'obf_random_seed' mode would be"
|
|
2497
|
-
" applied, remember to set 'obf_random_seed' when loading obfuscated model.")
|
|
2498
|
-
|
|
2499
|
-
if obf_random_seed == 0: # apply customized_func mode
|
|
2500
|
-
device_target = context.get_context('device_target')
|
|
2501
|
-
if device_target in ["GPU", "Ascend"]:
|
|
2502
|
-
raise ValueError(
|
|
2503
|
-
"Customized func mode only support 'device_target'='CPU, but got {}.".format(device_target))
|
|
2504
|
-
clean_funcs()
|
|
2505
|
-
for func in customized_funcs:
|
|
2506
|
-
add_opaque_predicate(func.__name__, func)
|
|
2507
|
-
_executor.obfuscate_config = {'obf_ratio': obf_ratio, 'obf_random_seed': obf_random_seed}
|
|
2508
|
-
|
|
2509
|
-
|
|
2510
2361
|
def _save_mindir(net, file_name, *inputs, **kwargs):
|
|
2511
2362
|
"""Save MindIR format file."""
|
|
2512
|
-
|
|
2513
|
-
if
|
|
2514
|
-
|
|
2515
|
-
for item in inputs:
|
|
2516
|
-
if -1 in item.shape:
|
|
2517
|
-
raise ValueError(
|
|
2518
|
-
"Dynamic shape input is not supported now, but got the shape of inputs: {}.".format(item.shape))
|
|
2363
|
+
executor = _executor
|
|
2364
|
+
if not isinstance(net, nn.Cell):
|
|
2365
|
+
executor = _JitExecutor(net, time.time() * 1e9)
|
|
2519
2366
|
|
|
2520
2367
|
incremental = kwargs.get('incremental', False)
|
|
2521
2368
|
|
|
2522
2369
|
model = mindir_model()
|
|
2523
2370
|
if not isinstance(net, nn.Cell):
|
|
2524
|
-
mindir_stream, net_dict = _msfunc_info(net, *inputs)
|
|
2371
|
+
mindir_stream, net_dict = _msfunc_info(net, executor, *inputs)
|
|
2525
2372
|
else:
|
|
2526
2373
|
mindir_stream, net_dict = _cell_info(net, incremental, *inputs)
|
|
2527
2374
|
model.ParseFromString(mindir_stream)
|
|
@@ -2594,8 +2441,10 @@ def _save_together(net_dict, model):
|
|
|
2594
2441
|
if name in net_dict.keys():
|
|
2595
2442
|
data_total += sys.getsizeof(net_dict[name].data.get_bytes()) / 1024
|
|
2596
2443
|
else:
|
|
2597
|
-
raise ValueError("
|
|
2598
|
-
"
|
|
2444
|
+
raise ValueError("There's a mindspore.Parameter that wasn't created in nn.Cell, and mindspore.export() "
|
|
2445
|
+
f"does not support exporting such Parameters. The parameter name is: {name}.\n"
|
|
2446
|
+
"You can find the supported syntax range for mindspore.export() at the following link:\n"
|
|
2447
|
+
"https://www.mindspore.cn/tutorials/zh-CN/master/beginner/save_load.html")
|
|
2599
2448
|
if data_total > TOTAL_SAVE:
|
|
2600
2449
|
return False
|
|
2601
2450
|
return True
|
|
@@ -2762,566 +2611,6 @@ def parse_print(print_file_name):
|
|
|
2762
2611
|
return tensor_list
|
|
2763
2612
|
|
|
2764
2613
|
|
|
2765
|
-
def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
|
|
2766
|
-
"""
|
|
2767
|
-
Merge data slices to one tensor with whole data when strategy is not None.
|
|
2768
|
-
|
|
2769
|
-
Args:
|
|
2770
|
-
sliced_data (list[numpy.ndarray]): Data slices in order of rank_id.
|
|
2771
|
-
parameter_name (str): Name of parameter.
|
|
2772
|
-
strategy (dict): Parameter slice strategy.
|
|
2773
|
-
is_even (bool): Slice manner that True represents slicing evenly and False represents slicing unevenly.
|
|
2774
|
-
|
|
2775
|
-
Returns:
|
|
2776
|
-
Tensor, the merged Tensor which has the whole data.
|
|
2777
|
-
|
|
2778
|
-
Raises:
|
|
2779
|
-
ValueError: Failed to merge.
|
|
2780
|
-
"""
|
|
2781
|
-
layout = strategy.get(parameter_name)
|
|
2782
|
-
try:
|
|
2783
|
-
dev_mat = list(layout.dev_matrix[0].dim)
|
|
2784
|
-
tensor_map = list(layout.tensor_map[0].dim)
|
|
2785
|
-
param_split_shape = list(layout.param_split_shape[0].dim)
|
|
2786
|
-
field_size = int(layout.field)
|
|
2787
|
-
except BaseException as e:
|
|
2788
|
-
raise ValueError(f"{e.__str__()}. For 'merge_sliced_parameter'"
|
|
2789
|
-
f", please make sure that 'strategy' is correct.") from e
|
|
2790
|
-
|
|
2791
|
-
device_count = 1
|
|
2792
|
-
for dim in dev_mat:
|
|
2793
|
-
device_count *= dim
|
|
2794
|
-
|
|
2795
|
-
if len(sliced_data) != device_count:
|
|
2796
|
-
raise ValueError(f"For 'merge_sliced_parameter', the length of 'sliced_parameters' should be equal to "
|
|
2797
|
-
f"device_count. The length of 'sliced_parameters' is {len(sliced_data)}, but "
|
|
2798
|
-
f"device_count is {device_count}.")
|
|
2799
|
-
|
|
2800
|
-
if not param_split_shape:
|
|
2801
|
-
if not is_even:
|
|
2802
|
-
raise ValueError("For 'merge_sliced_parameter', the shape of every parameter in 'sliced_parameters' "
|
|
2803
|
-
"should be the same when slice manner is even.")
|
|
2804
|
-
|
|
2805
|
-
all_gather_tensor = Tensor(np.concatenate(sliced_data))
|
|
2806
|
-
|
|
2807
|
-
if field_size > 0:
|
|
2808
|
-
merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, field_size)
|
|
2809
|
-
else:
|
|
2810
|
-
merged_tensor = _reshape_param_data(all_gather_tensor, dev_mat, tensor_map)
|
|
2811
|
-
|
|
2812
|
-
else:
|
|
2813
|
-
tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
|
|
2814
|
-
|
|
2815
|
-
slice_count = 1
|
|
2816
|
-
for dim in tensor_strategy:
|
|
2817
|
-
slice_count *= dim
|
|
2818
|
-
|
|
2819
|
-
if len(param_split_shape) != slice_count:
|
|
2820
|
-
raise ValueError(f"For 'merge_sliced_parameter', the param_split_shape length in 'strategy' should be "
|
|
2821
|
-
f"{slice_count}, but got {len(param_split_shape)}.")
|
|
2822
|
-
|
|
2823
|
-
tensor_slices_new = list(range(slice_count))
|
|
2824
|
-
tensor_slices = sliced_data
|
|
2825
|
-
for i in range(device_count):
|
|
2826
|
-
slice_index = int(_get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i))
|
|
2827
|
-
if tensor_slices[i].shape[0] != param_split_shape[slice_index]:
|
|
2828
|
-
raise ValueError(f"For 'merge_sliced_parameter', the slice {slice_index} should be "
|
|
2829
|
-
f"{param_split_shape[slice_index]} in 0 axis, but got "
|
|
2830
|
-
f"{tensor_slices[i].shape[0]}.")
|
|
2831
|
-
tensor_slices_new[slice_index] = np.array(tensor_slices[i])
|
|
2832
|
-
|
|
2833
|
-
dim_len = len(tensor_strategy)
|
|
2834
|
-
for i in range(dim_len):
|
|
2835
|
-
ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
|
|
2836
|
-
tensor_slices_new_inner = []
|
|
2837
|
-
for j in range(ele_count):
|
|
2838
|
-
new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
|
|
2839
|
-
for k in range(j * tensor_strategy[dim_len - 1 - i] + 1,
|
|
2840
|
-
(j + 1) * tensor_strategy[dim_len - 1 - i]):
|
|
2841
|
-
new_tensor = np.concatenate((new_tensor, tensor_slices_new[k]), axis=dim_len - 1 - i)
|
|
2842
|
-
tensor_slices_new_inner.insert(len(tensor_slices_new_inner), np.array(new_tensor))
|
|
2843
|
-
tensor_slices_new = tensor_slices_new_inner
|
|
2844
|
-
merged_tensor = Tensor(tensor_slices_new[0])
|
|
2845
|
-
|
|
2846
|
-
return merged_tensor
|
|
2847
|
-
|
|
2848
|
-
|
|
2849
|
-
def restore_group_info_list(group_info_file_name):
|
|
2850
|
-
"""
|
|
2851
|
-
Build rank list, the checkpoint of ranks in the rank list has the same contents with the local rank
|
|
2852
|
-
who saves the `group_info_file_name`. To save the group info file, please export GROUP_INFO_FIL
|
|
2853
|
-
environment variables like "export GROUP_INFO_FILE=/data/group_info.pb".
|
|
2854
|
-
|
|
2855
|
-
Args:
|
|
2856
|
-
group_info_file_name (str): Name of group information file.
|
|
2857
|
-
|
|
2858
|
-
Returns:
|
|
2859
|
-
List, the rank list.
|
|
2860
|
-
|
|
2861
|
-
Raises:
|
|
2862
|
-
ValueError: group information file is incorrect.
|
|
2863
|
-
TypeError: `group_info_file_name` is not str.
|
|
2864
|
-
|
|
2865
|
-
Examples:
|
|
2866
|
-
>>> import mindspore as ms
|
|
2867
|
-
>>> ms.restore_list = restore_group_info_list("./group_info.pb")
|
|
2868
|
-
"""
|
|
2869
|
-
if not isinstance(group_info_file_name, str):
|
|
2870
|
-
raise TypeError(f"For 'restore_group_info_list', the argument 'group_info_file_name' should be str, "
|
|
2871
|
-
f"but got {type(group_info_file_name)}.")
|
|
2872
|
-
|
|
2873
|
-
if not os.path.isfile(group_info_file_name):
|
|
2874
|
-
raise ValueError(f"For 'restore_group_info_list', no such group information file: {group_info_file_name}.")
|
|
2875
|
-
|
|
2876
|
-
if os.path.getsize(group_info_file_name) == 0:
|
|
2877
|
-
raise ValueError("For 'restore_group_info_list', the group information file should not be empty.")
|
|
2878
|
-
|
|
2879
|
-
return _restore_group_info_list(group_info_file_name)
|
|
2880
|
-
|
|
2881
|
-
|
|
2882
|
-
def build_searched_strategy(strategy_filename):
|
|
2883
|
-
"""
|
|
2884
|
-
Build strategy of every parameter in network. Used in the case of distributed inference.
|
|
2885
|
-
|
|
2886
|
-
Args:
|
|
2887
|
-
strategy_filename (str): Name of strategy file.
|
|
2888
|
-
|
|
2889
|
-
Returns:
|
|
2890
|
-
Dict, whose key is parameter name and value is slice strategy of this parameter.
|
|
2891
|
-
|
|
2892
|
-
Raises:
|
|
2893
|
-
ValueError: Strategy file is incorrect.
|
|
2894
|
-
TypeError: `strategy_filename` is not a string.
|
|
2895
|
-
|
|
2896
|
-
Examples:
|
|
2897
|
-
>>> import mindspore as ms
|
|
2898
|
-
>>> strategy = ms.build_searched_strategy("./strategy_train.ckpt")
|
|
2899
|
-
"""
|
|
2900
|
-
return _build_searched_strategy(strategy_filename)
|
|
2901
|
-
|
|
2902
|
-
|
|
2903
|
-
def merge_sliced_parameter(sliced_parameters, strategy=None):
    """
    Merge parameter slices into one parameter. Used in the case of distributed inference.

    Args:
        sliced_parameters (list[Parameter]): Parameter slices in order of rank id.
        strategy (Optional[dict]): Parameter slice strategy, whose key is parameter name and
            value is slice strategy of this parameter. If strategy is None, just merge
            parameter slices in 0 axis order. Default: ``None``.

    Returns:
        Parameter, the merged parameter which has the whole data. bfloat16 slices are converted to
        float32 before merging, so the merged data is float32 in that case.

    Raises:
        ValueError: Failed to merge.
        TypeError: The sliced_parameters is incorrect or strategy is not dict.
        KeyError: The parameter name is not in keys of strategy.

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import Tensor, Parameter
        >>>
        >>> sliced_parameters = [
        ...                      Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])),
        ...                                "network.embedding_table"),
        ...                      Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])),
        ...                                "network.embedding_table"),
        ...                      Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])),
        ...                                "network.embedding_table"),
        ...                      Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])),
        ...                                "network.embedding_table")]
        >>> merged_parameter = ms.merge_sliced_parameter(sliced_parameters)
        >>> print(merged_parameter)
        Parameter (name=network.embedding_table, shape=(12,), dtype=Float64, requires_grad=True)
    """
    if not isinstance(sliced_parameters, list):
        raise TypeError(f"For 'merge_sliced_parameter', the argument 'sliced_parameters' should be list, "
                        f"but got {type(sliced_parameters)}.")

    if not sliced_parameters:
        raise ValueError("For 'merge_sliced_parameter', the argument 'sliced_parameters' should not be empty.")

    if strategy and not isinstance(strategy, dict):
        raise TypeError(f"For 'merge_sliced_parameter', the argument 'strategy' should be dict, "
                        f"but got {type(strategy)}.")

    # The first slice is the reference for name and shape. Catch only ordinary errors (e.g. a missing
    # `.name`/`.data` attribute), so KeyboardInterrupt/SystemExit are not rewrapped as TypeError.
    try:
        parameter_name = sliced_parameters[0].name
        parameter_shape = sliced_parameters[0].data.shape
        parameter_shape_length = len(parameter_shape)
    except Exception as e:
        raise TypeError(str(e) + f" For 'merge_sliced_parameter', the element in 'sliced_parameters' should be "
                                 f"'Parameter', but got {type(sliced_parameters[0])} at index 0.") from e

    # Slices must agree on name, rank and every axis except 0. is_even records whether all slices
    # also agree on axis 0 (i.e. the parameter was split evenly).
    is_even = True
    for index, parameter in enumerate(sliced_parameters):
        if not isinstance(parameter, Parameter):
            raise TypeError(f"For 'merge_sliced_parameter', the element in 'sliced_parameters' should be 'Parameter', "
                            f"but got {type(parameter)} at index {index}.")

        if parameter.name != parameter_name \
                or len(parameter.data.shape) != parameter_shape_length \
                or parameter.data.shape[1:] != parameter_shape[1:]:
            raise ValueError(f"For 'merge_sliced_parameter', please make sure that the elements in 'sliced_parameters'"
                             f" have the same name, dimension length and shape except 0 axis. The name, dimension "
                             f"length, shape except 0 axis should be {parameter_name}, {parameter_shape_length}, "
                             f"{parameter_shape[1:]}, but got name: {parameter.name}, dimension length: "
                             f"{len(parameter.data.shape)}, shape except 0 axis: {parameter.data.shape[1:]} "
                             f"at index {index}.")

        if parameter.data.shape != parameter_shape:
            is_even = False

    layerwise_parallel = sliced_parameters[0].layerwise_parallel
    requires_grad = sliced_parameters[0].requires_grad

    # bfloat16 slices are cast to float32 on CPU before `asnumpy()` (presumably because numpy has no
    # native bfloat16); every other dtype is converted directly.
    sliced_data = [
        cpu_cast(parameter.data, mstype.float32).asnumpy() if parameter.data.dtype == mstype.bfloat16
        else parameter.data.asnumpy()
        for parameter in sliced_parameters
    ]

    if not strategy:
        # Without a strategy the slices are simply stacked along axis 0 in rank order.
        merged_tensor = Tensor(np.concatenate(sliced_data))
    else:
        if parameter_name not in strategy:
            raise KeyError(f"For 'merge_sliced_parameter', the parameter name {parameter_name} should be a key in "
                           f"the 'strategy'. Please check 'sliced_parameter' and 'strategy'.")
        merged_tensor = _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even)

    return Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
|
|
2998
|
-
|
|
2999
|
-
|
|
3000
|
-
def _gather_tasks_load_dis(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, dst_device_num,
|
|
3001
|
-
output_format, name_map, return_param_dict):
|
|
3002
|
-
"""gather transform tasks"""
|
|
3003
|
-
tasks = []
|
|
3004
|
-
for rank in range(0, dst_device_num):
|
|
3005
|
-
tasks.append(
|
|
3006
|
-
(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, rank, output_format, name_map,
|
|
3007
|
-
return_param_dict))
|
|
3008
|
-
return tasks
|
|
3009
|
-
|
|
3010
|
-
|
|
3011
|
-
def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
                                train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
                                format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None,
                                output_format='safetensors', name_map=None, max_process_num=64,
                                return_param_dict=False):
    """
    Load checkpoint into net for distributed prediction. Used in the case of distributed inference.

    Note:
        `output_format` will only take effect when `format` is set to `safetensors` and `network` is set to `None`.

    Args:
        network (Cell): Network for distributed prediction. When the format is `safetensors`, the network parameter
                        can be left blank or passed as None, and the interface will execute save mode.
        checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. Default: ``None`` .
        predict_strategy (Union[dict, str]): Strategy of prediction process. It means that using one device to
                                             predict when setting predict_strategy as None. Default: ``None`` .
        train_strategy_filename (str): The filename of training strategy protocol buffer file.
                                       When train_strategy_filename is None, the training strategy file will be
                                       obtained from context.get_auto_parallel_context("strategy_ckpt_load_file").
                                       Therefore, the training strategy file needs to be specified
                                       in at least one of them. Default: ``None`` .
        strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
                            into net when parameter name's suffix in checkpoint file is the same as the
                            parameter in the network. When the types are inconsistent, perform type conversion
                            on the parameters of the same type, such as float32 to float16. Default: ``False`` .
        dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
                                      is not required. Default: ``None`` .
        dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
                        mode, currently supports ``'AES-GCM'`` , ``'AES-CBC'`` and ``'SM4-CBC'`` .
                        Default: ``'AES-GCM'`` .
        format (str): Input weight format to be loaded into the network.
                      It can be set to either "ckpt" or "safetensors". Default: "ckpt".
        unified_safetensors_dir (str): Directory of input weight files to be loaded into the network.
                                       Default: ``None`` .
        dst_safetensors_dir (str): In the save mode scenario, the save directory for weights. Default: ``None`` .
        rank_id (int): The logical sequence number of the card. In non save mode, it is automatically obtained
                       globally by initializing the network; In save mode, the file of the given rank is saved.
                       If it is not given in save mode, one task per destination rank is run in a process pool.
                       Default: ``None`` .
        output_format (str, optional): Control the format of the output checkpoint after conversion.
            It can be set to either "ckpt" or "safetensors". Default: "safetensors".
        name_map (dict): The weight mapping dictionary will modify the weight names according to the mapping
                         dictionary before loading or saving the segmented weights into the network. Default: None.
        max_process_num (int): Maximum number of processes used in save mode. Default: 64.
        return_param_dict (bool): Whether to return the param_dict. Only forwarded to the loader when `format`
                                  is ``'safetensors'``. Default: ``False``.

    Returns:
        When `format` is ``'safetensors'`` and `network` is given, the value returned by the internal
        safetensors loader (presumably the param_dict when `return_param_dict` is ``True`` - TODO confirm).
        Otherwise ``True``.

    Raises:
        TypeError: The type of inputs do not match the requirements.
        ValueError: Failed to load checkpoint into net.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
            Please see the `rank table startup
            <https://www.mindspore.cn/docs/en/master/model_train/parallel/rank_table.html>`_
            for more details.

            For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
            <https://www.mindspore.cn/docs/en/master/model_train/parallel/mpirun.html>`_ .

            For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
            Startup <https://www.mindspore.cn/docs/en/master/model_train/parallel/dynamic_cluster.html>`_ .

        >>> import os
        >>> import numpy as np
        >>> import mindspore as ms
        >>> import mindspore.dataset as ds
        >>> from mindspore import nn, ops, train
        >>> from mindspore.communication import init
        >>>
        >>> step_per_epoch = 4
        >>> device_num = 8
        >>>
        >>> # Define the network structure.
        >>> class Net(nn.Cell):
        ...     def __init__(self, matmul_size, strategy=None):
        ...         super().__init__()
        ...         matmul_np = np.full(matmul_size, 0.5, dtype=np.float32)
        ...         self.matmul_weight = ms.Parameter(ms.Tensor(matmul_np))
        ...         self.matmul = ops.MatMul()
        ...         self.neg = ops.Neg()
        ...         if strategy is not None:
        ...             self.matmul.shard(strategy)
        ...
        ...     def construct(self, inputs):
        ...         x = self.matmul(inputs, self.matmul_weight)
        ...         x = self.neg(x)
        ...         return x
        >>>
        >>> # Create dataset.
        >>> def get_dataset(*inputs):
        ...     def generate():
        ...         for _ in range(step_per_epoch):
        ...             yield inputs
        ...     return generate
        >>>
        >>> # Train network and save distributed checkpoint.
        >>> def train_net():
        ...     ms.set_context(mode=ms.GRAPH_MODE)
        ...     init()
        ...     np.random.seed(1)
        ...     input_data = np.random.rand(16, 96).astype(np.float32)
        ...     label_data = np.random.rand(16, 16).astype(np.float32)
        ...     fake_dataset = get_dataset(input_data, label_data)
        ...     dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
        ...
        ...     # Set parallel strategy.
        ...     strategy = ((1, 4), (4, 1))
        ...     ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num,
        ...                                  strategy_ckpt_save_file="./train_strategy.ckpt")
        ...     network = Net(matmul_size=(96, 16), strategy=strategy)
        ...     net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
        ...     net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
        ...     model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
        ...     ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=False)
        ...     global_rank_id = int(os.getenv("RANK_ID"))
        ...     ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
        ...     ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
        ...     model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
        ...     ms.reset_auto_parallel_context()
        >>>
        >>> # Load distributed checkpoint and test.
        >>> def load_model():
        ...     ms.set_context(mode=ms.GRAPH_MODE)
        ...     init()
        ...     ms.set_auto_parallel_context(full_batch=True, parallel_mode="semi_auto_parallel",
        ...                                  strategy_ckpt_load_file="./train_strategy.ckpt", device_num=device_num)
        ...     predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
        ...     network = Net(matmul_size=(96, 16))
        ...     model = ms.Model(network)
        ...     predict_layout = model.infer_predict_layout(ms.Tensor(predict_data))
        ...     ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
        ...     ms.load_distributed_checkpoint(network, ckpt_file_list, predict_layout)
        ...     predict_result = model.predict(predict_data)
        ...     print(predict_result)
        >>>
        >>> train_net()
        >>> load_model()
        [[-7.3259363 -7.497216  -7.398196  ... -7.374962  -7.204874  -7.234935 ]
        [ 3.362938   3.3535435  3.3832688 ...  3.4263954  3.279045   3.3202887]
        ...
        [ 1.6067538  1.6244187  1.5384722 ...  1.5449994  1.6195512  1.6176052]]
    """
    supported_formats = ('safetensors', 'ckpt')
    if format not in supported_formats or output_format not in supported_formats:
        # Report both values: either one may be the invalid argument.
        raise ValueError(
            f"For 'load_distributed_checkpoint', 'format' and 'output_format' "
            f"must be 'ckpt' or 'safetensors', but got 'format': {format}, 'output_format': {output_format}.")

    if format == 'safetensors':
        if unified_safetensors_dir is None:
            raise ValueError("For 'load_distributed_checkpoint', 'unified_safetensors_dir' can not be None "
                             "when format is 'safetensors'.")
        # These arguments only apply to the 'ckpt' path. Name the offending argument in the error
        # rather than echoing its (possibly long or binary) value.
        unsupported_params = {'checkpoint_filenames': checkpoint_filenames,
                              'train_strategy_filename': train_strategy_filename,
                              'dec_key': dec_key}
        for param_name, param_value in unsupported_params.items():
            if param_value is not None:
                raise ValueError(f"For 'load_distributed_checkpoint', '{param_name}' must be None "
                                 f"when format is 'safetensors'.")
        if strict_load or dec_mode != 'AES-GCM':
            raise ValueError("For 'load_distributed_checkpoint', strict_load and dec_mode must be default "
                             "when format is 'safetensors'.")
        if network is not None:
            # Load mode: load the weights of the current rank directly into `network`.
            try:
                rank_id = get_rank()
            except RuntimeError:
                rank_id = 0
                logger.warning("Get rank failed, default loading weight for rank 0.")
            param_dict = _load_parallel_checkpoint(
                (unified_safetensors_dir, predict_strategy, network, None, rank_id, output_format, name_map,
                 return_param_dict))
            return param_dict
        # Save mode: write the re-sliced weights to `dst_safetensors_dir`.
        if dst_safetensors_dir is None:
            raise ValueError("For 'load_distributed_checkpoint', 'dst_safetensors_dir' can not be None "
                             "when network is None.")
        if rank_id is not None:
            _load_parallel_checkpoint(
                (unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
                 rank_id, output_format, name_map, return_param_dict))
        else:
            # No rank given: process every destination rank (stage device count * stage count) in a pool.
            dst_strategy_dict = _build_searched_strategy(predict_strategy)
            dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
            dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
            dst_device_num = dst_stage_device_num * dst_stage_num
            tasks = _gather_tasks_load_dis(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
                                           dst_device_num, output_format, name_map, return_param_dict)
            with Pool(processes=max_process_num) as pool:
                list(pool.imap(_load_parallel_checkpoint, tasks))
        return True

    network = Validator.check_isinstance("network", network, nn.Cell)
    _check_checkpoint_file(checkpoint_filenames)
    _check_predict_strategy(predict_strategy)

    dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
    dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)

    if train_strategy_filename is None:
        train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
    _train_strategy = build_searched_strategy(train_strategy_filename)
    train_strategy = _convert_to_list(_train_strategy)

    # The training device count is the product of the device arrangement (item 0) of the first
    # parameter in the training strategy; one checkpoint file is expected per training device.
    train_dev_count = 1
    ckpt_file_len = len(checkpoint_filenames)
    for dim in train_strategy[list(train_strategy.keys())[0]][0]:
        train_dev_count *= dim
    if train_dev_count != ckpt_file_len:
        raise ValueError(f"For 'load_distributed_checkpoint', the length of 'checkpoint_filenames' should be "
                         f"equal to the device count of training process. "
                         f"But got the length of 'checkpoint_filenames'"
                         f" is {ckpt_file_len} and the device count is {train_dev_count}.")
    rank_list = _infer_rank_list(train_strategy, predict_strategy)

    # param_total_dict[param_name][file_index] -> parameter read from the checkpoint at that index.
    param_total_dict = defaultdict(dict)
    for file_index, file_name in enumerate(checkpoint_filenames):
        ckpt_dict = load_checkpoint(file_name, dec_key=dec_key, dec_mode=dec_mode)
        for param_name, param in ckpt_dict.items():
            param_total_dict[param_name][file_index] = param

    param_dict = {}
    param_not_in_strategy = []
    param_not_in_ckpt = []
    for _, param in network.parameters_and_names():
        sliced_params = []
        if param.name not in rank_list.keys():
            param_not_in_strategy.append(param.name)
            continue
        if param.name not in param_total_dict:
            param_not_in_ckpt.append(param.name)
            continue

        param_rank = rank_list.get(param.name)[0]
        skip_merge_split = rank_list.get(param.name)[1]
        shard_stride = train_strategy.get(param.name)[4]
        tensor_map = train_strategy.get(param.name)[1]
        first_dim_shard_idx = tensor_map[0] if tensor_map else -1
        device_arrangement = train_strategy.get(param.name)[0]
        first_dim_shard_size = 1
        if first_dim_shard_idx >= 0:
            first_dim_shard_size = device_arrangement[-1 - first_dim_shard_idx]
        if train_strategy.get(param.name)[5]:
            repeat_size = int(ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] / first_dim_shard_size)
        else:
            repeat_size = 0
        for rank in param_rank:
            param_total_list = list(range(0, ckpt_file_len))
            if first_dim_shard_size != 1:
                param_total_list = _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_shard_idx, rank)
            if repeat_size > 0:
                shard_size = shard_stride * train_strategy.get(param.name)[5]
                rank_index = param_total_list.index(rank)
                start = rank_index // shard_size * shard_size
                param_total_list = param_total_list[start:start + shard_size]
            if shard_stride > 0:
                param_stride = []
                # merge pre parameter
                param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride]
                param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride])
                param_index = list(set(param_index))
                param_index.sort()
                for rank_num in param_index:
                    if param_total_dict[param.name][rank_num].data.dtype == mstype.bfloat16:
                        param_stride.append(
                            cpu_cast(param_total_dict[param.name][rank_num].data, mstype.float32).asnumpy())
                    else:
                        param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy())

                sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name)
            else:
                sliced_param = param_total_dict[param.name][rank]

            sliced_params.append(sliced_param)
        if skip_merge_split:
            split_param = sliced_params[0]
        else:
            param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
            _param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy)
            split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
        # Item 5 of the predict layout is the optimizer-shard group; when set, keep only this rank's
        # equal split (along axis 0) of the parameter within that group.
        opt_shard_group = predict_strategy[param.name][5] if predict_strategy else None
        if opt_shard_group:
            if split_param.data.dtype == mstype.bfloat16:
                data = cpu_cast(split_param.data, mstype.float32).asnumpy()
            else:
                data = split_param.data.asnumpy()
            group_rank = get_rank(opt_shard_group)
            group_size = get_group_size(opt_shard_group)
            try:
                data_slice = np.split(data, group_size)[group_rank]
            except Exception as e:
                logger.critical("Failed to load opt shard slice in load distributed checkpoint for {}. Data shape is {}"
                                " and group is {}".format(param.name, split_param.data.shape, opt_shard_group))
                raise RuntimeError(str(e) + f"\nFor 'load_distributed_checkpoint', failed to load opt shard slice"
                                            f" in load distributed checkpoint for {param.name}. Data shape is "
                                            f"{split_param.data.shape} and group is {opt_shard_group}.") from e
            split_param = Parameter(Tensor(data_slice), param.name,
                                    split_param.requires_grad, split_param.layerwise_parallel)
        param_dict[param.name] = split_param

    if param_not_in_strategy:
        logger.warning("For 'load_distributed_checkpoint', {} parameters in network are not in the slice strategy, "
                       "you can check whether 'predict_strategy' or 'train_strategy_filename' is correct."
                       .format(param_not_in_strategy))
    if param_not_in_ckpt:
        logger.warning("For 'load_distributed_checkpoint', {} parameters in network and slice strategy but not in "
                       "the checkpoint file, please check whether 'checkpoint_filenames' is correct."
                       .format(param_not_in_ckpt))

    load_param_into_net(network, param_dict, strict_load=strict_load)
    return True
|
|
3323
|
-
|
|
3324
|
-
|
|
3325
2614
|
def async_ckpt_thread_status():
|
|
3326
2615
|
"""
|
|
3327
2616
|
Get the status of asynchronous save checkpoint thread.
|
|
@@ -3346,69 +2635,6 @@ def async_ckpt_thread_status():
|
|
|
3346
2635
|
return True in [ele.getName() == "asyn_save_ckpt" for ele in thr_list]
|
|
3347
2636
|
|
|
3348
2637
|
|
|
3349
|
-
def _check_predict_strategy(predict_strategy):
|
|
3350
|
-
"""Check predict strategy."""
|
|
3351
|
-
|
|
3352
|
-
def _check_int_list(arg):
|
|
3353
|
-
if not isinstance(arg, list):
|
|
3354
|
-
return False
|
|
3355
|
-
for item in arg:
|
|
3356
|
-
if not isinstance(item, int):
|
|
3357
|
-
return False
|
|
3358
|
-
return True
|
|
3359
|
-
|
|
3360
|
-
if predict_strategy is None:
|
|
3361
|
-
return
|
|
3362
|
-
|
|
3363
|
-
flag = True
|
|
3364
|
-
predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
|
|
3365
|
-
for key in predict_strategy.keys():
|
|
3366
|
-
if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \
|
|
3367
|
-
or len(predict_strategy[key]) < 4:
|
|
3368
|
-
flag = False
|
|
3369
|
-
dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
|
|
3370
|
-
if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
|
|
3371
|
-
not (_check_int_list(param_split_shape) or not param_split_shape) or \
|
|
3372
|
-
not (isinstance(field_size, int) and field_size == 0):
|
|
3373
|
-
flag = False
|
|
3374
|
-
|
|
3375
|
-
if not flag:
|
|
3376
|
-
raise ValueError(f"For 'load_distributed_checkpoint', the argument 'predict_strategy' is dict, "
|
|
3377
|
-
f"the key of it must be string, and the value of it must be list or tuple that "
|
|
3378
|
-
f"the first four elements must be dev_matrix (list[int]), tensor_map (list[int]), "
|
|
3379
|
-
f"param_split_shape (list[int]) and field_size (int, which value is 0)."
|
|
3380
|
-
f"Please check whether 'predict_strategy' is correct.")
|
|
3381
|
-
|
|
3382
|
-
|
|
3383
|
-
def _check_checkpoint_file(checkpoint_filenames):
|
|
3384
|
-
"""Check checkpoint file name."""
|
|
3385
|
-
for index, filename in enumerate(checkpoint_filenames):
|
|
3386
|
-
if not isinstance(filename, str) or not os.path.exists(filename) \
|
|
3387
|
-
or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0:
|
|
3388
|
-
raise ValueError(f"For 'load_distributed_checkpoint', please check 'checkpoint_filenames', and "
|
|
3389
|
-
f"make sure the {filename} at index {index} is a valid checkpoint file, it must "
|
|
3390
|
-
f"be a string ending with '.ckpt', and the checkpoint file it represents must "
|
|
3391
|
-
f"be exist and not empty.")
|
|
3392
|
-
|
|
3393
|
-
|
|
3394
|
-
def _merge_and_split(sliced_params, train_strategy, predict_strategy):
    """Merge sliced parameter and split it according to the predict strategy.

    Args:
        sliced_params (list[Parameter]): slices of one parameter, forwarded to ``merge_sliced_parameter``.
        train_strategy (dict): strategy forwarded to ``merge_sliced_parameter`` to merge the slices.
        predict_strategy (dict or None): prediction layouts keyed by parameter name. When ``None`` the merged
            parameter is returned without splitting.

    Returns:
        Parameter: the merged parameter, or the current rank's slice of it (via ``_load_tensor`` with
        ``get_rank()``).
    """
    merged_param = merge_sliced_parameter(sliced_params, train_strategy)
    if predict_strategy is None:
        return merged_param
    # merge_sliced_parameter converts bfloat16 slices to float32 before merging, so the merged dtype
    # alone cannot tell whether the weights were bfloat16. Inspect the input slices as well (they have
    # already been validated as Parameters by merge_sliced_parameter).
    is_bf16 = merged_param.data.dtype == mstype.bfloat16 or \
        any(param.data.dtype == mstype.bfloat16 for param in sliced_params)
    param_name = merged_param.name
    tensor_layout = predict_strategy[param_name]
    rank = get_rank()
    # tensor_layout[0] / tensor_layout[1] are the device matrix and tensor map of the predict layout.
    split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank_id=rank)
    requires_grad = merged_param.requires_grad
    layerwise_parallel = merged_param.layerwise_parallel
    if is_bf16:
        split_param = Parameter(Tensor(split_tensor, mstype.bfloat16), param_name, requires_grad, layerwise_parallel)
    else:
        split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
    return split_param
|
|
3410
|
-
|
|
3411
|
-
|
|
3412
2638
|
def _calculation_net_size(net):
|
|
3413
2639
|
"""Calculate the size of parameters in the network."""
|
|
3414
2640
|
data_total = 0
|
|
@@ -3702,3 +2928,35 @@ def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_rege
|
|
|
3702
2928
|
ckpt_filename = os.path.basename(file_path).replace(".safetensors", ".ckpt")
|
|
3703
2929
|
dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), ckpt_filename)
|
|
3704
2930
|
mindspore.save_checkpoint(param_dict_tensor, dst_file)
|
|
2931
|
+
|
|
2932
|
+
|
|
2933
|
+
def restore_group_info_list(group_info_file_name):
    """
    Build the list of ranks whose checkpoints have the same contents as that of the local rank which saved
    `group_info_file_name`. To save the group info file, please export the GROUP_INFO_FILE
    environment variable, e.g. "export GROUP_INFO_FILE=/data/group_info.pb".

    Args:
        group_info_file_name (str): Name of group information file.

    Returns:
        List, the rank list.

    Examples:
        >>> import mindspore as ms
        >>> rank_list = ms.restore_group_info_list("./group_info.pb")
    """
    return new_restore_group_info_list(group_info_file_name)
|
|
2940
|
+
|
|
2941
|
+
|
|
2942
|
+
def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
                                train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
                                format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None,
                                output_format='safetensors', name_map=None, max_process_num=64,
                                return_param_dict=False):
    """
    Load checkpoint into net for distributed prediction. Used in the case of distributed inference.

    Every argument is forwarded positionally, in signature order, to ``new_load_distributed_checkpoint``.

    Returns:
        The value returned by ``new_load_distributed_checkpoint``. Presumably this is the loaded param dict
        when `return_param_dict` is ``True`` and `format` is ``'safetensors'``, and ``True`` otherwise.
        TODO confirm against the implementation.
    """
    # Return the delegate's result; discarding it would make this function always return None and
    # make `return_param_dict=True` ineffective.
    return new_load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy,
                                           train_strategy_filename, strict_load, dec_key, dec_mode,
                                           format, unified_safetensors_dir, dst_safetensors_dir, rank_id,
                                           output_format, name_map, max_process_num,
                                           return_param_dict)
|
|
2953
|
+
|
|
2954
|
+
|
|
2955
|
+
def merge_sliced_parameter(sliced_parameters, strategy=None):
    """
    Merge parameter slices into one parameter. Used in the case of distributed inference.

    Both arguments are passed through unchanged to ``new_merge_sliced_parameter``, whose result is returned.
    """
    merged_parameter = new_merge_sliced_parameter(sliced_parameters, strategy)
    return merged_parameter
|
|
2958
|
+
|
|
2959
|
+
|
|
2960
|
+
def build_searched_strategy(strategy_filename):
    """
    Build strategy of every parameter in network. Used in the case of distributed inference.

    `strategy_filename` is passed unchanged to ``new_build_searched_strategy``, whose result is returned.
    """
    searched_strategy = new_build_searched_strategy(strategy_filename)
    return searched_strategy
|