mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0rc1__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +46 -197
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +217 -98
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +435 -371
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +951 -1992
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +314 -566
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +182 -116
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +157 -117
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +796 -759
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +921 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1370 -189
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +17 -13
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +365 -363
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +27 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +236 -46
- mindspore/ops/auto_generate/gen_extend_func.py +764 -124
- mindspore/ops/auto_generate/gen_ops_def.py +4018 -2264
- mindspore/ops/auto_generate/gen_ops_prim.py +15463 -5037
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4501 -3802
- mindspore/ops/function/nn_func.py +1726 -620
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +440 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +22 -7
- mindspore/ops/functional_overload.py +1440 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +13 -7
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +232 -78
- mindspore/ops/operations/debug_ops.py +153 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +210 -498
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1888 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +299 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +152 -34
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +698 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -58
- mindspore/parallel/transform_safetensors.py +363 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +106 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +409 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +88 -25
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +204 -107
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +184 -113
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
mindspore/train/serialization.py
CHANGED
|
@@ -24,10 +24,12 @@ import os
24 24   import re
25 25   import shutil
26 26   import stat
   27 + import atexit
27 28   import threading
28 29   from threading import Thread, RLock
29    - from multiprocessing import
30    -
   30 + from multiprocessing import active_children
   31 + import multiprocessing as mp
   32 + from collections import OrderedDict
31 33   from io import BytesIO
32 34
33 35   import math
@@ -36,6 +38,9 @@ import time
36 38   import google
37 39   import numpy as np
38 40
   41 + from safetensors.numpy import save_file, load_file
   42 + from safetensors import safe_open
   43 +
39 44   from mindspore.train.checkpoint_pb2 import Checkpoint
40 45   from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
41 46   from mindspore.train.print_pb2 import Print
@@ -44,43 +49,37 @@ import mindspore
|
|
|
44
49
|
import mindspore.nn as nn
|
|
45
50
|
from mindspore import context
|
|
46
51
|
from mindspore import log as logger
|
|
52
|
+
from mindspore.log import vlog_print
|
|
47
53
|
from mindspore._checkparam import check_input_data, check_input_dataset
|
|
48
54
|
from mindspore import _checkparam as Validator
|
|
49
55
|
from mindspore.common import dtype as mstype
|
|
56
|
+
from mindspore.common import np_dtype
|
|
50
57
|
from mindspore.common.api import _cell_graph_executor as _executor
|
|
51
|
-
from mindspore.common.api import
|
|
58
|
+
from mindspore.common.api import _JitExecutor
|
|
52
59
|
from mindspore.common.api import _get_parameter_layout
|
|
53
|
-
from mindspore.common.api import _generate_branch_control_input
|
|
54
60
|
from mindspore.common.initializer import initializer, One
|
|
55
61
|
from mindspore.common.parameter import Parameter, _offload_if_config
|
|
56
62
|
from mindspore.common.tensor import Tensor
|
|
57
|
-
from mindspore._c_expression import
|
|
63
|
+
from mindspore._c_expression import TensorPy as Tensor_
|
|
58
64
|
from mindspore.common._utils import is_shape_unknown
|
|
59
65
|
from mindspore.common.file_system import FileSystem, _register_basic_file_system, _register_mindio_file_system
|
|
60
66
|
from mindspore.communication.management import get_rank, get_group_size
|
|
61
67
|
from mindspore.experimental import MapParameter
|
|
62
68
|
from mindspore.ops import Cast
|
|
63
69
|
from mindspore.parallel._cell_wrapper import get_allgather_cell, _single_parameter_broadcast
|
|
64
|
-
from mindspore.parallel._tensor import
|
|
65
|
-
from mindspore.parallel.
|
|
66
|
-
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode, \
|
|
67
|
-
_get_device_num
|
|
68
|
-
from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
|
|
69
|
-
from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
|
|
70
|
-
_restore_group_info_list, _get_param_list_when_first_dim_sharded
|
|
70
|
+
from mindspore.parallel._tensor import _reshape_param_data
|
|
71
|
+
from mindspore.parallel._utils import _is_in_auto_parallel_mode
|
|
71
72
|
from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
|
|
72
73
|
_store_warm_up_ptr_by_tensor_list, _cache_enable
|
|
73
74
|
from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
|
|
74
|
-
from mindspore.parallel.
|
|
75
|
-
|
|
76
|
-
from mindspore.
|
|
77
|
-
from mindspore.
|
|
75
|
+
from mindspore.parallel.checkpoint_transform import restore_group_info_list as new_restore_group_info_list
|
|
76
|
+
from mindspore.parallel.checkpoint_transform import load_distributed_checkpoint as new_load_distributed_checkpoint
|
|
77
|
+
from mindspore.parallel.checkpoint_transform import merge_sliced_parameter as new_merge_sliced_parameter
|
|
78
|
+
from mindspore.parallel.checkpoint_transform import build_searched_strategy as new_build_searched_strategy
|
|
79
|
+
from mindspore.train._utils import read_proto, get_parameter_redundancy, _progress_bar, _load_and_transform
|
|
80
|
+
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, \
|
|
78
81
|
split_mindir, split_dynamic_mindir
|
|
79
82
|
from mindspore.common.generator import Generator
|
|
80
|
-
from safetensors.numpy import save_file
|
|
81
|
-
from safetensors import safe_open
|
|
82
|
-
from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
|
|
83
|
-
|
|
84
83
|
|
|
85
84
|
tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
|
|
86
85
|
"Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
|
|
@@ -91,6 +90,9 @@ tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UIn
|
|
|
91
90
|
"Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
|
|
92
91
|
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}
|
|
93
92
|
|
|
93
|
+
if hasattr(np_dtype, "bfloat16"):
|
|
94
|
+
tensor_to_np_type["BFloat16"] = np_dtype.bfloat16
|
|
95
|
+
|
|
94
96
|
np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64}
|
|
95
97
|
|
|
96
98
|
mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4: mstype.uint16,
|
|
@@ -123,24 +125,55 @@ def init_ckpt_file_system(fs: FileSystem):
|
|
|
123
125
|
init_ckpt_file_system(_ckpt_fs)
|
|
124
126
|
|
|
125
127
|
|
|
128
|
+
def _wait_async_process_save_ckpt():
    """Block until every live child process named "asyn_save_ckpt" has exited.

    Child processes are matched purely by their ``name`` attribute; children
    with any other name are left untouched.
    """
    savers = [child for child in active_children() if child.name == "asyn_save_ckpt"]
    for saver in savers:
        saver.join()
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _wait_async_thread_save_ckpt():
    """Block until every live thread named "asyn_save_ckpt" has finished.

    Threads are matched purely by name; threads with any other name are left
    untouched. The calling thread is never joined, even if it carries the
    matching name.
    """
    current = threading.current_thread()
    for thread in threading.enumerate():
        # `Thread.getName()` is deprecated since Python 3.10, so read `.name`.
        # Joining the current thread raises RuntimeError, hence the identity check.
        if thread.name == "asyn_save_ckpt" and thread is not current:
            thread.join()
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _async_save_close():
    """Wait for all asynchronous checkpoint saves that are still running.

    Saving child processes are waited for first, then saving threads.
    """
    for wait_for_savers in (_wait_async_process_save_ckpt, _wait_async_thread_save_ckpt):
        wait_for_savers()


# Registered with atexit so that pending asynchronous saves are joined
# when the interpreter shuts down.
atexit.register(_async_save_close)
|
|
151
|
+
|
|
152
|
+
|
|
126
153
|
def _get_cur_rank_dp(parameter_layout_dict):
    """Pick the smallest redundancy group that contains the current rank.

    The redundancy mapping returned by `get_parameter_redundancy` is walked
    entry by entry (keys starting with "accu_grads" or "inputs" are skipped).
    Only groups containing the current global rank are considered, and the
    shortest one seen so far is tracked. The groups must nest: a shorter group
    has to be a subset of the tracked one, and a group that is not shorter has
    to be a superset of it. As soon as that does not hold, the current rank
    alone is returned.

    Args:
        parameter_layout_dict: Layout information forwarded unchanged to
            `get_parameter_redundancy`.

    Returns:
        The smallest matching group, ``()`` if no group contains the current
        rank, or ``(global_rank,)`` when the groups do not nest.
    """
    global_rank = get_rank()
    redundancy = get_parameter_redundancy(parameter_layout_dict)
    smallest_len = sys.maxsize
    smallest_group = ()
    smallest_members = set()
    skipped_prefixes = ("accu_grads", "inputs")
    for param_name, groups in redundancy.items():
        if param_name.startswith(skipped_prefixes):
            continue
        for group in groups:
            if global_rank not in group:
                continue
            members = set(group)
            if len(group) >= smallest_len:
                # Not shorter than the tracked group: it must contain that group.
                if not smallest_members.issubset(members):
                    return (global_rank,)
                continue
            # Shorter than the tracked group: it must lie inside that group
            # (the very first candidate is accepted unconditionally).
            if smallest_members and not members.issubset(smallest_members):
                return (global_rank,)
            smallest_len = len(group)
            smallest_members = members
            smallest_group = group
    return smallest_group
|
|
145
178
|
|
|
146
179
|
|
|
@@ -160,7 +193,7 @@ def get_ckpt_path_with_strategy(cur_ckpt_path, cur_strategy_path):
|
|
|
160
193
|
cur_strategy_path (str): strategy file path for current rank.
|
|
161
194
|
|
|
162
195
|
Returns:
|
|
163
|
-
- new_ckpt_file (
|
|
196
|
+
- new_ckpt_file (str), if found available checkpoint file , return it.
|
|
164
197
|
- None, if not found available checkpoint, return None.
|
|
165
198
|
|
|
166
199
|
Examples:
|
|
@@ -175,6 +208,9 @@ def get_ckpt_path_with_strategy(cur_ckpt_path, cur_strategy_path):
|
|
|
175
208
|
>>> ckpt_file_new = get_ckpt_path_with_strategy(ckpt_file, strategy_file)
|
|
176
209
|
>>> print(ckpt_file_new)
|
|
177
210
|
"""
|
|
211
|
+
cur_rank = get_rank()
|
|
212
|
+
if f"rank_{str(cur_rank)}" in cur_ckpt_path and os.path.isfile(cur_ckpt_path):
|
|
213
|
+
return cur_ckpt_path
|
|
178
214
|
dp = _get_cur_rank_dp(cur_strategy_path)
|
|
179
215
|
pattern = r'rank_\d+'
|
|
180
216
|
for i in dp:
|
|
@@ -282,7 +318,8 @@ def _type_convert(param, new_param, strict_load):
|
|
|
282
318
|
{param.data.dtype, new_param.data.dtype}.issubset(int_type)):
|
|
283
319
|
logger.warning(f"The type of {new_param.name}:{new_param.data.dtype} in 'parameter_dict' is different from "
|
|
284
320
|
f"the type of it in 'net':{param.data.dtype}, then the type convert from "
|
|
285
|
-
f"{new_param.data.dtype} to {param.data.dtype} in the network."
|
|
321
|
+
f"{new_param.data.dtype} to {param.data.dtype} in the network. May consume additional memory "
|
|
322
|
+
f"and time")
|
|
286
323
|
return True
|
|
287
324
|
return False
|
|
288
325
|
|
|
@@ -338,6 +375,7 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
|
|
|
338
375
|
os.chmod(tmp_name, stat.S_IWUSR)
|
|
339
376
|
os.remove(tmp_name)
|
|
340
377
|
if format == "ckpt":
|
|
378
|
+
ckpt_total_io_time = 0
|
|
341
379
|
with _ckpt_fs.create(tmp_name, *_ckpt_fs.create_args) as f:
|
|
342
380
|
plain_data = None
|
|
343
381
|
if enc_key is not None:
|
|
@@ -354,20 +392,26 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
|
|
|
354
392
|
if value[0] == "offload_parameter":
|
|
355
393
|
new_value = value[1:]
|
|
356
394
|
new_value[2] = value[3]
|
|
357
|
-
_write_parameter_bytes_data(name, new_value, f, enc_key, plain_data)
|
|
395
|
+
_write_parameter_bytes_data(name, new_value, f, enc_key, plain_data, ckpt_total_io_time)
|
|
358
396
|
_offload_if_config(value[3])
|
|
359
397
|
continue
|
|
360
398
|
if value[1] == "str":
|
|
361
|
-
crc_num = _write_parameter_data(name, value, f, enc_key, plain_data,
|
|
399
|
+
crc_num, ckpt_total_io_time = _write_parameter_data(name, value, f, enc_key, plain_data,
|
|
400
|
+
crc_num, crc_check,
|
|
401
|
+
ckpt_total_io_time)
|
|
362
402
|
continue
|
|
363
403
|
if isinstance(value[2], np.ndarray):
|
|
364
|
-
crc_num = _write_parameter_data(name, value, f, enc_key, plain_data,
|
|
404
|
+
crc_num, ckpt_total_io_time = _write_parameter_data(name, value, f, enc_key, plain_data,
|
|
405
|
+
crc_num, crc_check,
|
|
406
|
+
ckpt_total_io_time)
|
|
365
407
|
continue
|
|
366
408
|
if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
|
|
367
409
|
_write_hugeparameter(name, value, f)
|
|
368
410
|
continue
|
|
369
411
|
|
|
370
|
-
crc_num = _write_parameter_bytes_data(name, value, f, enc_key, plain_data,
|
|
412
|
+
crc_num, ckpt_total_io_time = _write_parameter_bytes_data(name, value, f, enc_key, plain_data,
|
|
413
|
+
crc_num, crc_check,
|
|
414
|
+
ckpt_total_io_time)
|
|
371
415
|
|
|
372
416
|
if enc_key is not None:
|
|
373
417
|
plain_data.seek(0)
|
|
@@ -378,11 +422,36 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
|
|
|
378
422
|
block_data = plain_data.read(max_block_size)
|
|
379
423
|
if crc_check:
|
|
380
424
|
f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big'))
|
|
425
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
|
|
426
|
+
f"Save ckpt io cost time:{ckpt_total_io_time}.")
|
|
427
|
+
|
|
381
428
|
elif format == "safetensors":
|
|
382
429
|
save_dict = {}
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
430
|
+
crc_num = 0
|
|
431
|
+
for name in sorted(data_list.keys()):
|
|
432
|
+
value = data_list[name]
|
|
433
|
+
if isinstance(value[2], np.ndarray):
|
|
434
|
+
save_dict[name] = value[2]
|
|
435
|
+
else:
|
|
436
|
+
bytes_data = value[2].get_bytes()
|
|
437
|
+
np_type = tensor_to_np_type.get(value[1])
|
|
438
|
+
np_array = np.frombuffer(bytes_data, np_type)
|
|
439
|
+
new_np_array = np_array.reshape(value[0])
|
|
440
|
+
save_dict[name] = new_np_array
|
|
441
|
+
|
|
442
|
+
if crc_check:
|
|
443
|
+
crc_num = binascii.crc32(bytes(name, encoding='utf-8'), crc_num)
|
|
444
|
+
crc_num = binascii.crc32(
|
|
445
|
+
bytes(save_dict[name]), crc_num)
|
|
446
|
+
safetensors_save_time_start = time.time()
|
|
447
|
+
if crc_check:
|
|
448
|
+
save_file(save_dict, tmp_name, metadata={
|
|
449
|
+
"crc_num": str(crc_num)})
|
|
450
|
+
else:
|
|
451
|
+
save_file(save_dict, tmp_name)
|
|
452
|
+
safetensors_save_time_end = time.time()
|
|
453
|
+
cost_time = safetensors_save_time_end - safetensors_save_time_start
|
|
454
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save safetensors io cost time:{cost_time}.")
|
|
386
455
|
if not os.path.exists(tmp_name):
|
|
387
456
|
logger.warning(f"Rename failed, can't find {tmp_name}, it is possible that multiple processes have "
|
|
388
457
|
f"simultaneously modified a file.")
|
|
@@ -407,7 +476,7 @@ def _write_random_seed(name, value, f):
|
|
|
407
476
|
f.write(checkpoint_list.SerializeToString())
|
|
408
477
|
|
|
409
478
|
|
|
410
|
-
def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
|
|
479
|
+
def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False, ckpt_total_io_time=0):
|
|
411
480
|
"""Write parameter data into protobuf file."""
|
|
412
481
|
data_size = value[2].nbytes / 1024
|
|
413
482
|
if data_size > SLICE_SIZE:
|
|
@@ -429,14 +498,18 @@ def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_ch
|
|
|
429
498
|
output_data = checkpoint_list.SerializeToString()
|
|
430
499
|
if crc_check:
|
|
431
500
|
crc_num = binascii.crc32(output_data, crc_num)
|
|
501
|
+
io_start_time = time.time()
|
|
432
502
|
f.write(output_data)
|
|
503
|
+
io_end_time = time.time()
|
|
504
|
+
io_cost_time = io_end_time - io_start_time
|
|
505
|
+
ckpt_total_io_time += io_cost_time
|
|
433
506
|
else:
|
|
434
507
|
plain_data.write(checkpoint_list.SerializeToString())
|
|
435
508
|
|
|
436
|
-
return crc_num
|
|
509
|
+
return crc_num, ckpt_total_io_time
|
|
437
510
|
|
|
438
511
|
|
|
439
|
-
def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
|
|
512
|
+
def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False, ckpt_total_io_time=0):
|
|
440
513
|
"""Write parameter bytes data into protobuf file."""
|
|
441
514
|
bytes_value = value[2].get_bytes()
|
|
442
515
|
chunk_size = 1024 * SLICE_SIZE
|
|
@@ -454,11 +527,15 @@ def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0,
|
|
|
454
527
|
output_data = checkpoint_list.SerializeToString()
|
|
455
528
|
if crc_check:
|
|
456
529
|
crc_num = binascii.crc32(output_data, crc_num)
|
|
530
|
+
io_start_time = time.time()
|
|
457
531
|
f.write(output_data)
|
|
532
|
+
io_end_time = time.time()
|
|
533
|
+
io_cost_time = io_end_time - io_start_time
|
|
534
|
+
ckpt_total_io_time += io_cost_time
|
|
458
535
|
else:
|
|
459
536
|
plain_data.write(checkpoint_list.SerializeToString())
|
|
460
537
|
|
|
461
|
-
return crc_num
|
|
538
|
+
return crc_num, ckpt_total_io_time
|
|
462
539
|
|
|
463
540
|
|
|
464
541
|
def _write_mapparameter(name, value, f, map_param_inc=False):
|
|
@@ -522,12 +599,56 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format):
|
|
|
522
599
|
return ckpt_file_name
|
|
523
600
|
|
|
524
601
|
|
|
525
|
-
def
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
602
|
+
def _check_load_checkpoint_upsupported_param(format, dec_key, dec_mode):
    """Reject decryption options that may not be customised for safetensors.

    Nothing is checked unless `format` is ``"safetensors"``.

    Args:
        format (str): File format being loaded.
        dec_key: Decryption key; must equal ``None`` for safetensors.
        dec_mode: Decryption mode; must equal ``"AES-GCM"`` for safetensors.

    Raises:
        ValueError: If `format` is ``"safetensors"`` and `dec_key` or
            `dec_mode` differs from its default value.
    """
    if format != "safetensors":
        return
    # (name, supplied value, required default), checked in this order.
    # Explicit tuples replace a `locals()[name]` lookup, which silently breaks
    # if a parameter is renamed.
    checked_params = (
        ("dec_key", dec_key, None),
        ("dec_mode", dec_mode, "AES-GCM"),
    )
    for param_name, current_value, default_value in checked_params:
        if current_value != default_value:
            raise ValueError(f"For 'load_checkpoint', when format is 'safetensors', the parameter '{param_name}' must "
                             f"be set to default value '{default_value}', but got '{current_value}'.")
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
def _check_save_checkpoint_upsupported_param(format, enc_key, enc_mode, map_param_inc=False, global_step_num=None):
    """Reject save options that may not be customised for safetensors.

    Nothing is checked unless `format` is ``"safetensors"``.

    Args:
        format (str): File format being saved.
        enc_key: Encryption key; must equal ``None`` for safetensors.
        enc_mode: Encryption mode; must equal ``"AES-GCM"`` for safetensors.
        map_param_inc: Must equal ``False`` for safetensors. Default: ``False``.
        global_step_num: Must equal ``None`` for safetensors. Default: ``None``.

    Raises:
        ValueError: If `format` is ``"safetensors"`` and any of the options
            above differs from its default value.
    """
    if format != "safetensors":
        return
    # (name, supplied value, required default), checked in this order.
    # Explicit tuples replace a `locals()[name]` lookup, which silently breaks
    # if a parameter is renamed.
    checked_params = (
        ("enc_key", enc_key, None),
        ("enc_mode", enc_mode, "AES-GCM"),
        ("map_param_inc", map_param_inc, False),
        ("global_step_num", global_step_num, None),
    )
    for param_name, current_value, default_value in checked_params:
        if current_value != default_value:
            raise ValueError(f"For 'save_checkpoint', when format is 'safetensors', the parameter '{param_name}' must "
                             f"be set to default value '{default_value}', but got '{current_value}'.")
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
def _check_async_save(async_save):
    """Validate the `async_save` argument of `save_checkpoint`.

    Args:
        async_save (Union[bool, str]): Either a bool, or one of the strings
            ``"process"`` and ``"thread"``.

    Returns:
        The value of `async_save`, unchanged.

    Raises:
        TypeError: If `async_save` is neither bool nor str.
        ValueError: If `async_save` is a str other than ``"process"`` or ``"thread"``.
    """
    if not isinstance(async_save, (bool, str)):
        raise TypeError("For 'save_checkpoint', the parameter 'async_save' must be bool or str, "
                        "but got {}.".format(type(async_save)))
    if isinstance(async_save, str) and async_save not in ("process", "thread"):
        # The two literals are concatenated, so the first must end in a space;
        # otherwise the message reads "'thread',but got ...".
        raise ValueError("For 'save_checkpoint', the argument 'async_save' can only be 'process' or 'thread', "
                         "but got {}.".format(async_save))
    return async_save
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
def _async_process_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False,
                        crc_check=False, format="ckpt", cond=None):
    """Signal that the saving process has started, then write the checkpoint.

    Used as the target of the asynchronous saving process. The parent waits on
    `cond` to learn that the child was started successfully, so the
    notification is sent before any saving work begins.

    Args:
        ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check,
            format: Forwarded unchanged to `_exec_save`.
        cond: Condition object to notify on start-up. When ``None`` (the
            default) no notification is sent. Default: ``None``.
    """
    # `cond` is optional in the signature, so skip the hand-shake when it is
    # absent instead of failing on `with None`.
    if cond is not None:
        with cond:
            cond.notify()
    _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
|
|
531
652
|
|
|
532
653
|
|
|
533
654
|
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
@@ -541,13 +662,19 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
541
662
|
|
|
542
663
|
Args:
|
|
543
664
|
save_obj (Union[Cell, list, dict]): The object to be saved. The data type can be :class:`mindspore.nn.Cell`,
|
|
544
|
-
list, or dict.
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
665
|
+
list, or dict.
|
|
666
|
+
|
|
667
|
+
- If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
|
|
668
|
+
elements(each element is a dictionary, like [{"name": param_name, "data": param_data},...], the type of
|
|
669
|
+
`param_name` must be string, and the type of `param_data` must be parameter or Tensor).
|
|
670
|
+
- If dict, it can be the returned value of :func:`mindspore.load_checkpoint`.
|
|
671
|
+
|
|
548
672
|
ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
|
|
549
673
|
integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: ``True`` .
|
|
550
|
-
async_save (bool): Whether to
|
|
674
|
+
async_save (Union[bool, str], optional): Whether to use asynchronous saving of the checkpoint file or
|
|
675
|
+
safetensors file, if True, the asynchronous thread is used by default. If the type
|
|
676
|
+
is string, the method of asynchronous saving, it can be "process" or "thread".
|
|
677
|
+
Default: ``False`` .
|
|
551
678
|
append_dict (dict): Additional information that needs to be saved. The key of dict must be str, the value
|
|
552
679
|
of dict must be one of int, float, bool, string, Parameter or Tensor. Default: ``None`` .
|
|
553
680
|
enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is ``None`` , the encryption
|
|
@@ -557,9 +684,12 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
557
684
|
Default: ``"AES-GCM"`` .
|
|
558
685
|
choice_func (function) : A function for saving custom selected parameters. The input value of `choice_func` is
|
|
559
686
|
a parameter name in string type, and the returned value is a bool.
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
687
|
+
Default: ``None`` .
|
|
688
|
+
|
|
689
|
+
- If returns ``True`` , the Parameter that matching the custom condition will be saved.
|
|
690
|
+
- If returns ``False`` , the Parameter that not matching the custom condition will not
|
|
691
|
+
be saved.
|
|
692
|
+
|
|
563
693
|
crc_check (bool) : Whether to perform crc32 calculation when saving checkpoint and save the calculation
|
|
564
694
|
result to the file. Default: ``False`` .
|
|
565
695
|
format (str): Format of the output file, can be "ckpt" or "safetensors". Default: "ckpt".
|
|
@@ -567,8 +697,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
567
697
|
|
|
568
698
|
Raises:
|
|
569
699
|
TypeError: If the parameter `save_obj` is not :class:`mindspore.nn.Cell` , list or dict type.
|
|
570
|
-
TypeError: If the parameter `integrated_save`
|
|
700
|
+
TypeError: If the parameter `integrated_save` is not bool type.
|
|
571
701
|
TypeError: If the parameter `ckpt_file_name` is not string type.
|
|
702
|
+
TypeError: If the parameter `async_save` is not bool or string type.
|
|
703
|
+
ValueError: If the parameter `async_save` is string type but not in ["process", "thread"].
|
|
572
704
|
|
|
573
705
|
Examples:
|
|
574
706
|
>>> import mindspore as ms
|
|
@@ -596,9 +728,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
596
728
|
- `Saving and Loading the Model - Saving and Loading the Model Weight
|
|
597
729
|
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
|
|
598
730
|
"""
|
|
731
|
+
start_save_time = time.time()
|
|
599
732
|
ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format)
|
|
600
733
|
integrated_save = Validator.check_bool(integrated_save)
|
|
601
|
-
async_save =
|
|
734
|
+
async_save = _check_async_save(async_save)
|
|
602
735
|
append_dict = _check_append_dict(append_dict)
|
|
603
736
|
enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
|
|
604
737
|
enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
|
|
@@ -606,12 +739,15 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
606
739
|
map_param_inc = kwargs.get('incremental', False)
|
|
607
740
|
logger.info("Execute the process of saving checkpoint files.")
|
|
608
741
|
global_step_num = kwargs.get('global_step_num', None)
|
|
609
|
-
|
|
742
|
+
_check_save_checkpoint_upsupported_param(format, enc_key, enc_mode, map_param_inc, global_step_num)
|
|
610
743
|
|
|
611
744
|
if append_dict and "__exception_save__" in append_dict:
|
|
612
745
|
s1 = mindspore.hal.Stream()
|
|
613
746
|
with mindspore.hal.StreamCtx(s1):
|
|
614
747
|
save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
|
|
748
|
+
for k_name, value in append_dict.items():
|
|
749
|
+
if isinstance(value, (Tensor, Parameter)):
|
|
750
|
+
append_dict[k_name] = Tensor(Tensor_.move_to(value, "CPU", False))
|
|
615
751
|
s1.synchronize()
|
|
616
752
|
else:
|
|
617
753
|
save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
|
|
@@ -682,23 +818,74 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
682
818
|
data_list[key].append(dims)
|
|
683
819
|
tensor_type = str(param["data"].dtype)
|
|
684
820
|
data_list[key].append(tensor_type)
|
|
685
|
-
data = param["data"]
|
|
821
|
+
data = param["data"] if async_save is False else param["data"].asnumpy()
|
|
686
822
|
data_list[key].append(data)
|
|
687
823
|
|
|
824
|
+
from mindspore.profiler import mstx
|
|
825
|
+
range_id = mstx.range_start('save_checkpoint', None)
|
|
688
826
|
if os.getenv("AITURBO") == "1":
|
|
689
827
|
from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
|
|
690
828
|
ckpt_name = os.path.basename(ckpt_file_name)
|
|
691
829
|
aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np, crc_check)
|
|
692
830
|
elif async_save:
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
831
|
+
if async_save == "process":
|
|
832
|
+
if sys.platform.startswith("win"):
|
|
833
|
+
logger.warining("The Win platform currently does not support asynchronous process saving of ckpt, "
|
|
834
|
+
"so serial saving of ckpt is used now.")
|
|
835
|
+
_exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
|
|
836
|
+
else:
|
|
837
|
+
_wait_async_process_save_ckpt()
|
|
838
|
+
ctx = mp.get_context("fork")
|
|
839
|
+
cond = ctx.Condition()
|
|
840
|
+
process_flag = True
|
|
841
|
+
while process_flag:
|
|
842
|
+
process = ctx.Process(target=_async_process_save,
|
|
843
|
+
args=(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check,
|
|
844
|
+
format, cond), daemon=True, name="asyn_save_ckpt")
|
|
845
|
+
process.start()
|
|
846
|
+
with cond:
|
|
847
|
+
wait_flag = cond.wait(timeout=5)
|
|
848
|
+
if not wait_flag:
|
|
849
|
+
logger.warning("Async save process fails to create. will kill and recreate")
|
|
850
|
+
process.kill()
|
|
851
|
+
else:
|
|
852
|
+
process_flag = False
|
|
853
|
+
else:
|
|
854
|
+
data_copy = copy.deepcopy(data_list)
|
|
855
|
+
_wait_async_thread_save_ckpt()
|
|
856
|
+
thr = Thread(target=_exec_save,
|
|
857
|
+
args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format),
|
|
858
|
+
name="asyn_save_ckpt")
|
|
859
|
+
thr.start()
|
|
698
860
|
else:
|
|
699
861
|
_exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
|
|
700
862
|
|
|
863
|
+
mstx.range_end(range_id)
|
|
701
864
|
logger.info("Saving checkpoint process is finished.")
|
|
865
|
+
end_save_time = time.time()
|
|
866
|
+
save_checkpoint_cost_time = end_save_time - start_save_time
|
|
867
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save checkpoint cost time {save_checkpoint_cost_time}.")
|
|
868
|
+
|
|
869
|
+
|
|
870
|
+
def _handle_shared_param_for_pipeline_parallel(save_obj):
    """Return the entries of `save_obj` that are not pipeline-shared parameters.

    Each entry is expected to expose its payload under the ``'data'`` key. An
    entry is dropped only when that payload is a `Parameter` whose
    ``param_info.is_pipeline_shared_param`` is truthy; every other entry is
    kept, in the original order.
    """
    kept_entries = []
    for entry in save_obj:
        payload = entry['data']
        if isinstance(payload, Parameter) and payload.param_info.is_pipeline_shared_param:
            continue
        kept_entries.append(entry)
    return kept_entries
|
|
881
|
+
|
|
882
|
+
|
|
883
|
+
def _is_auto_parallel_mode(save_obj):
    """Report whether any parameter of `save_obj` is flagged ``is_param_init``.

    Iterates ``save_obj.parameters_and_names()`` and returns ``True`` as soon
    as a parameter with a truthy ``param_info.is_param_init`` is found,
    otherwise ``False``. Callers use the result as the auto-parallel indicator.
    """
    return any(param.param_info.is_param_init for _, param in save_obj.parameters_and_names())
|
|
702
889
|
|
|
703
890
|
|
|
704
891
|
def _convert_list_to_param_list(save_obj, choice_func):
|
|
@@ -739,7 +926,7 @@ def _convert_dict_to_param_dict(save_obj, choice_func):
|
|
|
739
926
|
"""Convert a dict of Parameter to param_list."""
|
|
740
927
|
param_list = []
|
|
741
928
|
for (key, value) in save_obj.items():
|
|
742
|
-
if isinstance(key, str) and isinstance(value, (Parameter, str)):
|
|
929
|
+
if isinstance(key, str) and (isinstance(value, (Parameter, str)) or _is_buffer_type(value)):
|
|
743
930
|
if choice_func is not None and not choice_func(key):
|
|
744
931
|
continue
|
|
745
932
|
each_param = {"name": key, "data": value}
|
|
@@ -751,15 +938,19 @@ def _convert_dict_to_param_dict(save_obj, choice_func):
|
|
|
751
938
|
return param_list
|
|
752
939
|
|
|
753
940
|
|
|
754
|
-
def _convert_cell_param_and_names_to_dict(save_obj, choice_func):
|
|
941
|
+
def _convert_cell_param_and_names_to_dict(save_obj, choice_func, is_parallel_mode):
|
|
755
942
|
"""Convert cell.parameters_and_names to OrderedDict."""
|
|
756
943
|
param_dict = OrderedDict()
|
|
757
944
|
for _, param in save_obj.parameters_and_names():
|
|
945
|
+
if param.name.startswith("accu_grads") or param.name.endswith("expert_load"):
|
|
946
|
+
continue
|
|
758
947
|
not_sliced = not param.sliced
|
|
759
948
|
is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
|
|
760
949
|
# All parameters are initialized immediately under PyNative mode, skip this judgement.
|
|
761
950
|
judgment = not_sliced or param.has_init
|
|
762
|
-
if
|
|
951
|
+
if param.param_info.is_pipeline_shared_param:
|
|
952
|
+
continue
|
|
953
|
+
if is_graph_mode and is_parallel_mode and judgment:
|
|
763
954
|
continue
|
|
764
955
|
if choice_func is not None and not choice_func(param.name):
|
|
765
956
|
continue
|
|
@@ -777,11 +968,12 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
|
|
|
777
968
|
sync_pipeline_shared_parameters(save_obj)
|
|
778
969
|
param_list = []
|
|
779
970
|
parameter_layout_dict = save_obj.parameter_layout_dict
|
|
780
|
-
|
|
971
|
+
is_parallel_mode = _is_auto_parallel_mode(save_obj)
|
|
972
|
+
if is_parallel_mode and not parameter_layout_dict:
|
|
781
973
|
parameter_layout_dict = _get_parameter_layout()
|
|
782
|
-
if not
|
|
974
|
+
if not is_parallel_mode:
|
|
783
975
|
save_obj.init_parameters_data()
|
|
784
|
-
param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func)
|
|
976
|
+
param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func, is_parallel_mode)
|
|
785
977
|
if append_dict and "random_op" in append_dict:
|
|
786
978
|
phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key
|
|
787
979
|
if phase in save_obj.compile_cache and _executor.has_compiled(phase):
|
|
@@ -829,11 +1021,14 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
|
|
|
829
1021
|
|
|
830
1022
|
def _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func):
    """Turn `save_obj` into a parameter list, dispatching on its type.

    Anything that is not a list or dict is handed to
    `_convert_cell_to_param_list`. Lists and dicts are converted with the
    matching helper and then passed through
    `_handle_shared_param_for_pipeline_parallel`.
    """
    if not isinstance(save_obj, (list, dict)):
        return _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func)

    converted = save_obj
    if isinstance(converted, list):
        converted = _convert_list_to_param_list(converted, choice_func)

    # Checked on the (possibly already converted) object, not on the raw input.
    if isinstance(converted, dict):
        converted = _convert_dict_to_param_dict(converted, choice_func)

    return _handle_shared_param_for_pipeline_parallel(converted)
|
|
839
1034
|
|
|
@@ -864,11 +1059,8 @@ def _check_append_dict(append_dict):
|
|
|
864
1059
|
return append_dict
|
|
865
1060
|
|
|
866
1061
|
|
|
867
|
-
def
|
|
868
|
-
if
|
|
869
|
-
customized_func = _check_customized_func(kwargs.get('obf_func'))
|
|
870
|
-
clean_funcs()
|
|
871
|
-
add_opaque_predicate(customized_func.__name__, customized_func)
|
|
1062
|
+
def _is_buffer_type(value):
    """Return True when `value` is a `Tensor` with a truthy ``_is_buffer`` attribute.

    A missing ``_is_buffer`` attribute counts as ``False``.
    """
    return bool(isinstance(value, Tensor) and getattr(value, "_is_buffer", False))
|
|
874
1066
|
|
|
@@ -885,20 +1077,18 @@ def load(file_name, **kwargs):
|
|
|
885
1077
|
kwargs (dict): Configuration options dictionary.
|
|
886
1078
|
|
|
887
1079
|
- dec_key (bytes): Byte-type key used for decryption. The valid length is 16, 24, or 32.
|
|
888
|
-
- dec_mode (Union[str, function]):
|
|
1080
|
+
- dec_mode (Union[str, function], optional):
|
|
1081
|
+
Specifies the decryption mode, to take effect when dec_key is set.
|
|
889
1082
|
|
|
890
1083
|
- Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: ``'AES-GCM'``.
|
|
891
1084
|
- For details of using the customized decryption, please check the `tutorial
|
|
892
1085
|
<https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
|
|
893
1086
|
|
|
894
|
-
- obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to
|
|
895
|
-
`obfuscate_model()
|
|
896
|
-
<https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.obfuscate_model.html>`_.
|
|
897
|
-
|
|
898
1087
|
Returns:
|
|
899
1088
|
GraphCell, a compiled graph that can executed by `GraphCell`.
|
|
900
1089
|
|
|
901
1090
|
Raises:
|
|
1091
|
+
NotImplementedError: Dynamic model structure obfuscation is no longer supported.
|
|
902
1092
|
ValueError: MindIR file does not exist or `file_name` is not a string.
|
|
903
1093
|
RuntimeError: Failed to parse MindIR file.
|
|
904
1094
|
|
|
@@ -925,6 +1115,8 @@ def load(file_name, **kwargs):
|
|
|
925
1115
|
- `Saving and Loading the Model - Saving and Loading MindIR
|
|
926
1116
|
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
|
|
927
1117
|
"""
|
|
1118
|
+
if 'obf_func' in kwargs.keys():
|
|
1119
|
+
raise NotImplementedError("Dynamic model structure obfuscation is no longer supported.")
|
|
928
1120
|
if not isinstance(file_name, str):
|
|
929
1121
|
raise ValueError("For 'load', the argument 'file_name' must be string, but "
|
|
930
1122
|
"got {}.".format(type(file_name)))
|
|
@@ -936,9 +1128,6 @@ def load(file_name, **kwargs):
|
|
|
936
1128
|
"please check whether the 'file_name' is correct.")
|
|
937
1129
|
file_name = os.path.realpath(file_name)
|
|
938
1130
|
|
|
939
|
-
# set customized functions for dynamic obfuscation
|
|
940
|
-
obfuscated = _check_load_obfuscate(**kwargs)
|
|
941
|
-
|
|
942
1131
|
logger.info("Execute the process of loading mindir.")
|
|
943
1132
|
if 'dec_key' in kwargs.keys():
|
|
944
1133
|
dec_key = Validator.check_isinstance('dec_key', kwargs.get('dec_key'), bytes)
|
|
@@ -951,9 +1140,9 @@ def load(file_name, **kwargs):
|
|
|
951
1140
|
else:
|
|
952
1141
|
dec_mode = Validator.check_isinstance('dec_mode', kwargs.get('dec_mode'), str)
|
|
953
1142
|
graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode,
|
|
954
|
-
decrypt=dec_func
|
|
1143
|
+
decrypt=dec_func)
|
|
955
1144
|
else:
|
|
956
|
-
graph = load_mindir(file_name
|
|
1145
|
+
graph = load_mindir(file_name)
|
|
957
1146
|
|
|
958
1147
|
if graph is None:
|
|
959
1148
|
if _is_cipher_file(file_name):
|
|
@@ -1020,189 +1209,45 @@ def _check_param_type(param_config, key, target_type, requested):
|
|
|
1020
1209
|
if key in param_config:
|
|
1021
1210
|
if not isinstance(param_config[key], target_type):
|
|
1022
1211
|
raise TypeError("The type of {} must be {}, but got {}.".format(key, target_type, type(param_config[key])))
|
|
1023
|
-
if key == 'obf_random_seed':
|
|
1024
|
-
if param_config[key] > INT_64_MAX or param_config[key] <= 0:
|
|
1025
|
-
raise ValueError(
|
|
1026
|
-
"'obf_random_seed' must be in (0, INT_64_MAX({})], but got {}.".format(INT_64_MAX,
|
|
1027
|
-
param_config[key]))
|
|
1028
1212
|
return param_config[key]
|
|
1029
1213
|
if requested:
|
|
1030
1214
|
raise ValueError("The parameter {} is requested, but not got.".format(key))
|
|
1031
|
-
if key == "obf_random_seed":
|
|
1032
|
-
return 0
|
|
1033
1215
|
return None
|
|
1034
1216
|
|
|
1035
1217
|
|
|
1036
|
-
def _check_customized_func(customized_func):
|
|
1037
|
-
""" check customized function of dynamic obfuscation """
|
|
1038
|
-
if not callable(customized_func):
|
|
1039
|
-
raise TypeError(
|
|
1040
|
-
"'customized_func' must be a function, but not got {}.".format(type(customized_func)))
|
|
1041
|
-
# test customized_func
|
|
1042
|
-
try:
|
|
1043
|
-
func_result = customized_func(1.0, 1.0)
|
|
1044
|
-
except Exception as ex:
|
|
1045
|
-
raise TypeError("customized_func must be a function with two inputs, but got exception: {}".format(ex))
|
|
1046
|
-
else:
|
|
1047
|
-
if not isinstance(func_result, bool):
|
|
1048
|
-
raise TypeError("Return value of customized_func must be boolean, but got: {}".format(type(func_result)))
|
|
1049
|
-
return customized_func
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
def _check_obfuscate_params(obf_config):
|
|
1053
|
-
"""Check obfuscation parameters, including obf_random_seed, obf_ratio, customized_func"""
|
|
1054
|
-
if 'obf_random_seed' not in obf_config.keys() and 'customized_func' not in obf_config.keys():
|
|
1055
|
-
raise ValueError(
|
|
1056
|
-
"At least one of 'obf_random_seed' or 'customized_func' must be set in obf_config, but got None of them.")
|
|
1057
|
-
obfuscate_type = _check_param_type(obf_config, "type", str, False)
|
|
1058
|
-
if obfuscate_type not in (None, "dynamic"):
|
|
1059
|
-
raise ValueError("Only 'dynamic' type is supported by now, but got {}.".format(obfuscate_type))
|
|
1060
|
-
if ('obf_ratio' in obf_config) and isinstance(obf_config['obf_ratio'], str):
|
|
1061
|
-
if obf_config['obf_ratio'] not in ["small", "medium", "large"]:
|
|
1062
|
-
raise ValueError("'obf_ratio' can only be 'small', 'medium', 'large' or float, but got {}.".format(
|
|
1063
|
-
obf_config['obf_ratio']))
|
|
1064
|
-
ratio_dict = {"small": 0.1, "medium": 0.3, "large": 0.6}
|
|
1065
|
-
obf_config['obf_ratio'] = ratio_dict.get(obf_config['obf_ratio'])
|
|
1066
|
-
obf_ratio = _check_param_type(obf_config, "obf_ratio", float, True)
|
|
1067
|
-
if (obf_ratio <= 0) or (obf_ratio > 1):
|
|
1068
|
-
raise ValueError("'obf_ratio' must be in (0, 1] if it is a float, but got {}.".format(obf_config['obf_ratio']))
|
|
1069
|
-
customized_funcs = []
|
|
1070
|
-
if 'customized_func' in obf_config.keys():
|
|
1071
|
-
device_target = context.get_context('device_target')
|
|
1072
|
-
if device_target in ["GPU", "Ascend"]:
|
|
1073
|
-
raise ValueError(
|
|
1074
|
-
"Customized func mode only support 'device_target'='CPU, but got {}.".format(device_target))
|
|
1075
|
-
customized_funcs.append(_check_customized_func(obf_config['customized_func']))
|
|
1076
|
-
obf_random_seed = _check_param_type(obf_config, "obf_random_seed", int, False)
|
|
1077
|
-
return obf_ratio, customized_funcs, obf_random_seed
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
def obfuscate_model(obf_config, **kwargs):
|
|
1081
|
-
"""
|
|
1082
|
-
Obfuscate a model of MindIR format. Obfuscation means changing the struct of a network without affecting its
|
|
1083
|
-
predict correctness. The obfuscated model can prevent attackers from stealing the model.
|
|
1084
|
-
|
|
1085
|
-
Args:
|
|
1086
|
-
obf_config (dict): obfuscation config.
|
|
1087
|
-
|
|
1088
|
-
- type (str): The type of obfuscation, only 'dynamic' is supported until now.
|
|
1089
|
-
- original_model_path (str): The path of MindIR format model that need to be obfuscated. If the original
|
|
1090
|
-
model is encrypted, then enc_key and enc_mode should be provided.
|
|
1091
|
-
- save_model_path (str): The path to save the obfuscated model.
|
|
1092
|
-
- model_inputs (list(Tensor)): The inputs of the original model, the values of Tensor can be random, which
|
|
1093
|
-
is the same as using :func:`mindspore.export`.
|
|
1094
|
-
- obf_ratio (Union(float, str)): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
|
|
1095
|
-
should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are
|
|
1096
|
-
correspond to 0.1, 0.3, and 0.6 respectively.
|
|
1097
|
-
- customized_func (function): A python function used for customized function mode, which used for control
|
|
1098
|
-
the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
|
|
1099
|
-
Reference to 'my_func()' in
|
|
1100
|
-
`tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
|
|
1101
|
-
This function needs to ensure that its result is constant for any input. Users can refer to opaque
|
|
1102
|
-
predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface
|
|
1103
|
-
when loading obfuscated model.
|
|
1104
|
-
- obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
|
|
1105
|
-
structure of obfuscated models corresponding to different random seeds is different. If
|
|
1106
|
-
`obf_random_seed` is set, then it should be passed to :class:`mindspore.nn.GraphCell`
|
|
1107
|
-
interface when loading
|
|
1108
|
-
obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
|
|
1109
|
-
be set, and the latter mode would be applied if both of them are set.
|
|
1110
|
-
|
|
1111
|
-
kwargs (dict): Configuration options dictionary.
|
|
1112
|
-
|
|
1113
|
-
- enc_key (bytes): Byte type key used for encryption. The valid length is 16, 24, or 32.
|
|
1114
|
-
- enc_mode (str): Specifies the encryption mode, to take effect when dec_key is set.
|
|
1115
|
-
Options: ``'AES-GCM'`` | ``'AES-CBC'`` | ``'SM4-CBC'``. Default: ``'AES-GCM'``.
|
|
1116
|
-
|
|
1117
|
-
Raises:
|
|
1118
|
-
TypeError: If `obf_config` is not a dict.
|
|
1119
|
-
ValueError: If `enc_key` is passed and `enc_mode` is not in ["AES-GCM", "AES-CBC", "SM4-CBC"].
|
|
1120
|
-
ValueError: If `original_model_path` is not provided in `obf_config`.
|
|
1121
|
-
ValueError: If the model saved in `original_model_path` has been obfuscated.
|
|
1122
|
-
ValueError: If `save_model_path` is not provided in `obf_config`.
|
|
1123
|
-
ValueError: If `obf_ratio` is not provided in `obf_config`.
|
|
1124
|
-
ValueError: If both `customized_func` and `obf_random_seed` are not provided in `obf_config`.
|
|
1125
|
-
ValueError: If `obf_random_seed` is not in (0, 9223372036854775807].
|
|
1126
|
-
ValueError: If `original_model_path` does not exist or `original_model_path` does not end with '.mindir'.
|
|
1127
|
-
|
|
1128
|
-
Examples:
|
|
1129
|
-
>>> import mindspore as ms
|
|
1130
|
-
>>> import mindspore.nn as nn
|
|
1131
|
-
>>> import numpy as np
|
|
1132
|
-
>>> # Download ori_net.mindir
|
|
1133
|
-
>>> # https://gitee.com/mindspore/mindspore/blob/master/tests/ut/python/mindir/ori_net.mindir
|
|
1134
|
-
>>> input1 = ms.Tensor(np.ones((1, 1, 32, 32)).astype(np.float32))
|
|
1135
|
-
>>> obf_config = {'original_model_path': "./net.mindir",
|
|
1136
|
-
... 'save_model_path': "./obf_net",
|
|
1137
|
-
... 'model_inputs': [input1, ],
|
|
1138
|
-
... 'obf_ratio': 0.1, 'obf_random_seed': 173262358423}
|
|
1139
|
-
>>> ms.obfuscate_model(obf_config)
|
|
1140
|
-
>>> obf_func = ms.load("obf_net.mindir")
|
|
1141
|
-
>>> obf_net = nn.GraphCell(obf_func, obf_random_seed=173262358423)
|
|
1142
|
-
>>> print(obf_net(input1).asnumpy())
|
|
1143
|
-
"""
|
|
1144
|
-
if not isinstance(obf_config, dict):
|
|
1145
|
-
raise TypeError("'obf_config' must be a dict, but got {}.".format(type(obf_config)))
|
|
1146
|
-
file_path = _check_param_type(obf_config, "original_model_path", str, True)
|
|
1147
|
-
if not file_path.endswith(".mindir"):
|
|
1148
|
-
raise ValueError("For 'obfuscate_model', the argument 'file_path'(MindIR file) should end with '.mindir', "
|
|
1149
|
-
"please input the correct 'file_path'.")
|
|
1150
|
-
if not os.path.exists(file_path):
|
|
1151
|
-
raise ValueError("For 'obfuscate_model', the argument 'file_path'(MindIR file) does not exist, "
|
|
1152
|
-
"please check whether the 'file_path' is correct.")
|
|
1153
|
-
saved_path = _check_param_type(obf_config, "save_model_path", str, True)
|
|
1154
|
-
model_inputs = _check_param_type(obf_config, "model_inputs", list, True)
|
|
1155
|
-
for item in model_inputs:
|
|
1156
|
-
if not isinstance(item, Tensor):
|
|
1157
|
-
raise TypeError("The item in 'model_inputs' must be Tensor, but got {}.".format(type(item)))
|
|
1158
|
-
if -1 in item.shape:
|
|
1159
|
-
raise ValueError(
|
|
1160
|
-
"Dynamic shape input is not supported now, but got the shape of inputs: {}.".format(item.shape))
|
|
1161
|
-
obf_ratio, customized_funcs, obf_random_seed = _check_obfuscate_params(obf_config)
|
|
1162
|
-
if customized_funcs and obf_random_seed > 0:
|
|
1163
|
-
logger.warning("Although 'customized_func' and 'obf_random_seed' are set, the 'obf_random_seed' mode would be"
|
|
1164
|
-
" applied, remember to set 'obf_random_seed' when loading obfuscated model.")
|
|
1165
|
-
|
|
1166
|
-
if obf_random_seed == 0: # apply customized_func mode
|
|
1167
|
-
clean_funcs()
|
|
1168
|
-
for func in customized_funcs:
|
|
1169
|
-
add_opaque_predicate(func.__name__, func)
|
|
1170
|
-
branch_control_input = 0
|
|
1171
|
-
else: # apply password mode
|
|
1172
|
-
branch_control_input = _generate_branch_control_input(obf_random_seed)
|
|
1173
|
-
|
|
1174
|
-
if 'enc_key' in kwargs.keys():
|
|
1175
|
-
enc_key = Validator.check_isinstance('enc_key', kwargs.get('enc_key'), bytes)
|
|
1176
|
-
enc_mode = "AES-GCM"
|
|
1177
|
-
if 'enc_mode' in kwargs.keys():
|
|
1178
|
-
enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
|
|
1179
|
-
if enc_mode not in ["AES-GCM", "AES-CBC", "SM4-CBC"]:
|
|
1180
|
-
raise ValueError(
|
|
1181
|
-
"Only MindIR files that encrypted with 'AES-GCM', 'AES-CBC' or 'SM4-CBC' is supported for"
|
|
1182
|
-
"obfuscate_model(), but got {}.".format(enc_mode))
|
|
1183
|
-
obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio,
|
|
1184
|
-
branch_control_input=branch_control_input, dec_key=enc_key,
|
|
1185
|
-
key_len=len(enc_key),
|
|
1186
|
-
dec_mode=enc_mode)
|
|
1187
|
-
else:
|
|
1188
|
-
obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio,
|
|
1189
|
-
branch_control_input=branch_control_input)
|
|
1190
|
-
|
|
1191
|
-
obf_net = nn.GraphCell(obf_graph)
|
|
1192
|
-
if obf_random_seed != 0:
|
|
1193
|
-
append_y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
|
|
1194
|
-
model_inputs += [append_y_tensor]
|
|
1195
|
-
export(obf_net, *model_inputs, file_name=saved_path, file_format="MINDIR", **kwargs)
|
|
1196
|
-
|
|
1197
|
-
|
|
1198
1218
|
def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
|
|
1199
1219
|
dec_mode, crc_check, format):
|
|
1200
1220
|
"""load parameter into parameter_dict"""
|
|
1201
1221
|
ckpt_file_name = _check_ckpt_file_name(ckpt_file_name, format)
|
|
1202
1222
|
if format == "safetensors":
|
|
1203
1223
|
with safe_open(ckpt_file_name, framework='np') as f:
|
|
1204
|
-
|
|
1205
|
-
|
|
1224
|
+
cal_crc_num = 0
|
|
1225
|
+
total_io_cost_time = 0
|
|
1226
|
+
for k in sorted(f.keys()):
|
|
1227
|
+
if crc_check:
|
|
1228
|
+
cal_crc_num = binascii.crc32(bytes(k, encoding='utf-8'), cal_crc_num)
|
|
1229
|
+
cal_crc_num = binascii.crc32(bytes(f.get_tensor(k)), cal_crc_num)
|
|
1230
|
+
if choice_func is not None and not choice_func(k):
|
|
1231
|
+
continue
|
|
1232
|
+
io_start_time = time.time()
|
|
1233
|
+
value = f.get_tensor(k)
|
|
1234
|
+
io_end_time = time.time()
|
|
1235
|
+
io_cost_time = io_end_time - io_start_time
|
|
1236
|
+
total_io_cost_time += io_cost_time
|
|
1237
|
+
parameter_dict[k] = Parameter(Tensor.from_numpy(value))
|
|
1238
|
+
|
|
1239
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
|
|
1240
|
+
f"Load safetensors io cost time:{total_io_cost_time}.")
|
|
1241
|
+
if crc_check:
|
|
1242
|
+
if f.metadata() is None or f.metadata().get("crc_num") is None:
|
|
1243
|
+
logger.warning(
|
|
1244
|
+
"For 'load_checkpoint', the safetensors file do not contain the crc code, "
|
|
1245
|
+
"please check the file.")
|
|
1246
|
+
else:
|
|
1247
|
+
crc_num = int(f.metadata()["crc_num"])
|
|
1248
|
+
if cal_crc_num != crc_num:
|
|
1249
|
+
raise ValueError("For 'load_checkpoint', the crc check has failed. "
|
|
1250
|
+
"Please check whether the ckpt file is damaged.")
|
|
1206
1251
|
return
|
|
1207
1252
|
checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check)
|
|
1208
1253
|
try:
|
|
@@ -1270,38 +1315,37 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1270
1315
|
Load checkpoint info from a specified file.
|
|
1271
1316
|
|
|
1272
1317
|
Note:
|
|
1273
|
-
- `specify_prefix` and `filter_prefix` do not affect each other.
|
|
1274
|
-
- If none of the parameters are loaded from checkpoint file, it will throw ValueError.
|
|
1275
1318
|
- `specify_prefix` and `filter_prefix` are in the process of being deprecated,
|
|
1276
|
-
`choice_func` is recommended instead.
|
|
1319
|
+
`choice_func` is recommended instead. `specify_prefix` and `filter_prefix` do not affect each other.
|
|
1277
1320
|
And using either of those two args will override `choice_func` at the same time.
|
|
1321
|
+
- If none of the parameters are loaded from checkpoint file, it will throw ValueError.
|
|
1278
1322
|
- When loading a checkpoint that has removed redundancy, the network should be compiled.
|
|
1279
1323
|
|
|
1280
1324
|
Args:
|
|
1281
1325
|
ckpt_file_name (str): Checkpoint file name.
|
|
1282
|
-
net (Cell): The network where the parameters will be loaded. Default: ``None`` .
|
|
1283
|
-
strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load
|
|
1284
|
-
into net when parameter name's suffix in checkpoint file is the same as the
|
|
1326
|
+
net (Cell, optional): The network where the parameters will be loaded. Default: ``None`` .
|
|
1327
|
+
strict_load (bool, optional): Whether to strict load the parameter into net. If ``False`` , it will load
|
|
1328
|
+
parameter into net when parameter name's suffix in checkpoint file is the same as the
|
|
1285
1329
|
parameter in the network. When the types are inconsistent perform type conversion
|
|
1286
1330
|
on the parameters of the same type, such as float32 to float16. Default: ``False`` .
|
|
1287
|
-
filter_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`).
|
|
1288
|
-
filter_prefix will not be loaded. Default: ``None`` .
|
|
1289
|
-
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` ,
|
|
1290
|
-
is not required. Default: ``None`` .
|
|
1291
|
-
dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the
|
|
1292
|
-
mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
|
|
1331
|
+
filter_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`).
|
|
1332
|
+
Parameters starting with the filter_prefix will not be loaded. Default: ``None`` .
|
|
1333
|
+
dec_key (Union[None, bytes], optional): Byte type key used for decryption. If the value is ``None`` ,
|
|
1334
|
+
the decryption is not required. Default: ``None`` .
|
|
1335
|
+
dec_mode (str, optional): This parameter is valid only when dec_key is not set to ``None`` . Specifies the
|
|
1336
|
+
decryption mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
|
|
1293
1337
|
Default: ``"AES-GCM"`` .
|
|
1294
|
-
specify_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`).
|
|
1295
|
-
specify_prefix will be loaded. Default: ``None`` .
|
|
1296
|
-
choice_func (Union[None, function]) : Input value of the function is a Parameter name of type string,
|
|
1338
|
+
specify_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`).
|
|
1339
|
+
Parameters starting with the specify_prefix will be loaded. Default: ``None`` .
|
|
1340
|
+
choice_func (Union[None, function], optional) : Input value of the function is a Parameter name of type string,
|
|
1297
1341
|
and the return value is a bool. If returns ``True`` , the Parameter
|
|
1298
1342
|
that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
|
|
1299
1343
|
matches the custom condition will be removed. Default: ``None`` .
|
|
1300
|
-
crc_check (bool) : Whether to perform crc32 validation when loading checkpoint. Default: ``False`` .
|
|
1301
|
-
remove_redundancy (bool): Whether to enable loading of checkpoint saved with redundancy removal.
|
|
1344
|
+
crc_check (bool, optional) : Whether to perform crc32 validation when loading checkpoint. Default: ``False`` .
|
|
1345
|
+
remove_redundancy (bool, optional): Whether to enable loading of checkpoint saved with redundancy removal.
|
|
1302
1346
|
Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means
|
|
1303
1347
|
redundant-free loading is not enabled.
|
|
1304
|
-
format (str): Format of the input file, can be "ckpt" or "safetensors". Default: "ckpt".
|
|
1348
|
+
format (str, optional): Format of the input file, can be "ckpt" or "safetensors". Default: "ckpt".
|
|
1305
1349
|
|
|
1306
1350
|
Returns:
|
|
1307
1351
|
Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
|
|
@@ -1346,13 +1390,15 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1346
1390
|
- `Saving and Loading the Model - Saving and Loading the Model Weight
|
|
1347
1391
|
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
|
|
1348
1392
|
"""
|
|
1393
|
+
start_load_time = time.time()
|
|
1394
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin load checkpoint.")
|
|
1349
1395
|
specify_prefix = _check_prefix(specify_prefix)
|
|
1350
1396
|
filter_prefix = _check_prefix(filter_prefix)
|
|
1351
1397
|
dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
|
|
1352
1398
|
dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
|
|
1353
1399
|
crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
|
|
1354
1400
|
remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
|
|
1355
|
-
|
|
1401
|
+
_check_load_checkpoint_upsupported_param(format, dec_key, dec_mode)
|
|
1356
1402
|
logger.info("Execute the process of loading checkpoint files.")
|
|
1357
1403
|
|
|
1358
1404
|
parameter_dict = {}
|
|
@@ -1392,6 +1438,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1392
1438
|
if _warm_up_host_cache_enabled(parameter_dict):
|
|
1393
1439
|
_warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)
|
|
1394
1440
|
|
|
1441
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Load checkpoint is finished.")
|
|
1442
|
+
end_load_time = time.time()
|
|
1443
|
+
load_checkpoint_cost_time = end_load_time - start_load_time
|
|
1444
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Load checkpoint cost time {load_checkpoint_cost_time}.")
|
|
1395
1445
|
return parameter_dict
|
|
1396
1446
|
|
|
1397
1447
|
|
|
@@ -1411,7 +1461,7 @@ def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_pr
|
|
|
1411
1461
|
And using either of those two args will override `choice_func` at the same time.
|
|
1412
1462
|
|
|
1413
1463
|
Args:
|
|
1414
|
-
ckpt_file_name (str): Checkpoint file name.
|
|
1464
|
+
ckpt_file_name (str): Checkpoint file name. The file extension must be `ckpt` or `safetensors` .
|
|
1415
1465
|
net (Cell, optional): The network where the parameters will be loaded. Default: ``None`` .
|
|
1416
1466
|
strict_load (bool, optional): Whether to strict load the parameter into net. If ``False`` , it will load
|
|
1417
1467
|
parameter into net when parameter name's suffix in checkpoint file is the
|
|
@@ -1448,7 +1498,8 @@ def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_pr
|
|
|
1448
1498
|
>>> from mindspore import context
|
|
1449
1499
|
>>> from mindspore import load_checkpoint_async
|
|
1450
1500
|
>>> from mindspore import load_param_into_net
|
|
1451
|
-
>>>
|
|
1501
|
+
>>> mindspore.set_device(device_target="Ascend")
|
|
1502
|
+
>>> context.set_context(mode=context.GRAPH_MODE)
|
|
1452
1503
|
>>> # Create the dataset taking MNIST as an example. Refer to
|
|
1453
1504
|
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
|
|
1454
1505
|
>>> dataset = create_dataset()
|
|
@@ -1468,10 +1519,11 @@ def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_pr
|
|
|
1468
1519
|
>>> model.train(2, dataset)
|
|
1469
1520
|
>>> print("param dict len: ", len(param_dict), flush=True)
|
|
1470
1521
|
"""
|
|
1522
|
+
format = "safetensors" if ckpt_file_name.endswith(".safetensors") else "ckpt"
|
|
1471
1523
|
from concurrent.futures import ThreadPoolExecutor
|
|
1472
1524
|
executor = ThreadPoolExecutor(max_workers=2)
|
|
1473
1525
|
param_dict_future = executor.submit(load_checkpoint, ckpt_file_name, net, strict_load, filter_prefix,
|
|
1474
|
-
dec_key, dec_mode, specify_prefix, choice_func)
|
|
1526
|
+
dec_key, dec_mode, specify_prefix, choice_func, format=format)
|
|
1475
1527
|
return ParamDictFuture(executor, param_dict_future)
|
|
1476
1528
|
|
|
1477
1529
|
|
|
@@ -1555,7 +1607,12 @@ def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check):
|
|
|
1555
1607
|
try:
|
|
1556
1608
|
if dec_key is None:
|
|
1557
1609
|
with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
|
|
1610
|
+
ckpt_load_time_start = time.time()
|
|
1558
1611
|
pb_content = f.read()
|
|
1612
|
+
ckpt_load_time_end = time.time()
|
|
1613
|
+
cost_time = ckpt_load_time_end - ckpt_load_time_start
|
|
1614
|
+
vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Load ckpt io cost time:{cost_time}.")
|
|
1615
|
+
|
|
1559
1616
|
else:
|
|
1560
1617
|
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
|
|
1561
1618
|
if pb_content is None:
|
|
@@ -1625,17 +1682,18 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
|
|
|
1625
1682
|
Load parameters into network, return parameter list that are not loaded in the network.
|
|
1626
1683
|
|
|
1627
1684
|
Note:
|
|
1628
|
-
|
|
1685
|
+
When loading a parameter dict that has removed redundancy, the network should be compiled.
|
|
1629
1686
|
|
|
1630
1687
|
Args:
|
|
1631
1688
|
net (Cell): The network where the parameters will be loaded.
|
|
1632
1689
|
parameter_dict (dict): The dictionary generated by load checkpoint file,
|
|
1633
1690
|
it is a dictionary consisting of key: parameters's name, value: parameter.
|
|
1634
|
-
strict_load (bool): Whether to strict load the parameter into net. If ``False`` ,
|
|
1691
|
+
strict_load (bool, optional): Whether to strict load the parameter into net. If ``False`` ,
|
|
1692
|
+
it will load parameter
|
|
1635
1693
|
into net when parameter name's suffix in checkpoint file is the same as the
|
|
1636
1694
|
parameter in the network. When the types are inconsistent perform type conversion
|
|
1637
1695
|
on the parameters of the same type, such as float32 to float16. Default: ``False`` .
|
|
1638
|
-
remove_redundancy (bool): Whether to enable loading of checkpoint saved with redundancy removal.
|
|
1696
|
+
remove_redundancy (bool, optional): Whether to enable loading of checkpoint saved with redundancy removal.
|
|
1639
1697
|
Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means
|
|
1640
1698
|
redundant-free loading is not enabled.
|
|
1641
1699
|
|
|
@@ -1673,11 +1731,11 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
|
|
|
1673
1731
|
strict_load = Validator.check_bool(strict_load)
|
|
1674
1732
|
remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
|
|
1675
1733
|
logger.info("Execute the process of loading parameters into net.")
|
|
1676
|
-
for _, param in net.parameters_and_names():
|
|
1677
|
-
param.from_ckpt = True
|
|
1678
1734
|
param_not_load = []
|
|
1679
1735
|
ckpt_not_load = list(parameter_dict.keys())
|
|
1680
1736
|
for _, param in net.parameters_and_names():
|
|
1737
|
+
if param.param_info.is_pipeline_shared_param:
|
|
1738
|
+
continue
|
|
1681
1739
|
if param.name in parameter_dict:
|
|
1682
1740
|
if isinstance(param, MapParameter):
|
|
1683
1741
|
param.import_data(parameter_dict[param.name])
|
|
@@ -1696,31 +1754,24 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
|
|
|
1696
1754
|
if param_not_load and not strict_load:
|
|
1697
1755
|
_load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load)
|
|
1698
1756
|
|
|
1699
|
-
logger.info("Loading parameters into net is finished.")
|
|
1700
|
-
if param_not_load:
|
|
1701
|
-
logger.warning("For 'load_param_into_net', "
|
|
1702
|
-
"{} parameters in the 'net' are not loaded, because they are not in the "
|
|
1703
|
-
"'parameter_dict', please check whether the network structure is consistent "
|
|
1704
|
-
"when training and loading checkpoint. Another possibility is that "
|
|
1705
|
-
"the redundant loading is not enabled, but the loaded checkpoint is saved with "
|
|
1706
|
-
"redundancy removed. ".format(len(param_not_load)))
|
|
1707
|
-
logger.warning("{} are not loaded.".format(param_not_load))
|
|
1708
1757
|
if remove_redundancy:
|
|
1709
|
-
|
|
1710
|
-
if parallel_mode == "stand_alone":
|
|
1758
|
+
if get_group_size() == 1:
|
|
1711
1759
|
raise TypeError(f"The deduplication feature for loading checkpoint can only be used "
|
|
1712
|
-
f"in parallel scenarios, but got
|
|
1760
|
+
f"in parallel scenarios, but got stand_alone.")
|
|
1713
1761
|
if not net.compile_cache and not net.parameter_layout_dict:
|
|
1714
1762
|
raise ValueError("When loading a parameter dict that has removed redundancy, "
|
|
1715
1763
|
"the network should be compiled.")
|
|
1716
1764
|
param_layout = net.parameter_layout_dict
|
|
1717
|
-
|
|
1718
|
-
|
|
1719
|
-
stage_num = _get_auto_parallel_context("pipeline_stages")
|
|
1720
|
-
chunk_size = device_num // stage_num
|
|
1721
|
-
initial_rank = (rank_id // chunk_size) * chunk_size
|
|
1722
|
-
_single_parameter_broadcast(net, param_layout, rank_id, initial_rank)
|
|
1765
|
+
_single_parameter_broadcast(net, param_layout, param_not_load)
|
|
1766
|
+
mindspore.hal.synchronize()
|
|
1723
1767
|
|
|
1768
|
+
logger.info("Loading parameters into net is finished.")
|
|
1769
|
+
if param_not_load:
|
|
1770
|
+
logger.warning("For 'load_param_into_net', "
|
|
1771
|
+
"{} parameters in the 'net' are not loaded, because they are not in the "
|
|
1772
|
+
"'parameter_dict', please check whether the network structure is consistent "
|
|
1773
|
+
"when training and loading checkpoint.".format(len(param_not_load)))
|
|
1774
|
+
logger.warning("{} are not loaded.".format(param_not_load))
|
|
1724
1775
|
return param_not_load, ckpt_not_load
|
|
1725
1776
|
|
|
1726
1777
|
|
|
@@ -1903,9 +1954,6 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
|
|
|
1903
1954
|
elif opt_shard_group:
|
|
1904
1955
|
allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
|
|
1905
1956
|
tuple(after_reshape_slice_shape))
|
|
1906
|
-
elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
|
|
1907
|
-
allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
|
|
1908
|
-
tuple(after_reshape_slice_shape))
|
|
1909
1957
|
net.parallel_parameter_merge_net_dict[param_name] = allgather_net
|
|
1910
1958
|
if allgather_net:
|
|
1911
1959
|
param_data = allgather_net(param_data)
|
|
@@ -1959,27 +2007,6 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1959
2007
|
|
|
1960
2008
|
- dataset (Dataset): Specifies the preprocessing method of the dataset, which is used to import the
|
|
1961
2009
|
preprocessing of the dataset into MindIR.
|
|
1962
|
-
|
|
1963
|
-
- obf_config (dict): obfuscation config.
|
|
1964
|
-
|
|
1965
|
-
- type (str): The type of obfuscation, only 'dynamic' is supported until now.
|
|
1966
|
-
- obf_ratio (float, str): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
|
|
1967
|
-
should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are
|
|
1968
|
-
correspond to 0.1, 0.3, and 0.6 respectively.
|
|
1969
|
-
- customized_func (function): A python function used for customized function mode, which used for control
|
|
1970
|
-
the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
|
|
1971
|
-
Reference to 'my_func()' in
|
|
1972
|
-
`tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
|
|
1973
|
-
This function needs to ensure that its result is constant for any input. Users can refer to opaque
|
|
1974
|
-
predicates. If customized_func is set, then it should be passed to `load()` interface when loading
|
|
1975
|
-
obfuscated model.
|
|
1976
|
-
- obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
|
|
1977
|
-
structure of obfuscated models corresponding to different random seeds is different. If
|
|
1978
|
-
`obf_random_seed` is set, then it should be passed
|
|
1979
|
-
to :class:`mindspore.nn.GraphCell` interface when loading
|
|
1980
|
-
obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
|
|
1981
|
-
be set, and the latter mode would be applied if both of them are set.
|
|
1982
|
-
|
|
1983
2010
|
- incremental (bool): export MindIR incrementally.
|
|
1984
2011
|
|
|
1985
2012
|
- custom_func (function): Functions for custom defined export policies. This function will be used to
|
|
@@ -2013,6 +2040,8 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
2013
2040
|
- `Saving and Loading the Model - Saving and Loading MindIR
|
|
2014
2041
|
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
|
|
2015
2042
|
"""
|
|
2043
|
+
if 'obf_func' in kwargs.keys():
|
|
2044
|
+
raise NotImplementedError("Dynamic model structure obfuscation is no longer supported.")
|
|
2016
2045
|
old_ms_jit_value = context.get_context("jit_syntax_level")
|
|
2017
2046
|
context.set_context(jit_syntax_level=mindspore.STRICT)
|
|
2018
2047
|
|
|
@@ -2094,9 +2123,7 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
|
|
|
2094
2123
|
It is an internal conversion function. Export the MindSpore prediction model to a file in the specified format.
|
|
2095
2124
|
"""
|
|
2096
2125
|
logger.info("exporting model file:%s format:%s.", file_name, file_format)
|
|
2097
|
-
if "
|
|
2098
|
-
raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.")
|
|
2099
|
-
if "custom_func" in kwargs and file_format != "MINDIR":
|
|
2126
|
+
if "custom_func" in kwargs and file_format != "MINDIR" and kwargs["custom_func"] is not None:
|
|
2100
2127
|
raise ValueError(f"Currently only support custom_func for MindIR format, but got {file_format} format.")
|
|
2101
2128
|
if file_format == 'AIR':
|
|
2102
2129
|
_save_air(net, file_name, *inputs, **kwargs)
|
|
@@ -2309,14 +2336,13 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
|
|
|
2309
2336
|
os.chmod(data_file_name, stat.S_IRUSR)
|
|
2310
2337
|
|
|
2311
2338
|
|
|
2312
|
-
def _msfunc_info(net, *inputs):
|
|
2339
|
+
def _msfunc_info(net, jit_executor, *inputs):
|
|
2313
2340
|
"""Get mindir stream and parameter dict of ms_function"""
|
|
2314
2341
|
# pylint: disable=protected-access
|
|
2315
2342
|
net_dict = OrderedDict()
|
|
2316
|
-
|
|
2317
|
-
|
|
2318
|
-
|
|
2319
|
-
params = _ms_func_executor._graph_executor.get_params(graph_id)
|
|
2343
|
+
graph_id = jit_executor.compile(net.__name__, *inputs)
|
|
2344
|
+
mindir_stream = jit_executor._get_func_graph_proto(net, graph_id, 'mind_ir')
|
|
2345
|
+
params = jit_executor._graph_executor.get_params(graph_id)
|
|
2320
2346
|
for name, value in params.items():
|
|
2321
2347
|
net_dict[name] = Parameter(value, name=name)
|
|
2322
2348
|
return mindir_stream, net_dict
|
|
@@ -2328,53 +2354,21 @@ def _cell_info(net, incremental, *inputs):
|
|
|
2328
2354
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
|
2329
2355
|
# pylint: disable=protected-access
|
|
2330
2356
|
mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir', incremental=incremental)
|
|
2331
|
-
# clean obfuscation config to prevent the next call
|
|
2332
|
-
_executor.obfuscate_config = None
|
|
2333
|
-
|
|
2334
2357
|
net_dict = net.parameters_dict()
|
|
2335
2358
|
return mindir_stream, net_dict
|
|
2336
2359
|
|
|
2337
2360
|
|
|
2338
|
-
def _set_obfuscate_config(**kwargs):
|
|
2339
|
-
"""Set obfuscation config for executor."""
|
|
2340
|
-
logger.warning("Obfuscate model.")
|
|
2341
|
-
if 'enc_mode' in kwargs.keys():
|
|
2342
|
-
enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
|
|
2343
|
-
if enc_mode not in ["AES-GCM", "AES-CBC", "SM4-CBC"]:
|
|
2344
|
-
raise ValueError(
|
|
2345
|
-
"Only MindIR files that encrypted with 'AES-GCM', 'AES-CBC' or 'SM4-CBC' is supported for"
|
|
2346
|
-
"obfuscation, but got {}.".format(enc_mode))
|
|
2347
|
-
obf_ratio, customized_funcs, obf_random_seed = _check_obfuscate_params(kwargs.get('obf_config'))
|
|
2348
|
-
if customized_funcs and obf_random_seed > 0:
|
|
2349
|
-
logger.warning("Although 'customized_func' and 'obf_random_seed' are set, the 'obf_random_seed' mode would be"
|
|
2350
|
-
" applied, remember to set 'obf_random_seed' when loading obfuscated model.")
|
|
2351
|
-
|
|
2352
|
-
if obf_random_seed == 0: # apply customized_func mode
|
|
2353
|
-
device_target = context.get_context('device_target')
|
|
2354
|
-
if device_target in ["GPU", "Ascend"]:
|
|
2355
|
-
raise ValueError(
|
|
2356
|
-
"Customized func mode only support 'device_target'='CPU, but got {}.".format(device_target))
|
|
2357
|
-
clean_funcs()
|
|
2358
|
-
for func in customized_funcs:
|
|
2359
|
-
add_opaque_predicate(func.__name__, func)
|
|
2360
|
-
_executor.obfuscate_config = {'obf_ratio': obf_ratio, 'obf_random_seed': obf_random_seed}
|
|
2361
|
-
|
|
2362
|
-
|
|
2363
2361
|
def _save_mindir(net, file_name, *inputs, **kwargs):
|
|
2364
2362
|
"""Save MindIR format file."""
|
|
2365
|
-
|
|
2366
|
-
if
|
|
2367
|
-
|
|
2368
|
-
for item in inputs:
|
|
2369
|
-
if -1 in item.shape:
|
|
2370
|
-
raise ValueError(
|
|
2371
|
-
"Dynamic shape input is not supported now, but got the shape of inputs: {}.".format(item.shape))
|
|
2363
|
+
executor = _executor
|
|
2364
|
+
if not isinstance(net, nn.Cell):
|
|
2365
|
+
executor = _JitExecutor(net, time.time() * 1e9)
|
|
2372
2366
|
|
|
2373
2367
|
incremental = kwargs.get('incremental', False)
|
|
2374
2368
|
|
|
2375
2369
|
model = mindir_model()
|
|
2376
2370
|
if not isinstance(net, nn.Cell):
|
|
2377
|
-
mindir_stream, net_dict = _msfunc_info(net, *inputs)
|
|
2371
|
+
mindir_stream, net_dict = _msfunc_info(net, executor, *inputs)
|
|
2378
2372
|
else:
|
|
2379
2373
|
mindir_stream, net_dict = _cell_info(net, incremental, *inputs)
|
|
2380
2374
|
model.ParseFromString(mindir_stream)
|
|
@@ -2447,8 +2441,10 @@ def _save_together(net_dict, model):
|
|
|
2447
2441
|
if name in net_dict.keys():
|
|
2448
2442
|
data_total += sys.getsizeof(net_dict[name].data.get_bytes()) / 1024
|
|
2449
2443
|
else:
|
|
2450
|
-
raise ValueError("
|
|
2451
|
-
"
|
|
2444
|
+
raise ValueError("There's a mindspore.Parameter that wasn't created in nn.Cell, and mindspore.export() "
|
|
2445
|
+
f"does not support exporting such Parameters. The parameter name is: {name}.\n"
|
|
2446
|
+
"You can find the supported syntax range for mindspore.export() at the following link:\n"
|
|
2447
|
+
"https://www.mindspore.cn/tutorials/zh-CN/master/beginner/save_load.html")
|
|
2452
2448
|
if data_total > TOTAL_SAVE:
|
|
2453
2449
|
return False
|
|
2454
2450
|
return True
|
|
@@ -2478,6 +2474,9 @@ def check_checkpoint(ckpt_file_name):
|
|
|
2478
2474
|
"""
|
|
2479
2475
|
Check whether the checkpoint is valid.
|
|
2480
2476
|
|
|
2477
|
+
Note:
|
|
2478
|
+
The interface is deprecated from version 2.5 and will be removed in a future version.
|
|
2479
|
+
|
|
2481
2480
|
Args:
|
|
2482
2481
|
ckpt_file_name (str): Checkpoint file name.
|
|
2483
2482
|
|
|
@@ -2491,6 +2490,8 @@ def check_checkpoint(ckpt_file_name):
|
|
|
2491
2490
|
>>> print(check_result)
|
|
2492
2491
|
True
|
|
2493
2492
|
"""
|
|
2493
|
+
logger.warning("The interface 'mindspore.check_checkpoint' is deprecated from version 2.5 "
|
|
2494
|
+
"and will be removed in a future version.")
|
|
2494
2495
|
if not ckpt_file_name.endswith('.ckpt'):
|
|
2495
2496
|
return False
|
|
2496
2497
|
checkpoint_list = Checkpoint()
|
|
@@ -2517,6 +2518,9 @@ def parse_print(print_file_name):
|
|
|
2517
2518
|
"""
|
|
2518
2519
|
Parse data file generated by :class:`mindspore.ops.Print`.
|
|
2519
2520
|
|
|
2521
|
+
Note:
|
|
2522
|
+
The interface is deprecated from version 2.5 and will be removed in a future version.
|
|
2523
|
+
|
|
2520
2524
|
Args:
|
|
2521
2525
|
print_file_name (str): The file name needs to be parsed.
|
|
2522
2526
|
|
|
@@ -2551,6 +2555,8 @@ def parse_print(print_file_name):
|
|
|
2551
2555
|
[[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
|
|
2552
2556
|
[ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
|
|
2553
2557
|
"""
|
|
2558
|
+
logger.warning("The interface 'mindspore.parse_print' is deprecated from version 2.5 "
|
|
2559
|
+
"and will be removed in a future version.")
|
|
2554
2560
|
print_file_path = os.path.realpath(print_file_name)
|
|
2555
2561
|
|
|
2556
2562
|
if os.path.getsize(print_file_path) == 0:
|
|
@@ -2605,548 +2611,13 @@ def parse_print(print_file_name):
|
|
|
2605
2611
|
return tensor_list
|
|
2606
2612
|
|
|
2607
2613
|
|
|
2608
|
-
def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
|
|
2609
|
-
"""
|
|
2610
|
-
Merge data slices to one tensor with whole data when strategy is not None.
|
|
2611
|
-
|
|
2612
|
-
Args:
|
|
2613
|
-
sliced_data (list[numpy.ndarray]): Data slices in order of rank_id.
|
|
2614
|
-
parameter_name (str): Name of parameter.
|
|
2615
|
-
strategy (dict): Parameter slice strategy.
|
|
2616
|
-
is_even (bool): Slice manner that True represents slicing evenly and False represents slicing unevenly.
|
|
2617
|
-
|
|
2618
|
-
Returns:
|
|
2619
|
-
Tensor, the merged Tensor which has the whole data.
|
|
2620
|
-
|
|
2621
|
-
Raises:
|
|
2622
|
-
ValueError: Failed to merge.
|
|
2623
|
-
"""
|
|
2624
|
-
layout = strategy.get(parameter_name)
|
|
2625
|
-
try:
|
|
2626
|
-
dev_mat = list(layout.dev_matrix[0].dim)
|
|
2627
|
-
tensor_map = list(layout.tensor_map[0].dim)
|
|
2628
|
-
param_split_shape = list(layout.param_split_shape[0].dim)
|
|
2629
|
-
field_size = int(layout.field)
|
|
2630
|
-
except BaseException as e:
|
|
2631
|
-
raise ValueError(f"{e.__str__()}. For 'merge_sliced_parameter'"
|
|
2632
|
-
f", please make sure that 'strategy' is correct.") from e
|
|
2633
|
-
|
|
2634
|
-
device_count = 1
|
|
2635
|
-
for dim in dev_mat:
|
|
2636
|
-
device_count *= dim
|
|
2637
|
-
|
|
2638
|
-
if len(sliced_data) != device_count:
|
|
2639
|
-
raise ValueError(f"For 'merge_sliced_parameter', the length of 'sliced_parameters' should be equal to "
|
|
2640
|
-
f"device_count. The length of 'sliced_parameters' is {len(sliced_data)}, but "
|
|
2641
|
-
f"device_count is {device_count}.")
|
|
2642
|
-
|
|
2643
|
-
if not param_split_shape:
|
|
2644
|
-
if not is_even:
|
|
2645
|
-
raise ValueError("For 'merge_sliced_parameter', the shape of every parameter in 'sliced_parameters' "
|
|
2646
|
-
"should be the same when slice manner is even.")
|
|
2647
|
-
|
|
2648
|
-
all_gather_tensor = Tensor(np.concatenate(sliced_data))
|
|
2649
|
-
|
|
2650
|
-
if field_size > 0:
|
|
2651
|
-
merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, field_size)
|
|
2652
|
-
else:
|
|
2653
|
-
merged_tensor = _reshape_param_data(all_gather_tensor, dev_mat, tensor_map)
|
|
2654
|
-
|
|
2655
|
-
else:
|
|
2656
|
-
tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
|
|
2657
|
-
|
|
2658
|
-
slice_count = 1
|
|
2659
|
-
for dim in tensor_strategy:
|
|
2660
|
-
slice_count *= dim
|
|
2661
|
-
|
|
2662
|
-
if len(param_split_shape) != slice_count:
|
|
2663
|
-
raise ValueError(f"For 'merge_sliced_parameter', the param_split_shape length in 'strategy' should be "
|
|
2664
|
-
f"{slice_count}, but got {len(param_split_shape)}.")
|
|
2665
|
-
|
|
2666
|
-
tensor_slices_new = list(range(slice_count))
|
|
2667
|
-
tensor_slices = sliced_data
|
|
2668
|
-
for i in range(device_count):
|
|
2669
|
-
slice_index = int(_get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i))
|
|
2670
|
-
if tensor_slices[i].shape[0] != param_split_shape[slice_index]:
|
|
2671
|
-
raise ValueError(f"For 'merge_sliced_parameter', the slice {slice_index} should be "
|
|
2672
|
-
f"{param_split_shape[slice_index]} in 0 axis, but got "
|
|
2673
|
-
f"{tensor_slices[i].shape[0]}.")
|
|
2674
|
-
tensor_slices_new[slice_index] = np.array(tensor_slices[i])
|
|
2675
|
-
|
|
2676
|
-
dim_len = len(tensor_strategy)
|
|
2677
|
-
for i in range(dim_len):
|
|
2678
|
-
ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
|
|
2679
|
-
tensor_slices_new_inner = []
|
|
2680
|
-
for j in range(ele_count):
|
|
2681
|
-
new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
|
|
2682
|
-
for k in range(j * tensor_strategy[dim_len - 1 - i] + 1,
|
|
2683
|
-
(j + 1) * tensor_strategy[dim_len - 1 - i]):
|
|
2684
|
-
new_tensor = np.concatenate((new_tensor, tensor_slices_new[k]), axis=dim_len - 1 - i)
|
|
2685
|
-
tensor_slices_new_inner.insert(len(tensor_slices_new_inner), np.array(new_tensor))
|
|
2686
|
-
tensor_slices_new = tensor_slices_new_inner
|
|
2687
|
-
merged_tensor = Tensor(tensor_slices_new[0])
|
|
2688
|
-
|
|
2689
|
-
return merged_tensor
|
|
2690
|
-
|
|
2691
|
-
|
|
2692
|
-
def restore_group_info_list(group_info_file_name):
|
|
2693
|
-
"""
|
|
2694
|
-
Build rank list, the checkpoint of ranks in the rank list has the same contents with the local rank
|
|
2695
|
-
who saves the `group_info_file_name`. To save the group info file, please export GROUP_INFO_FIL
|
|
2696
|
-
environment variables like "export GROUP_INFO_FILE=/data/group_info.pb".
|
|
2697
|
-
|
|
2698
|
-
Args:
|
|
2699
|
-
group_info_file_name (str): Name of group information file.
|
|
2700
|
-
|
|
2701
|
-
Returns:
|
|
2702
|
-
List, the rank list.
|
|
2703
|
-
|
|
2704
|
-
Raises:
|
|
2705
|
-
ValueError: group information file is incorrect.
|
|
2706
|
-
TypeError: `group_info_file_name` is not str.
|
|
2707
|
-
|
|
2708
|
-
Examples:
|
|
2709
|
-
>>> import mindspore as ms
|
|
2710
|
-
>>> ms.restore_list = restore_group_info_list("./group_info.pb")
|
|
2711
|
-
"""
|
|
2712
|
-
if not isinstance(group_info_file_name, str):
|
|
2713
|
-
raise TypeError(f"For 'restore_group_info_list', the argument 'group_info_file_name' should be str, "
|
|
2714
|
-
f"but got {type(group_info_file_name)}.")
|
|
2715
|
-
|
|
2716
|
-
if not os.path.isfile(group_info_file_name):
|
|
2717
|
-
raise ValueError(f"For 'restore_group_info_list', no such group information file: {group_info_file_name}.")
|
|
2718
|
-
|
|
2719
|
-
if os.path.getsize(group_info_file_name) == 0:
|
|
2720
|
-
raise ValueError("For 'restore_group_info_list', the group information file should not be empty.")
|
|
2721
|
-
|
|
2722
|
-
return _restore_group_info_list(group_info_file_name)
|
|
2723
|
-
|
|
2724
|
-
|
|
2725
|
-
def build_searched_strategy(strategy_filename):
|
|
2726
|
-
"""
|
|
2727
|
-
Build strategy of every parameter in network. Used in the case of distributed inference.
|
|
2728
|
-
|
|
2729
|
-
Args:
|
|
2730
|
-
strategy_filename (str): Name of strategy file.
|
|
2731
|
-
|
|
2732
|
-
Returns:
|
|
2733
|
-
Dict, whose key is parameter name and value is slice strategy of this parameter.
|
|
2734
|
-
|
|
2735
|
-
Raises:
|
|
2736
|
-
ValueError: Strategy file is incorrect.
|
|
2737
|
-
TypeError: `strategy_filename` is not a string.
|
|
2738
|
-
|
|
2739
|
-
Examples:
|
|
2740
|
-
>>> import mindspore as ms
|
|
2741
|
-
>>> strategy = ms.build_searched_strategy("./strategy_train.ckpt")
|
|
2742
|
-
"""
|
|
2743
|
-
return _build_searched_strategy(strategy_filename)
|
|
2744
|
-
|
|
2745
|
-
|
|
2746
|
-
def merge_sliced_parameter(sliced_parameters, strategy=None):
|
|
2747
|
-
"""
|
|
2748
|
-
Merge parameter slices into one parameter. Used in the case of distributed inference.
|
|
2749
|
-
|
|
2750
|
-
Args:
|
|
2751
|
-
sliced_parameters (list[Parameter]): Parameter slices in order of rank id.
|
|
2752
|
-
strategy (Optional[dict]): Parameter slice strategy, whose key is parameter name and
|
|
2753
|
-
value is slice strategy of this parameter. If strategy is None, just merge
|
|
2754
|
-
parameter slices in 0 axis order. Default: ``None``.
|
|
2755
|
-
|
|
2756
|
-
Returns:
|
|
2757
|
-
Parameter, the merged parameter which has the whole data.
|
|
2758
|
-
|
|
2759
|
-
Raises:
|
|
2760
|
-
ValueError: Failed to merge.
|
|
2761
|
-
TypeError: The sliced_parameters is incorrect or strategy is not dict.
|
|
2762
|
-
KeyError: The parameter name is not in keys of strategy.
|
|
2763
|
-
|
|
2764
|
-
Examples:
|
|
2765
|
-
>>> import numpy as np
|
|
2766
|
-
>>> import mindspore as ms
|
|
2767
|
-
>>> from mindspore import Tensor, Parameter
|
|
2768
|
-
>>>
|
|
2769
|
-
>>> sliced_parameters = [
|
|
2770
|
-
... Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])),
|
|
2771
|
-
... "network.embedding_table"),
|
|
2772
|
-
... Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])),
|
|
2773
|
-
... "network.embedding_table"),
|
|
2774
|
-
... Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])),
|
|
2775
|
-
... "network.embedding_table"),
|
|
2776
|
-
... Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])),
|
|
2777
|
-
... "network.embedding_table")]
|
|
2778
|
-
>>> merged_parameter = ms.merge_sliced_parameter(sliced_parameters)
|
|
2779
|
-
>>> print(merged_parameter)
|
|
2780
|
-
Parameter (name=network.embedding_table, shape=(12,), dtype=Float64, requires_grad=True)
|
|
2781
|
-
"""
|
|
2782
|
-
if not isinstance(sliced_parameters, list):
|
|
2783
|
-
raise TypeError(f"For 'merge_sliced_parameter', the argument 'sliced_parameters' should be list, "
|
|
2784
|
-
f"but got {type(sliced_parameters)}.")
|
|
2785
|
-
|
|
2786
|
-
if not sliced_parameters:
|
|
2787
|
-
raise ValueError("For 'merge_sliced_parameter', the argument 'sliced_parameters' should not be empty.")
|
|
2788
|
-
|
|
2789
|
-
if strategy and not isinstance(strategy, dict):
|
|
2790
|
-
raise TypeError(f"For 'merge_sliced_parameter', the argument 'strategy' should be dict, "
|
|
2791
|
-
f"but got {type(strategy)}.")
|
|
2792
|
-
|
|
2793
|
-
try:
|
|
2794
|
-
parameter_name = sliced_parameters[0].name
|
|
2795
|
-
parameter_shape = sliced_parameters[0].data.shape
|
|
2796
|
-
parameter_shape_length = len(parameter_shape)
|
|
2797
|
-
except BaseException as e:
|
|
2798
|
-
raise TypeError(e.__str__() + f" For 'merge_sliced_parameter', the element in 'sliced_parameters' should be "
|
|
2799
|
-
f"'Parameter', but got {type(sliced_parameters[0])} at index 0.") from e
|
|
2800
|
-
|
|
2801
|
-
is_even = True
|
|
2802
|
-
for index, parameter in enumerate(sliced_parameters):
|
|
2803
|
-
if not isinstance(parameter, Parameter):
|
|
2804
|
-
raise TypeError(f"For 'merge_sliced_parameter', the element in 'sliced_parameters' should be 'Parameter', "
|
|
2805
|
-
f"but got {type(parameter)} at index {index}.")
|
|
2806
|
-
|
|
2807
|
-
if parameter.name != parameter_name \
|
|
2808
|
-
or len(parameter.data.shape) != parameter_shape_length \
|
|
2809
|
-
or parameter.data.shape[1:] != parameter_shape[1:]:
|
|
2810
|
-
raise ValueError(f"For 'merge_sliced_parameter', please make sure that the elements in 'slice_parameters'"
|
|
2811
|
-
f" have the same name, dimension length and shape except 0 axis. The name, dimension "
|
|
2812
|
-
f"length, shape except 0 axis should be {parameter_name}, {parameter_shape_length}, "
|
|
2813
|
-
f"{parameter_shape[1:]}, but got name: {parameter.name}, dimension length: "
|
|
2814
|
-
f"{len(parameter.data.shape)}, shape except 0 axis: {parameter.data.shape[1:]} "
|
|
2815
|
-
f"at index {index}.")
|
|
2816
|
-
|
|
2817
|
-
if parameter.data.shape != parameter_shape:
|
|
2818
|
-
is_even = False
|
|
2819
|
-
|
|
2820
|
-
layerwise_parallel = sliced_parameters[0].layerwise_parallel
|
|
2821
|
-
requires_grad = sliced_parameters[0].requires_grad
|
|
2822
|
-
sliced_data = []
|
|
2823
|
-
for parameter in sliced_parameters:
|
|
2824
|
-
if parameter.data.dtype == mstype.bfloat16:
|
|
2825
|
-
sliced_data.append(cpu_cast(parameter.data, mstype.float32).asnumpy())
|
|
2826
|
-
else:
|
|
2827
|
-
sliced_data.append(parameter.data.asnumpy())
|
|
2828
|
-
|
|
2829
|
-
if not strategy:
|
|
2830
|
-
merged_tensor = Tensor(np.concatenate(sliced_data))
|
|
2831
|
-
merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
|
|
2832
|
-
|
|
2833
|
-
else:
|
|
2834
|
-
if parameter_name not in strategy.keys():
|
|
2835
|
-
raise KeyError(f"For 'merge_sliced_parameter', the parameter name {parameter_name} should be a key in "
|
|
2836
|
-
f"the 'strategy'. Please check 'sliced_parameter' and 'strategy'.")
|
|
2837
|
-
merged_tensor = _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even)
|
|
2838
|
-
merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
|
|
2839
|
-
|
|
2840
|
-
return merged_parameter
|
|
2841
|
-
|
|
2842
|
-
|
|
2843
|
-
def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
|
|
2844
|
-
train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
|
|
2845
|
-
format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None):
|
|
2846
|
-
"""
|
|
2847
|
-
Load checkpoint into net for distributed predication. Used in the case of distributed inference.
|
|
2848
|
-
|
|
2849
|
-
Args:
|
|
2850
|
-
network (Cell): Network for distributed predication.
|
|
2851
|
-
checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. Default: ``None`` .
|
|
2852
|
-
predict_strategy (dict): Strategy of predication process. It means that using one device to predict
|
|
2853
|
-
when setting predict_strategy as None. Default: ``None`` .
|
|
2854
|
-
train_strategy_filename (str): The filename of training strategy protocol buffer file.
|
|
2855
|
-
When train_strategy_filename is None, the training strategy file will be
|
|
2856
|
-
obtained from context.get_auto_parallel_context("strategy_ckpt_load_file").
|
|
2857
|
-
Therefore, the training strategy file needs to be specified
|
|
2858
|
-
in at least one of them. Default: ``None`` .
|
|
2859
|
-
strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
|
|
2860
|
-
into net when parameter name's suffix in checkpoint file is the same as the
|
|
2861
|
-
parameter in the network. When the types are inconsistent, perform type conversion
|
|
2862
|
-
on the parameters of the same type, such as float32 to float16. Default: ``False`` .
|
|
2863
|
-
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
|
|
2864
|
-
is not required. Default: ``None`` .
|
|
2865
|
-
dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
|
|
2866
|
-
mode, currently supports ``'AES-GCM'`` , ``'AES-CBC'`` and ``'SM4-CBC'`` .
|
|
2867
|
-
Default: ``'AES-GCM'`` .
|
|
2868
|
-
format (str): Input weight format to be loaded into the network.
|
|
2869
|
-
It can be set to either "ckpt" or "safetensors". Default: "ckpt".
|
|
2870
|
-
unified_safetensors_dir (str): Directory of input weight files to be loaded into the network.
|
|
2871
|
-
Default: ``None`` .
|
|
2872
|
-
dst_safetensors_dir (str): In the save mode scenario, the save directory for safetensors.
|
|
2873
|
-
rank_id (int): The logical sequence number of the card. In non save mode, it is automatically obtained
|
|
2874
|
-
globally by initializing the network; In save mode, save the file according to the input
|
|
2875
|
-
sequence number. If it is not input, save the entire file.
|
|
2876
|
-
|
|
2877
|
-
Raises:
|
|
2878
|
-
TypeError: The type of inputs do not match the requirements.
|
|
2879
|
-
ValueError: Failed to load checkpoint into net.
|
|
2880
|
-
|
|
2881
|
-
Supported Platforms:
|
|
2882
|
-
``Ascend`` ``GPU``
|
|
2883
|
-
|
|
2884
|
-
Examples:
|
|
2885
|
-
.. note::
|
|
2886
|
-
Before running the following examples, you need to configure the communication environment variables.
|
|
2887
|
-
|
|
2888
|
-
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
|
|
2889
|
-
Please see the `rank table startup
|
|
2890
|
-
<https://www.mindspore.cn/docs/en/master/model_train/parallel/rank_table.html>`_
|
|
2891
|
-
for more details.
|
|
2892
|
-
|
|
2893
|
-
For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
|
|
2894
|
-
<https://www.mindspore.cn/docs/en/master/model_train/parallel/mpirun.html>`_ .
|
|
2895
|
-
|
|
2896
|
-
For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
|
|
2897
|
-
Startup <https://www.mindspore.cn/docs/en/master/model_train/parallel/dynamic_cluster.html>`_ .
|
|
2898
|
-
|
|
2899
|
-
>>> import os
|
|
2900
|
-
>>> import numpy as np
|
|
2901
|
-
>>> import mindspore as ms
|
|
2902
|
-
>>> import mindspore.dataset as ds
|
|
2903
|
-
>>> from mindspore import nn, ops, train
|
|
2904
|
-
>>> from mindspore.communication import init
|
|
2905
|
-
>>>
|
|
2906
|
-
>>> step_per_epoch = 4
|
|
2907
|
-
>>> device_num = 8
|
|
2908
|
-
>>>
|
|
2909
|
-
>>> # Define the network structure.
|
|
2910
|
-
>>> class Net(nn.Cell):
|
|
2911
|
-
... def __init__(self, matmul_size, strategy=None):
|
|
2912
|
-
... super().__init__()
|
|
2913
|
-
... matmul_np = np.full(matmul_size, 0.5, dtype=np.float32)
|
|
2914
|
-
... self.matmul_weight = ms.Parameter(ms.Tensor(matmul_np))
|
|
2915
|
-
... self.matmul = ops.MatMul()
|
|
2916
|
-
... self.neg = ops.Neg()
|
|
2917
|
-
... if strategy is not None:
|
|
2918
|
-
... self.matmul.shard(strategy)
|
|
2919
|
-
...
|
|
2920
|
-
... def construct(self, inputs):
|
|
2921
|
-
... x = self.matmul(inputs, self.matmul_weight)
|
|
2922
|
-
... x = self.neg(x)
|
|
2923
|
-
... return x
|
|
2924
|
-
>>>
|
|
2925
|
-
>>> # Create dataset.
|
|
2926
|
-
>>> def get_dataset(*inputs):
|
|
2927
|
-
... def generate():
|
|
2928
|
-
... for _ in range(step_per_epoch):
|
|
2929
|
-
... yield inputs
|
|
2930
|
-
... return generate
|
|
2931
|
-
>>>
|
|
2932
|
-
>>> # Train network and save distributed checkpoint.
|
|
2933
|
-
>>> def train_net():
|
|
2934
|
-
... ms.set_context(mode=ms.GRAPH_MODE)
|
|
2935
|
-
... init()
|
|
2936
|
-
... np.random.seed(1)
|
|
2937
|
-
... input_data = np.random.rand(16, 96).astype(np.float32)
|
|
2938
|
-
... label_data = np.random.rand(16, 16).astype(np.float32)
|
|
2939
|
-
... fake_dataset = get_dataset(input_data, label_data)
|
|
2940
|
-
... dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
|
|
2941
|
-
...
|
|
2942
|
-
... # Set parallel strategy.
|
|
2943
|
-
... strategy = ((1, 4), (4, 1))
|
|
2944
|
-
... ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num,
|
|
2945
|
-
... strategy_ckpt_save_file="./train_strategy.ckpt")
|
|
2946
|
-
... network = Net(matmul_size=(96, 16), strategy=strategy)
|
|
2947
|
-
... net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
|
|
2948
|
-
... net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
|
|
2949
|
-
... model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
|
|
2950
|
-
... ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=False)
|
|
2951
|
-
... global_rank_id = int(os.getenv("RANK_ID"))
|
|
2952
|
-
... ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
|
|
2953
|
-
... ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
|
|
2954
|
-
... model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
|
|
2955
|
-
... ms.reset_auto_parallel_context()
|
|
2956
|
-
>>>
|
|
2957
|
-
>>> # Load distributed checkpoint and test.
|
|
2958
|
-
>>> def load_model():
|
|
2959
|
-
... ms.set_context(mode=ms.GRAPH_MODE)
|
|
2960
|
-
... init()
|
|
2961
|
-
... ms.set_auto_parallel_context(full_batch=True, parallel_mode="semi_auto_parallel",
|
|
2962
|
-
... strategy_ckpt_load_file="./train_strategy.ckpt", device_num=device_num)
|
|
2963
|
-
... predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
|
|
2964
|
-
... network = Net(matmul_size=(96, 16))
|
|
2965
|
-
... model = ms.Model(network)
|
|
2966
|
-
... predict_layout = model.infer_predict_layout(ms.Tensor(predict_data))
|
|
2967
|
-
... ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
|
|
2968
|
-
... ms.load_distributed_checkpoint(network, ckpt_file_list, predict_layout)
|
|
2969
|
-
... predict_result = model.predict(predict_data)
|
|
2970
|
-
... print(predict_result)
|
|
2971
|
-
>>>
|
|
2972
|
-
>>> train_net()
|
|
2973
|
-
>>> load_model()
|
|
2974
|
-
[[-7.3259363 -7.497216 -7.398196 ... -7.374962 -7.204874 -7.234935 ]
|
|
2975
|
-
[ 3.362938 3.3535435 3.3832688 ... 3.4263954 3.279045 3.3202887]
|
|
2976
|
-
...
|
|
2977
|
-
[ 1.6067538 1.6244187 1.5384722 ... 1.5449994 1.6195512 1.6176052]]
|
|
2978
|
-
"""
|
|
2979
|
-
if format not in ['safetensors', 'ckpt']:
|
|
2980
|
-
raise ValueError(
|
|
2981
|
-
f"For 'load_distributed_checkpoint', 'format' must be 'ckpt' or 'safetensors', but got {format}.")
|
|
2982
|
-
|
|
2983
|
-
if format == 'safetensors':
|
|
2984
|
-
if unified_safetensors_dir is None:
|
|
2985
|
-
raise ValueError(f"For 'load_distributed_checkpoint', 'unified_safetensors_dir' can not be None "
|
|
2986
|
-
f"when format is 'safetensors'.")
|
|
2987
|
-
unsupport_param = [checkpoint_filenames, train_strategy_filename, dec_key]
|
|
2988
|
-
for param in unsupport_param:
|
|
2989
|
-
if param is not None:
|
|
2990
|
-
raise ValueError(f"For 'load_distributed_checkpoint', {param} must be None "
|
|
2991
|
-
f"when format is 'safetensors'.")
|
|
2992
|
-
if strict_load or dec_mode != 'AES-GCM':
|
|
2993
|
-
raise ValueError(f"For 'load_distributed_checkpoint', strict_load and dec_mode must be default "
|
|
2994
|
-
f"when format is 'safetensors'.")
|
|
2995
|
-
if network is not None:
|
|
2996
|
-
rank_id = get_rank()
|
|
2997
|
-
_load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, rank_id=rank_id)
|
|
2998
|
-
else:
|
|
2999
|
-
if dst_safetensors_dir is None:
|
|
3000
|
-
raise ValueError(f"For 'load_distributed_checkpoint', 'dst_safetensors_dir' can not be None "
|
|
3001
|
-
f"when network is None.")
|
|
3002
|
-
if rank_id is not None:
|
|
3003
|
-
_load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
|
|
3004
|
-
rank_id)
|
|
3005
|
-
else:
|
|
3006
|
-
dst_strategy_dict = _build_searched_strategy(predict_strategy)
|
|
3007
|
-
dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
|
|
3008
|
-
dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
|
|
3009
|
-
dst_device_num = dst_stage_device_num * dst_stage_num
|
|
3010
|
-
processes = []
|
|
3011
|
-
activate_processes = 0
|
|
3012
|
-
for rank in range(0, dst_device_num):
|
|
3013
|
-
p = Process(target=_load_parallel_checkpoint, args=(
|
|
3014
|
-
unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, rank))
|
|
3015
|
-
p.start()
|
|
3016
|
-
processes.append(p)
|
|
3017
|
-
activate_processes += 1
|
|
3018
|
-
max_processes = 64
|
|
3019
|
-
if activate_processes >= max_processes:
|
|
3020
|
-
p = processes.pop(0)
|
|
3021
|
-
p.join()
|
|
3022
|
-
activate_processes -= 1
|
|
3023
|
-
for p in processes:
|
|
3024
|
-
p.join()
|
|
3025
|
-
return
|
|
3026
|
-
|
|
3027
|
-
network = Validator.check_isinstance("network", network, nn.Cell)
|
|
3028
|
-
_check_checkpoint_file(checkpoint_filenames)
|
|
3029
|
-
_check_predict_strategy(predict_strategy)
|
|
3030
|
-
|
|
3031
|
-
dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
|
|
3032
|
-
dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
|
|
3033
|
-
|
|
3034
|
-
if train_strategy_filename is None:
|
|
3035
|
-
train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
|
|
3036
|
-
_train_strategy = build_searched_strategy(train_strategy_filename)
|
|
3037
|
-
train_strategy = _convert_to_list(_train_strategy)
|
|
3038
|
-
|
|
3039
|
-
train_dev_count = 1
|
|
3040
|
-
ckpt_file_len = len(checkpoint_filenames)
|
|
3041
|
-
for dim in train_strategy[list(train_strategy.keys())[0]][0]:
|
|
3042
|
-
train_dev_count *= dim
|
|
3043
|
-
if train_dev_count != ckpt_file_len:
|
|
3044
|
-
raise ValueError(f"For 'Load_distributed_checkpoint', the length of 'checkpoint_filenames' should be "
|
|
3045
|
-
f"equal to the device count of training process. "
|
|
3046
|
-
f"But got the length of 'checkpoint_filenames'"
|
|
3047
|
-
f" is {ckpt_file_len} and the device count is {train_dev_count}.")
|
|
3048
|
-
rank_list = _infer_rank_list(train_strategy, predict_strategy)
|
|
3049
|
-
|
|
3050
|
-
param_total_dict = defaultdict(dict)
|
|
3051
|
-
for file_index, file_name in enumerate(checkpoint_filenames):
|
|
3052
|
-
ckpt_dict = load_checkpoint(file_name, dec_key=dec_key, dec_mode=dec_mode)
|
|
3053
|
-
for param_name, param in ckpt_dict.items():
|
|
3054
|
-
param_total_dict[param_name][file_index] = param
|
|
3055
|
-
|
|
3056
|
-
param_dict = {}
|
|
3057
|
-
param_not_in_strategy = []
|
|
3058
|
-
param_not_in_ckpt = []
|
|
3059
|
-
for _, param in network.parameters_and_names():
|
|
3060
|
-
sliced_params = []
|
|
3061
|
-
if param.name not in rank_list.keys():
|
|
3062
|
-
param_not_in_strategy.append(param.name)
|
|
3063
|
-
continue
|
|
3064
|
-
if param.name not in param_total_dict:
|
|
3065
|
-
param_not_in_ckpt.append(param.name)
|
|
3066
|
-
continue
|
|
3067
|
-
|
|
3068
|
-
param_rank = rank_list.get(param.name)[0]
|
|
3069
|
-
skip_merge_split = rank_list.get(param.name)[1]
|
|
3070
|
-
shard_stride = train_strategy.get(param.name)[4]
|
|
3071
|
-
tensor_map = train_strategy.get(param.name)[1]
|
|
3072
|
-
first_dim_shard_idx = tensor_map[0] if tensor_map else -1
|
|
3073
|
-
device_arrangement = train_strategy.get(param.name)[0]
|
|
3074
|
-
first_dim_shard_size = 1
|
|
3075
|
-
if first_dim_shard_idx >= 0:
|
|
3076
|
-
first_dim_shard_size = device_arrangement[-1 - first_dim_shard_idx]
|
|
3077
|
-
if train_strategy.get(param.name)[5]:
|
|
3078
|
-
shard_size = int(ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] / first_dim_shard_size)
|
|
3079
|
-
else:
|
|
3080
|
-
shard_size = 0
|
|
3081
|
-
for rank in param_rank:
|
|
3082
|
-
param_total_list = list(range(0, ckpt_file_len))
|
|
3083
|
-
if first_dim_shard_size != 1:
|
|
3084
|
-
param_total_list = _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_shard_idx, rank)
|
|
3085
|
-
if shard_size > 0:
|
|
3086
|
-
rank_index = param_total_list.index(rank)
|
|
3087
|
-
start = rank_index // shard_size * shard_size
|
|
3088
|
-
param_total_list = param_total_list[start:start + shard_size]
|
|
3089
|
-
if shard_stride > 0:
|
|
3090
|
-
param_stride = []
|
|
3091
|
-
# merge pre parameter
|
|
3092
|
-
param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride]
|
|
3093
|
-
param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride])
|
|
3094
|
-
param_index = list(set(param_index))
|
|
3095
|
-
param_index.sort()
|
|
3096
|
-
for rank_num in param_index:
|
|
3097
|
-
if param_total_dict[param.name][rank_num].data.dtype == mstype.bfloat16:
|
|
3098
|
-
param_stride.append(
|
|
3099
|
-
cpu_cast(param_total_dict[param.name][rank_num].data, mstype.float32).asnumpy())
|
|
3100
|
-
else:
|
|
3101
|
-
param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy())
|
|
3102
|
-
|
|
3103
|
-
sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name)
|
|
3104
|
-
else:
|
|
3105
|
-
sliced_param = param_total_dict[param.name][rank]
|
|
3106
|
-
|
|
3107
|
-
sliced_params.append(sliced_param)
|
|
3108
|
-
if skip_merge_split:
|
|
3109
|
-
split_param = sliced_params[0]
|
|
3110
|
-
else:
|
|
3111
|
-
param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
|
|
3112
|
-
_param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy)
|
|
3113
|
-
split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
|
|
3114
|
-
opt_shard_group = predict_strategy[param.name][5] if predict_strategy else None
|
|
3115
|
-
if opt_shard_group:
|
|
3116
|
-
if split_param.data.dtype == mstype.bfloat16:
|
|
3117
|
-
data = cpu_cast(split_param.data, mstype.float32).asnumpy()
|
|
3118
|
-
else:
|
|
3119
|
-
data = split_param.data.asnumpy()
|
|
3120
|
-
rank = get_rank(opt_shard_group)
|
|
3121
|
-
size = get_group_size(opt_shard_group)
|
|
3122
|
-
try:
|
|
3123
|
-
data_slice = np.split(data, size)[rank]
|
|
3124
|
-
except BaseException as e:
|
|
3125
|
-
logger.critical("Failed to load opt shard slice in load distributed checkpoint for {}. Data shape is {}"
|
|
3126
|
-
" and group is {}".format(param.name, split_param.data.shape, opt_shard_group))
|
|
3127
|
-
raise RuntimeError(e.__str__() + f"\nFor 'load_distributed_checkpoint', failed to load opt shard slice"
|
|
3128
|
-
f" in load distributed checkpoint for {param.name}. Data shape is "
|
|
3129
|
-
f"{split_param.data.shape} and group is {opt_shard_group}.") from e
|
|
3130
|
-
split_param = Parameter(Tensor(data_slice), param.name,
|
|
3131
|
-
split_param.requires_grad, split_param.layerwise_parallel)
|
|
3132
|
-
param_dict[param.name] = split_param
|
|
3133
|
-
|
|
3134
|
-
if param_not_in_strategy:
|
|
3135
|
-
logger.warning("For 'load_distributed_checkpoint', {} parameters in network are not in the slice strategy, "
|
|
3136
|
-
"you can check whether 'predict_strategy' or 'train_strategy_filename' is correct."
|
|
3137
|
-
.format(param_not_in_strategy))
|
|
3138
|
-
if param_not_in_ckpt:
|
|
3139
|
-
logger.warning("For 'load_distributed_checkpoint', {} parameters in network and slice strategy but not in "
|
|
3140
|
-
"the checkpoint file, please check whether 'checkpoint_filenames' is correct."
|
|
3141
|
-
.format(param_not_in_ckpt))
|
|
3142
|
-
|
|
3143
|
-
load_param_into_net(network, param_dict, strict_load=strict_load)
|
|
3144
|
-
|
|
3145
|
-
|
|
3146
2614
|
def async_ckpt_thread_status():
|
|
3147
2615
|
"""
|
|
3148
2616
|
Get the status of asynchronous save checkpoint thread.
|
|
3149
2617
|
|
|
2618
|
+
Note:
|
|
2619
|
+
The interface is deprecated from version 2.5 and will be removed in a future version.
|
|
2620
|
+
|
|
3150
2621
|
When performing asynchronous save checkpoint, you can determine whether the asynchronous thread is completed.
|
|
3151
2622
|
|
|
3152
2623
|
Returns:
|
|
@@ -3158,73 +2629,12 @@ def async_ckpt_thread_status():
|
|
|
3158
2629
|
>>> ms.async_ckpt_thread_status()
|
|
3159
2630
|
False
|
|
3160
2631
|
"""
|
|
2632
|
+
logger.warning("The interface 'mindspore.async_ckpt_thread_status' is deprecated from version 2.5 "
|
|
2633
|
+
"and will be removed in a future version.")
|
|
3161
2634
|
thr_list = threading.enumerate()
|
|
3162
2635
|
return True in [ele.getName() == "asyn_save_ckpt" for ele in thr_list]
|
|
3163
2636
|
|
|
3164
2637
|
|
|
3165
|
-
def _check_predict_strategy(predict_strategy):
|
|
3166
|
-
"""Check predict strategy."""
|
|
3167
|
-
|
|
3168
|
-
def _check_int_list(arg):
|
|
3169
|
-
if not isinstance(arg, list):
|
|
3170
|
-
return False
|
|
3171
|
-
for item in arg:
|
|
3172
|
-
if not isinstance(item, int):
|
|
3173
|
-
return False
|
|
3174
|
-
return True
|
|
3175
|
-
|
|
3176
|
-
if predict_strategy is None:
|
|
3177
|
-
return
|
|
3178
|
-
|
|
3179
|
-
flag = True
|
|
3180
|
-
predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
|
|
3181
|
-
for key in predict_strategy.keys():
|
|
3182
|
-
if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \
|
|
3183
|
-
or len(predict_strategy[key]) < 4:
|
|
3184
|
-
flag = False
|
|
3185
|
-
dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
|
|
3186
|
-
if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
|
|
3187
|
-
not (_check_int_list(param_split_shape) or not param_split_shape) or \
|
|
3188
|
-
not (isinstance(field_size, int) and field_size == 0):
|
|
3189
|
-
flag = False
|
|
3190
|
-
|
|
3191
|
-
if not flag:
|
|
3192
|
-
raise ValueError(f"For 'load_distributed_checkpoint', the argument 'predict_strategy' is dict, "
|
|
3193
|
-
f"the key of it must be string, and the value of it must be list or tuple that "
|
|
3194
|
-
f"the first four elements must be dev_matrix (list[int]), tensor_map (list[int]), "
|
|
3195
|
-
f"param_split_shape (list[int]) and field_size (int, which value is 0)."
|
|
3196
|
-
f"Please check whether 'predict_strategy' is correct.")
|
|
3197
|
-
|
|
3198
|
-
|
|
3199
|
-
def _check_checkpoint_file(checkpoint_filenames):
|
|
3200
|
-
"""Check checkpoint file name."""
|
|
3201
|
-
for index, filename in enumerate(checkpoint_filenames):
|
|
3202
|
-
if not isinstance(filename, str) or not os.path.exists(filename) \
|
|
3203
|
-
or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0:
|
|
3204
|
-
raise ValueError(f"For 'load_distributed_checkpoint', please check 'checkpoint_filenames', and "
|
|
3205
|
-
f"make sure the {filename} at index {index} is a valid checkpoint file, it must "
|
|
3206
|
-
f"be a string ending with '.ckpt', and the checkpoint file it represents must "
|
|
3207
|
-
f"be exist and not empty.")
|
|
3208
|
-
|
|
3209
|
-
|
|
3210
|
-
def _merge_and_split(sliced_params, train_strategy, predict_strategy):
    """
    Merge parameter slices and, when a predict strategy is given, re-slice for the local rank.

    The slices are merged with `merge_sliced_parameter`. If `predict_strategy` is None the
    merged parameter is returned as is. Otherwise the first two elements of the layout stored
    under the parameter's name are passed to `_load_tensor` together with the rank from
    `get_rank()`, and the result is wrapped in a new Parameter that keeps the merged
    parameter's name, `requires_grad` and `layerwise_parallel`.
    """
    merged = merge_sliced_parameter(sliced_params, train_strategy)
    if predict_strategy is None:
        return merged

    name = merged.name
    layout = predict_strategy[name]
    local_rank = get_rank()
    local_slice = _load_tensor(merged.data, layout[0], layout[1], rank_id=local_rank)
    grad_flag = merged.requires_grad
    layerwise_flag = merged.layerwise_parallel
    # When the merged data is bfloat16, the slice is wrapped in a Tensor with that dtype first.
    if merged.data.dtype == mstype.bfloat16:
        return Parameter(Tensor(local_slice, mstype.bfloat16), name, grad_flag, layerwise_flag)
    return Parameter(local_slice, name, grad_flag, layerwise_flag)
|
|
3226
|
-
|
|
3227
|
-
|
|
3228
2638
|
def _calculation_net_size(net):
|
|
3229
2639
|
"""Calculate the size of parameters in the network."""
|
|
3230
2640
|
data_total = 0
|
|
@@ -3288,8 +2698,8 @@ def convert_model(mindir_file, convert_file, file_format):
|
|
|
3288
2698
|
"""
|
|
3289
2699
|
Convert mindir model to other format model. The current version only supports conversion to ONNX models.
|
|
3290
2700
|
|
|
3291
|
-
|
|
3292
|
-
|
|
2701
|
+
Note:
|
|
2702
|
+
The interface is deprecated from version 2.5 and will be removed in a future version.
|
|
3293
2703
|
|
|
3294
2704
|
Args:
|
|
3295
2705
|
mindir_file (str): MindIR file name.
|
|
@@ -3305,6 +2715,8 @@ def convert_model(mindir_file, convert_file, file_format):
|
|
|
3305
2715
|
>>> import mindspore as ms
|
|
3306
2716
|
>>> ms.convert_model("lenet.mindir", "lenet.onnx", "ONNX")
|
|
3307
2717
|
"""
|
|
2718
|
+
logger.warning("The interface 'mindspore.train.serialization.convert_model' is deprecated from version 2.5 "
|
|
2719
|
+
"and will be removed in a future version.")
|
|
3308
2720
|
Validator.check_file_name_by_regular(mindir_file)
|
|
3309
2721
|
Validator.check_file_name_by_regular(convert_file)
|
|
3310
2722
|
if file_format != "ONNX":
|
|
@@ -3316,3 +2728,235 @@ def convert_model(mindir_file, convert_file, file_format):
|
|
|
3316
2728
|
export(net, net_input, file_name=convert_file, file_format=file_format)
|
|
3317
2729
|
else:
|
|
3318
2730
|
export(net, *net_input, file_name=convert_file, file_format=file_format)
|
|
2731
|
+
|
|
2732
|
+
|
|
2733
|
+
def _transform_tensor_to_numpy(path, name_map=None):
    """
    Delegate to `_load_and_transform` with `mindspore.load_checkpoint` as the loader.

    Each value is converted by calling its `asnumpy()` method; the new name handed to the
    converter is not used.
    """

    def _to_numpy(value, new_name):  # pylint: disable=unused-argument
        return value.asnumpy()

    return _load_and_transform(path, name_map, mindspore.load_checkpoint, _to_numpy)
|
|
2735
|
+
|
|
2736
|
+
|
|
2737
|
+
def _transform_numpy_to_tensor(path, name_map=None):
    """
    Delegate to `_load_and_transform` with `load_file` as the loader.

    Each value is wrapped in a `mindspore.Parameter` named with the new name handed to
    the converter.
    """

    def _to_parameter(value, new_name):
        return mindspore.Parameter(value, name=new_name)

    return _load_and_transform(path, name_map, load_file, _to_parameter)
|
|
2739
|
+
|
|
2740
|
+
|
|
2741
|
+
def _process_file(file_info):
    """
    Convert one checkpoint file to a safetensors file.

    Args:
        file_info (tuple): `(cur_ckpt_path, name_map, save_path, file)`, i.e. the source
            checkpoint path, the optional parameter rename map, the destination directory
            and the source file name.
    """
    cur_ckpt_path, name_map, save_path, file = file_info
    param_dict_numpy = _transform_tensor_to_numpy(cur_ckpt_path, name_map)
    ckpt_suffix = ".ckpt"
    if file.endswith(ckpt_suffix):
        # Swap only the trailing extension: a name such as "a.ckpt_v2.ckpt" must keep
        # the ".ckpt" that appears inside its stem.
        safetensors_filename = file[:-len(ckpt_suffix)] + ".safetensors"
    else:
        # Names without the expected extension keep the previous replace-based behaviour.
        safetensors_filename = file.replace(ckpt_suffix, ".safetensors")
    dst_file = os.path.join(save_path, safetensors_filename)
    save_file(param_dict_numpy, dst_file)
|
|
2747
|
+
|
|
2748
|
+
|
|
2749
|
+
def _process_file_safetensors(file_info):
    """
    Convert one safetensors file to a checkpoint file.

    Args:
        file_info (tuple): `(cur_safe_path, name_map, save_path, file)`, i.e. the source
            safetensors path, the optional parameter rename map, the destination directory
            and the source file name.
    """
    cur_safe_path, name_map, save_path, file = file_info
    param_dict_tensor = _transform_numpy_to_tensor(cur_safe_path, name_map)
    safetensors_suffix = ".safetensors"
    if file.endswith(safetensors_suffix):
        # Swap only the trailing extension so a ".safetensors" inside the stem is kept.
        ckpt_filename = file[:-len(safetensors_suffix)] + ".ckpt"
    else:
        # Names without the expected extension keep the previous replace-based behaviour.
        ckpt_filename = file.replace(safetensors_suffix, ".ckpt")
    dst_file = os.path.join(save_path, ckpt_filename)
    mindspore.save_checkpoint(param_dict_tensor, dst_file)
|
|
2755
|
+
|
|
2756
|
+
|
|
2757
|
+
def _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map):
|
|
2758
|
+
"""gather transform rank together"""
|
|
2759
|
+
tasks = []
|
|
2760
|
+
for root, dirs, _ in os.walk(file_path):
|
|
2761
|
+
if root != file_path:
|
|
2762
|
+
continue
|
|
2763
|
+
|
|
2764
|
+
rank_dirs = [d for d in dirs if d.startswith('rank')]
|
|
2765
|
+
if not rank_dirs:
|
|
2766
|
+
raise ValueError(
|
|
2767
|
+
f"For 'safetensors_to_ckpt', no directories starting with 'rank' found in {file_path}")
|
|
2768
|
+
|
|
2769
|
+
for rank_dir in rank_dirs:
|
|
2770
|
+
rank_dir_path = os.path.join(root, rank_dir)
|
|
2771
|
+
dst_root = os.path.join(save_path,
|
|
2772
|
+
os.path.relpath(rank_dir_path, file_path)) if save_path else rank_dir_path
|
|
2773
|
+
os.makedirs(dst_root, exist_ok=True)
|
|
2774
|
+
tasks.extend(
|
|
2775
|
+
(os.path.join(rank_dir_path, file), name_map, dst_root, file)
|
|
2776
|
+
for file in os.listdir(rank_dir_path)
|
|
2777
|
+
if file.endswith(".safetensors") and (file_name_regex is None or re.findall(file_name_regex, file))
|
|
2778
|
+
)
|
|
2779
|
+
return tasks
|
|
2780
|
+
|
|
2781
|
+
|
|
2782
|
+
def _gather_tasks_covert(file_path, save_path, file_name_regex, name_map):
|
|
2783
|
+
"""gather transform rank together"""
|
|
2784
|
+
tasks = []
|
|
2785
|
+
for root, dirs, _ in os.walk(file_path):
|
|
2786
|
+
if root != file_path:
|
|
2787
|
+
continue
|
|
2788
|
+
|
|
2789
|
+
rank_dirs = [d for d in dirs if d.startswith('rank')]
|
|
2790
|
+
if not rank_dirs:
|
|
2791
|
+
raise ValueError(
|
|
2792
|
+
f"For 'ckpt_to_safetensors', no directories starting with 'rank' found in {file_path}")
|
|
2793
|
+
|
|
2794
|
+
for rank_dir in rank_dirs:
|
|
2795
|
+
rank_dir_path = os.path.join(root, rank_dir)
|
|
2796
|
+
dst_root = os.path.join(save_path,
|
|
2797
|
+
os.path.relpath(rank_dir_path, file_path)) if save_path else rank_dir_path
|
|
2798
|
+
os.makedirs(dst_root, exist_ok=True)
|
|
2799
|
+
tasks.extend(
|
|
2800
|
+
(os.path.join(rank_dir_path, file), name_map, dst_root, file)
|
|
2801
|
+
for file in os.listdir(rank_dir_path)
|
|
2802
|
+
if file.endswith(".ckpt") and (file_name_regex is None or re.findall(file_name_regex, file))
|
|
2803
|
+
)
|
|
2804
|
+
return tasks
|
|
2805
|
+
|
|
2806
|
+
|
|
2807
|
+
def ckpt_to_safetensors(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
    """
    Converts MindSpore checkpoint files into safetensors format and saves them to `save_path`.
    Safetensors is a reliable and portable machine learning model storage format introduced by Huggingface,
    used for securely storing Tensors with fast speed (zero copy).

    Note:
        The number of multiprocess settings is related to the size of the host, and it is not recommended to set it
        too large, otherwise it may cause freezing.
        The safetensors format does not support the enc verification function. If ckpt is enabled to save enc
        verification, an error will be generated when performing the conversion.
        The safetensors format currently does not support crc verification function. If ckpt contains crc verification
        information, the crc verification information will be lost after conversion to safetensors.

    Args:
        file_path (str): Path to the directory containing checkpoint files or a single checkpoint file (.ckpt).
        save_path (str, optional): Directory path where safetensors files will be saved. Defaults: ``None``.
        name_map (dict, optional): Dictionary mapping original parameter names to new names. Defaults: ``None``.
        file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
                                   Defaults: ``None``.
        processes_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.

    Raises:
        ValueError: If the input path is invalid or the save_path is not a directory,
                    or the file_path does not end with '.ckpt', or `name_map` is not a dict,
                    or a single input file does not match `file_name_regex`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> ms.ckpt_to_safetensors("./ckpt_save_path")
        >>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt")
        >>> ms.ckpt_to_safetensors(file_path="./ckpt_save_path/rank0/checkpoint_0.ckpt", save_path="./new_path/")
        >>> namemap = {"lin.weight":"new_name"}
        >>> ms.ckpt_to_safetensors("./ckpt_save_path/rank0/checkpoint_0.ckpt", "./new_path/", namemap)
    """
    is_dir = os.path.isdir(file_path)
    is_file = os.path.isfile(file_path)
    if not is_dir and not is_file:
        raise ValueError(f"For 'ckpt_to_safetensors', the input path must be a valid path or file, but got {file_path}")
    # A save_path that carries a file extension is treated as a file name, not a directory.
    if save_path and os.path.splitext(save_path)[1]:
        raise ValueError(f"For 'ckpt_to_safetensors', the save_path must be a directory, but got '{save_path}'")
    if name_map is not None and not isinstance(name_map, dict):
        raise ValueError(
            f"For 'ckpt_to_safetensors', the type of 'name_map' must be a dictionary, but got '{type(name_map)}'")

    if is_dir:
        tasks = _gather_tasks_covert(file_path, save_path, file_name_regex, name_map)
        with mp.Pool(processes=processes_num) as pool:
            # Drain the iterator so every task runs while the progress bar advances.
            list(_progress_bar(pool.imap(_process_file, tasks), total=len(tasks)))
        return

    ckpt_suffix = ".ckpt"
    if not file_path.endswith(ckpt_suffix):
        raise ValueError(f"For 'ckpt_to_safetensors', the input file must be a .ckpt file, but got {file_path}")
    if file_name_regex is not None and not re.findall(file_name_regex, file_path):
        raise ValueError("For 'ckpt_to_safetensors', the input file does not match the regular expression.")
    if save_path and not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)

    param_dict_numpy = _transform_tensor_to_numpy(file_path, name_map)
    # Swap only the trailing extension so a ".ckpt" inside the file stem is preserved.
    safetensors_filename = os.path.basename(file_path)[:-len(ckpt_suffix)] + ".safetensors"
    dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), safetensors_filename)
    save_file(param_dict_numpy, dst_file)
|
|
2869
|
+
|
|
2870
|
+
|
|
2871
|
+
def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
    """
    Converts safetensors files into MindSpore checkpoint format and saves them to `save_path`.
    Safetensors is a reliable and portable machine learning model storage format introduced by Huggingface,
    used for securely storing Tensors with fast speed (zero copy).

    Note:
        The number of multiprocess settings is related to the size of the host, and it is not recommended to set it
        too large, otherwise it may cause freezing.

    Args:
        file_path (str): Path to the directory containing safetensors files or a single safetensors file (.safetensors).
        save_path (str, optional): Directory path where checkpoint files will be saved. Defaults: ``None``.
        name_map (dict, optional): Dictionary mapping original parameter names to new names. Defaults: ``None``.
        file_name_regex (str, optional): Regular expression used to match the file that needs to be converted.
                                   Defaults: ``None``.
        processes_num (int, optional): Number of processes to use for parallel processing. Defaults: 1.

    Raises:
        ValueError: If the input path is invalid, the save_path is not a directory,
                    or the file_path does not end with '.safetensors', or `name_map` is not a dict,
                    or a single input file does not match `file_name_regex`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> ms.safetensors_to_ckpt("./safetensors_save_path")
        >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors")
        >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/")
        >>> namemap = {"lin.weight":"new_name"}
        >>> ms.safetensors_to_ckpt("./safetensors_save_path/rank0/checkpoint_0.safetensors", "./new_path/", namemap)
    """
    is_dir = os.path.isdir(file_path)
    is_file = os.path.isfile(file_path)
    if not is_dir and not is_file:
        raise ValueError(f"For 'safetensors_to_ckpt', the input path must be a valid path or file, but got {file_path}")
    # A save_path that carries a file extension is treated as a file name, not a directory.
    if save_path and os.path.splitext(save_path)[1]:
        raise ValueError(f"For 'safetensors_to_ckpt', the save_path must be a directory, but got '{save_path}'")
    if name_map is not None and not isinstance(name_map, dict):
        raise ValueError(
            f"For 'safetensors_to_ckpt', the type of 'name_map' must be a dictionary, but got '{type(name_map)}'")

    if is_dir:
        tasks = _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map)
        with mp.Pool(processes=processes_num) as pool:
            # Drain the iterator so every task runs while the progress bar advances.
            list(_progress_bar(pool.imap(_process_file_safetensors, tasks), total=len(tasks)))
        return

    safetensors_suffix = ".safetensors"
    if not file_path.endswith(safetensors_suffix):
        raise ValueError(
            f"For 'safetensors_to_ckpt', the input file must be a .safetensors file, but got {file_path}")
    if file_name_regex is not None and not re.findall(file_name_regex, file_path):
        raise ValueError("For 'safetensors_to_ckpt', the input file does not match the regular expression.")
    if save_path and not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)

    param_dict_tensor = _transform_numpy_to_tensor(file_path, name_map)
    # Swap only the trailing extension so a ".safetensors" inside the file stem is preserved.
    ckpt_filename = os.path.basename(file_path)[:-len(safetensors_suffix)] + ".ckpt"
    dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), ckpt_filename)
    mindspore.save_checkpoint(param_dict_tensor, dst_file)
|
|
2931
|
+
|
|
2932
|
+
|
|
2933
|
+
def restore_group_info_list(group_info_file_name):
    """
    Build rank list, the checkpoint of ranks in the rank list has the same contents with the local rank
    who saves the `group_info_file_name`. To save the group info file, please export GROUP_INFO_FILE
    environment variables like "export GROUP_INFO_FILE=/data/group_info.pb".

    This function forwards its argument to `new_restore_group_info_list` and returns that result.
    """
    rank_list = new_restore_group_info_list(group_info_file_name)
    return rank_list
|
|
2940
|
+
|
|
2941
|
+
|
|
2942
|
+
def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
                                train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
                                format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None,
                                output_format='safetensors', name_map=None, max_process_num=64,
                                return_param_dict=False):
    """
    Load checkpoint into net for distributed predication. Used in the case of distributed inference.

    This is a thin wrapper: every argument is forwarded positionally, in declaration order, to
    `new_load_distributed_checkpoint`.

    Returns:
        Whatever `new_load_distributed_checkpoint` returns. The result is propagated so that
        callers passing `return_param_dict` can receive the delegate's output instead of
        always getting ``None``.
    """
    return new_load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy,
                                           train_strategy_filename, strict_load, dec_key, dec_mode,
                                           format, unified_safetensors_dir, dst_safetensors_dir, rank_id,
                                           output_format, name_map, max_process_num,
                                           return_param_dict)
|
|
2953
|
+
|
|
2954
|
+
|
|
2955
|
+
def merge_sliced_parameter(sliced_parameters, strategy=None):
    """
    Merge parameter slices into one parameter. Used in the case of distributed inference.

    Forwards both arguments to `new_merge_sliced_parameter` and returns its result.
    """
    merged_parameter = new_merge_sliced_parameter(sliced_parameters, strategy)
    return merged_parameter
|
|
2958
|
+
|
|
2959
|
+
|
|
2960
|
+
def build_searched_strategy(strategy_filename):
    """
    Build strategy of every parameter in network. Used in the case of distributed inference.

    Forwards `strategy_filename` to `new_build_searched_strategy` and returns its result.
    """
    searched_strategy = new_build_searched_strategy(strategy_filename)
    return searched_strategy
|