mindspore 2.2.11__cp39-cp39-win_amd64.whl → 2.3.0__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +7 -5
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +76 -18
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +258 -0
- mindspore/_extends/parse/namespace.py +2 -2
- mindspore/_extends/parse/parser.py +174 -62
- mindspore/_extends/parse/resources.py +45 -14
- mindspore/_extends/parse/standard_method.py +142 -240
- mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
- mindspore/_extends/pijit/pijit_func_white_list.py +343 -0
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +51 -24
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/base.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/boost/grad_freeze.py +2 -2
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/boost/less_batch_normalization.py +9 -6
- mindspore/common/__init__.py +15 -4
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +8 -9
- mindspore/common/_stub_tensor.py +7 -1
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +411 -106
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +17 -10
- mindspore/common/dump.py +6 -8
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +260 -0
- mindspore/common/hook_handle.py +51 -4
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +34 -14
- mindspore/common/lazy_inline.py +72 -19
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +30 -11
- mindspore/common/recompute.py +262 -0
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +272 -24
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +468 -496
- mindspore/communication/__init__.py +6 -11
- mindspore/communication/_comm_helper.py +5 -0
- mindspore/communication/comm_func.py +1140 -0
- mindspore/communication/management.py +118 -102
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +378 -65
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +163 -83
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +33 -3
- mindspore/dataset/engine/datasets_text.py +38 -38
- mindspore/dataset/engine/datasets_user_defined.py +78 -59
- mindspore/dataset/engine/datasets_vision.py +77 -73
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/engine/queue.py +56 -38
- mindspore/dataset/engine/validators.py +11 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/vision/__init__.py +8 -8
- mindspore/dataset/vision/c_transforms.py +10 -10
- mindspore/dataset/vision/py_transforms_util.py +3 -3
- mindspore/dataset/vision/transforms.py +2844 -549
- mindspore/dataset/vision/utils.py +161 -10
- mindspore/dataset/vision/validators.py +14 -2
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +40 -16
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +71 -127
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +15 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +339 -0
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +4 -3
- mindspore/include/api/status.h +14 -0
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/ms/base/types.h +1 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/execute.h +1 -3
- mindspore/include/dataset/vision.h +54 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +2 -2
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +76 -58
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +53 -66
- mindspore/mindrecord/tools/cifar10_to_mr.py +48 -63
- mindspore/mindrecord/tools/csv_to_mr.py +7 -17
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +11 -21
- mindspore/mindrecord/tools/tfrecord_to_mr.py +2 -10
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/mint/__init__.py +1137 -0
- mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
- mindspore/mint/nn/__init__.py +512 -0
- mindspore/mint/nn/functional.py +573 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +185 -0
- mindspore/multiprocessing/__init__.py +72 -0
- mindspore/nn/__init__.py +1 -0
- mindspore/nn/cell.py +213 -257
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/extend/__init__.py +29 -0
- mindspore/nn/extend/basic.py +140 -0
- mindspore/nn/extend/embedding.py +143 -0
- mindspore/{rewrite/ast_creator_register.py → nn/extend/layer/__init__.py} +9 -19
- mindspore/nn/extend/layer/normalization.py +109 -0
- mindspore/nn/extend/pooling.py +117 -0
- mindspore/nn/layer/activation.py +84 -94
- mindspore/nn/layer/basic.py +177 -82
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +75 -66
- mindspore/nn/layer/embedding.py +103 -45
- mindspore/nn/layer/embedding_service.py +531 -0
- mindspore/nn/layer/embedding_service_layer.py +393 -0
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/math.py +1 -1
- mindspore/nn/layer/normalization.py +52 -66
- mindspore/nn/layer/padding.py +30 -39
- mindspore/nn/layer/pooling.py +18 -9
- mindspore/nn/layer/rnn_cells.py +6 -16
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +52 -50
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/loss.py +63 -84
- mindspore/nn/optim/ada_grad.py +6 -4
- mindspore/nn/optim/adadelta.py +3 -1
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +3 -3
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +7 -4
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +58 -13
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +32 -9
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +2 -0
- mindspore/numpy/array_ops.py +6 -6
- mindspore/numpy/dtypes.py +3 -3
- mindspore/numpy/fft.py +431 -0
- mindspore/numpy/math_ops.py +61 -67
- mindspore/numpy/utils.py +3 -0
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +8 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -160
- mindspore/ops/_grad_experimental/grad_comm_ops.py +93 -36
- mindspore/ops/_grad_experimental/grad_inner_ops.py +8 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +92 -287
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +1 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/{cpu/concat.py → aicpu/generate_eod_mask.py} +16 -17
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +164 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +130 -58
- mindspore/ops/_vmap/vmap_nn_ops.py +249 -115
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +231 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +250 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +980 -0
- mindspore/ops/auto_generate/gen_ops_def.py +6443 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +13167 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +429 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +121 -23
- mindspore/ops/composite/math_ops.py +10 -49
- mindspore/ops/composite/multitype_ops/_compile_utils.py +191 -618
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/extend/__init__.py +53 -0
- mindspore/ops/extend/array_func.py +218 -0
- mindspore/ops/extend/math_func.py +76 -0
- mindspore/ops/extend/nn_func.py +308 -0
- mindspore/ops/function/__init__.py +31 -11
- mindspore/ops/function/array_func.py +848 -1736
- mindspore/ops/function/clip_func.py +19 -31
- mindspore/ops/function/debug_func.py +2 -5
- mindspore/ops/function/fft_func.py +31 -0
- mindspore/ops/function/grad/grad_func.py +27 -20
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +30 -53
- mindspore/ops/function/math_func.py +916 -2791
- mindspore/ops/function/nn_func.py +1445 -889
- mindspore/ops/function/other_func.py +6 -7
- mindspore/ops/function/parameter_func.py +6 -92
- mindspore/ops/function/random_func.py +254 -108
- mindspore/ops/function/reshard_func.py +102 -0
- mindspore/ops/function/sparse_func.py +4 -4
- mindspore/ops/function/sparse_unary_func.py +11 -18
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +15 -14
- mindspore/ops/functional.py +342 -343
- mindspore/ops/op_info_register.py +16 -43
- mindspore/ops/operations/__init__.py +32 -23
- mindspore/ops/operations/_embedding_cache_ops.py +1 -1
- mindspore/ops/operations/_grad_ops.py +21 -853
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +155 -511
- mindspore/ops/operations/_quant_ops.py +4 -4
- mindspore/ops/operations/_rl_inner_ops.py +3 -3
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +6 -36
- mindspore/ops/operations/_tensor_array.py +8 -8
- mindspore/ops/operations/array_ops.py +112 -2698
- mindspore/ops/operations/comm_ops.py +801 -118
- mindspore/ops/operations/custom_ops.py +62 -121
- mindspore/ops/operations/debug_ops.py +105 -36
- mindspore/ops/operations/image_ops.py +3 -219
- mindspore/ops/operations/inner_ops.py +54 -40
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +61 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2016 -0
- mindspore/ops/operations/math_ops.py +621 -4654
- mindspore/ops/operations/nn_ops.py +316 -2226
- mindspore/ops/operations/other_ops.py +53 -45
- mindspore/ops/operations/random_ops.py +4 -51
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/sparse_ops.py +8 -8
- mindspore/ops/primitive.py +204 -103
- mindspore/ops/silent_check.py +162 -0
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +250 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_ops.py +1084 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +968 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +138 -0
- mindspore/ops_generate/pyboost_utils.py +354 -0
- mindspore/ops_generate/template.py +239 -0
- mindspore/parallel/__init__.py +7 -4
- mindspore/parallel/_auto_parallel_context.py +155 -6
- mindspore/parallel/_cell_wrapper.py +16 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +62 -14
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +18 -9
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +10 -10
- mindspore/parallel/_utils.py +161 -6
- mindspore/parallel/algo_parameter_config.py +6 -8
- mindspore/parallel/checkpoint_transform.py +369 -64
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +344 -0
- mindspore/parallel/cluster/process_entity/_utils.py +126 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +152 -0
- mindspore/parallel/shard.py +128 -17
- mindspore/profiler/__init__.py +3 -2
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/util.py +125 -0
- mindspore/profiler/envprofiling.py +2 -2
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +53 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +159 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +161 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +85 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +57 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +116 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +68 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
- mindspore/profiler/parser/ascend_flops_generator.py +27 -5
- mindspore/profiler/parser/ascend_fpbp_generator.py +8 -2
- mindspore/profiler/parser/ascend_hccl_generator.py +31 -280
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +151 -126
- mindspore/profiler/parser/ascend_msprof_generator.py +75 -274
- mindspore/profiler/parser/ascend_op_generator.py +94 -36
- mindspore/profiler/parser/ascend_timeline_generator.py +297 -131
- mindspore/profiler/parser/base_timeline_generator.py +17 -3
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -1
- mindspore/profiler/parser/framework_parser.py +11 -4
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/memory_usage_parser.py +8 -2
- mindspore/profiler/parser/minddata_analyzer.py +8 -2
- mindspore/profiler/parser/minddata_parser.py +73 -4
- mindspore/profiler/parser/msadvisor_analyzer.py +5 -3
- mindspore/profiler/parser/msadvisor_parser.py +10 -4
- mindspore/profiler/parser/profiler_info.py +16 -1
- mindspore/profiler/profiling.py +522 -195
- mindspore/rewrite/__init__.py +2 -13
- mindspore/rewrite/api/node.py +123 -37
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +46 -30
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +637 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +115 -148
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +704 -185
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +6 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +9 -19
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -5
- mindspore/train/_utils.py +178 -4
- mindspore/train/amp.py +167 -245
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/__init__.py +5 -2
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +151 -37
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_flops_collector.py +238 -0
- mindspore/train/callback/_landscape.py +16 -11
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_mindio_ttp.py +443 -0
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +13 -14
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/data_sink.py +6 -5
- mindspore/train/dataset_helper.py +66 -21
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +89 -15
- mindspore/train/model.py +298 -56
- mindspore/train/serialization.py +501 -221
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +1 -1
- mindspore/train/summary/summary_record.py +56 -34
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.2.11.dist-info → mindspore-2.3.0.dist-info}/METADATA +3 -3
- mindspore-2.3.0.dist-info/RECORD +1400 -0
- {mindspore-2.2.11.dist-info → mindspore-2.3.0.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/gen_ops.py +0 -273
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- mindspore-2.2.11.dist-info/RECORD +0 -1920
- {mindspore-2.2.11.dist-info → mindspore-2.3.0.dist-info}/WHEEL +0 -0
- {mindspore-2.2.11.dist-info → mindspore-2.3.0.dist-info}/top_level.txt +0 -0
mindspore/train/serialization.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2020-
|
|
1
|
+
# Copyright 2020-2024 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -17,6 +17,7 @@
|
|
|
17
17
|
from __future__ import absolute_import
|
|
18
18
|
from __future__ import division
|
|
19
19
|
|
|
20
|
+
import binascii
|
|
20
21
|
import copy
|
|
21
22
|
import json
|
|
22
23
|
import os
|
|
@@ -30,6 +31,7 @@ from io import BytesIO
|
|
|
30
31
|
import math
|
|
31
32
|
import sys
|
|
32
33
|
import time
|
|
34
|
+
import google
|
|
33
35
|
import numpy as np
|
|
34
36
|
|
|
35
37
|
from mindspore.train.checkpoint_pb2 import Checkpoint
|
|
@@ -50,9 +52,12 @@ from mindspore.common.api import _generate_branch_control_input
|
|
|
50
52
|
from mindspore.common.initializer import initializer, One
|
|
51
53
|
from mindspore.common.parameter import Parameter, _offload_if_config
|
|
52
54
|
from mindspore.common.tensor import Tensor
|
|
55
|
+
from mindspore._c_expression import Tensor as Tensor_
|
|
53
56
|
from mindspore.common._utils import is_shape_unknown
|
|
57
|
+
from mindspore.common.file_system import FileSystem, _register_basic_file_system, _register_mindio_file_system
|
|
54
58
|
from mindspore.communication.management import get_rank, get_group_size
|
|
55
59
|
from mindspore.experimental import MapParameter
|
|
60
|
+
from mindspore.ops import Cast
|
|
56
61
|
from mindspore.parallel._cell_wrapper import get_allgather_cell
|
|
57
62
|
from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
|
|
58
63
|
from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
|
|
@@ -61,21 +66,23 @@ from mindspore.parallel._parallel_serialization import _convert_to_list, _conver
|
|
|
61
66
|
_restore_group_info_list
|
|
62
67
|
from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
|
|
63
68
|
_store_warm_up_ptr_by_tensor_list, _cache_enable
|
|
69
|
+
from mindspore.parallel.checkpoint_transform import sync_pipeline_shared_parameters
|
|
64
70
|
from mindspore.train._utils import read_proto
|
|
65
71
|
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
|
|
66
72
|
split_mindir, split_dynamic_mindir
|
|
73
|
+
from mindspore.common.generator import Generator
|
|
74
|
+
from mindspore.train._utils import get_parameter_redundancy, remove_param_redundancy
|
|
75
|
+
from mindspore.parallel.parameter_broadcast import parameter_broadcast
|
|
67
76
|
from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
|
|
68
|
-
from ..ops.operations import Cast
|
|
69
77
|
|
|
70
78
|
tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
|
|
71
79
|
"Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
|
|
72
80
|
"Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
|
|
73
|
-
"Bool": mstype.bool_, "str": mstype.string, "BFloat16": mstype.bfloat16}
|
|
81
|
+
"Bool": mstype.bool_, "str": mstype.string, "BFloat16": mstype.bfloat16, "Int4": mstype.qint4x2}
|
|
74
82
|
|
|
75
83
|
tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UInt16": np.uint16,
|
|
76
84
|
"Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
|
|
77
|
-
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"
|
|
78
|
-
"BFloat16": np.float32}
|
|
85
|
+
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}
|
|
79
86
|
|
|
80
87
|
np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64}
|
|
81
88
|
|
|
@@ -95,6 +102,30 @@ INT_64_MAX = 9223372036854775807
|
|
|
95
102
|
|
|
96
103
|
cpu_cast = Cast().set_device("CPU")
|
|
97
104
|
|
|
105
|
+
_ckpt_fs = FileSystem()
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def init_ckpt_file_system(fs: FileSystem):
|
|
109
|
+
"""Initialize checkpoint file system"""
|
|
110
|
+
if _register_mindio_file_system(fs):
|
|
111
|
+
return
|
|
112
|
+
_register_basic_file_system(fs)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
# Initialize checkpoint file system
|
|
116
|
+
init_ckpt_file_system(_ckpt_fs)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class ParamDictFuture:
|
|
120
|
+
def __init__(self, executor, param_dict_future):
|
|
121
|
+
self.executor = executor
|
|
122
|
+
self.param_dict_future = param_dict_future
|
|
123
|
+
|
|
124
|
+
def result(self):
|
|
125
|
+
param_dict = self.param_dict_future.result()
|
|
126
|
+
self.executor.shutdown()
|
|
127
|
+
return param_dict
|
|
128
|
+
|
|
98
129
|
|
|
99
130
|
def _special_process_par(par, new_par):
|
|
100
131
|
"""
|
|
@@ -176,7 +207,7 @@ def _update_param(param, new_param, strict_load):
|
|
|
176
207
|
|
|
177
208
|
def _type_convert(param, new_param, strict_load):
|
|
178
209
|
"""Whether to convert parameter's type during load checkpoint into network."""
|
|
179
|
-
float_type = (mstype.float16, mstype.float32, mstype.float64)
|
|
210
|
+
float_type = (mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16)
|
|
180
211
|
int_type = (mstype.int8, mstype.int16, mstype.int32, mstype.int64)
|
|
181
212
|
if not strict_load and ({param.data.dtype, new_param.data.dtype}.issubset(float_type) or
|
|
182
213
|
{param.data.dtype, new_param.data.dtype}.issubset(int_type)):
|
|
@@ -221,18 +252,19 @@ def _save_weight(checkpoint_dir, model_name, iteration, params):
|
|
|
221
252
|
logger.warning(f"Checkpoint dir: '{checkpoint_dir}' is not existed.")
|
|
222
253
|
|
|
223
254
|
|
|
224
|
-
def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False):
|
|
255
|
+
def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False, crc_check=False):
|
|
225
256
|
"""Execute the process of saving checkpoint into file."""
|
|
226
257
|
try:
|
|
227
258
|
with _ckpt_mutex:
|
|
228
259
|
if os.path.exists(ckpt_file_name):
|
|
229
260
|
os.chmod(ckpt_file_name, stat.S_IWUSR)
|
|
230
261
|
os.remove(ckpt_file_name)
|
|
231
|
-
with
|
|
262
|
+
with _ckpt_fs.create(ckpt_file_name, *_ckpt_fs.create_args) as f:
|
|
232
263
|
plain_data = None
|
|
233
264
|
if enc_key is not None:
|
|
234
265
|
plain_data = BytesIO()
|
|
235
266
|
|
|
267
|
+
crc_num = 0
|
|
236
268
|
for name, value in data_list.items():
|
|
237
269
|
if name == "random_op":
|
|
238
270
|
_write_random_seed(name, value, f)
|
|
@@ -242,21 +274,21 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
|
|
|
242
274
|
continue
|
|
243
275
|
if value[0] == "offload_parameter":
|
|
244
276
|
new_value = value[1:]
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
else:
|
|
248
|
-
new_value[2] = value[3].asnumpy().reshape(-1)
|
|
249
|
-
_write_parameter_data(name, new_value, f, enc_key, plain_data)
|
|
277
|
+
new_value[2] = value[3]
|
|
278
|
+
_write_parameter_bytes_data(name, new_value, f, enc_key, plain_data)
|
|
250
279
|
_offload_if_config(value[3])
|
|
251
280
|
continue
|
|
252
|
-
if value[
|
|
253
|
-
|
|
281
|
+
if value[1] == "str":
|
|
282
|
+
crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
|
|
283
|
+
continue
|
|
284
|
+
if isinstance(value[2], np.ndarray):
|
|
285
|
+
crc_num = _write_parameter_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
|
|
254
286
|
continue
|
|
255
|
-
if isinstance(value[2], Tensor):
|
|
287
|
+
if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
|
|
256
288
|
_write_hugeparameter(name, value, f)
|
|
257
289
|
continue
|
|
258
290
|
|
|
259
|
-
|
|
291
|
+
crc_num = _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num, crc_check)
|
|
260
292
|
|
|
261
293
|
if enc_key is not None:
|
|
262
294
|
plain_data.seek(0)
|
|
@@ -266,7 +298,10 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
|
|
|
266
298
|
f.write(_encrypt(block_data, len(block_data), enc_key, len(enc_key), enc_mode))
|
|
267
299
|
block_data = plain_data.read(max_block_size)
|
|
268
300
|
|
|
269
|
-
|
|
301
|
+
if crc_check:
|
|
302
|
+
f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big'))
|
|
303
|
+
|
|
304
|
+
os.chmod(ckpt_file_name, stat.S_IRUSR)
|
|
270
305
|
|
|
271
306
|
except BaseException as e:
|
|
272
307
|
logger.critical("Failed to save the checkpoint file %s. Maybe don't have the permission to write files, "
|
|
@@ -286,22 +321,7 @@ def _write_random_seed(name, value, f):
|
|
|
286
321
|
f.write(checkpoint_list.SerializeToString())
|
|
287
322
|
|
|
288
323
|
|
|
289
|
-
def
|
|
290
|
-
"""Write bfloat16 data into protobuf file"""
|
|
291
|
-
checkpoint_list = Checkpoint()
|
|
292
|
-
param_value = checkpoint_list.value.add()
|
|
293
|
-
param_value.tag = name
|
|
294
|
-
param_tensor = param_value.tensor
|
|
295
|
-
param_tensor.dims.extend(value[1])
|
|
296
|
-
param_tensor.tensor_type = value[2]
|
|
297
|
-
param_tensor.tensor_content = value[3].get_bytes()
|
|
298
|
-
if enc_key is None:
|
|
299
|
-
f.write(checkpoint_list.SerializeToString())
|
|
300
|
-
else:
|
|
301
|
-
plain_data.write(checkpoint_list.SerializeToString())
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
def _write_parameter_data(name, value, f, enc_key, plain_data):
|
|
324
|
+
def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
|
|
305
325
|
"""Write parameter data into protobuf file."""
|
|
306
326
|
data_size = value[2].nbytes / 1024
|
|
307
327
|
if data_size > SLICE_SIZE:
|
|
@@ -320,10 +340,40 @@ def _write_parameter_data(name, value, f, enc_key, plain_data):
|
|
|
320
340
|
param_tensor.tensor_content = param_slice.tobytes()
|
|
321
341
|
|
|
322
342
|
if enc_key is None:
|
|
323
|
-
|
|
343
|
+
output_data = checkpoint_list.SerializeToString()
|
|
344
|
+
if crc_check:
|
|
345
|
+
crc_num = binascii.crc32(output_data, crc_num)
|
|
346
|
+
f.write(output_data)
|
|
347
|
+
else:
|
|
348
|
+
plain_data.write(checkpoint_list.SerializeToString())
|
|
349
|
+
|
|
350
|
+
return crc_num
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False):
|
|
354
|
+
"""Write parameter bytes data into protobuf file."""
|
|
355
|
+
bytes_value = value[2].get_bytes()
|
|
356
|
+
chunk_size = 1024 * SLICE_SIZE
|
|
357
|
+
|
|
358
|
+
for i in range(0, len(bytes_value), chunk_size):
|
|
359
|
+
checkpoint_list = Checkpoint()
|
|
360
|
+
param_value = checkpoint_list.value.add()
|
|
361
|
+
param_value.tag = name
|
|
362
|
+
param_tensor = param_value.tensor
|
|
363
|
+
param_tensor.dims.extend(value[0])
|
|
364
|
+
param_tensor.tensor_type = value[1]
|
|
365
|
+
param_tensor.tensor_content = bytes_value[i:i + chunk_size]
|
|
366
|
+
|
|
367
|
+
if enc_key is None:
|
|
368
|
+
output_data = checkpoint_list.SerializeToString()
|
|
369
|
+
if crc_check:
|
|
370
|
+
crc_num = binascii.crc32(output_data, crc_num)
|
|
371
|
+
f.write(output_data)
|
|
324
372
|
else:
|
|
325
373
|
plain_data.write(checkpoint_list.SerializeToString())
|
|
326
374
|
|
|
375
|
+
return crc_num
|
|
376
|
+
|
|
327
377
|
|
|
328
378
|
def _write_mapparameter(name, value, f, map_param_inc=False):
|
|
329
379
|
"""Write map parameter into protobuf file."""
|
|
@@ -384,10 +434,14 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
|
|
|
384
434
|
|
|
385
435
|
|
|
386
436
|
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
387
|
-
async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None,
|
|
437
|
+
async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None,
|
|
438
|
+
crc_check=False, **kwargs):
|
|
388
439
|
r"""
|
|
389
440
|
Save checkpoint to a specified file.
|
|
390
441
|
|
|
442
|
+
Note:
|
|
443
|
+
The `enc_mode` and `crc_check` parameters are mutually exclusive and cannot be configured simultaneously.
|
|
444
|
+
|
|
391
445
|
Args:
|
|
392
446
|
save_obj (Union[Cell, list, dict]): The object to be saved. The data type can be :class:`mindspore.nn.Cell`,
|
|
393
447
|
list, or dict. If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
|
|
@@ -409,6 +463,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
409
463
|
If returns ``True`` , the Parameter that matching the custom condition will be saved.
|
|
410
464
|
If returns ``False`` , the Parameter that not matching the custom condition will not
|
|
411
465
|
be saved. Default: ``None`` .
|
|
466
|
+
crc_check (bool) : Whether to perform crc32 calculation when saving checkpoint and save the calculation
|
|
467
|
+
result to the file. Default: ``False`` .
|
|
412
468
|
kwargs (dict): Configuration options dictionary.
|
|
413
469
|
|
|
414
470
|
Raises:
|
|
@@ -420,7 +476,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
420
476
|
>>> import mindspore as ms
|
|
421
477
|
>>>
|
|
422
478
|
>>> # Define the network structure of LeNet5. Refer to
|
|
423
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
479
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
424
480
|
>>> net = LeNet5()
|
|
425
481
|
>>> ms.save_checkpoint(net, "./lenet.ckpt",
|
|
426
482
|
... choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
|
|
@@ -440,7 +496,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
440
496
|
|
|
441
497
|
Tutorial Examples:
|
|
442
498
|
- `Saving and Loading the Model - Saving and Loading the Model Weight
|
|
443
|
-
<https://mindspore.cn/tutorials/en/
|
|
499
|
+
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
|
|
444
500
|
"""
|
|
445
501
|
ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name)
|
|
446
502
|
integrated_save = Validator.check_bool(integrated_save)
|
|
@@ -448,24 +504,32 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
448
504
|
append_dict = _check_append_dict(append_dict)
|
|
449
505
|
enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
|
|
450
506
|
enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
|
|
507
|
+
crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
|
|
451
508
|
map_param_inc = kwargs.get('incremental', False)
|
|
452
509
|
logger.info("Execute the process of saving checkpoint files.")
|
|
510
|
+
global_step_num = kwargs.get('global_step_num', None)
|
|
453
511
|
|
|
454
512
|
save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
|
|
455
513
|
|
|
456
514
|
if append_dict:
|
|
457
515
|
append_info_list = []
|
|
458
516
|
for k_name, value in append_dict.items():
|
|
459
|
-
if
|
|
517
|
+
if isinstance(value, Generator):
|
|
518
|
+
value = value.get_state()
|
|
519
|
+
elif not isinstance(value, str):
|
|
460
520
|
value = Tensor(value)
|
|
461
521
|
append_info_list.append({"name": k_name, "data": value})
|
|
462
522
|
save_obj.extend(append_info_list)
|
|
463
523
|
|
|
464
524
|
data_list = OrderedDict()
|
|
525
|
+
data_list_np = OrderedDict()
|
|
465
526
|
with _ckpt_mutex:
|
|
466
527
|
for param in save_obj:
|
|
467
528
|
if param["name"] == "random_op":
|
|
468
|
-
|
|
529
|
+
if os.getenv("AITURBO") == "1":
|
|
530
|
+
data_list_np["random_op"] = param["data"]
|
|
531
|
+
else:
|
|
532
|
+
data_list["random_op"] = param["data"]
|
|
469
533
|
continue
|
|
470
534
|
key = param["name"]
|
|
471
535
|
data_list[key] = []
|
|
@@ -479,49 +543,41 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
479
543
|
elif param["data"][0] == "offload_parameter":
|
|
480
544
|
data_list[key].append("offload_parameter")
|
|
481
545
|
_save_param_list_data(data_list, key, param)
|
|
482
|
-
elif param["data"][0] == "BFloat16_tensor":
|
|
483
|
-
data_list[key].append("BFloat16_tensor")
|
|
484
|
-
_save_param_list_data(data_list, key, param)
|
|
485
|
-
continue
|
|
486
546
|
|
|
487
547
|
if isinstance(param["data"], str):
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
548
|
+
if os.getenv("AITURBO") == "1":
|
|
549
|
+
data_list_np[key] = np.array(param["data"])
|
|
550
|
+
else:
|
|
551
|
+
data_list[key].append([0])
|
|
552
|
+
data_list[key].append('str')
|
|
553
|
+
data = np.array(param["data"])
|
|
554
|
+
data_list[key].append(data)
|
|
492
555
|
else:
|
|
493
556
|
if isinstance(param["data"], Parameter):
|
|
494
557
|
param["data"].init_data()
|
|
495
|
-
if
|
|
496
|
-
|
|
497
|
-
dims = []
|
|
498
|
-
for dim in param["data"].shape:
|
|
499
|
-
dims.append(dim)
|
|
500
|
-
data_list[key].append(dims)
|
|
501
|
-
data_list[key].append("BFloat16")
|
|
502
|
-
data_list[key].append(cpu_cast(param["data"], mstype.float32))
|
|
503
|
-
continue
|
|
504
|
-
dims = []
|
|
505
|
-
if param['data'].shape == ():
|
|
506
|
-
dims.append(0)
|
|
558
|
+
if os.getenv("AITURBO") == "1":
|
|
559
|
+
data_list_np[key] = param["data"].asnumpy()
|
|
507
560
|
else:
|
|
561
|
+
dims = []
|
|
508
562
|
for dim in param['data'].shape:
|
|
509
563
|
dims.append(dim)
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
564
|
+
data_list[key].append(dims)
|
|
565
|
+
tensor_type = str(param["data"].dtype)
|
|
566
|
+
data_list[key].append(tensor_type)
|
|
567
|
+
data = param["data"]
|
|
568
|
+
data_list[key].append(data)
|
|
569
|
+
|
|
570
|
+
if os.getenv("AITURBO") == "1":
|
|
571
|
+
import aiturbo
|
|
572
|
+
ckpt_name = os.path.basename(ckpt_file_name)
|
|
573
|
+
aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np)
|
|
574
|
+
elif async_save:
|
|
520
575
|
data_copy = copy.deepcopy(data_list)
|
|
521
|
-
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_copy, enc_key, enc_mode
|
|
576
|
+
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check),
|
|
577
|
+
name="asyn_save_ckpt")
|
|
522
578
|
thr.start()
|
|
523
579
|
else:
|
|
524
|
-
_exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc)
|
|
580
|
+
_exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check)
|
|
525
581
|
|
|
526
582
|
logger.info("Saving checkpoint process is finished.")
|
|
527
583
|
|
|
@@ -532,7 +588,21 @@ def _convert_list_to_param_list(save_obj, choice_func):
|
|
|
532
588
|
if not save_obj:
|
|
533
589
|
return param_list
|
|
534
590
|
if isinstance(save_obj[0], dict):
|
|
535
|
-
|
|
591
|
+
for param in save_obj:
|
|
592
|
+
if isinstance(param, dict) and "name" in param and "data" in param:
|
|
593
|
+
if not isinstance(param["name"], str):
|
|
594
|
+
raise TypeError(f"For save_checkpoint, when save_obj is a list of dict items, the name in dict "
|
|
595
|
+
f"should be string, but got {type(param['name'])}.")
|
|
596
|
+
if not isinstance(param["data"], Tensor):
|
|
597
|
+
raise TypeError(f"For save_checkpoint, when save_obj is a list of dict items, the data in dict "
|
|
598
|
+
f"should be parameter, but got {type(param['data'])}.")
|
|
599
|
+
if choice_func is not None and not choice_func(param["name"]):
|
|
600
|
+
continue
|
|
601
|
+
each_param = {"name": param["name"], "data": param["data"]}
|
|
602
|
+
param_list.append(each_param)
|
|
603
|
+
else:
|
|
604
|
+
raise TypeError(f"For save_checkpoint, save_obj should be a list of dict items, and the dict should "
|
|
605
|
+
f"have key values 'name' and 'value', but got {type(param)} and {param}.")
|
|
536
606
|
else:
|
|
537
607
|
for param in save_obj:
|
|
538
608
|
if isinstance(param, Parameter):
|
|
@@ -585,6 +655,7 @@ def _convert_cell_param_and_names_to_dict(save_obj, choice_func):
|
|
|
585
655
|
|
|
586
656
|
def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func):
|
|
587
657
|
"""Convert nn.Cell to param_list."""
|
|
658
|
+
sync_pipeline_shared_parameters(save_obj)
|
|
588
659
|
param_list = []
|
|
589
660
|
parameter_layout_dict = save_obj.parameter_layout_dict
|
|
590
661
|
if _is_in_auto_parallel_mode() and not parameter_layout_dict:
|
|
@@ -597,7 +668,7 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
|
|
|
597
668
|
if phase in save_obj.compile_cache and _executor.has_compiled(phase):
|
|
598
669
|
random_byte = _executor._graph_executor.get_random_status(phase)
|
|
599
670
|
param_list.append({"name": "random_op", "data": random_byte})
|
|
600
|
-
|
|
671
|
+
append_dict.pop("random_op")
|
|
601
672
|
for (key, value) in param_dict.items():
|
|
602
673
|
each_param = {"name": key}
|
|
603
674
|
if isinstance(value, MapParameter):
|
|
@@ -619,18 +690,13 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
|
|
|
619
690
|
param_data.append(param_tensor.shape)
|
|
620
691
|
param_data.append(str(param_tensor.dtype))
|
|
621
692
|
param_data.append(value.key)
|
|
622
|
-
elif value.data.dtype == mstype.bfloat16:
|
|
623
|
-
param_data = ["BFloat16_tensor"]
|
|
624
|
-
param_data.append(cpu_cast(value.data, mstype.float32))
|
|
625
|
-
param_data.append(value.data.shape)
|
|
626
|
-
param_data.append("BFloat16")
|
|
627
|
-
param_data.append(value.key)
|
|
628
693
|
else:
|
|
629
|
-
param_data =
|
|
694
|
+
param_data = value.data
|
|
630
695
|
|
|
631
696
|
# in automatic model parallel scenario, some parameters were split to all the devices,
|
|
632
697
|
# which should be combined before saving
|
|
633
698
|
if key in parameter_layout_dict:
|
|
699
|
+
param_data = Tensor(value.data)
|
|
634
700
|
param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
|
|
635
701
|
integrated_save)
|
|
636
702
|
|
|
@@ -670,9 +736,9 @@ def _check_append_dict(append_dict):
|
|
|
670
736
|
raise TypeError("For 'save_checkpoint', the argument 'append_dict' must be dict, but got "
|
|
671
737
|
"{}.".format(type(append_dict)))
|
|
672
738
|
for key, value in append_dict.items():
|
|
673
|
-
if not isinstance(key, str) or not isinstance(value, (int, float, bool, str, Parameter, Tensor)):
|
|
739
|
+
if not isinstance(key, str) or not isinstance(value, (int, float, bool, str, Parameter, Tensor, Generator)):
|
|
674
740
|
raise TypeError(f"For 'save_checkpoint', the type of dict 'append_info' must be key: string, "
|
|
675
|
-
f"value: int, float or
|
|
741
|
+
f"value: int, float, bool or Generator, but got key: {type(key)}, value: {type(value)}")
|
|
676
742
|
return append_dict
|
|
677
743
|
|
|
678
744
|
|
|
@@ -699,13 +765,13 @@ def load(file_name, **kwargs):
|
|
|
699
765
|
- dec_key (bytes): Byte-type key used for decryption. The valid length is 16, 24, or 32.
|
|
700
766
|
- dec_mode (Union[str, function]): Specifies the decryption mode, to take effect when dec_key is set.
|
|
701
767
|
|
|
702
|
-
- Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: 'AES-GCM'
|
|
768
|
+
- Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: ``'AES-GCM'``.
|
|
703
769
|
- For details of using the customized decryption, please check the `tutorial
|
|
704
|
-
<https://mindspore.cn/mindarmour/docs/en/
|
|
770
|
+
<https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
|
|
705
771
|
|
|
706
772
|
- obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to
|
|
707
773
|
`obfuscate_model()
|
|
708
|
-
<https://www.mindspore.cn/docs/en/
|
|
774
|
+
<https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.obfuscate_model.html>`_.
|
|
709
775
|
|
|
710
776
|
Returns:
|
|
711
777
|
GraphCell, a compiled graph that can executed by `GraphCell`.
|
|
@@ -735,7 +801,7 @@ def load(file_name, **kwargs):
|
|
|
735
801
|
|
|
736
802
|
Tutorial Examples:
|
|
737
803
|
- `Saving and Loading the Model - Saving and Loading MindIR
|
|
738
|
-
<https://mindspore.cn/tutorials/en/
|
|
804
|
+
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
|
|
739
805
|
"""
|
|
740
806
|
if not isinstance(file_name, str):
|
|
741
807
|
raise ValueError("For 'load', the argument 'file_name' must be string, but "
|
|
@@ -776,7 +842,7 @@ def load(file_name, **kwargs):
|
|
|
776
842
|
return graph
|
|
777
843
|
|
|
778
844
|
|
|
779
|
-
def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=
|
|
845
|
+
def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=True):
|
|
780
846
|
"""
|
|
781
847
|
Auto Split MindIR.
|
|
782
848
|
|
|
@@ -784,10 +850,10 @@ def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=F
|
|
|
784
850
|
|
|
785
851
|
Args:
|
|
786
852
|
file_name (str): MindIR file name.
|
|
787
|
-
device_num (int): device number.
|
|
788
|
-
rank_id (int): rank id.
|
|
789
|
-
dynamic (bool): Indicates whether the model is a dynamic shape mindir model.
|
|
790
|
-
sapp (bool): Indicates whether to automatically generate split strategy through SAPP.
|
|
853
|
+
device_num (int): device number. Default: '8'.
|
|
854
|
+
rank_id (int): rank id. Default: '0'.
|
|
855
|
+
dynamic (bool): Indicates whether the model is a dynamic shape mindir model. Default: 'True'.
|
|
856
|
+
sapp (bool): Indicates whether to automatically generate split strategy through SAPP. Default: 'True'.
|
|
791
857
|
|
|
792
858
|
Raises:
|
|
793
859
|
ValueError: MindIR file does not exist or `file_name` is not a string.
|
|
@@ -909,13 +975,14 @@ def obfuscate_model(obf_config, **kwargs):
|
|
|
909
975
|
- customized_func (function): A python function used for customized function mode, which used for control
|
|
910
976
|
the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
|
|
911
977
|
Reference to 'my_func()' in
|
|
912
|
-
`tutorials <https://www.mindspore.cn/mindarmour/docs/en/
|
|
978
|
+
`tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
|
|
913
979
|
This function needs to ensure that its result is constant for any input. Users can refer to opaque
|
|
914
980
|
predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface
|
|
915
981
|
when loading obfuscated model.
|
|
916
982
|
- obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
|
|
917
983
|
structure of obfuscated models corresponding to different random seeds is different. If
|
|
918
|
-
`obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell
|
|
984
|
+
`obf_random_seed` is set, then it should be passed to :class:`mindspore.nn.GraphCell`
|
|
985
|
+
interface when loading
|
|
919
986
|
obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
|
|
920
987
|
be set, and the latter mode would be applied if both of them are set.
|
|
921
988
|
|
|
@@ -923,7 +990,7 @@ def obfuscate_model(obf_config, **kwargs):
|
|
|
923
990
|
|
|
924
991
|
- enc_key (bytes): Byte type key used for encryption. The valid length is 16, 24, or 32.
|
|
925
992
|
- enc_mode (str): Specifies the encryption mode, to take effect when dec_key is set.
|
|
926
|
-
|
|
993
|
+
Options: ``'AES-GCM'`` | ``'AES-CBC'`` | ``'SM4-CBC'``. Default: ``'AES-GCM'``.
|
|
927
994
|
|
|
928
995
|
Raises:
|
|
929
996
|
TypeError: If `obf_config` is not a dict.
|
|
@@ -934,11 +1001,15 @@ def obfuscate_model(obf_config, **kwargs):
|
|
|
934
1001
|
ValueError: If `obf_ratio` is not provided in `obf_config`.
|
|
935
1002
|
ValueError: If both `customized_func` and `obf_random_seed` are not provided in `obf_config`.
|
|
936
1003
|
ValueError: If `obf_random_seed` is not in (0, 9223372036854775807].
|
|
937
|
-
ValueError: If `original_model_path`
|
|
1004
|
+
ValueError: If `original_model_path` does not exist or `original_model_path` does not end with '.mindir'.
|
|
938
1005
|
|
|
939
1006
|
Examples:
|
|
940
1007
|
>>> import mindspore as ms
|
|
941
1008
|
>>> import mindspore.nn as nn
|
|
1009
|
+
>>> import numpy as np
|
|
1010
|
+
>>> # Download ori_net.mindir
|
|
1011
|
+
>>> # https://gitee.com/mindspore/mindspore/blob/master/tests/ut/python/mindir/ori_net.mindir
|
|
1012
|
+
>>> input1 = ms.Tensor(np.ones((1, 1, 32, 32)).astype(np.float32))
|
|
942
1013
|
>>> obf_config = {'original_model_path': "./net.mindir",
|
|
943
1014
|
... 'save_model_path': "./obf_net",
|
|
944
1015
|
... 'model_inputs': [input1, ],
|
|
@@ -998,12 +1069,76 @@ def obfuscate_model(obf_config, **kwargs):
|
|
|
998
1069
|
obf_net = nn.GraphCell(obf_graph)
|
|
999
1070
|
if obf_random_seed != 0:
|
|
1000
1071
|
append_y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
|
|
1001
|
-
model_inputs += [append_y_tensor
|
|
1072
|
+
model_inputs += [append_y_tensor]
|
|
1002
1073
|
export(obf_net, *model_inputs, file_name=saved_path, file_format="MINDIR", **kwargs)
|
|
1003
1074
|
|
|
1004
1075
|
|
|
1076
|
+
def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
|
|
1077
|
+
dec_mode, crc_check):
|
|
1078
|
+
"""load parameter into parameter_dict"""
|
|
1079
|
+
ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
|
|
1080
|
+
checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check)
|
|
1081
|
+
try:
|
|
1082
|
+
param_data_list = []
|
|
1083
|
+
map_data_list = [[], [], []]
|
|
1084
|
+
map_shape_list = [0, 0, 0]
|
|
1085
|
+
if specify_prefix:
|
|
1086
|
+
logger.warning("For load_checkpoint, this parameter `specity_prefix` will be deprecated, "
|
|
1087
|
+
"please use `choice_func` instead.")
|
|
1088
|
+
if filter_prefix:
|
|
1089
|
+
logger.warning("For load_checkpoint, this parameter `filter_prefix` will be deprecated, "
|
|
1090
|
+
"please use `choice_func` instead.")
|
|
1091
|
+
for element_id, element in enumerate(checkpoint_list.value):
|
|
1092
|
+
if element.tag == "random_op":
|
|
1093
|
+
parameter_dict["random_op"] = element.tensor.tensor_content
|
|
1094
|
+
continue
|
|
1095
|
+
if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
|
|
1096
|
+
continue
|
|
1097
|
+
if specify_prefix is None and filter_prefix is None and \
|
|
1098
|
+
choice_func is not None and not choice_func(element.tag):
|
|
1099
|
+
continue
|
|
1100
|
+
if element.tensor.ByteSize() == 0:
|
|
1101
|
+
_load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list,
|
|
1102
|
+
parameter_dict)
|
|
1103
|
+
if element.tag in parameter_dict:
|
|
1104
|
+
map_data_list = [[], [], []]
|
|
1105
|
+
map_shape_list = [0, 0, 0]
|
|
1106
|
+
continue
|
|
1107
|
+
data = element.tensor.tensor_content
|
|
1108
|
+
data_type = element.tensor.tensor_type
|
|
1109
|
+
np_type = tensor_to_np_type.get(data_type)
|
|
1110
|
+
ms_type = tensor_to_ms_type[data_type]
|
|
1111
|
+
if data_type == 'str':
|
|
1112
|
+
str_length = int(len(data) / 4)
|
|
1113
|
+
np_type = np_type + str(str_length)
|
|
1114
|
+
param_data_list.append(data)
|
|
1115
|
+
if (element_id == len(checkpoint_list.value) - 1) or \
|
|
1116
|
+
(element.tag != checkpoint_list.value[element_id + 1].tag):
|
|
1117
|
+
new_data = b"".join(param_data_list)
|
|
1118
|
+
param_data_list.clear()
|
|
1119
|
+
dims = element.tensor.dims
|
|
1120
|
+
if data_type == 'str':
|
|
1121
|
+
str_value = np.frombuffer(new_data, np_type)
|
|
1122
|
+
parameter_dict[element.tag] = str(str_value[0])
|
|
1123
|
+
else:
|
|
1124
|
+
if dims == [0]:
|
|
1125
|
+
dims = []
|
|
1126
|
+
param_data = Tensor_.convert_bytes_to_tensor(new_data, tuple(dims), ms_type)
|
|
1127
|
+
parameter = Parameter(param_data, name=element.tag)
|
|
1128
|
+
parameter_dict[element.tag] = parameter
|
|
1129
|
+
_offload_if_config(parameter)
|
|
1130
|
+
|
|
1131
|
+
logger.info("Loading checkpoint files process is finished.")
|
|
1132
|
+
|
|
1133
|
+
except BaseException as e:
|
|
1134
|
+
logger.critical("Failed to load the checkpoint file '%s'.", ckpt_file_name)
|
|
1135
|
+
raise ValueError(e.__str__() + "\nFor 'load_checkpoint', "
|
|
1136
|
+
"failed to load the checkpoint file {}.".format(ckpt_file_name)) from e
|
|
1137
|
+
|
|
1138
|
+
|
|
1005
1139
|
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
|
|
1006
|
-
dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None
|
|
1140
|
+
dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None,
|
|
1141
|
+
crc_check=False):
|
|
1007
1142
|
"""
|
|
1008
1143
|
Load checkpoint info from a specified file.
|
|
1009
1144
|
|
|
@@ -1034,6 +1169,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1034
1169
|
and the return value is a bool. If returns ``True`` , the Parameter
|
|
1035
1170
|
that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
|
|
1036
1171
|
matches the custom condition will be removed. Default: ``None`` .
|
|
1172
|
+
crc_check (bool) : Whether to perform crc32 validation when loading checkpoint. Default: ``False`` .
|
|
1037
1173
|
|
|
1038
1174
|
Returns:
|
|
1039
1175
|
Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
|
|
@@ -1076,83 +1212,31 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1076
1212
|
|
|
1077
1213
|
Tutorial Examples:
|
|
1078
1214
|
- `Saving and Loading the Model - Saving and Loading the Model Weight
|
|
1079
|
-
<https://mindspore.cn/tutorials/en/
|
|
1215
|
+
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
|
|
1080
1216
|
"""
|
|
1081
|
-
ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
|
|
1082
1217
|
specify_prefix = _check_prefix(specify_prefix)
|
|
1083
1218
|
filter_prefix = _check_prefix(filter_prefix)
|
|
1084
1219
|
dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
|
|
1085
1220
|
dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
|
|
1221
|
+
crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
|
|
1086
1222
|
logger.info("Execute the process of loading checkpoint files.")
|
|
1087
|
-
checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode)
|
|
1088
1223
|
|
|
1089
1224
|
parameter_dict = {}
|
|
1090
|
-
try:
|
|
1091
|
-
param_data_list = []
|
|
1092
|
-
map_data_list = [[], [], []]
|
|
1093
|
-
map_shape_list = [0, 0, 0]
|
|
1094
|
-
if specify_prefix:
|
|
1095
|
-
logger.warning("For load_checkpoint, this parameter `specity_prefix` will be deprecated, "
|
|
1096
|
-
"please use `choice_func` instead.")
|
|
1097
|
-
if filter_prefix:
|
|
1098
|
-
logger.warning("For load_checkpoint, this parameter `filter_prefix` will be deprecated, "
|
|
1099
|
-
"please use `choice_func` instead.")
|
|
1100
|
-
for element_id, element in enumerate(checkpoint_list.value):
|
|
1101
|
-
if element.tag == "random_op":
|
|
1102
|
-
parameter_dict["random_op"] = element.tensor.tensor_content
|
|
1103
|
-
continue
|
|
1104
|
-
if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
|
|
1105
|
-
continue
|
|
1106
|
-
if specify_prefix is None and filter_prefix is None and \
|
|
1107
|
-
choice_func is not None and not choice_func(element.tag):
|
|
1108
|
-
continue
|
|
1109
|
-
if element.tensor.ByteSize() == 0:
|
|
1110
|
-
_load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list, parameter_dict)
|
|
1111
|
-
if element.tag in parameter_dict:
|
|
1112
|
-
map_data_list = [[], [], []]
|
|
1113
|
-
map_shape_list = [0, 0, 0]
|
|
1114
|
-
continue
|
|
1115
|
-
data = element.tensor.tensor_content
|
|
1116
|
-
data_type = element.tensor.tensor_type
|
|
1117
|
-
np_type = tensor_to_np_type.get(data_type)
|
|
1118
|
-
ms_type = tensor_to_ms_type[data_type]
|
|
1119
|
-
if data_type == 'str':
|
|
1120
|
-
str_length = int(len(data) / 4)
|
|
1121
|
-
np_type = np_type + str(str_length)
|
|
1122
|
-
if data_type == "BFloat16":
|
|
1123
|
-
dims = element.tensor.dims
|
|
1124
|
-
param_data = np.frombuffer(data, np_type)
|
|
1125
|
-
param_data = param_data.reshape(list(dims))
|
|
1126
|
-
parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
|
|
1127
|
-
parameter_dict[element.tag] = parameter
|
|
1128
|
-
continue
|
|
1129
|
-
element_data = np.frombuffer(data, np_type)
|
|
1130
|
-
param_data_list.append(element_data)
|
|
1131
|
-
if (element_id == len(checkpoint_list.value) - 1) or \
|
|
1132
|
-
(element.tag != checkpoint_list.value[element_id + 1].tag):
|
|
1133
|
-
new_data = b"".join(param_data_list)
|
|
1134
|
-
param_data = np.frombuffer(new_data, np_type)
|
|
1135
|
-
param_data_list.clear()
|
|
1136
|
-
dims = element.tensor.dims
|
|
1137
|
-
if dims == [0] and data_type == 'str':
|
|
1138
|
-
parameter_dict[element.tag] = str(element_data[0])
|
|
1139
|
-
else:
|
|
1140
|
-
if dims == [0] and 'Float' in data_type:
|
|
1141
|
-
param_data = float(param_data[0])
|
|
1142
|
-
if dims == [0] and 'Int' in data_type:
|
|
1143
|
-
param_data = int(param_data[0])
|
|
1144
|
-
if dims not in ([0], [1]):
|
|
1145
|
-
param_data = param_data.reshape(list(dims))
|
|
1146
|
-
parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
|
|
1147
|
-
parameter_dict[element.tag] = parameter
|
|
1148
|
-
_offload_if_config(parameter)
|
|
1149
|
-
|
|
1150
|
-
logger.info("Loading checkpoint files process is finished.")
|
|
1151
1225
|
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
|
|
1226
|
+
if os.getenv("AITURBO") == "1":
|
|
1227
|
+
rank_id = get_rank()
|
|
1228
|
+
import aiturbo
|
|
1229
|
+
ckpt_path = os.path.dirname(ckpt_file_name)
|
|
1230
|
+
ckpt_name = os.path.basename(ckpt_file_name)
|
|
1231
|
+
np_dict = aiturbo.load_ckpt(ckpt_path, ckpt_name, rank_id)
|
|
1232
|
+
for key, value in np_dict.items():
|
|
1233
|
+
if isinstance(value, str):
|
|
1234
|
+
parameter_dict[key] = value
|
|
1235
|
+
else:
|
|
1236
|
+
parameter_dict[key] = Parameter(Tensor(value), name=key)
|
|
1237
|
+
else:
|
|
1238
|
+
_load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
|
|
1239
|
+
dec_mode, crc_check)
|
|
1156
1240
|
|
|
1157
1241
|
if not parameter_dict:
|
|
1158
1242
|
raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
|
|
@@ -1168,6 +1252,86 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
1168
1252
|
return parameter_dict
|
|
1169
1253
|
|
|
1170
1254
|
|
|
1255
|
+
def load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None,
|
|
1256
|
+
dec_mode="AES-GCM", specify_prefix=None, choice_func=None):
|
|
1257
|
+
"""
|
|
1258
|
+
Load checkpoint info from a specified file asyncly.
|
|
1259
|
+
|
|
1260
|
+
.. warning::
|
|
1261
|
+
This is an experimental API that is subject to change or deletion.
|
|
1262
|
+
|
|
1263
|
+
Note:
|
|
1264
|
+
- `specify_prefix` and `filter_prefix` do not affect each other.
|
|
1265
|
+
- If none of the parameters are loaded from checkpoint file, it will throw ValueError.
|
|
1266
|
+
- `specify_prefix` and `filter_prefix` are in the process of being deprecated,
|
|
1267
|
+
`choice_func` is recommended instead.
|
|
1268
|
+
And using either of those two args will override `choice_func` at the same time.
|
|
1269
|
+
|
|
1270
|
+
Args:
|
|
1271
|
+
ckpt_file_name (str): Checkpoint file name.
|
|
1272
|
+
net (Cell, optional): The network where the parameters will be loaded. Default: ``None`` .
|
|
1273
|
+
strict_load (bool, optional): Whether to strict load the parameter into net. If ``False`` , it will load
|
|
1274
|
+
parameter into net when parameter name's suffix in checkpoint file is the
|
|
1275
|
+
same as the parameter in the network. When the types are inconsistent
|
|
1276
|
+
perform type conversion on the parameters of the same type, such as float32
|
|
1277
|
+
to float16. Default: ``False`` .
|
|
1278
|
+
filter_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`). Parameters
|
|
1279
|
+
starting with the `filter_prefix` will not be loaded. Default: ``None`` .
|
|
1280
|
+
dec_key (Union[None, bytes], optional): Byte type key used for decryption. If the value is ``None`` ,
|
|
1281
|
+
the decryption is not required. Default: ``None`` .
|
|
1282
|
+
dec_mode (str, optional): This parameter is valid only when dec_key is not set to ``None`` . Specifies
|
|
1283
|
+
the decryption mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"``
|
|
1284
|
+
and ``"SM4-CBC"`` . Default: ``"AES-GCM"`` .
|
|
1285
|
+
specify_prefix (Union[str, list[str], tuple[str]], optional): Deprecated(see `choice_func`). Parameters
|
|
1286
|
+
starting with the specify_prefix will be loaded. Default: ``None`` .
|
|
1287
|
+
choice_func (Union[None, function], optional): Input value of the function is a Parameter name of type
|
|
1288
|
+
string, and the return value is a bool. If returns ``True`` , the Parameter
|
|
1289
|
+
that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
|
|
1290
|
+
matches the custom condition will be removed. Default: ``None`` .
|
|
1291
|
+
|
|
1292
|
+
Returns:
|
|
1293
|
+
A custom inner class, calling its `result` method yields the :func:`mindspore.load_checkpoint` result.
|
|
1294
|
+
|
|
1295
|
+
Raises:
|
|
1296
|
+
ValueError: Checkpoint file's format is incorrect.
|
|
1297
|
+
ValueError: Parameter's dict is None after load checkpoint file.
|
|
1298
|
+
TypeError: The type of `specify_prefix` or `filter_prefix` is incorrect.
|
|
1299
|
+
|
|
1300
|
+
Examples:
|
|
1301
|
+
>>> import mindspore
|
|
1302
|
+
>>> from mindspore import nn
|
|
1303
|
+
>>> from mindspore.train import Model
|
|
1304
|
+
>>> from mindspore.amp import FixedLossScaleManager
|
|
1305
|
+
>>> from mindspore import context
|
|
1306
|
+
>>> from mindspore import load_checkpoint_async
|
|
1307
|
+
>>> from mindspore import load_param_into_net
|
|
1308
|
+
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
|
1309
|
+
>>> # Create the dataset taking MNIST as an example. Refer to
|
|
1310
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
|
|
1311
|
+
>>> dataset = create_dataset()
|
|
1312
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
1313
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
1314
|
+
>>> ckpt_file = "./checkpoint/LeNet5-1_32.ckpt"
|
|
1315
|
+
>>> net = LeNet5()
|
|
1316
|
+
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
|
1317
|
+
>>> loss_scale_manager = FixedLossScaleManager()
|
|
1318
|
+
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
|
1319
|
+
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
|
|
1320
|
+
... loss_scale_manager=loss_scale_manager)
|
|
1321
|
+
>>> pd_future = load_checkpoint_async(ckpt_file)
|
|
1322
|
+
>>> model.build(train_dataset=dataset, epoch=2)
|
|
1323
|
+
>>> param_dict = pd_future.result()
|
|
1324
|
+
>>> load_param_into_net(net, param_dict)
|
|
1325
|
+
>>> model.train(2, dataset)
|
|
1326
|
+
>>> print("param dict len: ", len(param_dict), flush=True)
|
|
1327
|
+
"""
|
|
1328
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
1329
|
+
executor = ThreadPoolExecutor(max_workers=2)
|
|
1330
|
+
param_dict_future = executor.submit(load_checkpoint, ckpt_file_name, net, strict_load, filter_prefix,
|
|
1331
|
+
dec_key, dec_mode, specify_prefix, choice_func)
|
|
1332
|
+
return ParamDictFuture(executor, param_dict_future)
|
|
1333
|
+
|
|
1334
|
+
|
|
1171
1335
|
def _load_map_parameter(checkpoint_list, element, element_id, map_data_list,
|
|
1172
1336
|
map_shape_list, parameter_dict):
|
|
1173
1337
|
"""load map parameter."""
|
|
@@ -1239,17 +1403,28 @@ def _check_prefix(prefix):
|
|
|
1239
1403
|
return prefix
|
|
1240
1404
|
|
|
1241
1405
|
|
|
1242
|
-
def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode):
|
|
1406
|
+
def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check):
|
|
1243
1407
|
"""Parse checkpoint protobuf."""
|
|
1244
1408
|
checkpoint_list = Checkpoint()
|
|
1245
1409
|
try:
|
|
1246
1410
|
if dec_key is None:
|
|
1247
|
-
with open(ckpt_file_name,
|
|
1411
|
+
with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
|
|
1248
1412
|
pb_content = f.read()
|
|
1249
1413
|
else:
|
|
1250
1414
|
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
|
|
1251
1415
|
if pb_content is None:
|
|
1252
1416
|
raise ValueError("For 'load_checkpoint', failed to decrypt the checkpoint file.")
|
|
1417
|
+
if crc_check and pb_content[-17:-10] == b"crc_num":
|
|
1418
|
+
logger.warning("For 'load_checkpoint', the ckpt file do not contain the crc code, please check the file.")
|
|
1419
|
+
if pb_content[-17:-10] == b"crc_num":
|
|
1420
|
+
crc_num_bytes = pb_content[-10:]
|
|
1421
|
+
pb_content = pb_content[:-17]
|
|
1422
|
+
if crc_check:
|
|
1423
|
+
crc_num = int.from_bytes(crc_num_bytes, byteorder='big')
|
|
1424
|
+
cal_crc_num = binascii.crc32(pb_content, 0)
|
|
1425
|
+
if cal_crc_num != crc_num:
|
|
1426
|
+
raise ValueError("For 'load_checkpoint', the crc check is failed, "
|
|
1427
|
+
"please check whether the ckpt file is damaged.")
|
|
1253
1428
|
checkpoint_list.ParseFromString(pb_content)
|
|
1254
1429
|
except BaseException as e:
|
|
1255
1430
|
if _is_cipher_file(ckpt_file_name):
|
|
@@ -1282,13 +1457,33 @@ def _whether_load_param(specify_prefix, filter_prefix, param_name):
|
|
|
1282
1457
|
|
|
1283
1458
|
def _init_parameter_data_in_parallel_mode(net, parameter_dict):
|
|
1284
1459
|
"""In parallel mode, only init the paraemters in ckpt."""
|
|
1460
|
+
is_train_phase = net.phase.startswith('train')
|
|
1285
1461
|
for _, param in net.parameters_and_names():
|
|
1462
|
+
if param.name in parameter_dict and param.from_ckpt and not is_train_phase:
|
|
1463
|
+
param.shape = tuple(parameter_dict[param.name].shape)
|
|
1464
|
+
continue
|
|
1286
1465
|
if param.name in parameter_dict and param.has_init:
|
|
1287
1466
|
logger.warning("{} is not init while load ckpt.".format(param.name))
|
|
1288
1467
|
new_tensor = param.init_data()
|
|
1289
1468
|
param._update_tensor_data(new_tensor)
|
|
1290
1469
|
|
|
1291
1470
|
|
|
1471
|
+
def _check_load_param_into_net(net, parameter_dict):
|
|
1472
|
+
"""check load_param_into_net"""
|
|
1473
|
+
if not isinstance(net, nn.Cell):
|
|
1474
|
+
logger.critical("Failed to combine the net and the parameters.")
|
|
1475
|
+
msg = ("For 'load_param_into_net', the argument 'net' should be a Cell, but got {}.".format(type(net)))
|
|
1476
|
+
raise TypeError(msg)
|
|
1477
|
+
if not isinstance(parameter_dict, dict):
|
|
1478
|
+
logger.critical("Failed to combine the net and the parameters.")
|
|
1479
|
+
msg = ("For 'load_param_into_net', the argument 'parameter_dict' should be a dict, "
|
|
1480
|
+
"but got {}.".format(type(parameter_dict)))
|
|
1481
|
+
raise TypeError(msg)
|
|
1482
|
+
if "random_op" in parameter_dict.keys():
|
|
1483
|
+
net._add_attr("random_op_snapshot", parameter_dict["random_op"])
|
|
1484
|
+
parameter_dict.pop("random_op")
|
|
1485
|
+
|
|
1486
|
+
|
|
1292
1487
|
def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
1293
1488
|
"""
|
|
1294
1489
|
Load parameters into network, return parameter list that are not loaded in the network.
|
|
@@ -1303,8 +1498,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1303
1498
|
on the parameters of the same type, such as float32 to float16. Default: ``False`` .
|
|
1304
1499
|
|
|
1305
1500
|
Returns:
|
|
1306
|
-
param_not_load (List), the parameter name in model which are not loaded into the network.
|
|
1307
|
-
ckpt_not_load (List), the parameter name in checkpoint file which are not loaded into the network.
|
|
1501
|
+
- param_not_load (List), the parameter name in model which are not loaded into the network.
|
|
1502
|
+
- ckpt_not_load (List), the parameter name in checkpoint file which are not loaded into the network.
|
|
1308
1503
|
|
|
1309
1504
|
Raises:
|
|
1310
1505
|
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
|
|
@@ -1313,7 +1508,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1313
1508
|
>>> import mindspore as ms
|
|
1314
1509
|
>>>
|
|
1315
1510
|
>>> # Define the network structure of LeNet5. Refer to
|
|
1316
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
1511
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
1317
1512
|
>>> net = LeNet5()
|
|
1318
1513
|
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
|
|
1319
1514
|
>>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1")
|
|
@@ -1323,20 +1518,9 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1323
1518
|
|
|
1324
1519
|
Tutorial Examples:
|
|
1325
1520
|
- `Saving and Loading the Model - Saving and Loading the Model Weight
|
|
1326
|
-
<https://mindspore.cn/tutorials/en/
|
|
1521
|
+
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
|
|
1327
1522
|
"""
|
|
1328
|
-
|
|
1329
|
-
logger.critical("Failed to combine the net and the parameters.")
|
|
1330
|
-
msg = ("For 'load_param_into_net', the argument 'net' should be a Cell, but got {}.".format(type(net)))
|
|
1331
|
-
raise TypeError(msg)
|
|
1332
|
-
if not isinstance(parameter_dict, dict):
|
|
1333
|
-
logger.critical("Failed to combine the net and the parameters.")
|
|
1334
|
-
msg = ("For 'load_param_into_net', the argument 'parameter_dict' should be a dict, "
|
|
1335
|
-
"but got {}.".format(type(parameter_dict)))
|
|
1336
|
-
raise TypeError(msg)
|
|
1337
|
-
if "random_op" in parameter_dict.keys():
|
|
1338
|
-
net._add_attr("random_op_snapshot", parameter_dict["random_op"])
|
|
1339
|
-
parameter_dict.pop("random_op")
|
|
1523
|
+
_check_load_param_into_net(net, parameter_dict)
|
|
1340
1524
|
for key, value in parameter_dict.items():
|
|
1341
1525
|
if not isinstance(key, str) or not isinstance(value, (Parameter, str, list)):
|
|
1342
1526
|
logger.critical("Load parameters into net failed.")
|
|
@@ -1346,6 +1530,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1346
1530
|
|
|
1347
1531
|
strict_load = Validator.check_bool(strict_load)
|
|
1348
1532
|
logger.info("Execute the process of loading parameters into net.")
|
|
1533
|
+
for _, param in net.parameters_and_names():
|
|
1534
|
+
param.from_ckpt = True
|
|
1349
1535
|
if not _is_in_auto_parallel_mode():
|
|
1350
1536
|
net.init_parameters_data()
|
|
1351
1537
|
else:
|
|
@@ -1360,7 +1546,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1360
1546
|
# Add has attr protection when load server checkpoint file on worker.
|
|
1361
1547
|
if not hasattr(parameter_dict[param.name], "data"):
|
|
1362
1548
|
continue
|
|
1363
|
-
new_param =
|
|
1549
|
+
new_param = parameter_dict[param.name]
|
|
1364
1550
|
_update_param(param, new_param, strict_load)
|
|
1365
1551
|
ckpt_not_load.remove(param.name)
|
|
1366
1552
|
else:
|
|
@@ -1369,18 +1555,21 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1369
1555
|
if param_not_load and not strict_load:
|
|
1370
1556
|
_load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load)
|
|
1371
1557
|
|
|
1372
|
-
logger.debug("Params not matched(in net but not in parameter_dict):")
|
|
1373
|
-
for param_name in param_not_load:
|
|
1374
|
-
logger.debug("%s", param_name)
|
|
1375
|
-
|
|
1376
1558
|
logger.info("Loading parameters into net is finished.")
|
|
1377
1559
|
if param_not_load:
|
|
1378
1560
|
logger.warning("For 'load_param_into_net', "
|
|
1379
1561
|
"{} parameters in the 'net' are not loaded, because they are not in the "
|
|
1380
1562
|
"'parameter_dict', please check whether the network structure is consistent "
|
|
1381
1563
|
"when training and loading checkpoint.".format(len(param_not_load)))
|
|
1382
|
-
|
|
1383
|
-
|
|
1564
|
+
logger.warning("{} are not loaded.".format(param_not_load))
|
|
1565
|
+
if os.getenv("AITURBO") == "1" and net.parameter_layout_dict is not None:
|
|
1566
|
+
param_layout = net.parameter_layout_dict
|
|
1567
|
+
param_redundancy = get_parameter_redundancy(param_layout)
|
|
1568
|
+
remove_param_redundancy_dict = remove_param_redundancy(param_redundancy)
|
|
1569
|
+
target_parameter_name_set = set(parameter_dict.keys())
|
|
1570
|
+
for rank_id, param_name_set in remove_param_redundancy_dict:
|
|
1571
|
+
if param_name_set == target_parameter_name_set:
|
|
1572
|
+
parameter_broadcast(net, param_layout, rank_id)
|
|
1384
1573
|
return param_not_load, ckpt_not_load
|
|
1385
1574
|
|
|
1386
1575
|
|
|
@@ -1494,6 +1683,23 @@ def _save_graph(network, file_name):
|
|
|
1494
1683
|
f.write(graph_pb)
|
|
1495
1684
|
|
|
1496
1685
|
|
|
1686
|
+
def _reshape_tensor(tensor, dst_shape):
|
|
1687
|
+
"""reshape tensor to dst shape"""
|
|
1688
|
+
np_tensor = tensor.asnumpy()
|
|
1689
|
+
np_tensor = np_tensor.reshape(dst_shape)
|
|
1690
|
+
return Tensor(np_tensor, tensor.dtype)
|
|
1691
|
+
|
|
1692
|
+
|
|
1693
|
+
def _check_param_for_integrate_save(pipeline_stages, uniform_split):
|
|
1694
|
+
"""check whether current settings and parameters are supported in integrated save checkpoint mode"""
|
|
1695
|
+
if pipeline_stages > 1:
|
|
1696
|
+
raise RuntimeError("Pipeline Parallel don't support Integrated save checkpoint now.")
|
|
1697
|
+
if uniform_split == 0:
|
|
1698
|
+
raise RuntimeError("For 'save_checkpoint' and in automatic model parallel scene, when set "
|
|
1699
|
+
"'integrated_save' to True, the checkpoint will be integrated save, it "
|
|
1700
|
+
"is only supports uniform split tensor now.")
|
|
1701
|
+
|
|
1702
|
+
|
|
1497
1703
|
def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, integrated_save):
|
|
1498
1704
|
"""
|
|
1499
1705
|
Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.
|
|
@@ -1507,7 +1713,7 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
|
|
|
1507
1713
|
Tensor, the combined tensor which with the whole data value.
|
|
1508
1714
|
"""
|
|
1509
1715
|
layout = parameter_layout_dict[param_name]
|
|
1510
|
-
if len(layout) <
|
|
1716
|
+
if len(layout) < 8:
|
|
1511
1717
|
logger.info("The layout dict does not contain the key %s", param_name)
|
|
1512
1718
|
return param_data
|
|
1513
1719
|
|
|
@@ -1515,6 +1721,13 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
|
|
|
1515
1721
|
tensor_map = layout[1]
|
|
1516
1722
|
uniform_split = layout[4]
|
|
1517
1723
|
opt_shard_group = layout[5]
|
|
1724
|
+
before_reshape_slice_shape = layout[2]
|
|
1725
|
+
before_reshape_full_shape = layout[6]
|
|
1726
|
+
after_reshape_slice_shape = layout[7]
|
|
1727
|
+
do_reshape = False
|
|
1728
|
+
if before_reshape_full_shape and after_reshape_slice_shape \
|
|
1729
|
+
and after_reshape_slice_shape != before_reshape_slice_shape:
|
|
1730
|
+
do_reshape = True
|
|
1518
1731
|
|
|
1519
1732
|
allgather_net = None
|
|
1520
1733
|
mp_weight = False
|
|
@@ -1527,26 +1740,26 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
|
|
|
1527
1740
|
else:
|
|
1528
1741
|
logger.info("Need to create allgather net for %s", param_name)
|
|
1529
1742
|
if integrated_save:
|
|
1530
|
-
|
|
1531
|
-
raise RuntimeError("Pipeline Parallel don't support Integrated save checkpoint now.")
|
|
1532
|
-
if uniform_split == 0:
|
|
1533
|
-
raise RuntimeError("For 'save_checkpoint' and in automatic model parallel scene, when set "
|
|
1534
|
-
"'integrated_save' to True, the checkpoint will be integrated save, it "
|
|
1535
|
-
"is only supports uniform split tensor now.")
|
|
1743
|
+
_check_param_for_integrate_save(context.get_auto_parallel_context("pipeline_stages"), uniform_split)
|
|
1536
1744
|
# while any dim is not equal to -1, means param is split and needs to be merged
|
|
1537
1745
|
# pipeline parallel need to be supported here later
|
|
1538
1746
|
if mp_weight:
|
|
1539
|
-
allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group)
|
|
1747
|
+
allgather_net = get_allgather_cell(opt_shard_group, bool(opt_shard_group), do_reshape,
|
|
1748
|
+
tuple(after_reshape_slice_shape))
|
|
1540
1749
|
object.__setattr__(allgather_net, "keep_input_unchanged", True)
|
|
1541
1750
|
elif opt_shard_group:
|
|
1542
|
-
allgather_net = get_allgather_cell(opt_shard_group, False
|
|
1751
|
+
allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
|
|
1752
|
+
tuple(after_reshape_slice_shape))
|
|
1543
1753
|
elif opt_shard_group and context.get_auto_parallel_context("optimizer_weight_shard_aggregated_save"):
|
|
1544
|
-
allgather_net = get_allgather_cell(opt_shard_group, False
|
|
1754
|
+
allgather_net = get_allgather_cell(opt_shard_group, False, do_reshape,
|
|
1755
|
+
tuple(after_reshape_slice_shape))
|
|
1545
1756
|
net.parallel_parameter_merge_net_dict[param_name] = allgather_net
|
|
1546
1757
|
if allgather_net:
|
|
1547
1758
|
param_data = allgather_net(param_data)
|
|
1548
1759
|
if mp_weight and integrated_save:
|
|
1549
1760
|
param_data = _reshape_param_data(param_data, dev_mat, tensor_map)
|
|
1761
|
+
if do_reshape:
|
|
1762
|
+
param_data = _reshape_tensor(param_data, before_reshape_full_shape)
|
|
1550
1763
|
return param_data
|
|
1551
1764
|
|
|
1552
1765
|
|
|
@@ -1556,10 +1769,13 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1556
1769
|
|
|
1557
1770
|
Note:
|
|
1558
1771
|
1. When exporting AIR, ONNX format, the size of a single tensor can not exceed 2GB.
|
|
1559
|
-
2. When file_name does not have a suffix, the system will automatically add one
|
|
1772
|
+
2. When `file_name` does not have a suffix, the system will automatically add one
|
|
1773
|
+
according to the `file_format`.
|
|
1560
1774
|
3. Exporting functions decorated with :func:`mindspore.jit` to mindir format is supported.
|
|
1561
1775
|
4. When exporting a function decorated with :func:`mindspore.jit`, the function should not involve
|
|
1562
1776
|
class properties in calculations.
|
|
1777
|
+
5. AIR format is deprecated, and will be removed in a future version, please use other format or use
|
|
1778
|
+
MindSpore Lite to do offline inference.
|
|
1563
1779
|
|
|
1564
1780
|
Args:
|
|
1565
1781
|
net (Union[Cell, function]): MindSpore network.
|
|
@@ -1584,9 +1800,9 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1584
1800
|
- For 'AIR' and 'ONNX' models, only customized encryption is supported.
|
|
1585
1801
|
- For 'MINDIR', all options are supported. Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC'
|
|
1586
1802
|
or Customized encryption.
|
|
1587
|
-
Default: 'AES-GCM'
|
|
1803
|
+
Default: ``'AES-GCM'``.
|
|
1588
1804
|
- For details of using the customized encryption, please check the `tutorial
|
|
1589
|
-
<https://mindspore.cn/mindarmour/docs/en/
|
|
1805
|
+
<https://mindspore.cn/mindarmour/docs/en/master/model_encrypt_protection.html>`_.
|
|
1590
1806
|
|
|
1591
1807
|
- dataset (Dataset): Specifies the preprocessing method of the dataset, which is used to import the
|
|
1592
1808
|
preprocessing of the dataset into MindIR.
|
|
@@ -1600,32 +1816,49 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1600
1816
|
- customized_func (function): A python function used for customized function mode, which used for control
|
|
1601
1817
|
the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
|
|
1602
1818
|
Reference to 'my_func()' in
|
|
1603
|
-
`tutorials <https://www.mindspore.cn/mindarmour/docs/en/
|
|
1819
|
+
`tutorials <https://www.mindspore.cn/mindarmour/docs/en/master/dynamic_obfuscation_protection.html>`_).
|
|
1604
1820
|
This function needs to ensure that its result is constant for any input. Users can refer to opaque
|
|
1605
1821
|
predicates. If customized_func is set, then it should be passed to `load()` interface when loading
|
|
1606
1822
|
obfuscated model.
|
|
1607
1823
|
- obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
|
|
1608
1824
|
structure of obfuscated models corresponding to different random seeds is different. If
|
|
1609
|
-
`obf_random_seed` is set, then it should be passed
|
|
1825
|
+
`obf_random_seed` is set, then it should be passed
|
|
1826
|
+
to :class:`mindspore.nn.GraphCell` interface when loading
|
|
1610
1827
|
obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
|
|
1611
1828
|
be set, and the latter mode would be applied if both of them are set.
|
|
1612
1829
|
|
|
1613
1830
|
- incremental (bool): export MindIR incrementally.
|
|
1614
1831
|
|
|
1832
|
+
- custom_func (function): Functions for custom defined export policies. This function will be used to
|
|
1833
|
+
customize the model during network export. Currently only support for files with mindir format. The
|
|
1834
|
+
function only accepts one input representing the proto object of the mindir file. When modifying a model,
|
|
1835
|
+
it is necessary to ensure the correctness of the `custom_func` , otherwise it may lead to model loading
|
|
1836
|
+
failure or functional errors. Default: ``None`` .
|
|
1837
|
+
|
|
1615
1838
|
Examples:
|
|
1616
1839
|
>>> import mindspore as ms
|
|
1617
1840
|
>>> import numpy as np
|
|
1618
1841
|
>>> from mindspore import Tensor
|
|
1619
1842
|
>>>
|
|
1620
1843
|
>>> # Define the network structure of LeNet5. Refer to
|
|
1621
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
1844
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
1622
1845
|
>>> net = LeNet5()
|
|
1623
1846
|
>>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
|
|
1624
1847
|
>>> ms.export(net, input_tensor, file_name='lenet', file_format='MINDIR')
|
|
1848
|
+
>>>
|
|
1849
|
+
>>> # Export model in MindIR format and modified the model info using custom_func
|
|
1850
|
+
>>> # The custom_func only support one input representing the Proto object of the model
|
|
1851
|
+
>>> # And custom_func does not support return value
|
|
1852
|
+
>>> def _custom_func(mindir_model):
|
|
1853
|
+
... mindir_model.producer_name = "test11111"
|
|
1854
|
+
... mindir_model.producer_version = "11.0"
|
|
1855
|
+
... mindir_model.user_info["version"] = "11.0"
|
|
1856
|
+
>>> ms.export(net, input_tensor, file_name="lenet", file_format='MINDIR', custom_func=_custom_func)
|
|
1857
|
+
|
|
1625
1858
|
|
|
1626
1859
|
Tutorial Examples:
|
|
1627
1860
|
- `Saving and Loading the Model - Saving and Loading MindIR
|
|
1628
|
-
<https://mindspore.cn/tutorials/en/
|
|
1861
|
+
<https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-mindir>`_
|
|
1629
1862
|
"""
|
|
1630
1863
|
old_ms_jit_value = context.get_context("jit_syntax_level")
|
|
1631
1864
|
context.set_context(jit_syntax_level=mindspore.STRICT)
|
|
@@ -1633,6 +1866,9 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1633
1866
|
supported_formats = ['AIR', 'ONNX', 'MINDIR']
|
|
1634
1867
|
if file_format not in supported_formats:
|
|
1635
1868
|
raise ValueError(f"For 'export', 'file_format' must be one of {supported_formats}, but got {file_format}.")
|
|
1869
|
+
if file_format == 'AIR':
|
|
1870
|
+
logger.warning("AIR format is deprecated, and will be removed in a future version, please use other format or "
|
|
1871
|
+
"use MindSpore Lite to do offline inference")
|
|
1636
1872
|
Validator.check_file_name_by_regular(file_name)
|
|
1637
1873
|
logger.info("exporting model file:%s format:%s.", file_name, file_format)
|
|
1638
1874
|
|
|
@@ -1685,7 +1921,7 @@ def _get_funcgraph(net, *inputs):
|
|
|
1685
1921
|
>>> from mindspore import Tensor
|
|
1686
1922
|
>>>
|
|
1687
1923
|
>>> # Define the network structure of LeNet5. Refer to
|
|
1688
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
1924
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
1689
1925
|
>>> net = LeNet5()
|
|
1690
1926
|
>>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
|
|
1691
1927
|
>>> ms.get_funcgraph(net, input_tensor)
|
|
@@ -1707,6 +1943,8 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
|
|
|
1707
1943
|
logger.info("exporting model file:%s format:%s.", file_name, file_format)
|
|
1708
1944
|
if "obf_config" in kwargs and file_format != "MINDIR":
|
|
1709
1945
|
raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.")
|
|
1946
|
+
if "custom_func" in kwargs and file_format != "MINDIR":
|
|
1947
|
+
raise ValueError(f"Currently only support custom_func for MindIR format, but got {file_format} format.")
|
|
1710
1948
|
if file_format == 'AIR':
|
|
1711
1949
|
_save_air(net, file_name, *inputs, **kwargs)
|
|
1712
1950
|
elif file_format == 'ONNX':
|
|
@@ -1867,12 +2105,12 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
|
|
|
1867
2105
|
data_file_name = os.path.join(dirname, external_local)
|
|
1868
2106
|
f, parameter_size, offset = _get_data_file(is_encrypt, kwargs, data_file_name)
|
|
1869
2107
|
try:
|
|
1870
|
-
|
|
2108
|
+
round = 0
|
|
1871
2109
|
names = []
|
|
1872
2110
|
for param_proto in model.graph.parameter:
|
|
1873
2111
|
name = param_proto.name[param_proto.name.find(":") + 1:]
|
|
1874
2112
|
names.append((name, param_proto))
|
|
1875
|
-
|
|
2113
|
+
names.sort(key=lambda x: x[0])
|
|
1876
2114
|
for pairs in names:
|
|
1877
2115
|
name = pairs[0]
|
|
1878
2116
|
param_proto = pairs[1]
|
|
@@ -1895,8 +2133,8 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
|
|
|
1895
2133
|
offset += (data_length + append_size)
|
|
1896
2134
|
write_data = _encrypt_data(is_encrypt, write_data, kwargs)
|
|
1897
2135
|
f.write(write_data)
|
|
1898
|
-
|
|
1899
|
-
logger.debug(f"writing {
|
|
2136
|
+
round += 1
|
|
2137
|
+
logger.debug(f"writing {round}th split data, name:{name}")
|
|
1900
2138
|
|
|
1901
2139
|
graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
|
|
1902
2140
|
if os.path.exists(graph_file_name):
|
|
@@ -1993,6 +2231,10 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
|
|
|
1993
2231
|
dataset = kwargs.get('dataset')
|
|
1994
2232
|
_save_dataset_to_mindir(model, dataset)
|
|
1995
2233
|
|
|
2234
|
+
custom_func = kwargs.get('custom_func', None)
|
|
2235
|
+
if custom_func is not None:
|
|
2236
|
+
custom_func(model)
|
|
2237
|
+
|
|
1996
2238
|
save_together = _save_together(net_dict, model)
|
|
1997
2239
|
is_encrypt = lambda: 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys()
|
|
1998
2240
|
if save_together:
|
|
@@ -2079,6 +2321,45 @@ def _save_dataset_to_mindir(model, dataset):
|
|
|
2079
2321
|
model.preprocessor.op[-1].offload = op['offload'] if 'offload' in op.keys() else False
|
|
2080
2322
|
|
|
2081
2323
|
|
|
2324
|
+
def check_checkpoint(ckpt_file_name):
|
|
2325
|
+
"""
|
|
2326
|
+
Check whether the checkpoint is valid.
|
|
2327
|
+
|
|
2328
|
+
Args:
|
|
2329
|
+
ckpt_file_name (str): Checkpoint file name.
|
|
2330
|
+
|
|
2331
|
+
Returns:
|
|
2332
|
+
bool, whether the checkpoint is valid.
|
|
2333
|
+
|
|
2334
|
+
Examples:
|
|
2335
|
+
>>> import mindspore as ms
|
|
2336
|
+
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
|
|
2337
|
+
>>> check_result = ms.check_checkpoint(ckpt_file_name)
|
|
2338
|
+
>>> print(check_result)
|
|
2339
|
+
True
|
|
2340
|
+
"""
|
|
2341
|
+
if not ckpt_file_name.endswith('.ckpt'):
|
|
2342
|
+
return False
|
|
2343
|
+
checkpoint_list = Checkpoint()
|
|
2344
|
+
with _ckpt_fs.open(ckpt_file_name, *_ckpt_fs.open_args) as f:
|
|
2345
|
+
pb_content = f.read()
|
|
2346
|
+
if pb_content[-17:-10] == b"crc_num":
|
|
2347
|
+
crc_num_bytes = pb_content[-10:]
|
|
2348
|
+
pb_content = pb_content[:-17]
|
|
2349
|
+
crc_num = int.from_bytes(crc_num_bytes, byteorder='big')
|
|
2350
|
+
cal_crc_num = binascii.crc32(pb_content, 0)
|
|
2351
|
+
if cal_crc_num != crc_num:
|
|
2352
|
+
logger.warning("For 'check_checkpoint', the ckpt crc check is failed.")
|
|
2353
|
+
return False
|
|
2354
|
+
try:
|
|
2355
|
+
checkpoint_list.ParseFromString(pb_content)
|
|
2356
|
+
except google.protobuf.message.DecodeError as e:
|
|
2357
|
+
logger.warning("For 'check_checkpoint', the ckpt parse is failed.")
|
|
2358
|
+
logger.warning(e)
|
|
2359
|
+
return False
|
|
2360
|
+
return True
|
|
2361
|
+
|
|
2362
|
+
|
|
2082
2363
|
def parse_print(print_file_name):
|
|
2083
2364
|
"""
|
|
2084
2365
|
Parse data file generated by :class:`mindspore.ops.Print`.
|
|
@@ -2423,7 +2704,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|
|
2423
2704
|
in at least one of them. Default: ``None`` .
|
|
2424
2705
|
strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
|
|
2425
2706
|
into net when parameter name's suffix in checkpoint file is the same as the
|
|
2426
|
-
parameter in the network. When the types are inconsistent perform type conversion
|
|
2707
|
+
parameter in the network. When the types are inconsistent, perform type conversion
|
|
2427
2708
|
on the parameters of the same type, such as float32 to float16. Default: ``False`` .
|
|
2428
2709
|
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
|
|
2429
2710
|
is not required. Default: ``None`` .
|
|
@@ -2444,14 +2725,14 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|
|
2444
2725
|
|
|
2445
2726
|
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
|
|
2446
2727
|
Please see the `rank table startup
|
|
2447
|
-
<https://www.mindspore.cn/tutorials/experts/en/
|
|
2728
|
+
<https://www.mindspore.cn/tutorials/experts/en/master/parallel/rank_table.html>`_
|
|
2448
2729
|
for more details.
|
|
2449
2730
|
|
|
2450
2731
|
For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
|
|
2451
|
-
<https://www.mindspore.cn/tutorials/experts/en/
|
|
2732
|
+
<https://www.mindspore.cn/tutorials/experts/en/master/parallel/mpirun.html>`_ .
|
|
2452
2733
|
|
|
2453
2734
|
For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
|
|
2454
|
-
Startup <https://www.mindspore.cn/tutorials/experts/en/
|
|
2735
|
+
Startup <https://www.mindspore.cn/tutorials/experts/en/master/parallel/dynamic_cluster.html>`_ .
|
|
2455
2736
|
|
|
2456
2737
|
>>> import os
|
|
2457
2738
|
>>> import numpy as np
|
|
@@ -2717,11 +2998,10 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
|
|
|
2717
2998
|
param_name = merged_param.name
|
|
2718
2999
|
tensor_layout = predict_strategy[param_name]
|
|
2719
3000
|
rank = get_rank()
|
|
2720
|
-
split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank)
|
|
3001
|
+
split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank_id=rank)
|
|
2721
3002
|
requires_grad = merged_param.requires_grad
|
|
2722
3003
|
layerwise_parallel = merged_param.layerwise_parallel
|
|
2723
|
-
|
|
2724
|
-
if data_type == mstype.bfloat16:
|
|
3004
|
+
if merged_param.data.dtype == mstype.bfloat16:
|
|
2725
3005
|
split_param = Parameter(Tensor(split_tensor, mstype.bfloat16), param_name, requires_grad, layerwise_parallel)
|
|
2726
3006
|
else:
|
|
2727
3007
|
split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
|
|
@@ -2789,7 +3069,7 @@ def _get_mindir_inputs(file_name):
|
|
|
2789
3069
|
|
|
2790
3070
|
def convert_model(mindir_file, convert_file, file_format):
|
|
2791
3071
|
"""
|
|
2792
|
-
Convert mindir model to other format model.
|
|
3072
|
+
Convert mindir model to other format model. The current version only supports conversion to ONNX models.
|
|
2793
3073
|
|
|
2794
3074
|
.. warning::
|
|
2795
3075
|
This is an experimental API that is subject to change or deletion.
|