mindspore-2.2.14-cp39-cp39-win_amd64.whl → mindspore-2.4.0-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +8 -5
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +124 -25
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +2 -2
- mindspore/_extends/parse/parser.py +182 -68
- mindspore/_extends/parse/resources.py +45 -14
- mindspore/_extends/parse/standard_method.py +192 -252
- mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
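The two `_extends/pijit` entries above introduce PIJit, MindSpore's bytecode-capture JIT path; the 669-line white list enumerates Python functions the capturer is allowed to handle. A minimal sketch of how it was surfaced to users, assuming the `mode` flag documented for `mindspore.jit` around the 2.3 timeframe; the exact knob may differ in 2.4:

```python
import mindspore as ms

# Assumption: PIJit is selected through mindspore.jit's `mode` argument,
# as documented in the MS 2.3 era; verify against the 2.4 API reference.
@ms.jit(mode="PIJit")
def fused(x):
    return x * 2 + 1

out = fused(ms.Tensor([1.0, 2.0]))
```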
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +67 -26
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/base.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/boost/grad_freeze.py +2 -2
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/boost/less_batch_normalization.py +9 -6
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +20 -7
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +10 -10
- mindspore/common/_stub_tensor.py +7 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +449 -129
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +17 -10
- mindspore/common/dump.py +8 -11
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
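The new `mindspore/common/generator.py` adds a seedable RNG-state object, `mindspore.Generator`, analogous to `torch.Generator`. A minimal sketch, assuming the `manual_seed`/`get_state`/`set_state` methods from the MindSpore 2.3+ docs:

```python
import mindspore as ms

gen = ms.Generator()        # independent random-state object
gen.manual_seed(42)         # deterministic seeding
state = gen.get_state()     # snapshot the state ...
gen.set_state(state)        # ... and restore it to replay a sampling sequence
```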
- mindspore/common/hook_handle.py +65 -30
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +34 -14
- mindspore/common/lazy_inline.py +72 -19
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +73 -21
- mindspore/common/recompute.py +292 -0
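`mindspore/common/recompute.py` appears to back the functional activation-recomputation API. A sketch assuming the `mindspore.recompute(block, *inputs)` form from the 2.3+ docs, which re-runs a block's forward pass during backprop instead of storing its activations:

```python
import mindspore as ms
from mindspore import nn

class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.block = nn.SequentialCell(nn.Dense(4, 4), nn.ReLU())

    def construct(self, x):
        # Trade compute for memory: activations of self.block are
        # recomputed in the backward pass rather than saved.
        return ms.recompute(self.block, x)
```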
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +276 -24
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +668 -514
- mindspore/communication/__init__.py +6 -11
- mindspore/communication/_comm_helper.py +43 -3
- mindspore/communication/comm_func.py +1395 -0
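`mindspore/communication/comm_func.py` (1395 new lines) introduces functional collectives that operate on tensors directly instead of going through `ops.AllReduce`-style primitives. A sketch assuming the `all_reduce` signature from the 2.3+ docs; the return convention may differ slightly across releases:

```python
import mindspore as ms
from mindspore.communication import init
from mindspore.communication.comm_func import all_reduce

init()                        # set up the collective backend (HCCL/NCCL/MCCL)
x = ms.Tensor([1.0, 2.0, 3.0])
y = all_reduce(x)             # element-wise sum across all ranks by default
```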
- mindspore/communication/management.py +117 -104
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +455 -71
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +201 -116
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +83 -3
- mindspore/dataset/engine/datasets_text.py +39 -39
- mindspore/dataset/engine/datasets_user_defined.py +230 -141
- mindspore/dataset/engine/datasets_vision.py +78 -74
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/engine/queue.py +138 -66
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +41 -15
- mindspore/dataset/text/__init__.py +2 -5
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +7 -10
- mindspore/dataset/vision/c_transforms.py +10 -10
- mindspore/dataset/vision/py_transforms_util.py +1 -1
- mindspore/dataset/vision/transforms.py +2844 -549
- mindspore/dataset/vision/utils.py +161 -10
- mindspore/dataset/vision/validators.py +16 -3
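The vision pipeline is the single largest change in `dataset` (+2844 lines in `transforms.py`). The unified transforms run both inside `Dataset.map` pipelines and eagerly on raw arrays; the eager form below uses only long-stable API:

```python
import numpy as np
import mindspore.dataset.vision as vision

img = np.random.randint(0, 255, size=(32, 32, 3), dtype=np.uint8)
out = vision.Resize((224, 224))(img)   # transforms are callable eagerly
print(out.shape)                       # (224, 224, 3)
```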
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/{rewrite/ast_creator_register.py → experimental/es/__init__.py} +5 -20
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +124 -15
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +66 -121
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +18 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
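The `experimental/optim` package gains eight optimizers (Adadelta through Rprop above), rounding out its PyTorch-style interface. A sketch using one of the new classes together with the existing `lr_scheduler` module; parameter defaults are assumed to mirror their PyTorch counterparts:

```python
from mindspore import nn
from mindspore.experimental import optim

net = nn.Dense(3, 2)
opt = optim.Adadelta(net.trainable_params(), lr=1.0)           # newly listed above
sched = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)
```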
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
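The entirely new `mindspore/hal` package exposes a device/stream/event/memory abstraction layer. A sketch assuming the names from the MindSpore 2.3+ `hal` docs:

```python
import mindspore as ms

print(ms.hal.device_count())      # visible devices for the current backend
s = ms.hal.Stream()               # explicit execution stream
with ms.hal.StreamCtx(s):         # launch subsequent ops on that stream
    y = ms.ops.ones((2, 2)) * 3
ms.hal.synchronize()              # block until all streams drain
```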
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +4 -3
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/status.h +14 -0
- mindspore/include/api/types.h +10 -10
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +3 -5
- mindspore/include/dataset/vision.h +58 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +3 -3
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +138 -103
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
- mindspore/mindrecord/tools/csv_to_mr.py +4 -9
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
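The `mindspore/mint` tree is the headline addition: a PyTorch-flavored functional namespace (note `dim=` rather than `axis=`), plus `mint.nn`, `mint.optim.AdamW`, and `mint.distributed`. A sketch assuming the torch-style naming the module advertises:

```python
import mindspore as ms
from mindspore import mint

x = mint.arange(6).reshape(2, 3).astype(ms.float32)
p = mint.nn.functional.softmax(x, dim=-1)   # torch-style `dim` keyword
w = mint.ones((3, 2))
y = mint.matmul(x, w)
```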
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +73 -0
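`mindspore/multiprocessing` is a thin wrapper over the standard library's `multiprocessing` that adds fork-safety handling for the MindSpore runtime; by design it should work as a drop-in replacement:

```python
import mindspore.multiprocessing as mp   # mirrors stdlib multiprocessing

def worker(rank):
    print(f"worker {rank} started")

if __name__ == "__main__":
    procs = [mp.Process(target=worker, args=(i,)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```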
- mindspore/nn/cell.py +461 -323
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/layer/activation.py +292 -135
- mindspore/nn/layer/basic.py +288 -83
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +75 -66
- mindspore/nn/layer/embedding.py +221 -45
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/math.py +1 -1
- mindspore/nn/layer/normalization.py +150 -68
- mindspore/nn/layer/padding.py +64 -87
- mindspore/nn/layer/pooling.py +175 -12
- mindspore/nn/layer/rnn_cells.py +6 -16
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +55 -53
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +145 -88
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/ada_grad.py +4 -2
- mindspore/nn/optim/adadelta.py +4 -2
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +3 -3
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +5 -3
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +46 -12
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +44 -7
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +67 -68
- mindspore/numpy/array_ops.py +70 -66
- mindspore/numpy/dtypes.py +3 -3
- mindspore/numpy/fft.py +966 -0
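`mindspore/numpy/fft.py` adds a spectral module to the NumPy-compatibility layer (paired with the new `ops/function/fft_func.py` below). Assuming it mirrors `numpy.fft` as the path suggests:

```python
import mindspore as ms
import mindspore.numpy as mnp

x = ms.Tensor([1.0, 0.0, -1.0, 0.0])
freq = mnp.fft.fft(x)        # forward transform, numpy.fft-style interface
back = mnp.fft.ifft(freq)    # round-trips to x up to numerical error
```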
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +147 -152
- mindspore/numpy/utils.py +3 -0
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +9 -6
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
- mindspore/ops/_grad_experimental/grad_comm_ops.py +135 -36
- mindspore/ops/_grad_experimental/grad_math_ops.py +61 -298
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +162 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +147 -59
- mindspore/ops/_vmap/vmap_nn_ops.py +292 -117
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +201 -66
- mindspore/ops/composite/math_ops.py +10 -49
- mindspore/ops/composite/multitype_ops/_compile_utils.py +192 -618
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +8 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/function/__init__.py +53 -11
- mindspore/ops/function/array_func.py +1269 -1821
- mindspore/ops/function/clip_func.py +19 -31
- mindspore/ops/function/debug_func.py +114 -5
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/grad_func.py +30 -22
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +35 -68
- mindspore/ops/function/math_func.py +1170 -2697
- mindspore/ops/function/nn_func.py +2116 -1128
- mindspore/ops/function/other_func.py +8 -8
- mindspore/ops/function/parameter_func.py +5 -93
- mindspore/ops/function/random_func.py +435 -113
- mindspore/ops/function/reshard_func.py +104 -0
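`ops/function/reshard_func.py` and the `reshard_ops.py` entry below point to a tensor-redistribution API for parallel layouts. A rough sketch, assuming the `mindspore.reshard` plus `Layout` interface from the 2.3+ parallel docs; treat the import paths as unverified:

```python
import mindspore as ms
from mindspore import Layout   # assumed import path

# Describe a 2x2 device mesh with data-parallel and model-parallel axes.
layout = Layout((2, 2), ("dp", "mp"))
x = ms.ops.ones((8, 8))
y = ms.reshard(x, layout("dp", "mp"))   # redistribute x across the mesh
```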
- mindspore/ops/function/sparse_func.py +4 -4
- mindspore/ops/function/sparse_unary_func.py +9 -16
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +16 -15
- mindspore/ops/functional.py +355 -346
- mindspore/ops/op_info_register.py +18 -45
- mindspore/ops/operations/__init__.py +38 -24
- mindspore/ops/operations/_grad_ops.py +21 -927
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +173 -607
- mindspore/ops/operations/_rl_inner_ops.py +2 -2
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +6 -36
- mindspore/ops/operations/_tensor_array.py +8 -8
- mindspore/ops/operations/array_ops.py +106 -2837
- mindspore/ops/operations/comm_ops.py +799 -127
- mindspore/ops/operations/custom_ops.py +124 -119
- mindspore/ops/operations/debug_ops.py +142 -41
- mindspore/ops/operations/image_ops.py +1 -217
- mindspore/ops/operations/inner_ops.py +5 -40
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +666 -4972
- mindspore/ops/operations/nn_ops.py +205 -2213
- mindspore/ops/operations/other_ops.py +60 -49
- mindspore/ops/operations/random_ops.py +50 -54
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/sparse_ops.py +4 -4
- mindspore/ops/primitive.py +216 -103
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +8 -4
- mindspore/parallel/_auto_parallel_context.py +100 -10
- mindspore/parallel/_cell_wrapper.py +99 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +67 -23
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +99 -22
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +2 -2
- mindspore/parallel/_utils.py +173 -6
- mindspore/parallel/algo_parameter_config.py +8 -10
- mindspore/parallel/checkpoint_transform.py +204 -38
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
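The new `parallel/cluster` package backs the `msrun` distributed launcher (see the extra entry point in `entry_points.txt` near the end of this diff). Assuming the flag names from the MindSpore 2.3+ docs, a typical single-node launch looks like `msrun --worker_num=8 --local_worker_num=8 --master_port=8118 train.py`.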
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +279 -37
- mindspore/parallel/transform_safetensors.py +993 -0
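`parallel/transform_safetensors.py` (993 lines) brings safetensors support to checkpoint handling. A sketch assuming the `format="safetensors"` switch that `save_checkpoint`/`load_checkpoint` gained in this timeframe:

```python
import mindspore as ms
from mindspore import nn

net = nn.Dense(3, 2)
ms.save_checkpoint(net, "model.safetensors", format="safetensors")
params = ms.load_checkpoint("model.safetensors", format="safetensors")
ms.load_param_into_net(net, params)
```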
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +4 -2
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/util.py +153 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +18 -20
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
- mindspore/profiler/parser/ascend_flops_generator.py +20 -4
- mindspore/profiler/parser/ascend_hccl_generator.py +29 -278
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +148 -146
- mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
- mindspore/profiler/parser/ascend_op_generator.py +92 -42
- mindspore/profiler/parser/ascend_timeline_generator.py +298 -133
- mindspore/profiler/parser/base_timeline_generator.py +25 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +4 -393
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/minddata_parser.py +72 -3
- mindspore/profiler/parser/profiler_info.py +94 -7
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +631 -508
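The profiler is heavily reworked (a new `ascend_analysis` parser tree, `dynamic_profiler.py`, and a largely rewritten `profiling.py`), but the user-facing entry point remains the long-stable `mindspore.Profiler`:

```python
import mindspore as ms

profiler = ms.Profiler(output_path="./prof_out")
# ... run a few training or inference steps here ...
profiler.analyse()   # parse collected data into timeline/summary files
```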
- mindspore/rewrite/__init__.py +2 -14
- mindspore/rewrite/api/node.py +122 -36
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +45 -29
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +637 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +115 -148
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +705 -186
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +40 -115
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +597 -263
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +7 -5
- mindspore/train/_utils.py +204 -4
- mindspore/train/amp.py +335 -295
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/__init__.py +5 -2
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +220 -43
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_landscape.py +15 -9
- mindspore/train/callback/_loss_monitor.py +5 -5
- mindspore/train/callback/_on_request_exit.py +136 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +12 -12
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/data_sink.py +6 -5
- mindspore/train/dataset_helper.py +66 -23
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +382 -76
- mindspore/train/serialization.py +787 -288
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +51 -28
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/METADATA +8 -4
- mindspore-2.4.0.dist-info/RECORD +1406 -0
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/gen_ops.py +0 -273
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -282
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/_op_impl/cpu/concat.py +0 -39
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- mindspore-2.2.14.dist-info/RECORD +0 -1924
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/train/serialization.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2020-
+# Copyright 2020-2024 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,19 +17,23 @@
 from __future__ import absolute_import
 from __future__ import division
 
+import binascii
 import copy
 import json
 import os
+import re
 import shutil
 import stat
 import threading
 from threading import Thread, RLock
+from multiprocessing import Process
 from collections import defaultdict, OrderedDict
 from io import BytesIO
 
 import math
 import sys
 import time
+import google
 import numpy as np
 
 from mindspore.train.checkpoint_pb2 import Checkpoint
@@ -50,32 +54,41 @@ from mindspore.common.api import _generate_branch_control_input
 from mindspore.common.initializer import initializer, One
 from mindspore.common.parameter import Parameter, _offload_if_config
 from mindspore.common.tensor import Tensor
+from mindspore._c_expression import Tensor as Tensor_
 from mindspore.common._utils import is_shape_unknown
+from mindspore.common.file_system import FileSystem, _register_basic_file_system, _register_mindio_file_system
 from mindspore.communication.management import get_rank, get_group_size
 from mindspore.experimental import MapParameter
-from mindspore.
+from mindspore.ops import Cast
+from mindspore.parallel._cell_wrapper import get_allgather_cell, _single_parameter_broadcast
 from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
 from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
-from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode
+from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode, \
+    _get_device_num, _is_parallel_mode
+from mindspore.parallel._auto_parallel_context import _get_auto_parallel_context
 from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
-    _restore_group_info_list
+    _restore_group_info_list, _get_param_list_when_first_dim_sharded
 from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
     _store_warm_up_ptr_by_tensor_list, _cache_enable
-from mindspore.
+from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
+from mindspore.parallel.transform_safetensors import _load_parallel_checkpoint, _get_device_num_from_strategy, \
+    _extract_pipeline_stage_num
+from mindspore.train._utils import read_proto, get_parameter_redundancy
 from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
     split_mindir, split_dynamic_mindir
+from mindspore.common.generator import Generator
+from safetensors.numpy import save_file
+from safetensors import safe_open
 from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
-from ..ops.operations import Cast
 
 tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
                      "Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
                      "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
-                     "Bool": mstype.bool_, "str": mstype.string, "BFloat16": mstype.bfloat16}
+                     "Bool": mstype.bool_, "str": mstype.string, "BFloat16": mstype.bfloat16, "Int4": mstype.qint4x2}
 
 tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UInt16": np.uint16,
                      "Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
-                     "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"
-                     "BFloat16": np.float32}
+                     "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}
 
 np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64}
 
@@ -95,6 +108,92 @@ INT_64_MAX = 9223372036854775807
 
 cpu_cast = Cast().set_device("CPU")
 
+_ckpt_fs = FileSystem()
+
+
+def init_ckpt_file_system(fs: FileSystem):
+    """Initialize checkpoint file system"""
+    if _register_mindio_file_system(fs):
+        return
+    _register_basic_file_system(fs)
+
+
+# Initialize checkpoint file system
+init_ckpt_file_system(_ckpt_fs)
+
+
+def _get_cur_rank_dp(parameter_layout_dict):
+    """ Get dp and tp from layout dict. """
+    pp_num = _get_auto_parallel_context("pipeline_stages")
+    dev_num = _get_device_num()
+    global_rank = get_rank()
+    pipe_size = dev_num // pp_num
+    initial_rank = (global_rank // pipe_size) * pipe_size
+    parameter_redundancy_dict = get_parameter_redundancy(
+        parameter_layout_dict, initial_rank)
+    value_len = sys.maxsize
+    min_value = ()
+    for key, value in parameter_redundancy_dict.items():
+        if "accu_grads" in key or "inputs" in key:
+            continue
+        for item in value:
+            if len(item) < value_len and global_rank in item:
+                value_len = len(item)
+                min_value = item
+    return min_value
+
+
+def get_ckpt_path_with_strategy(cur_ckpt_path, cur_strategy_path):
+    """
+    Find an available checkpoint file path among all backup checkpoint files of the current rank.
+
+    It assumes that the checkpoint path contains the substring 'rank_{rank_id}', which is used to
+    distinguish between different paths. If cur_ckpt_path does not contain a 'rank_{rank_id}' substring,
+    cur_ckpt_path itself is returned when it exists; otherwise None is returned.
+
+    Note:
+        This API must be called after the communication is initialized because the cluster information
+        needs to be obtained internally.
+
+    Args:
+        cur_ckpt_path (str): the checkpoint file path which the current rank needs.
+        cur_strategy_path (str): strategy file path for the current rank.
+
+    Returns:
+        - new_ckpt_file (String), the available checkpoint file, if one is found.
+        - None, if no available checkpoint file is found.
+
+    Examples:
+        >>> import mindspore as ms
+        >>> from mindspore.communication import init
+        >>> from mindspore import get_ckpt_path_with_strategy
+        >>> ms.set_context(mode=ms.GRAPH_MODE)
+        >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
+        >>> init()
+        >>> ckpt_file = "./rank_5/iteration-1_40.ckpt"
+        >>> strategy_file = "./src_pipeline_strategys/src_strategy_5.ckpt"
+        >>> ckpt_file_new = get_ckpt_path_with_strategy(ckpt_file, strategy_file)
+        >>> print(ckpt_file_new)
+    """
+    dp = _get_cur_rank_dp(cur_strategy_path)
+    pattern = r'rank_\d+'
+    for i in dp:
+        new_ckpt_path = re.sub(pattern, f"rank_{str(i)}", cur_ckpt_path)
+        if not os.path.isfile(new_ckpt_path):
+            continue
+        return new_ckpt_path
+    return None
+
+
+class ParamDictFuture:
+    def __init__(self, executor, param_dict_future):
+        self.executor = executor
+        self.param_dict_future = param_dict_future
+
+    def result(self):
+        param_dict = self.param_dict_future.result()
+        self.executor.shutdown()
+        return param_dict
+
 
 def _special_process_par(par, new_par):
     """
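The rank substitution that `get_ckpt_path_with_strategy` performs can be seen in isolation. Below is a minimal sketch of the `re.sub` pattern used above, with made-up paths and a hard-coded candidate tuple standing in for the result of `_get_cur_rank_dp` (which needs an initialized cluster):

    import os
    import re

    # Candidate ranks holding a redundant copy of the same parameters;
    # in the real API this tuple comes from _get_cur_rank_dp.
    dp = (5, 3, 7)
    cur_ckpt_path = "./rank_5/iteration-1_40.ckpt"

    # Rewrite the 'rank_{id}' path component for each candidate rank and
    # keep the first checkpoint file that actually exists on this node.
    for i in dp:
        candidate = re.sub(r'rank_\d+', f"rank_{i}", cur_ckpt_path)
        if os.path.isfile(candidate):
            print("usable backup checkpoint:", candidate)
            break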
@@ -221,53 +320,72 @@ def _save_weight(checkpoint_dir, model_name, iteration, params):
         logger.warning(f"Checkpoint dir: '{checkpoint_dir}' is not existed.")
 
 
-def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False):
+def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False, crc_check=False,
+               format="ckpt"):
     """Execute the process of saving checkpoint into file."""
     try:
         with _ckpt_mutex:
+            file_name_list = list(os.path.splitext(ckpt_file_name))
+            file_name_list[1] = file_name_list[1].replace(f".{format}", ".tmp")
+            tmp_name = ''.join(file_name_list)
             if os.path.exists(ckpt_file_name):
                 os.chmod(ckpt_file_name, stat.S_IWUSR)
                 os.remove(ckpt_file_name)
-            ...
+            if os.path.exists(tmp_name):
+                os.chmod(tmp_name, stat.S_IWUSR)
+                os.remove(tmp_name)
+            if format == "ckpt":
+                with _ckpt_fs.create(tmp_name, *_ckpt_fs.create_args) as f:
+                    plain_data = None
+                    if enc_key is not None:
+                        plain_data = BytesIO()
+
+                    crc_num = 0
+                    for name, value in data_list.items():
+                        if name == "random_op":
+                            _write_random_seed(name, value, f)
+                            continue
+                        if value[0] == "mapparameter":
+                            _write_mapparameter(name, value, f, map_param_inc)
+                            continue
+                        if value[0] == "offload_parameter":
+                            new_value = value[1:]
+                            new_value[2] = value[3]
+                            _write_parameter_bytes_data(name, new_value, f, enc_key, plain_data)
+                            _offload_if_config(value[3])
+                            continue
+                        if value[1] == "str":
+                            crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
+                            continue
+                        if isinstance(value[2], np.ndarray):
+                            crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
+                            continue
+                        if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
+                            _write_hugeparameter(name, value, f)
+                            continue
+
+                        crc_num = _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
+
+                    if enc_key is not None:
+                        plain_data.seek(0)
+                        max_block_size = ENCRYPT_BLOCK_SIZE * 1024
                         block_data = plain_data.read(max_block_size)
-                        ...
+                        while block_data:
+                            f.write(_encrypt(block_data, len(block_data), enc_key, len(enc_key), enc_mode))
+                            block_data = plain_data.read(max_block_size)
+                    if crc_check:
+                        f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big'))
+            elif format == "safetensors":
+                save_dict = {}
+                for name, value in data_list.items():
+                    save_dict[name] = value[2].asnumpy()
+                save_file(save_dict, tmp_name)
+            if not os.path.exists(tmp_name):
+                logger.warning(f"Rename failed, can't find {tmp_name}, it is possible that multiple processes have "
+                               f"simultaneously modified a file.")
+            else:
+                os.rename(tmp_name, ckpt_file_name)
+                os.chmod(ckpt_file_name, stat.S_IRUSR)
     except BaseException as e:
         logger.critical("Failed to save the checkpoint file %s. Maybe don't have the permission to write files, "
                         "or the disk space is insufficient and so on.", ckpt_file_name)
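The `.tmp`-then-rename flow added to `_exec_save` above is the usual pattern for making a checkpoint write atomic: data is written under a temporary name and only renamed to the final name after the write completes, so an interrupted save never leaves a half-written checkpoint under the real name. A standalone sketch of the same idea (file name and payload are illustrative):

    import os
    import stat

    def atomic_write(path: str, payload: bytes):
        """Write payload via a temporary file, then rename it into place."""
        tmp_name = os.path.splitext(path)[0] + ".tmp"
        with open(tmp_name, "wb") as f:
            f.write(payload)
        os.rename(tmp_name, path)      # atomic within one filesystem
        os.chmod(path, stat.S_IRUSR)   # saved checkpoints are made read-only

    atomic_write("demo.ckpt", b"checkpoint bytes")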
@@ -286,22 +404,7 @@ def _write_random_seed(name, value, f):
     f.write(checkpoint_list.SerializeToString())
 
 
-def
-    """Write bfloat16 data into protobuf file"""
-    checkpoint_list = Checkpoint()
-    param_value = checkpoint_list.value.add()
-    param_value.tag = name
-    param_tensor = param_value.tensor
-    param_tensor.dims.extend(value[1])
-    param_tensor.tensor_type = value[2]
-    param_tensor.tensor_content = value[3].get_bytes()
-    if enc_key is None:
-        f.write(checkpoint_list.SerializeToString())
-    else:
-        plain_data.write(checkpoint_list.SerializeToString())
-
-
-def _write_parameter_data(name, value, f, enc_key, plain_data):
+def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
     """Write parameter data into protobuf file."""
     data_size = value[2].nbytes / 1024
     if data_size > SLICE_SIZE:
@@ -320,10 +423,40 @@ def _write_parameter_data(name, value, f, enc_key, plain_data):
         param_tensor.tensor_content = param_slice.tobytes()
 
         if enc_key is None:
-            f.write(checkpoint_list.SerializeToString())
+            output_data = checkpoint_list.SerializeToString()
+            if crc_check:
+                crc_num = binascii.crc32(output_data, crc_num)
+            f.write(output_data)
+        else:
+            plain_data.write(checkpoint_list.SerializeToString())
+
+    return crc_num
+
+
+def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
+    """Write parameter bytes data into protobuf file."""
+    bytes_value = value[2].get_bytes()
+    chunk_size = 1024 * SLICE_SIZE
+
+    for i in range(0, len(bytes_value), chunk_size):
+        checkpoint_list = Checkpoint()
+        param_value = checkpoint_list.value.add()
+        param_value.tag = name
+        param_tensor = param_value.tensor
+        param_tensor.dims.extend(value[0])
+        param_tensor.tensor_type = value[1]
+        param_tensor.tensor_content = bytes_value[i:i + chunk_size]
+
+        if enc_key is None:
+            output_data = checkpoint_list.SerializeToString()
+            if crc_check:
+                crc_num = binascii.crc32(output_data, crc_num)
+            f.write(output_data)
         else:
             plain_data.write(checkpoint_list.SerializeToString())
 
+    return crc_num
+
 
 def _write_mapparameter(name, value, f, map_param_inc=False):
     """Write map parameter into protobuf file."""
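`binascii.crc32` accepts the previous checksum as its second argument, which is what lets `_write_parameter_data` and `_write_parameter_bytes_data` thread a single `crc_num` through many serialized protobuf chunks. A small sketch of that accumulation:

    import binascii

    chunks = [b"first chunk", b"second chunk", b"third chunk"]

    # Feed each chunk into crc32, passing the running value back in,
    # exactly as the writers above thread crc_num through each call.
    crc_num = 0
    for chunk in chunks:
        crc_num = binascii.crc32(chunk, crc_num)

    # The accumulated value equals the CRC of the concatenated stream.
    assert crc_num == binascii.crc32(b"".join(chunks))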
@@ -365,8 +498,11 @@ def _write_hugeparameter(name, value, f):
         offset += numpy_data.shape[0]
 
 
-def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
+def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format):
     """Check save_obj and ckpt_file_name for save_checkpoint."""
+    if format not in ["safetensors", "ckpt"]:
+        raise ValueError(f"For 'save_checkpoint', the format must be "
+                         f"'safetensors' or 'ckpt', but got {format}.")
     if not isinstance(save_obj, (nn.Cell, list, dict)):
         raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell, list or dict, "
                         "but got {}.".format(type(save_obj)))
@@ -374,20 +510,32 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
         raise TypeError("For 'save_checkpoint', the parameter {} for checkpoint file name is invalid,"
                         "'ckpt_file_name' must be "
                         "string, but got {}.".format(ckpt_file_name, type(ckpt_file_name)))
-    ckpt_file_name = os.path.
+    ckpt_file_name = os.path.realpath(ckpt_file_name)
     if os.path.isdir(ckpt_file_name):
         raise IsADirectoryError("For 'save_checkpoint', the parameter `ckpt_file_name`: {} is a directory, "
                                 "it must be a file name.".format(ckpt_file_name))
-    if not ckpt_file_name.endswith(
-        ckpt_file_name += ".
+    if not ckpt_file_name.endswith(format):
+        ckpt_file_name += f".{format}"
     return ckpt_file_name
 
 
+def _check_format_and_other_params(format, enc_key, enc_mode, crc_check=False, async_save=False, map_param_inc=False,
+                                   global_step_num=None):
+    param_not_default = (enc_key is not None or enc_mode != "AES-GCM" or crc_check or async_save
+                         or map_param_inc or global_step_num is not None)
+    if format == "safetensors" and param_not_default:
+        raise ValueError("For 'save_checkpoint', when format is 'safetensors', other param must be default.")
+
+
 def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
-                    async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None,
+                    async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None,
+                    crc_check=False, format="ckpt", **kwargs):
     r"""
     Save checkpoint to a specified file.
 
+    Note:
+        The `enc_mode` and `crc_check` parameters are mutually exclusive and cannot be configured simultaneously.
+
     Args:
         save_obj (Union[Cell, list, dict]): The object to be saved. The data type can be :class:`mindspore.nn.Cell`,
             list, or dict. If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
@@ -409,6 +557,9 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
             If returns ``True`` , the Parameter that matches the custom condition will be saved.
             If returns ``False`` , the Parameter that does not match the custom condition will not
             be saved. Default: ``None`` .
+        crc_check (bool): Whether to perform crc32 calculation when saving checkpoint and save the calculation
+            result to the file. Default: ``False`` .
+        format (str): Format of the output file, can be "ckpt" or "safetensors". Default: "ckpt".
         kwargs (dict): Configuration options dictionary.
 
     Raises:
@@ -420,7 +571,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
         >>> import mindspore as ms
         >>>
         >>> # Define the network structure of LeNet5. Refer to
-        >>> # https://gitee.com/mindspore/docs/blob/
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
         >>> net = LeNet5()
        >>> ms.save_checkpoint(net, "./lenet.ckpt",
        ...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
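Putting the new `save_checkpoint` parameters together, a plausible usage looks like the sketch below. Note that, per `_check_format_and_other_params` above, `format="safetensors"` requires every other optional argument to keep its default, so `crc_check` effectively applies to the `ckpt` format only (the small network here is illustrative):

    import mindspore as ms
    from mindspore import nn

    class TinyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.dense = nn.Dense(4, 2)

        def construct(self, x):
            return self.dense(x)

    net = TinyNet()

    # ckpt format with a crc32 checksum appended to the file.
    ms.save_checkpoint(net, "./tiny.ckpt", crc_check=True)

    # safetensors format; the other optional parameters must stay at defaults.
    ms.save_checkpoint(net, "./tiny.safetensors", format="safetensors")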
@@ -440,35 +591,57 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
 
     Tutorial Examples:
         - `Saving and Loading the Model - Saving and Loading the Model Weight
-          <https://mindspore.cn/tutorials/en/
+          <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
     """
-    ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name)
+    ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format)
     integrated_save = Validator.check_bool(integrated_save)
     async_save = Validator.check_bool(async_save)
     append_dict = _check_append_dict(append_dict)
     enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
     enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
+    crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
     map_param_inc = kwargs.get('incremental', False)
     logger.info("Execute the process of saving checkpoint files.")
-    ...
+    global_step_num = kwargs.get('global_step_num', None)
+    _check_format_and_other_params(format, enc_key, enc_mode, crc_check, async_save, map_param_inc, global_step_num)
+
+    if append_dict and "__exception_save__" in append_dict:
+        s1 = mindspore.hal.Stream()
+        with mindspore.hal.StreamCtx(s1):
+            save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
+            s1.synchronize()
+    else:
+        save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
 
     if append_dict:
+        if "__exception_save__" in append_dict:
+            del append_dict["__exception_save__"]
         append_info_list = []
         for k_name, value in append_dict.items():
-            if
+            if isinstance(value, Generator):
+                value = value.get_state()
+            elif not isinstance(value, str):
                 value = Tensor(value)
             append_info_list.append({"name": k_name, "data": value})
         save_obj.extend(append_info_list)
 
     data_list = OrderedDict()
+    data_list_np = OrderedDict()
     with _ckpt_mutex:
         for param in save_obj:
             if param["name"] == "random_op":
-                data_list["random_op"] = param["data"]
+                if os.getenv("AITURBO") == "1":
+                    data_list_np["random_op"] = []
+                    data_list_np["random_op"].append(param["data"])
+                    if crc_check:
+                        bytes_value = bytes(data_list_np[key][0])
+                        data_list_np[key].append(binascii.crc32(bytes_value))
+                else:
+                    data_list["random_op"] = param["data"]
                 continue
             key = param["name"]
             data_list[key] = []
+            data_list_np[key] = []
             if isinstance(param["data"], MapParameter):
                 data_list[param["name"]].append("mapparameter")
                 data_list[param["name"]].append(param["data"])
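With the `Generator` branch added above, `append_dict` can now carry a random-number generator whose state is captured through `get_state()` at save time, alongside the usual scalars and strings. A hedged sketch, assuming `Generator` from `mindspore.common.generator` (added in this release) can be constructed with default arguments:

    import mindspore as ms
    from mindspore import nn
    from mindspore.common.generator import Generator

    net = nn.Dense(4, 2)
    gen = Generator()  # assumed default-constructible

    # Scalars become tensors, strings stay strings, and the Generator is
    # stored via its get_state() result, as in the save loop above.
    append_info = {"epoch": 10, "lr": 0.01, "rng_state": gen}
    ms.save_checkpoint(net, "./with_meta.ckpt", append_dict=append_info)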
@@ -479,49 +652,48 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
             elif param["data"][0] == "offload_parameter":
                 data_list[key].append("offload_parameter")
                 _save_param_list_data(data_list, key, param)
-            elif param["data"][0] == "BFloat16_tensor":
-                data_list[key].append("BFloat16_tensor")
-                _save_param_list_data(data_list, key, param)
-                continue
 
             if isinstance(param["data"], str):
-                data_list[key].append([0])
-                data_list[key].append('str')
-                data = np.array(param["data"])
-                data_list[key].append(data)
+                if os.getenv("AITURBO") == "1":
+                    data_list_np[key].append(np.array(param["data"]))
+                    if crc_check:
+                        bytes_value = data_list_np[key][0].tobytes()
+                        data_list_np[key].append(binascii.crc32(bytes_value))
+                else:
+                    data_list[key].append([0])
+                    data_list[key].append('str')
+                    data = np.array(param["data"])
+                    data_list[key].append(data)
             else:
                 if isinstance(param["data"], Parameter):
                     param["data"].init_data()
-                if
-                ...
-                    data_list[key].append(dims)
-                    data_list[key].append("BFloat16")
-                    data_list[key].append(cpu_cast(param["data"], mstype.float32))
-                    continue
-                dims = []
-                if param['data'].shape == ():
-                    dims.append(0)
+                if os.getenv("AITURBO") == "1":
+                    data_list_np[key].append(param["data"].asnumpy())
+                    if crc_check:
+                        bytes_value = data_list_np[key][0].tobytes()
+                        data_list_np[key].append(binascii.crc32(bytes_value))
                 else:
+                    dims = []
                     for dim in param['data'].shape:
                         dims.append(dim)
-                ...
+                    data_list[key].append(dims)
+                    tensor_type = str(param["data"].dtype)
+                    data_list[key].append(tensor_type)
+                    data = param["data"]
+                    data_list[key].append(data)
+
+    if os.getenv("AITURBO") == "1":
+        from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
+        ckpt_name = os.path.basename(ckpt_file_name)
+        aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np, crc_check)
+    elif async_save:
         data_copy = copy.deepcopy(data_list)
-        thr = Thread(target=_exec_save,
+        thr = Thread(target=_exec_save,
+                     args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format),
+                     name="asyn_save_ckpt")
         thr.start()
     else:
-        _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc)
+        _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
 
     logger.info("Saving checkpoint process is finished.")
 
@@ -532,7 +704,21 @@ def _convert_list_to_param_list(save_obj, choice_func):
     if not save_obj:
         return param_list
     if isinstance(save_obj[0], dict):
-        ...
+        for param in save_obj:
+            if isinstance(param, dict) and "name" in param and "data" in param:
+                if not isinstance(param["name"], str):
+                    raise TypeError(f"For save_checkpoint, when save_obj is a list of dict items, the name in dict "
+                                    f"should be string, but got {type(param['name'])}.")
+                if not isinstance(param["data"], Tensor):
+                    raise TypeError(f"For save_checkpoint, when save_obj is a list of dict items, the data in dict "
+                                    f"should be parameter, but got {type(param['data'])}.")
+                if choice_func is not None and not choice_func(param["name"]):
+                    continue
+                each_param = {"name": param["name"], "data": param["data"]}
+                param_list.append(each_param)
+            else:
+                raise TypeError(f"For save_checkpoint, save_obj should be a list of dict items, and the dict should "
+                                f"have key values 'name' and 'value', but got {type(param)} and {param}.")
     else:
         for param in save_obj:
             if isinstance(param, Parameter):
@@ -585,6 +771,7 @@ def _convert_cell_param_and_names_to_dict(save_obj, choice_func):
 
 def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func):
     """Convert nn.Cell to param_list."""
+    sync_pipeline_shared_parameters(save_obj)
     param_list = []
     parameter_layout_dict = save_obj.parameter_layout_dict
     if _is_in_auto_parallel_mode() and not parameter_layout_dict:
@@ -597,7 +784,7 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
     if phase in save_obj.compile_cache and _executor.has_compiled(phase):
         random_byte = _executor._graph_executor.get_random_status(phase)
         param_list.append({"name": "random_op", "data": random_byte})
-        ...
+        append_dict.pop("random_op")
     for (key, value) in param_dict.items():
         each_param = {"name": key}
         if isinstance(value, MapParameter):
@@ -619,18 +806,16 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
             param_data.append(param_tensor.shape)
             param_data.append(str(param_tensor.dtype))
             param_data.append(value.key)
-        elif value.data.dtype == mstype.bfloat16:
-            param_data = ["BFloat16_tensor"]
-            param_data.append(cpu_cast(value.data, mstype.float32))
-            param_data.append(value.data.shape)
-            param_data.append("BFloat16")
-            param_data.append(value.key)
         else:
-            param_data =
+            param_data = value.data
+            if append_dict and "__exception_save__" in append_dict:
+                param_data = Tensor(Tensor_.move_to(value, "CPU", False))
 
         # in automatic model parallel scenario, some parameters were split to all the devices,
         # which should be combined before saving
         if key in parameter_layout_dict:
+            if not append_dict or "__exception_save__" not in append_dict:
+                param_data = Tensor(value.data)
             param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
                                                 integrated_save)
 
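The validation loop added to `_convert_list_to_param_list` above expects every list item to be a dict with a string `name` and a Tensor `data`; anything else raises `TypeError`, and `choice_func` can still drop items by name. A minimal sketch of the accepted shape:

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor

    # Each item must be a dict carrying a string 'name' and a Tensor 'data'.
    save_obj = [
        {"name": "conv1.weight", "data": Tensor(np.ones((2, 2), np.float32))},
        {"name": "fc.bias", "data": Tensor(np.zeros((2,), np.float32))},
    ]

    # choice_func filters by name before saving; here 'fc.bias' is skipped.
    ms.save_checkpoint(save_obj, "./from_list.ckpt",
                       choice_func=lambda name: not name.startswith("fc"))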
@@ -670,9 +855,9 @@ def _check_append_dict(append_dict):
         raise TypeError("For 'save_checkpoint', the argument 'append_dict' must be dict, but got "
                         "{}.".format(type(append_dict)))
     for key, value in append_dict.items():
-        if not isinstance(key, str) or not isinstance(value, (int, float, bool, str, Parameter, Tensor)):
+        if not isinstance(key, str) or not isinstance(value, (int, float, bool, str, Parameter, Tensor, Generator)):
             raise TypeError(f"For 'save_checkpoint', the type of dict 'append_info' must be key: string, "
-                            f"value: int, float or
+                            f"value: int, float, bool or Generator, but got key: {type(key)}, value: {type(value)}")
     return append_dict
 
 
@@ -699,13 +884,13 @@ def load(file_name, **kwargs):
         - dec_key (bytes): Byte-type key used for decryption. The valid length is 16, 24, or 32.
         - dec_mode (Union[str, function]): Specifies the decryption mode, to take effect when dec_key is set.
 
-          - Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: 'AES-GCM'
+          - Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: ``'AES-GCM'``.
           - For details of using the customized decryption, please check the `tutorial
-            <https://mindspore.cn/mindarmour/docs/en/
+            <https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
 
         - obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to
           `obfuscate_model()
-          <https://www.mindspore.cn/docs/en/
+          <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.obfuscate_model.html>`_.
 
     Returns:
         GraphCell, a compiled graph that can be executed by `GraphCell`.
@@ -735,7 +920,7 @@ def load(file_name, **kwargs):
 
     Tutorial Examples:
         - `Saving and Loading the Model - Saving and Loading MindIR
-          <https://mindspore.cn/tutorials/en/
+          <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
     """
     if not isinstance(file_name, str):
         raise ValueError("For 'load', the argument 'file_name' must be string, but "
@@ -746,7 +931,7 @@ def load(file_name, **kwargs):
     if not os.path.exists(file_name):
         raise ValueError("For 'load', the argument 'file_name'(MindIR file) does not exist, "
                          "please check whether the 'file_name' is correct.")
-    file_name = os.path.
+    file_name = os.path.realpath(file_name)
 
     # set customized functions for dynamic obfuscation
     obfuscated = _check_load_obfuscate(**kwargs)
@@ -776,7 +961,7 @@ def load(file_name, **kwargs):
     return graph
 
 
-def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=
+def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=True):
     """
     Auto Split MindIR.
 
@@ -784,10 +969,10 @@ def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=F
 
     Args:
         file_name (str): MindIR file name.
-        device_num (int): device number.
-        rank_id (int): rank id.
-        dynamic (bool): Indicates whether the model is a dynamic shape mindir model.
-        sapp (bool): Indicates whether to automatically generate split strategy through SAPP.
+        device_num (int): device number. Default: '8'.
+        rank_id (int): rank id. Default: '0'.
+        dynamic (bool): Indicates whether the model is a dynamic shape mindir model. Default: 'True'.
+        sapp (bool): Indicates whether to automatically generate split strategy through SAPP. Default: 'True'.
 
     Raises:
         ValueError: MindIR file does not exist or `file_name` is not a string.
@@ -809,7 +994,7 @@ def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=F
     if not os.path.exists(file_name):
         raise ValueError("For 'Split MindIR', the argument 'file_name'(MindIR file) does not exist, "
                          "please check whether the 'file_name' is correct.")
-    file_name = os.path.
+    file_name = os.path.realpath(file_name)
 
     logger.info("Execute the process of export and split mindir.")
     dynamic = True
@@ -909,13 +1094,14 @@ def obfuscate_model(obf_config, **kwargs):
         - customized_func (function): A python function used for customized function mode, which is used to control
           the switch branch of the obfuscation structure. The outputs of customized_func should be boolean and const (
          refer to 'my_func()' in
-          `tutorials <https://www.mindspore.cn/mindarmour/docs/en/
+          `tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
           This function needs to ensure that its result is constant for any input. Users can refer to opaque
           predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface
           when loading obfuscated model.
         - obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
           structure of obfuscated models corresponding to different random seeds is different. If
-          `obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell
+          `obf_random_seed` is set, then it should be passed to :class:`mindspore.nn.GraphCell`
+          interface when loading
           obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
           be set, and the latter mode would be applied if both of them are set.
 
@@ -923,7 +1109,7 @@ def obfuscate_model(obf_config, **kwargs):
 
         - enc_key (bytes): Byte type key used for encryption. The valid length is 16, 24, or 32.
         - enc_mode (str): Specifies the encryption mode, to take effect when dec_key is set.
-          ...
+          Options: ``'AES-GCM'`` | ``'AES-CBC'`` | ``'SM4-CBC'``. Default: ``'AES-GCM'``.
 
     Raises:
         TypeError: If `obf_config` is not a dict.
@@ -934,11 +1120,15 @@ def obfuscate_model(obf_config, **kwargs):
         ValueError: If `obf_ratio` is not provided in `obf_config`.
         ValueError: If both `customized_func` and `obf_random_seed` are not provided in `obf_config`.
         ValueError: If `obf_random_seed` is not in (0, 9223372036854775807].
-        ValueError: If `original_model_path`
+        ValueError: If `original_model_path` does not exist or `original_model_path` does not end with '.mindir'.
 
     Examples:
         >>> import mindspore as ms
         >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> # Download ori_net.mindir
+        >>> # https://gitee.com/mindspore/mindspore/blob/master/tests/ut/python/mindir/ori_net.mindir
+        >>> input1 = ms.Tensor(np.ones((1, 1, 32, 32)).astype(np.float32))
         >>> obf_config = {'original_model_path': "./net.mindir",
         ...               'save_model_path': "./obf_net",
         ...               'model_inputs': [input1, ],
@@ -998,12 +1188,81 @@ def obfuscate_model(obf_config, **kwargs):
     obf_net = nn.GraphCell(obf_graph)
     if obf_random_seed != 0:
         append_y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
-        model_inputs += [append_y_tensor
+        model_inputs += [append_y_tensor]
     export(obf_net, *model_inputs, file_name=saved_path, file_format="MINDIR", **kwargs)
 
 
+def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
+                          dec_mode, crc_check, format):
+    """load parameter into parameter_dict"""
+    ckpt_file_name = _check_ckpt_file_name(ckpt_file_name, format)
+    if format == "safetensors":
+        with safe_open(ckpt_file_name, framework='np') as f:
+            for k in f.keys():
+                parameter_dict[k] = Parameter(f.get_tensor(k))
+        return
+    checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check)
+    try:
+        param_data_list = []
+        map_data_list = [[], [], []]
+        map_shape_list = [0, 0, 0]
+        if specify_prefix:
+            logger.warning("For load_checkpoint, this parameter `specity_prefix` will be deprecated, "
+                           "please use `choice_func` instead.")
+        if filter_prefix:
+            logger.warning("For load_checkpoint, this parameter `filter_prefix` will be deprecated, "
+                           "please use `choice_func` instead.")
+        for element_id, element in enumerate(checkpoint_list.value):
+            if element.tag == "random_op":
+                parameter_dict["random_op"] = element.tensor.tensor_content
+                continue
+            if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
+                continue
+            if specify_prefix is None and filter_prefix is None and \
+                    choice_func is not None and not choice_func(element.tag):
+                continue
+            if element.tensor.ByteSize() == 0:
+                _load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list,
+                                    parameter_dict)
+                if element.tag in parameter_dict:
+                    map_data_list = [[], [], []]
+                    map_shape_list = [0, 0, 0]
+                continue
+            data = element.tensor.tensor_content
+            data_type = element.tensor.tensor_type
+            np_type = tensor_to_np_type.get(data_type)
+            ms_type = tensor_to_ms_type[data_type]
+            if data_type == 'str':
+                str_length = int(len(data) / 4)
+                np_type = np_type + str(str_length)
+            param_data_list.append(data)
+            if (element_id == len(checkpoint_list.value) - 1) or \
+                    (element.tag != checkpoint_list.value[element_id + 1].tag):
+                new_data = b"".join(param_data_list)
+                param_data_list.clear()
+                dims = element.tensor.dims
+                if data_type == 'str':
+                    str_value = np.frombuffer(new_data, np_type)
+                    parameter_dict[element.tag] = str(str_value[0])
+                else:
+                    if dims == [0]:
+                        dims = []
+                    param_data = Tensor_.convert_bytes_to_tensor(new_data, tuple(dims), ms_type)
+                    parameter = Parameter(param_data, name=element.tag)
+                    parameter_dict[element.tag] = parameter
+                    _offload_if_config(parameter)
+
+        logger.info("Loading checkpoint files process is finished.")
+
+    except BaseException as e:
+        logger.critical("Failed to load the checkpoint file '%s'.", ckpt_file_name)
+        raise ValueError(e.__str__() + "\nFor 'load_checkpoint', "
+                         "failed to load the checkpoint file {}.".format(ckpt_file_name)) from e
+
+
 def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
-                    dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None
+                    dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None,
+                    crc_check=False, remove_redundancy=False, format="ckpt"):
     """
     Load checkpoint info from a specified file.
 
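The safetensors branch of `_load_into_param_dict` above is a thin wrapper over `safetensors.safe_open`, so the same file can also be inspected without MindSpore. A sketch, assuming `./tiny.safetensors` was produced by the earlier save example:

    from safetensors import safe_open

    # Iterate over the tensors in a safetensors checkpoint as numpy arrays.
    with safe_open("./tiny.safetensors", framework="np") as f:
        for name in f.keys():
            arr = f.get_tensor(name)
            print(name, arr.shape, arr.dtype)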
@@ -1013,6 +1272,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
         - `specify_prefix` and `filter_prefix` are in the process of being deprecated,
           `choice_func` is recommended instead.
           And using either of those two args will override `choice_func` at the same time.
+        - When loading a checkpoint that has removed redundancy, the network should be compiled.
 
     Args:
         ckpt_file_name (str): Checkpoint file name.
@@ -1034,6 +1294,11 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
             and the return value is a bool. If returns ``True`` , the Parameter
             that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
             matches the custom condition will be removed. Default: ``None`` .
+        crc_check (bool): Whether to perform crc32 validation when loading checkpoint. Default: ``False`` .
+        remove_redundancy (bool): Whether to enable loading of a checkpoint saved with redundancy removal.
+            Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` ,
+            meaning redundancy-free loading is not enabled.
+        format (str): Format of the input file, can be "ckpt" or "safetensors". Default: "ckpt".
 
     Returns:
         Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
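A plausible round trip through the new loading options: `crc_check=True` re-validates the checksum written by `save_checkpoint(..., crc_check=True)`, while `format="safetensors"` reads a safetensors file (and, per `_check_format_and_other_params`, cannot be combined with decryption or CRC options). File names follow the earlier save sketches:

    import mindspore as ms

    # Verify the crc32 checksum embedded at save time.
    params = ms.load_checkpoint("./tiny.ckpt", crc_check=True)

    # Load a safetensors checkpoint; other optional arguments stay at defaults.
    params_st = ms.load_checkpoint("./tiny.safetensors", format="safetensors")
    print(sorted(params_st.keys()))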
@@ -1076,83 +1341,42 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
 
     Tutorial Examples:
         - `Saving and Loading the Model - Saving and Loading the Model Weight
-          <https://mindspore.cn/tutorials/en/
+          <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
     """
-    ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
     specify_prefix = _check_prefix(specify_prefix)
     filter_prefix = _check_prefix(filter_prefix)
     dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
     dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
+    crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
+    remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
+    _check_format_and_other_params(format, dec_key, dec_mode, crc_check)
     logger.info("Execute the process of loading checkpoint files.")
-    checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode)
 
     parameter_dict = {}
-    try:
-        param_data_list = []
-        map_data_list = [[], [], []]
-        map_shape_list = [0, 0, 0]
-        if specify_prefix:
-            logger.warning("For load_checkpoint, this parameter `specity_prefix` will be deprecated, "
-                           "please use `choice_func` instead.")
-        if filter_prefix:
-            logger.warning("For load_checkpoint, this parameter `filter_prefix` will be deprecated, "
-                           "please use `choice_func` instead.")
-        for element_id, element in enumerate(checkpoint_list.value):
-            if element.tag == "random_op":
-                parameter_dict["random_op"] = element.tensor.tensor_content
-                continue
-            if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
-                continue
-            if specify_prefix is None and filter_prefix is None and \
-                    choice_func is not None and not choice_func(element.tag):
-                continue
-            if element.tensor.ByteSize() == 0:
-                _load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list, parameter_dict)
-                if element.tag in parameter_dict:
-                    map_data_list = [[], [], []]
-                    map_shape_list = [0, 0, 0]
-                continue
-            data = element.tensor.tensor_content
-            data_type = element.tensor.tensor_type
-            np_type = tensor_to_np_type.get(data_type)
-            ms_type = tensor_to_ms_type[data_type]
-            if data_type == 'str':
-                str_length = int(len(data) / 4)
-                np_type = np_type + str(str_length)
-            if data_type == "BFloat16":
-                dims = element.tensor.dims
-                param_data = np.frombuffer(data, np_type)
-                param_data = param_data.reshape(list(dims))
-                parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
-                parameter_dict[element.tag] = parameter
-                continue
-            element_data = np.frombuffer(data, np_type)
-            param_data_list.append(element_data)
-            if (element_id == len(checkpoint_list.value) - 1) or \
-                    (element.tag != checkpoint_list.value[element_id + 1].tag):
-                new_data = b"".join(param_data_list)
-                param_data = np.frombuffer(new_data, np_type)
-                param_data_list.clear()
-                dims = element.tensor.dims
-                if dims == [0] and data_type == 'str':
-                    parameter_dict[element.tag] = str(element_data[0])
-                else:
-                    if dims == [0] and 'Float' in data_type:
-                        param_data = float(param_data[0])
-                    if dims == [0] and 'Int' in data_type:
-                        param_data = int(param_data[0])
-                    if dims not in ([0], [1]):
-                        param_data = param_data.reshape(list(dims))
-                    parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
-                    parameter_dict[element.tag] = parameter
-                    _offload_if_config(parameter)
-
-        logger.info("Loading checkpoint files process is finished.")
-
-    ...
+    if os.getenv("AITURBO") == "1":
+        rank_id = get_rank()
+        from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
+        ckpt_path = os.path.dirname(ckpt_file_name)
+        ckpt_name = os.path.basename(ckpt_file_name)
+        np_dict = aiturbo.load_ckpt(ckpt_path, ckpt_name, rank_id, crc_check)
+        for key, value in np_dict.items():
+            if crc_check and len(value) != 2:
+                raise ValueError(f"When loading a checkpoint from AITurbo, if CRC check is enabled, "
+                                 f"the length of the value must be 2, but got {len(value)}.")
+            if isinstance(value, str):
+                if crc_check and value[1] != binascii.crc32(np.array(value[0]).tobytes()):
+                    raise ValueError(f"When loading a checkpoint from AITurbo, the value of the string has not "
+                                     f"passed the CRC check and has been corrupted.")
+                parameter_dict[key] = value[0]
+            else:
+                if crc_check and value[1] != binascii.crc32(value[0].tobytes()):
+                    raise ValueError(f"When loading a checkpoint from AITurbo, the value of the parameter has not "
+                                     f"passed the CRC check and has been corrupted.")
+                parameter_dict[key] = Parameter(Tensor(value[0]), name=key)
+    else:
+        _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
+                              dec_mode, crc_check, format)
 
     if not parameter_dict:
         raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
@@ -1161,13 +1385,93 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
     if _warm_up_host_cache_enabled(parameter_dict):
         (is_worker, net_dict, warm_up_dict) = _warm_up_host_cache(parameter_dict, net)
     if net is not None:
-        load_param_into_net(net, parameter_dict, strict_load)
+        load_param_into_net(net, parameter_dict, strict_load, remove_redundancy)
     if _warm_up_host_cache_enabled(parameter_dict):
         _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)
 
     return parameter_dict
 
 
+def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None,
+                          dec_mode="AES-GCM", specify_prefix=None, choice_func=None):
+    """
+    Load checkpoint info from a specified file asynchronously.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Note:
+        - `specify_prefix` and `filter_prefix` do not affect each other.
+        - If none of the parameters are loaded from checkpoint file, it will throw ValueError.
+        - `specify_prefix` and `filter_prefix` are in the process of being deprecated,
+          `choice_func` is recommended instead.
+          And using either of those two args will override `choice_func` at the same time.
+
+    Args:
+        ckpt_file_name (str): Checkpoint file name.
+        net (Cell, optional): The network where the parameters will be loaded. Default: ``None`` .
+        strict_load (bool, optional): Whether to strict load the parameter into net. If ``False`` , it will load
+            parameter into net when parameter name's suffix in checkpoint file is the
+            same as the parameter in the network. When the types are inconsistent
+            perform type conversion on the parameters of the same type, such as float32
+            to float16. Default: ``False`` .
+        filter_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`). Parameters
+            starting with the `filter_prefix` will not be loaded. Default: ``None`` .
+        dec_key (Union[None, bytes], optional): Byte type key used for decryption. If the value is ``None`` ,
+            the decryption is not required. Default: ``None`` .
+        dec_mode (str, optional): This parameter is valid only when dec_key is not set to ``None`` . Specifies
+            the decryption mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"``
+            and ``"SM4-CBC"`` . Default: ``"AES-GCM"`` .
+        specify_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`). Parameters
+            starting with the specify_prefix will be loaded. Default: ``None`` .
+        choice_func (Union[None, function], optional): Input value of the function is a Parameter name of type
+            string, and the return value is a bool. If returns ``True`` , the Parameter
+            that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
+            matches the custom condition will be removed. Default: ``None`` .
+
+    Returns:
+        A custom inner class, calling its `result` method yields the :func:`mindspore.load_checkpoint` result.
+
+    Raises:
+        ValueError: Checkpoint file's format is incorrect.
+        ValueError: Parameter's dict is None after load checkpoint file.
+        TypeError: The type of `specify_prefix` or `filter_prefix` is incorrect.
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import nn
+        >>> from mindspore.train import Model
+        >>> from mindspore.amp import FixedLossScaleManager
+        >>> from mindspore import context
+        >>> from mindspore import load_checkpoint_async
+        >>> from mindspore import load_param_into_net
+        >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+        >>> # Create the dataset taking MNIST as an example. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
>>> dataset = create_dataset()
|
|
1452
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
1453
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
1454
|
+
>>> ckpt_file = "./checkpoint/LeNet5-1_32.ckpt"
|
|
1455
|
+
>>> net = LeNet5()
|
|
1456
|
+
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
|
1457
|
+
>>> loss_scale_manager = FixedLossScaleManager()
|
|
1458
|
+
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
|
1459
|
+
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
|
|
1460
|
+
... loss_scale_manager=loss_scale_manager)
|
|
1461
|
+
>>> pd_future = load_checkpoint_async(ckpt_file)
|
|
1462
|
+
>>> model.build(train_dataset=dataset, epoch=2)
|
|
1463
|
+
>>> param_dict = pd_future.result()
|
|
1464
|
+
>>> load_param_into_net(net, param_dict)
|
|
1465
|
+
>>> model.train(2, dataset)
|
|
1466
|
+
>>> print("param dict len: ", len(param_dict), flush=True)
|
|
1467
|
+
"""
|
|
1468
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
1469
|
+
executor = ThreadPoolExecutor(max_workers=2)
|
|
1470
|
+
param_dict_future = executor.submit(load_checkpoint, ckpt_file_name, net, strict_load, filter_prefix,
|
|
1471
|
+
dec_key, dec_mode, specify_prefix, choice_func)
|
|
1472
|
+
return ParamDictFuture(executor, param_dict_future)
|
|
1473
|
+
|
|
1474
|
+
|
|
1171
1475
|
def _load_map_parameter(checkpoint_list, element, element_id, map_data_list,
|
|
1172
1476
|
map_shape_list, parameter_dict):
|
|
1173
1477
|
"""load map parameter."""
|
|
@@ -1198,17 +1502,20 @@ def _load_map_parameter(checkpoint_list, element, element_id, map_data_list,
|
|
|
1198
1502
|
parameter_dict[element.tag] = map_array
|
|
1199
1503
|
|
|
1200
1504
|
|
|
1201
|
-
def _check_ckpt_file_name(ckpt_file_name):
|
|
1505
|
+
def _check_ckpt_file_name(ckpt_file_name, format):
|
|
1202
1506
|
"""Check function load_checkpoint's ckpt_file_name."""
|
|
1203
1507
|
if not isinstance(ckpt_file_name, str):
|
|
1204
1508
|
raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
|
|
1205
1509
|
"but got {}.".format(type(ckpt_file_name)))
|
|
1206
1510
|
|
|
1207
|
-
if
|
|
1208
|
-
raise ValueError("For 'load_checkpoint', the checkpoint file should end with '.ckpt', please "
|
|
1511
|
+
if format not in ['ckpt', 'safetensors']:
|
|
1512
|
+
raise ValueError("For 'load_checkpoint', the checkpoint file should end with '.ckpt' or '.safetensors', please "
|
|
1209
1513
|
"input the correct 'ckpt_file_name'.")
|
|
1514
|
+
if not ckpt_file_name.endswith(format):
|
|
1515
|
+
raise ValueError(f"For 'load_checkpoint', the checkpoint file format must same with 'format', but got "
|
|
1516
|
+
f"file_name:'{ckpt_file_name}', format:'{format}'")
|
|
1210
1517
|
|
|
1211
|
-
ckpt_file_name = os.path.
|
|
1518
|
+
ckpt_file_name = os.path.realpath(ckpt_file_name)
|
|
1212
1519
|
if not os.path.exists(ckpt_file_name):
|
|
1213
1520
|
raise ValueError("For 'load_checkpoint', the checkpoint file: {} does not exist, please check "
|
|
1214
1521
|
"whether the 'ckpt_file_name' is correct.".format(ckpt_file_name))
|
|
@@ -1239,17 +1546,28 @@ def _check_prefix(prefix):
|
|
|
1239
1546
|
return prefix
|
|
1240
1547
|
|
|
1241
1548
|
|
|
1242
|
-
def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode):
|
|
1549
|
+
def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check):
|
|
1243
1550
|
"""Parse checkpoint protobuf."""
|
|
1244
1551
|
checkpoint_list = Checkpoint()
|
|
1245
1552
|
try:
|
|
1246
1553
|
if dec_key is None:
|
|
1247
|
-
with open(ckpt_file_name,
|
|
1554
|
+
with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
|
|
1248
1555
|
pb_content = f.read()
|
|
1249
1556
|
else:
|
|
1250
1557
|
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
|
|
1251
1558
|
if pb_content is None:
|
|
1252
1559
|
raise ValueError("For 'load_checkpoint', failed to decrypt the checkpoint file.")
|
|
1560
|
+
if crc_check and pb_content[-17:-10] != b"crc_num":
|
|
1561
|
+
logger.warning("For 'load_checkpoint', the ckpt file do not contain the crc code, please check the file.")
|
|
1562
|
+
if pb_content[-17:-10] == b"crc_num":
|
|
1563
|
+
crc_num_bytes = pb_content[-10:]
|
|
1564
|
+
pb_content = pb_content[:-17]
|
|
1565
|
+
if crc_check:
|
|
1566
|
+
crc_num = int.from_bytes(crc_num_bytes, byteorder='big')
|
|
1567
|
+
cal_crc_num = binascii.crc32(pb_content, 0)
|
|
1568
|
+
if cal_crc_num != crc_num:
|
|
1569
|
+
raise ValueError("For 'load_checkpoint', the crc check is failed, "
|
|
1570
|
+
"please check whether the ckpt file is damaged.")
|
|
1253
1571
|
checkpoint_list.ParseFromString(pb_content)
|
|
1254
1572
|
except BaseException as e:
|
|
1255
1573
|
if _is_cipher_file(ckpt_file_name):
|
|
@@ -1282,17 +1600,40 @@ def _whether_load_param(specify_prefix, filter_prefix, param_name):
|
|
|
1282
1600
|
|
|
1283
1601
|
def _init_parameter_data_in_parallel_mode(net, parameter_dict):
|
|
1284
1602
|
"""In parallel mode, only init the paraemters in ckpt."""
|
|
1603
|
+
is_train_phase = net.phase.startswith('train')
|
|
1285
1604
|
for _, param in net.parameters_and_names():
|
|
1605
|
+
if param.name in parameter_dict and param.from_ckpt and not is_train_phase:
|
|
1606
|
+
param.shape = tuple(parameter_dict[param.name].shape)
|
|
1607
|
+
continue
|
|
1286
1608
|
if param.name in parameter_dict and param.has_init:
|
|
1287
1609
|
logger.warning("{} is not init while load ckpt.".format(param.name))
|
|
1288
1610
|
new_tensor = param.init_data()
|
|
1289
1611
|
param._update_tensor_data(new_tensor)
|
|
1290
1612
|
|
|
1291
1613
|
|
|
1292
|
-
def
|
|
1614
|
+
def _check_load_param_into_net(net, parameter_dict):
|
|
1615
|
+
"""check load_param_into_net"""
|
|
1616
|
+
if not isinstance(net, nn.Cell):
|
|
1617
|
+
logger.critical("Failed to combine the net and the parameters.")
|
|
1618
|
+
msg = ("For 'load_param_into_net', the argument 'net' should be a Cell, but got {}.".format(type(net)))
|
|
1619
|
+
raise TypeError(msg)
|
|
1620
|
+
if not isinstance(parameter_dict, dict):
|
|
1621
|
+
logger.critical("Failed to combine the net and the parameters.")
|
|
1622
|
+
msg = ("For 'load_param_into_net', the argument 'parameter_dict' should be a dict, "
|
|
1623
|
+
"but got {}.".format(type(parameter_dict)))
|
|
1624
|
+
raise TypeError(msg)
|
|
1625
|
+
if "random_op" in parameter_dict.keys():
|
|
1626
|
+
net._add_attr("random_op_snapshot", parameter_dict["random_op"])
|
|
1627
|
+
parameter_dict.pop("random_op")
|
|
1628
|
+
|
|
1629
|
+
|
|
1630
|
+
def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundancy=False):
|
|
1293
1631
|
"""
|
|
1294
1632
|
Load parameters into network, return parameter list that are not loaded in the network.
|
|
1295
1633
|
|
|
1634
|
+
Note:
|
|
1635
|
+
- When loading a parameter dict that has removed redundancy, the network should be compiled.
|
|
1636
|
+
|
|
1296
1637
|
Args:
|
|
1297
1638
|
net (Cell): The network where the parameters will be loaded.
|
|
1298
1639
|
parameter_dict (dict): The dictionary generated by load checkpoint file,
|
|
@@ -1301,10 +1642,13 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1301
1642
|
into net when parameter name's suffix in checkpoint file is the same as the
|
|
1302
1643
|
parameter in the network. When the types are inconsistent perform type conversion
|
|
1303
1644
|
on the parameters of the same type, such as float32 to float16. Default: ``False`` .
|
|
1645
|
+
remove_redundancy (bool): Whether to enable loading of checkpoint saved with redundancy removal.
|
|
1646
|
+
Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: ``False`` , means
|
|
1647
|
+
redundant-free loading is not enabled.
|
|
1304
1648
|
|
|
1305
1649
|
Returns:
|
|
1306
|
-
param_not_load (List), the parameter name in model which are not loaded into the network.
|
|
1307
|
-
ckpt_not_load (List), the parameter name in checkpoint file which are not loaded into the network.
|
|
1650
|
+
- param_not_load (List), the parameter name in model which are not loaded into the network.
|
|
1651
|
+
- ckpt_not_load (List), the parameter name in checkpoint file which are not loaded into the network.
|
|
1308
1652
|
|
|
1309
1653
|
Raises:
|
|
1310
1654
|
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
|
|
@@ -1313,7 +1657,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1313
1657
|
>>> import mindspore as ms
|
|
1314
1658
|
>>>
|
|
1315
1659
|
>>> # Define the network structure of LeNet5. Refer to
|
|
1316
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
1660
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
1317
1661
|
>>> net = LeNet5()
|
|
1318
1662
|
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
|
|
1319
1663
|
>>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1")
|
|
@@ -1323,20 +1667,9 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1323
1667
|
|
|
1324
1668
|
Tutorial Examples:
|
|
1325
1669
|
- `Saving and Loading the Model - Saving and Loading the Model Weight
|
|
1326
|
-
<https://mindspore.cn/tutorials/en/
|
|
1670
|
+
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
|
|
1327
1671
|
"""
|
|
1328
|
-
|
|
1329
|
-
logger.critical("Failed to combine the net and the parameters.")
|
|
1330
|
-
msg = ("For 'load_param_into_net', the argument 'net' should be a Cell, but got {}.".format(type(net)))
|
|
1331
|
-
raise TypeError(msg)
|
|
1332
|
-
if not isinstance(parameter_dict, dict):
|
|
1333
|
-
logger.critical("Failed to combine the net and the parameters.")
|
|
1334
|
-
msg = ("For 'load_param_into_net', the argument 'parameter_dict' should be a dict, "
|
|
1335
|
-
"but got {}.".format(type(parameter_dict)))
|
|
1336
|
-
raise TypeError(msg)
|
|
1337
|
-
if "random_op" in parameter_dict.keys():
|
|
1338
|
-
net._add_attr("random_op_snapshot", parameter_dict["random_op"])
|
|
1339
|
-
parameter_dict.pop("random_op")
|
|
1672
|
+
_check_load_param_into_net(net, parameter_dict)
|
|
1340
1673
|
for key, value in parameter_dict.items():
|
|
1341
1674
|
if not isinstance(key, str) or not isinstance(value, (Parameter, str, list)):
|
|
1342
1675
|
logger.critical("Load parameters into net failed.")
|
|
@@ -1345,8 +1678,11 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1345
1678
|
raise TypeError(msg)
|
|
1346
1679
|
|
|
1347
1680
|
strict_load = Validator.check_bool(strict_load)
|
|
1681
|
+
remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
|
|
1348
1682
|
logger.info("Execute the process of loading parameters into net.")
|
|
1349
|
-
|
|
1683
|
+
for _, param in net.parameters_and_names():
|
|
1684
|
+
param.from_ckpt = True
|
|
1685
|
+
if not (_is_in_auto_parallel_mode() or _is_parallel_mode()):
|
|
1350
1686
|
net.init_parameters_data()
|
|
1351
1687
|
else:
|
|
1352
1688
|
_init_parameter_data_in_parallel_mode(net, parameter_dict)
|
|
@@ -1360,7 +1696,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1360
1696
|
# Add has attr protection when load server checkpoint file on worker.
|
|
1361
1697
|
if not hasattr(parameter_dict[param.name], "data"):
|
|
1362
1698
|
continue
|
|
1363
|
-
new_param =
|
|
1699
|
+
new_param = parameter_dict[param.name]
|
|
1364
1700
|
_update_param(param, new_param, strict_load)
|
|
1365
1701
|
ckpt_not_load.remove(param.name)
|
|
1366
1702
|
else:
|
|
@@ -1369,18 +1705,31 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1369
1705
|
if param_not_load and not strict_load:
|
|
1370
1706
|
_load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load)
|
|
1371
1707
|
|
|
1372
|
-
logger.debug("Params not matched(in net but not in parameter_dict):")
|
|
1373
|
-
for param_name in param_not_load:
|
|
1374
|
-
logger.debug("%s", param_name)
|
|
1375
|
-
|
|
1376
1708
|
logger.info("Loading parameters into net is finished.")
|
|
1377
1709
|
if param_not_load:
|
|
1378
1710
|
logger.warning("For 'load_param_into_net', "
|
|
1379
1711
|
"{} parameters in the 'net' are not loaded, because they are not in the "
|
|
1380
1712
|
"'parameter_dict', please check whether the network structure is consistent "
|
|
1381
|
-
"when training and loading checkpoint."
|
|
1382
|
-
|
|
1383
|
-
|
|
1713
|
+
"when training and loading checkpoint. Another possibility is that "
|
|
1714
|
+
"the redundant loading is not enabled, but the loaded checkpoint is saved with "
|
|
1715
|
+
"redundancy removed. ".format(len(param_not_load)))
|
|
1716
|
+
logger.warning("{} are not loaded.".format(param_not_load))
|
|
1717
|
+
if remove_redundancy:
|
|
1718
|
+
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
|
1719
|
+
if parallel_mode == "stand_alone":
|
|
1720
|
+
raise TypeError(f"The deduplication feature for loading checkpoint can only be used "
|
|
1721
|
+
f"in parallel scenarios, but got {parallel_mode}.")
|
|
1722
|
+
if not net.compile_cache and not net.parameter_layout_dict:
|
|
1723
|
+
raise ValueError("When loading a parameter dict that has removed redundancy, "
|
|
1724
|
+
"the network should be compiled.")
|
|
1725
|
+
param_layout = net.parameter_layout_dict
|
|
1726
|
+
rank_id = get_rank()
|
|
1727
|
+
device_num = _get_device_num()
|
|
1728
|
+
stage_num = _get_auto_parallel_context("pipeline_stages")
|
|
1729
|
+
chunk_size = device_num // stage_num
|
|
1730
|
+
initial_rank = (rank_id // chunk_size) * chunk_size
|
|
1731
|
+
_single_parameter_broadcast(net, param_layout, rank_id, initial_rank)
|
|
1732
|
+
|
|
1384
1733
|
return param_not_load, ckpt_not_load
|
|
1385
1734
|
|
|
1386
1735
|
|
|
@@ -1486,7 +1835,7 @@ def _save_graph(network, file_name):
|
|
|
1486
1835
|
"""
|
|
1487
1836
|
logger.info("Execute the process of saving graph.")
|
|
1488
1837
|
|
|
1489
|
-
file_name = os.path.
|
|
1838
|
+
file_name = os.path.realpath(file_name)
|
|
1490
1839
|
graph_pb = network.get_func_graph_proto()
|
|
1491
1840
|
if graph_pb:
|
|
1492
1841
|
with open(file_name, "wb") as f:
|
|
@@ -1494,6 +1843,23 @@ def _save_graph(network, file_name):
|
|
|
1494
1843
|
f.write(graph_pb)
|
|
1495
1844
|
|
|
1496
1845
|
|
|
1846
|
+
def _reshape_tensor(tensor, dst_shape):
|
|
1847
|
+
"""reshape tensor to dst shape"""
|
|
1848
|
+
np_tensor = tensor.asnumpy()
|
|
1849
|
+
np_tensor = np_tensor.reshape(dst_shape)
|
|
1850
|
+
return Tensor(np_tensor, tensor.dtype)
|
|
1851
|
+
|
|
1852
|
+
|
|
1853
|
+
def _check_param_for_integrate_save(pipeline_stages, uniform_split):
|
|
1854
|
+
"""check whether current settings and parameters are supported in integrated save checkpoint mode"""
|
|
1855
|
+
if pipeline_stages > 1:
|
|
1856
|
+
raise RuntimeError("Pipeline Parallel don't support Integrated save checkpoint now.")
|
|
1857
|
+
if uniform_split == 0:
|
|
1858
|
+
raise RuntimeError("For 'save_checkpoint' and in automatic model parallel scene, when set "
|
|
1859
|
+
"'integrated_save' to True, the checkpoint will be integrated save, it "
|
|
1860
|
+
"is only supports uniform split tensor now.")
|
|
1861
|
+
|
|
1862
|
+
|
|
1497
1863
|
def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, integrated_save):
|
|
1498
1864
|
"""
|
|
1499
1865
|
Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.
|
|
@@ -1507,7 +1873,7 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
|
|
|
1507
1873
|
Tensor, the combined tensor which with the whole data value.
|
|
1508
1874
|
"""
|
|
1509
1875
|
layout = parameter_layout_dict[param_name]
|
|
1510
|
-
if len(layout) <
|
|
1876
|
+
if len(layout) < 8:
|
|
1511
1877
|
logger.info("The layout dict does not contain the key %s", param_name)
|
|
1512
1878
|
return param_data
|
|
1513
1879
|
|
|
@@ -1515,6 +1881,13 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
|
|
|
1515
1881
|
tensor_map = layout[1]
|
|
1516
1882
|
uniform_split = layout[4]
|
|
1517
1883
|
opt_shard_group = layout[5]
|
|
1884
|
+
before_reshape_slice_shape = layout[2]
|
|
1885
|
+
before_reshape_full_shape = layout[6]
|
|
1886
|
+
after_reshape_slice_shape = layout[7]
|
|
1887
|
+
do_reshape = False
|
|
1888
|
+
if before_reshape_full_shape and after_reshape_slice_shape \
|
|
1889
|
+
and after_reshape_slice_shape != before_reshape_slice_shape:
|
|
1890
|
+
do_reshape = True
|
|
1518
1891
|
|
|
1519
1892
|
allgather_net = None
|
|
1520
1893
|
mp_weight = False
|
|
@@ -1527,26 +1900,26 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
|
|
|
1527
1900
|
else:
|
|
1528
1901
|
logger.info("Need to create allgather net for %s", param_name)
|
|
1529
1902
|
if integrated_save:
|
|
1530
|
-
|
|
1531
|
-
raise RuntimeError("Pipeline Parallel don't support Integrated save checkpoint now.")
|
|
1532
|
-
if uniform_split == 0:
|
|
1533
|
-
raise RuntimeError("For 'save_checkpoint' and in automatic model parallel scene, when set "
|
|
1534
|
-
"'integrated_save' to True, the checkpoint will be integrated save, it "
|
|
1535
|
-
"is only supports uniform split tensor now.")
|
|
1903
|
+
_check_param_for_integrate_save(context.get_auto_parallel_context("pipeline_stages"), uniform_split)
|
|
1536
1904
|
# while any dim is not equal to -1, means param is split and needs to be merged
|
|
1537
1905
|
# pipeline parallel need to be supported here later
|
|
1538
1906
|
if mp_weight:
|
|
1539
|
-
allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group)
|
|
1907
|
+
allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group), do_reshape,
|
|
1908
|
+
tuple(after_reshape_slice_shape))
|
|
1540
1909
|
object.__setattr__(allgather_net, "keep_input_unchanged", True)
|
|
1541
1910
|
elif opt_shard_group:
|
|
1542
|
-
allgather_net = get_allgather_cell(opt_shard_group, False
|
|
1911
|
+
allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
|
|
1912
|
+
tuple(after_reshape_slice_shape))
|
|
1543
1913
|
elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
|
|
1544
|
-
allgather_net = get_allgather_cell(opt_shard_group, False
|
|
1914
|
+
allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
|
|
1915
|
+
tuple(after_reshape_slice_shape))
|
|
1545
1916
|
net.parallel_parameter_merge_net_dict[param_name] = allgather_net
|
|
1546
1917
|
if allgather_net:
|
|
1547
1918
|
param_data = allgather_net(param_data)
|
|
1548
1919
|
if mp_weight and integrated_save:
|
|
1549
1920
|
param_data = _reshape_param_data(param_data, dev_mat, tensor_map)
|
|
1921
|
+
if do_reshape:
|
|
1922
|
+
param_data = _reshape_tensor(param_data, before_reshape_full_shape)
|
|
1550
1923
|
return param_data
|
|
1551
1924
|
|
|
1552
1925
|
|
|
@@ -1556,7 +1929,8 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1556
1929
|
|
|
1557
1930
|
Note:
|
|
1558
1931
|
1. When exporting AIR, ONNX format, the size of a single tensor can not exceed 2GB.
|
|
1559
|
-
2. When file_name does not have a suffix, the system will automatically add one
|
|
1932
|
+
2. When `file_name` does not have a suffix, the system will automatically add one
|
|
1933
|
+
according to the `file_format`.
|
|
1560
1934
|
3. Exporting functions decorated with :func:`mindspore.jit` to mindir format is supported.
|
|
1561
1935
|
4. When exporting a function decorated with :func:`mindspore.jit`, the function should not involve
|
|
1562
1936
|
class properties in calculations.
|
|
@@ -1576,7 +1950,7 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1576
1950
|
- AIR: Ascend Intermediate Representation. An intermediate representation format of Ascend model.
|
|
1577
1951
|
- ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
|
|
1578
1952
|
- MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation format
|
|
1579
|
-
for MindSpore models.
|
|
1953
|
+
for MindSpore models. MINDIR does not support operators which have dictionary attribute.
|
|
1580
1954
|
|
|
1581
1955
|
kwargs (dict): Configuration options dictionary.
|
|
1582
1956
|
|
|
@@ -1586,9 +1960,9 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1586
1960
|
- For 'AIR' and 'ONNX' models, only customized encryption is supported.
|
|
1587
1961
|
- For 'MINDIR', all options are supported. Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC'
|
|
1588
1962
|
or Customized encryption.
|
|
1589
|
-
Default: 'AES-GCM'
|
|
1963
|
+
Default: ``'AES-GCM'``.
|
|
1590
1964
|
- For details of using the customized encryption, please check the `tutorial
|
|
1591
|
-
<https://mindspore.cn/mindarmour/docs/en/
|
|
1965
|
+
<https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
|
|
1592
1966
|
|
|
1593
1967
|
- dataset (Dataset): Specifies the preprocessing method of the dataset, which is used to import the
|
|
1594
1968
|
preprocessing of the dataset into MindIR.
|
|
@@ -1602,32 +1976,49 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1602
1976
|
- customized_func (function): A python function used for customized function mode, which used for control
|
|
1603
1977
|
the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
|
|
1604
1978
|
Reference to 'my_func()' in
|
|
1605
|
-
`tutorials <https://www.mindspore.cn/mindarmour/docs/en/
|
|
1979
|
+
`tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
|
|
1606
1980
|
This function needs to ensure that its result is constant for any input. Users can refer to opaque
|
|
1607
1981
|
predicates. If customized_func is set, then it should be passed to `load()` interface when loading
|
|
1608
1982
|
obfuscated model.
|
|
1609
1983
|
- obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
|
|
1610
1984
|
structure of obfuscated models corresponding to different random seeds is different. If
|
|
1611
|
-
`obf_random_seed` is set, then it should be passed
|
|
1985
|
+
`obf_random_seed` is set, then it should be passed
|
|
1986
|
+
to :class:`mindspore.nn.GraphCell` interface when loading
|
|
1612
1987
|
obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
|
|
1613
1988
|
be set, and the latter mode would be applied if both of them are set.
|
|
1614
1989
|
|
|
1615
1990
|
- incremental (bool): export MindIR incrementally.
|
|
1616
1991
|
|
|
1992
|
+
- custom_func (function): Functions for custom defined export policies. This function will be used to
|
|
1993
|
+
customize the model during network export. Currently only support for files with mindir format. The
|
|
1994
|
+
function only accepts one input representing the proto object of the mindir file. When modifying a model,
|
|
1995
|
+
it is necessary to ensure the correctness of the `custom_func` , otherwise it may lead to model loading
|
|
1996
|
+
failure or functional errors. Default: ``None`` .
|
|
1997
|
+
|
|
1617
1998
|
Examples:
|
|
1618
1999
|
>>> import mindspore as ms
|
|
1619
2000
|
>>> import numpy as np
|
|
1620
2001
|
>>> from mindspore import Tensor
|
|
1621
2002
|
>>>
|
|
1622
2003
|
>>> # Define the network structure of LeNet5. Refer to
|
|
1623
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
2004
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
1624
2005
|
>>> net = LeNet5()
|
|
1625
2006
|
>>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
|
|
1626
2007
|
>>> ms.export(net, input_tensor, file_name='lenet', file_format='MINDIR')
|
|
2008
|
+
>>>
|
|
2009
|
+
>>> # Export model in MindIR format and modified the model info using custom_func
|
|
2010
|
+
>>> # The custom_func only support one input representing the Proto object of the model
|
|
2011
|
+
>>> # And custom_func does not support return value
|
|
2012
|
+
>>> def _custom_func(mindir_model):
|
|
2013
|
+
... mindir_model.producer_name = "test11111"
|
|
2014
|
+
... mindir_model.producer_version = "11.0"
|
|
2015
|
+
... mindir_model.user_info["version"] = "11.0"
|
|
2016
|
+
>>> ms.export(net, input_tensor, file_name="lenet", file_format='MINDIR', custom_func=_custom_func)
|
|
2017
|
+
|
|
1627
2018
|
|
|
1628
2019
|
Tutorial Examples:
|
|
1629
2020
|
- `Saving and Loading the Model - Saving and Loading MindIR
|
|
1630
|
-
<https://mindspore.cn/tutorials/en/
|
|
2021
|
+
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
|
|
1631
2022
|
"""
|
|
1632
2023
|
old_ms_jit_value = context.get_context("jit_syntax_level")
|
|
1633
2024
|
context.set_context(jit_syntax_level=mindspore.STRICT)
|
|
@@ -1658,7 +2049,7 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1658
2049
|
+ str(columns))
|
|
1659
2050
|
inputs = tuple(inputs_col)
|
|
1660
2051
|
|
|
1661
|
-
file_name = os.path.
|
|
2052
|
+
file_name = os.path.realpath(file_name)
|
|
1662
2053
|
if 'enc_key' in kwargs.keys():
|
|
1663
2054
|
kwargs['enc_key'], kwargs['enc_mode'] = _check_key_mode_type(file_format, **kwargs)
|
|
1664
2055
|
_export(net, file_name, file_format, *inputs, **kwargs)
|
|
@@ -1690,7 +2081,7 @@ def _get_funcgraph(net, *inputs):
|
|
|
1690
2081
|
>>> from mindspore import Tensor
|
|
1691
2082
|
>>>
|
|
1692
2083
|
>>> # Define the network structure of LeNet5. Refer to
|
|
1693
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
2084
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
1694
2085
|
>>> net = LeNet5()
|
|
1695
2086
|
>>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
|
|
1696
2087
|
>>> ms.get_funcgraph(net, input_tensor)
|
|
@@ -1712,6 +2103,8 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
|
|
|
1712
2103
|
logger.info("exporting model file:%s format:%s.", file_name, file_format)
|
|
1713
2104
|
if "obf_config" in kwargs and file_format != "MINDIR":
|
|
1714
2105
|
raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.")
|
|
2106
|
+
if "custom_func" in kwargs and file_format != "MINDIR":
|
|
2107
|
+
raise ValueError(f"Currently only support custom_func for MindIR format, but got {file_format} format.")
|
|
1715
2108
|
if file_format == 'AIR':
|
|
1716
2109
|
_save_air(net, file_name, *inputs, **kwargs)
|
|
1717
2110
|
elif file_format == 'ONNX':
|
|
@@ -1749,8 +2142,8 @@ def _save_air(net, file_name, *inputs, **kwargs):
|
|
|
1749
2142
|
if os.path.exists(file_name):
|
|
1750
2143
|
os.chmod(file_name, stat.S_IWUSR)
|
|
1751
2144
|
if "/" in file_name:
|
|
1752
|
-
real_path = os.path.
|
|
1753
|
-
os.makedirs(real_path, exist_ok=True)
|
|
2145
|
+
real_path = os.path.realpath(file_name[:file_name.rfind("/")])
|
|
2146
|
+
os.makedirs(real_path, mode=0o700, exist_ok=True)
|
|
1754
2147
|
if 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys():
|
|
1755
2148
|
_executor.export(file_name, graph_id, enc_key=kwargs.get('enc_key'), encrypt_func=kwargs.get('enc_mode'))
|
|
1756
2149
|
else:
|
|
@@ -1860,24 +2253,24 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
|
|
|
1860
2253
|
file_prefix = file_name.split("/")[-1]
|
|
1861
2254
|
if file_prefix.endswith(".mindir"):
|
|
1862
2255
|
file_prefix = file_prefix[:-7]
|
|
1863
|
-
current_path = os.path.
|
|
2256
|
+
current_path = os.path.realpath(file_name)
|
|
1864
2257
|
dirname = os.path.dirname(current_path)
|
|
1865
2258
|
data_path = os.path.join(dirname, file_prefix + "_variables")
|
|
1866
2259
|
if os.path.exists(data_path):
|
|
1867
2260
|
shutil.rmtree(data_path)
|
|
1868
|
-
os.makedirs(data_path, exist_ok=True)
|
|
2261
|
+
os.makedirs(data_path, mode=0o700, exist_ok=True)
|
|
1869
2262
|
os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
|
|
1870
2263
|
index = 0
|
|
1871
2264
|
external_local = os.path.join(file_prefix + "_variables", "data_" + str(index))
|
|
1872
2265
|
data_file_name = os.path.join(dirname, external_local)
|
|
1873
2266
|
f, parameter_size, offset = _get_data_file(is_encrypt, kwargs, data_file_name)
|
|
1874
2267
|
try:
|
|
1875
|
-
|
|
2268
|
+
round = 0
|
|
1876
2269
|
names = []
|
|
1877
2270
|
for param_proto in model.graph.parameter:
|
|
1878
2271
|
name = param_proto.name[param_proto.name.find(":") + 1:]
|
|
1879
2272
|
names.append((name, param_proto))
|
|
1880
|
-
|
|
2273
|
+
names.sort(key=lambda x: x[0])
|
|
1881
2274
|
for pairs in names:
|
|
1882
2275
|
name = pairs[0]
|
|
1883
2276
|
param_proto = pairs[1]
|
|
@@ -1900,8 +2293,8 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
|
|
|
1900
2293
|
offset += (data_length + append_size)
|
|
1901
2294
|
write_data = _encrypt_data(is_encrypt, write_data, kwargs)
|
|
1902
2295
|
f.write(write_data)
|
|
1903
|
-
|
|
1904
|
-
logger.debug(f"writing {
|
|
2296
|
+
round += 1
|
|
2297
|
+
logger.debug(f"writing {round}th split data, name:{name}")
|
|
1905
2298
|
|
|
1906
2299
|
graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
|
|
1907
2300
|
if os.path.exists(graph_file_name):
|
|
@@ -1998,6 +2391,10 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
|
|
|
1998
2391
|
dataset = kwargs.get('dataset')
|
|
1999
2392
|
_save_dataset_to_mindir(model, dataset)
|
|
2000
2393
|
|
|
2394
|
+
custom_func = kwargs.get('custom_func', None)
|
|
2395
|
+
if custom_func is not None:
|
|
2396
|
+
custom_func(model)
|
|
2397
|
+
|
|
2001
2398
|
save_together = _save_together(net_dict, model)
|
|
2002
2399
|
is_encrypt = lambda: 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys()
|
|
2003
2400
|
if save_together:
|
|
@@ -2030,9 +2427,9 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
|
|
|
2030
2427
|
"the data of parameter cannot be exported.".format(map_param_proto.name))
|
|
2031
2428
|
if not file_name.endswith('.mindir'):
|
|
2032
2429
|
file_name += ".mindir"
|
|
2033
|
-
current_path = os.path.
|
|
2430
|
+
current_path = os.path.realpath(file_name)
|
|
2034
2431
|
dirname = os.path.dirname(current_path)
|
|
2035
|
-
os.makedirs(dirname, exist_ok=True)
|
|
2432
|
+
os.makedirs(dirname, mode=0o700, exist_ok=True)
|
|
2036
2433
|
if os.path.exists(file_name):
|
|
2037
2434
|
os.chmod(file_name, stat.S_IWUSR)
|
|
2038
2435
|
with open(file_name, 'wb') as f:
|
|
@@ -2084,6 +2481,45 @@ def _save_dataset_to_mindir(model, dataset):
|
|
|
2084
2481
|
model.preprocessor.op[-1].offload = op['offload'] if 'offload' in op.keys() else False
|
|
2085
2482
|
|
|
2086
2483
|
|
|
2484
|
+
def check_checkpoint(ckpt_file_name):
|
|
2485
|
+
"""
|
|
2486
|
+
Check whether the checkpoint is valid.
|
|
2487
|
+
|
|
2488
|
+
Args:
|
|
2489
|
+
ckpt_file_name (str): Checkpoint file name.
|
|
2490
|
+
|
|
2491
|
+
Returns:
|
|
2492
|
+
bool, whether the checkpoint is valid.
|
|
2493
|
+
|
|
2494
|
+
Examples:
|
|
2495
|
+
>>> import mindspore as ms
|
|
2496
|
+
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
|
|
2497
|
+
>>> check_result = ms.check_checkpoint(ckpt_file_name)
|
|
2498
|
+
>>> print(check_result)
|
|
2499
|
+
True
|
|
2500
|
+
"""
|
|
2501
|
+
if not ckpt_file_name.endswith('.ckpt'):
|
|
2502
|
+
return False
|
|
2503
|
+
checkpoint_list = Checkpoint()
|
|
2504
|
+
with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
|
|
2505
|
+
pb_content = f.read()
|
|
2506
|
+
if pb_content[-17:-10] == b"crc_num":
|
|
2507
|
+
crc_num_bytes = pb_content[-10:]
|
|
2508
|
+
pb_content = pb_content[:-17]
|
|
2509
|
+
crc_num = int.from_bytes(crc_num_bytes, byteorder='big')
|
|
2510
|
+
cal_crc_num = binascii.crc32(pb_content, 0)
|
|
2511
|
+
if cal_crc_num != crc_num:
|
|
2512
|
+
logger.warning("For 'check_checkpoint', the ckpt crc check is failed.")
|
|
2513
|
+
return False
|
|
2514
|
+
try:
|
|
2515
|
+
checkpoint_list.ParseFromString(pb_content)
|
|
2516
|
+
except google.protobuf.message.DecodeError as e:
|
|
2517
|
+
logger.warning("For 'check_checkpoint', the ckpt parse is failed.")
|
|
2518
|
+
logger.warning(e)
|
|
2519
|
+
return False
|
|
2520
|
+
return True
|
|
2521
|
+
|
|
2522
|
+
|
|
2087
2523
|
def parse_print(print_file_name):
|
|
2088
2524
|
"""
|
|
2089
2525
|
Parse data file generated by :class:`mindspore.ops.Print`.
|
|
@@ -2122,7 +2558,7 @@ def parse_print(print_file_name):
|
|
|
2122
2558
|
[[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
|
|
2123
2559
|
[ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
|
|
2124
2560
|
"""
|
|
2125
|
-
print_file_path = os.path.
|
|
2561
|
+
print_file_path = os.path.realpath(print_file_name)
|
|
2126
2562
|
|
|
2127
2563
|
if os.path.getsize(print_file_path) == 0:
|
|
2128
2564
|
raise ValueError("For 'parse_print', the print file may be empty, please make sure enter the correct "
|
|
@@ -2411,14 +2847,15 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
|
|
|
2411
2847
|
return merged_parameter
|
|
2412
2848
|
|
|
2413
2849
|
|
|
2414
|
-
def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None,
|
|
2415
|
-
train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM'
|
|
2850
|
+
def load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None,
|
|
2851
|
+
train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM',
|
|
2852
|
+
format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None):
|
|
2416
2853
|
"""
|
|
2417
2854
|
Load checkpoint into net for distributed predication. Used in the case of distributed inference.
|
|
2418
2855
|
|
|
2419
2856
|
Args:
|
|
2420
2857
|
network (Cell): Network for distributed predication.
|
|
2421
|
-
checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
|
|
2858
|
+
checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. Default: ``None`` .
|
|
2422
2859
|
predict_strategy (dict): Strategy of predication process. It means that using one device to predict
|
|
2423
2860
|
when setting predict_strategy as None. Default: ``None`` .
|
|
2424
2861
|
train_strategy_filename (str): The filename of training strategy protocol buffer file.
|
|
@@ -2428,13 +2865,21 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|
|
2428
2865
|
in at least one of them. Default: ``None`` .
|
|
2429
2866
|
strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
|
|
2430
2867
|
into net when parameter name's suffix in checkpoint file is the same as the
|
|
2431
|
-
parameter in the network. When the types are inconsistent perform type conversion
|
|
2868
|
+
parameter in the network. When the types are inconsistent, perform type conversion
|
|
2432
2869
|
on the parameters of the same type, such as float32 to float16. Default: ``False`` .
|
|
2433
2870
|
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
|
|
2434
2871
|
is not required. Default: ``None`` .
|
|
2435
2872
|
dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
|
|
2436
2873
|
mode, currently supports ``'AES-GCM'`` , ``'AES-CBC'`` and ``'SM4-CBC'`` .
|
|
2437
2874
|
Default: ``'AES-GCM'`` .
|
|
2875
|
+
format (str): Input weight format to be loaded into the network.
|
|
2876
|
+
It can be set to either "ckpt" or "safetensors". Default: "ckpt".
|
|
2877
|
+
unified_safetensors_dir (str): Directory of input weight files to be loaded into the network.
|
|
2878
|
+
Default: ``None`` .
|
|
2879
|
+
dst_safetensors_dir (str): In the save mode scenario, the save directory for safetensors.
|
|
2880
|
+
rank_id (int): The logical sequence number of the card. In non save mode, it is automatically obtained
|
|
2881
|
+
globally by initializing the network; In save mode, save the file according to the input
|
|
2882
|
+
sequence number. If it is not input, save the entire file.
|
|
2438
2883
|
|
|
2439
2884
|
Raises:
|
|
2440
2885
|
TypeError: The type of inputs do not match the requirements.
|
|
@@ -2449,14 +2894,14 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|
|
2449
2894
|
|
|
2450
2895
|
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
|
|
2451
2896
|
Please see the `rank table startup
|
|
2452
|
-
<https://www.mindspore.cn/
|
|
2897
|
+
<https://www.mindspore.cn/docs/en/master/model_train/parallel/rank_table.html>`_
|
|
2453
2898
|
for more details.
|
|
2454
2899
|
|
|
2455
2900
|
For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
|
|
2456
|
-
<https://www.mindspore.cn/
|
|
2901
|
+
<https://www.mindspore.cn/docs/en/master/model_train/parallel/mpirun.html>`_ .
|
|
2457
2902
|
|
|
2458
2903
|
For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
|
|
2459
|
-
Startup <https://www.mindspore.cn/
|
|
2904
|
+
Startup <https://www.mindspore.cn/docs/en/master/model_train/parallel/dynamic_cluster.html>`_ .
|
|
2460
2905
|
|
|
2461
2906
|
>>> import os
|
|
2462
2907
|
>>> import numpy as np
|
|
@@ -2538,6 +2983,54 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|
|
2538
2983
|
...
|
|
2539
2984
|
[ 1.6067538 1.6244187 1.5384722 ... 1.5449994 1.6195512 1.6176052]]
|
|
2540
2985
|
"""
|
|
2986
|
+
if format not in ['safetensors', 'ckpt']:
|
|
2987
|
+
raise ValueError(
|
|
2988
|
+
f"For 'load_distributed_checkpoint', 'format' must be 'ckpt' or 'safetensors', but got {format}.")
|
|
2989
|
+
|
|
2990
|
+
if format == 'safetensors':
|
|
2991
|
+
if unified_safetensors_dir is None:
|
|
2992
|
+
raise ValueError(f"For 'load_distributed_checkpoint', 'unified_safetensors_dir' can not be None "
|
|
2993
|
+
f"when format is 'safetensors'.")
|
|
2994
|
+
unsupport_param = [checkpoint_filenames, train_strategy_filename, dec_key]
|
|
2995
|
+
for param in unsupport_param:
|
|
2996
|
+
if param is not None:
|
|
2997
|
+
raise ValueError(f"For 'load_distributed_checkpoint', {param} must be None "
|
|
2998
|
+
f"when format is 'safetensors'.")
|
|
2999
|
+
if strict_load or dec_mode != 'AES-GCM':
|
|
3000
|
+
raise ValueError(f"For 'load_distributed_checkpoint', strict_load and dec_mode must be default "
|
|
3001
|
+
f"when format is 'safetensors'.")
|
|
3002
|
+
if network is not None:
|
|
3003
|
+
rank_id = get_rank()
|
|
3004
|
+
_load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, rank_id=rank_id)
|
|
3005
|
+
else:
|
|
3006
|
+
if dst_safetensors_dir is None:
|
|
3007
|
+
raise ValueError(f"For 'load_distributed_checkpoint', 'dst_safetensors_dir' can not be None "
|
|
3008
|
+
f"when network is None.")
|
|
3009
|
+
if rank_id is not None:
|
|
3010
|
+
_load_parallel_checkpoint(unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir,
|
|
3011
|
+
rank_id)
|
|
3012
|
+
else:
|
|
3013
|
+
dst_strategy_dict = _build_searched_strategy(predict_strategy)
|
|
3014
|
+
dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_dict)
|
|
3015
|
+
dst_stage_num = _extract_pipeline_stage_num(dst_strategy_dict)
|
|
3016
|
+
dst_device_num = dst_stage_device_num * dst_stage_num
|
|
3017
|
+
processes = []
|
|
3018
|
+
activate_processes = 0
|
|
3019
|
+
for rank in range(0, dst_device_num):
|
|
3020
|
+
p = Process(target=_load_parallel_checkpoint, args=(
|
|
3021
|
+
unified_safetensors_dir, predict_strategy, network, dst_safetensors_dir, rank))
|
|
3022
|
+
p.start()
|
|
3023
|
+
processes.append(p)
|
|
3024
|
+
activate_processes += 1
|
|
3025
|
+
max_processes = 64
|
|
3026
|
+
if activate_processes >= max_processes:
|
|
3027
|
+
p = processes.pop(0)
|
|
3028
|
+
p.join()
|
|
3029
|
+
activate_processes -= 1
|
|
3030
|
+
for p in processes:
|
|
3031
|
+
p.join()
|
|
3032
|
+
return
|
|
3033
|
+
|
|
2541
3034
|
network = Validator.check_isinstance("network", network, nn.Cell)
|
|
2542
3035
|
_check_checkpoint_file(checkpoint_filenames)
|
|
2543
3036
|
_check_predict_strategy(predict_strategy)
|
|
@@ -2582,17 +3075,24 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|
|
2582
3075
|
param_rank = rank_list.get(param.name)[0]
|
|
2583
3076
|
skip_merge_split = rank_list.get(param.name)[1]
|
|
2584
3077
|
shard_stride = train_strategy.get(param.name)[4]
|
|
3078
|
+
tensor_map = train_strategy.get(param.name)[1]
|
|
3079
|
+
first_dim_shard_idx = tensor_map[0] if tensor_map else -1
|
|
3080
|
+
device_arrangement = train_strategy.get(param.name)[0]
|
|
3081
|
+
first_dim_shard_size = 1
|
|
3082
|
+
if first_dim_shard_idx >= 0:
|
|
3083
|
+
first_dim_shard_size = device_arrangement[-1 - first_dim_shard_idx]
|
|
2585
3084
|
if train_strategy.get(param.name)[5]:
|
|
2586
|
-
shard_size = ckpt_file_len / shard_stride / train_strategy.get(param.name)[5]
|
|
3085
|
+
shard_size = int(ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] / first_dim_shard_size)
|
|
2587
3086
|
else:
|
|
2588
3087
|
shard_size = 0
|
|
2589
3088
|
for rank in param_rank:
|
|
2590
3089
|
param_total_list = list(range(0, ckpt_file_len))
|
|
3090
|
+
if first_dim_shard_size != 1:
|
|
3091
|
+
param_total_list = _get_param_list_when_first_dim_sharded(device_arrangement, first_dim_shard_idx, rank)
|
|
2591
3092
|
if shard_size > 0:
|
|
2592
|
-
|
|
2593
|
-
|
|
2594
|
-
|
|
2595
|
-
param_total_list = shard_total_list[rank // shard_size]
|
|
3093
|
+
rank_index = param_total_list.index(rank)
|
|
3094
|
+
start = rank_index // shard_size * shard_size
|
|
3095
|
+
param_total_list = param_total_list[start:start + shard_size]
|
|
2596
3096
|
if shard_stride > 0:
|
|
2597
3097
|
param_stride = []
|
|
2598
3098
|
# merge pre parameter
|
|
@@ -2722,11 +3222,10 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
|
|
|
2722
3222
|
param_name = merged_param.name
|
|
2723
3223
|
tensor_layout = predict_strategy[param_name]
|
|
2724
3224
|
rank = get_rank()
|
|
2725
|
-
split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank)
|
|
3225
|
+
split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank_id=rank)
|
|
2726
3226
|
requires_grad = merged_param.requires_grad
|
|
2727
3227
|
layerwise_parallel = merged_param.layerwise_parallel
|
|
2728
|
-
|
|
2729
|
-
if data_type == mstype.bfloat16:
|
|
3228
|
+
if merged_param.data.dtype == mstype.bfloat16:
|
|
2730
3229
|
split_param = Parameter(Tensor(split_tensor, mstype.bfloat16), param_name, requires_grad, layerwise_parallel)
|
|
2731
3230
|
else:
|
|
2732
3231
|
split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
|
|
@@ -2765,7 +3264,7 @@ def _get_mindir_inputs(file_name):
|
|
|
2765
3264
|
>>> input_tensor = get_mindir_inputs("lenet.mindir")
|
|
2766
3265
|
"""
|
|
2767
3266
|
Validator.check_file_name_by_regular(file_name)
|
|
2768
|
-
file_name = os.path.
|
|
3267
|
+
file_name = os.path.realpath(file_name)
|
|
2769
3268
|
model = read_proto(file_name)
|
|
2770
3269
|
input_tensor = []
|
|
2771
3270
|
|
|
@@ -2794,7 +3293,7 @@ def _get_mindir_inputs(file_name):
|
|
|
2794
3293
|
|
|
2795
3294
|
def convert_model(mindir_file, convert_file, file_format):
|
|
2796
3295
|
"""
|
|
2797
|
-
Convert mindir model to other format model.
|
|
3296
|
+
Convert mindir model to other format model. The current version only supports conversion to ONNX models.
|
|
2798
3297
|
|
|
2799
3298
|
.. warning::
|
|
2800
3299
|
This is an experimental API that is subject to change or deletion.
|