mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. See the package's registry page for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/train/serialization.py
CHANGED
|
@@ -23,7 +23,7 @@ import os
|
|
|
23
23
|
import shutil
|
|
24
24
|
import stat
|
|
25
25
|
import threading
|
|
26
|
-
from threading import Thread,
|
|
26
|
+
from threading import Thread, RLock
|
|
27
27
|
from collections import defaultdict, OrderedDict
|
|
28
28
|
from io import BytesIO
|
|
29
29
|
|
|
@@ -48,7 +48,7 @@ from mindspore.common.api import _MindsporeFunctionExecutor
|
|
|
48
48
|
from mindspore.common.api import _get_parameter_layout
|
|
49
49
|
from mindspore.common.api import _generate_branch_control_input
|
|
50
50
|
from mindspore.common.initializer import initializer, One
|
|
51
|
-
from mindspore.common.parameter import Parameter
|
|
51
|
+
from mindspore.common.parameter import Parameter, _offload_if_config
|
|
52
52
|
from mindspore.common.tensor import Tensor
|
|
53
53
|
from mindspore.common._utils import is_shape_unknown
|
|
54
54
|
from mindspore.communication.management import get_rank, get_group_size
|
|
@@ -59,8 +59,11 @@ from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_
|
|
|
59
59
|
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode
|
|
60
60
|
from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
|
|
61
61
|
_restore_group_info_list
|
|
62
|
+
from mindspore.parallel._ps_context import _set_checkpoint_load_status, _store_warm_up_ptr_by_tensor, \
|
|
63
|
+
_store_warm_up_ptr_by_tensor_list, _cache_enable
|
|
62
64
|
from mindspore.train._utils import read_proto
|
|
63
|
-
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir
|
|
65
|
+
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir, \
|
|
66
|
+
split_mindir, split_dynamic_mindir
|
|
64
67
|
from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
|
|
65
68
|
|
|
66
69
|
tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
|
|
@@ -72,11 +75,13 @@ tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UIn
|
|
|
72
75
|
"Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
|
|
73
76
|
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}
|
|
74
77
|
|
|
78
|
+
np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64}
|
|
79
|
+
|
|
75
80
|
mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4: mstype.uint16,
|
|
76
81
|
5: mstype.int16, 6: mstype.int32, 7: mstype.int64, 10: mstype.float16,
|
|
77
82
|
11: mstype.float64, 12: mstype.uint32, 13: mstype.uint64}
|
|
78
83
|
|
|
79
|
-
_ckpt_mutex =
|
|
84
|
+
_ckpt_mutex = RLock()
|
|
80
85
|
|
|
81
86
|
# unit is KB
|
|
82
87
|
SLICE_SIZE = 512 * 1024
|
|
@@ -124,7 +129,7 @@ def _update_param(param, new_param, strict_load):
|
|
|
124
129
|
if param.data.dtype != new_param.data.dtype:
|
|
125
130
|
if _type_convert(param, new_param, strict_load):
|
|
126
131
|
new_tensor = Tensor(new_param.data.asnumpy(), param.data.dtype)
|
|
127
|
-
param.set_data(new_tensor)
|
|
132
|
+
param.set_data(new_tensor, param.sliced)
|
|
128
133
|
return
|
|
129
134
|
|
|
130
135
|
logger.critical("Failed to combine the net and the parameters for param %s.", param.name)
|
|
@@ -205,7 +210,7 @@ def _save_weight(checkpoint_dir, model_name, iteration, params):
|
|
|
205
210
|
logger.warning(f"Checkpoint dir: '{checkpoint_dir}' is not existed.")
|
|
206
211
|
|
|
207
212
|
|
|
208
|
-
def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
|
|
213
|
+
def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False):
|
|
209
214
|
"""Execute the process of saving checkpoint into file."""
|
|
210
215
|
try:
|
|
211
216
|
with _ckpt_mutex:
|
|
@@ -213,37 +218,28 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
|
|
|
213
218
|
os.chmod(ckpt_file_name, stat.S_IWUSR)
|
|
214
219
|
os.remove(ckpt_file_name)
|
|
215
220
|
with open(ckpt_file_name, "ab") as f:
|
|
221
|
+
plain_data = None
|
|
216
222
|
if enc_key is not None:
|
|
217
223
|
plain_data = BytesIO()
|
|
218
224
|
|
|
219
225
|
for name, value in data_list.items():
|
|
226
|
+
if name == "random_op":
|
|
227
|
+
_write_random_seed(name, value, f)
|
|
228
|
+
continue
|
|
220
229
|
if value[0] == "mapparameter":
|
|
221
|
-
_write_mapparameter(name, value, f)
|
|
230
|
+
_write_mapparameter(name, value, f, map_param_inc)
|
|
231
|
+
continue
|
|
232
|
+
if value[0] == "offload_parameter":
|
|
233
|
+
new_value = value[1:]
|
|
234
|
+
new_value[2] = value[3].asnumpy().reshape(-1)
|
|
235
|
+
_write_parameter_data(name, new_value, f, enc_key, plain_data)
|
|
236
|
+
_offload_if_config(value[3])
|
|
222
237
|
continue
|
|
223
238
|
if isinstance(value[2], Tensor):
|
|
224
239
|
_write_hugeparameter(name, value, f)
|
|
225
240
|
continue
|
|
226
241
|
|
|
227
|
-
|
|
228
|
-
if data_size > SLICE_SIZE:
|
|
229
|
-
slice_count = math.ceil(data_size / SLICE_SIZE)
|
|
230
|
-
param_slice_list = np.array_split(value[2], slice_count)
|
|
231
|
-
else:
|
|
232
|
-
param_slice_list = [value[2]]
|
|
233
|
-
|
|
234
|
-
for param_slice in param_slice_list:
|
|
235
|
-
checkpoint_list = Checkpoint()
|
|
236
|
-
param_value = checkpoint_list.value.add()
|
|
237
|
-
param_value.tag = name
|
|
238
|
-
param_tensor = param_value.tensor
|
|
239
|
-
param_tensor.dims.extend(value[0])
|
|
240
|
-
param_tensor.tensor_type = value[1]
|
|
241
|
-
param_tensor.tensor_content = param_slice.tobytes()
|
|
242
|
-
|
|
243
|
-
if enc_key is None:
|
|
244
|
-
f.write(checkpoint_list.SerializeToString())
|
|
245
|
-
else:
|
|
246
|
-
plain_data.write(checkpoint_list.SerializeToString())
|
|
242
|
+
_write_parameter_data(name, value, f, enc_key, plain_data)
|
|
247
243
|
|
|
248
244
|
if enc_key is not None:
|
|
249
245
|
plain_data.seek(0)
|
|
@@ -261,18 +257,59 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
|
|
|
261
257
|
raise e
|
|
262
258
|
|
|
263
259
|
|
|
264
|
-
def
|
|
265
|
-
"""Write
|
|
260
|
+
def _write_random_seed(name, value, f):
|
|
261
|
+
"""Write random op into protobuf file."""
|
|
266
262
|
checkpoint_list = Checkpoint()
|
|
267
263
|
param_value = checkpoint_list.value.add()
|
|
268
264
|
param_value.tag = name
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
265
|
+
param_tensor = param_value.tensor
|
|
266
|
+
param_tensor.dims.extend(0)
|
|
267
|
+
param_tensor.tensor_type = "random_op"
|
|
268
|
+
param_tensor.tensor_content = value
|
|
269
|
+
f.write(checkpoint_list.SerializeToString())
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def _write_parameter_data(name, value, f, enc_key, plain_data):
|
|
273
|
+
"""Write parameter data into protobuf file."""
|
|
274
|
+
data_size = value[2].nbytes / 1024
|
|
275
|
+
if data_size > SLICE_SIZE:
|
|
276
|
+
slice_count = math.ceil(data_size / SLICE_SIZE)
|
|
277
|
+
param_slice_list = np.array_split(value[2], slice_count)
|
|
278
|
+
else:
|
|
279
|
+
param_slice_list = [value[2]]
|
|
280
|
+
|
|
281
|
+
for param_slice in param_slice_list:
|
|
282
|
+
checkpoint_list = Checkpoint()
|
|
283
|
+
param_value = checkpoint_list.value.add()
|
|
284
|
+
param_value.tag = name
|
|
285
|
+
param_tensor = param_value.tensor
|
|
286
|
+
param_tensor.dims.extend(value[0])
|
|
287
|
+
param_tensor.tensor_type = value[1]
|
|
288
|
+
param_tensor.tensor_content = param_slice.tobytes()
|
|
289
|
+
|
|
290
|
+
if enc_key is None:
|
|
291
|
+
f.write(checkpoint_list.SerializeToString())
|
|
292
|
+
else:
|
|
293
|
+
plain_data.write(checkpoint_list.SerializeToString())
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def _write_mapparameter(name, value, f, map_param_inc=False):
|
|
297
|
+
"""Write map parameter into protobuf file."""
|
|
298
|
+
while True:
|
|
299
|
+
logger.info("Checkpoint save map_parameter.")
|
|
300
|
+
data_map_slice = value[1].export_slice_data(map_param_inc)
|
|
301
|
+
checkpoint_list = Checkpoint()
|
|
302
|
+
param_value = checkpoint_list.value.add()
|
|
303
|
+
param_value.tag = name
|
|
304
|
+
map_tensor = param_value.maptensor
|
|
305
|
+
for numpy_data in data_map_slice[:3]:
|
|
306
|
+
tensor_pro = map_tensor.tensor.add()
|
|
307
|
+
tensor_pro.dims.extend(numpy_data.shape)
|
|
308
|
+
tensor_pro.tensor_type = str(numpy_data.dtype)
|
|
309
|
+
tensor_pro.tensor_content = numpy_data.reshape(-1).tobytes()
|
|
275
310
|
f.write(checkpoint_list.SerializeToString())
|
|
311
|
+
if data_map_slice[3]:
|
|
312
|
+
break
|
|
276
313
|
|
|
277
314
|
|
|
278
315
|
def _write_hugeparameter(name, value, f):
|
|
@@ -298,8 +335,8 @@ def _write_hugeparameter(name, value, f):
|
|
|
298
335
|
|
|
299
336
|
def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
|
|
300
337
|
"""Check save_obj and ckpt_file_name for save_checkpoint."""
|
|
301
|
-
if not isinstance(save_obj, nn.Cell
|
|
302
|
-
raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell or
|
|
338
|
+
if not isinstance(save_obj, (nn.Cell, list, dict)):
|
|
339
|
+
raise TypeError("For 'save_checkpoint', the parameter 'save_obj' must be nn.Cell, list or dict, "
|
|
303
340
|
"but got {}.".format(type(save_obj)))
|
|
304
341
|
if not isinstance(ckpt_file_name, str):
|
|
305
342
|
raise TypeError("For 'save_checkpoint', the parameter {} for checkpoint file name is invalid,"
|
|
@@ -315,34 +352,63 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
|
|
|
315
352
|
|
|
316
353
|
|
|
317
354
|
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
318
|
-
async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM"):
|
|
319
|
-
"""
|
|
355
|
+
async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None, **kwargs):
|
|
356
|
+
r"""
|
|
320
357
|
Save checkpoint to a specified file.
|
|
321
358
|
|
|
322
359
|
Args:
|
|
323
|
-
save_obj (Union[Cell, list]): The
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
360
|
+
save_obj (Union[Cell, list, dict]): The object to be saved. The data type can be :class:`mindspore.nn.Cell`,
|
|
361
|
+
list, or dict. If a list, it can be the returned value of `Cell.trainable_params()`, or a list of dict
|
|
362
|
+
elements(each element is a dictionary, like [{"name": param_name, "data": param_data},...], the type of
|
|
363
|
+
`param_name` must be string, and the type of `param_data` must be parameter or Tensor); If dict,
|
|
364
|
+
it can be the returned value of `mindspore.load_checkpoint()`.
|
|
327
365
|
ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
|
|
328
|
-
integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True
|
|
329
|
-
async_save (bool): Whether to open an independent thread to save the checkpoint file. Default: False
|
|
366
|
+
integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: ``True`` .
|
|
367
|
+
async_save (bool): Whether to open an independent thread to save the checkpoint file. Default: ``False`` .
|
|
330
368
|
append_dict (dict): Additional information that needs to be saved. The key of dict must be str, the value
|
|
331
|
-
of dict must be one of int, float, bool, string, Parameter or Tensor. Default: None.
|
|
332
|
-
enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption
|
|
333
|
-
is not required. Default: None.
|
|
334
|
-
enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption
|
|
335
|
-
mode, currently supports
|
|
369
|
+
of dict must be one of int, float, bool, string, Parameter or Tensor. Default: ``None`` .
|
|
370
|
+
enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is ``None`` , the encryption
|
|
371
|
+
is not required. Default: ``None`` .
|
|
372
|
+
enc_mode (str): This parameter is valid only when enc_key is not set to ``None`` . Specifies the encryption
|
|
373
|
+
mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
|
|
374
|
+
Default: ``"AES-GCM"`` .
|
|
375
|
+
choice_func (function) : A function for saving custom selected parameters. The input value of `choice_func` is
|
|
376
|
+
a parameter name in string type, and the returned value is a bool.
|
|
377
|
+
If returns ``True`` , the Parameter that matching the custom condition will be saved.
|
|
378
|
+
If returns ``False`` , the Parameter that not matching the custom condition will not
|
|
379
|
+
be saved. Default: ``None`` .
|
|
380
|
+
kwargs (dict): Configuration options dictionary.
|
|
336
381
|
|
|
337
382
|
Raises:
|
|
338
|
-
TypeError: If the parameter save_obj is not
|
|
339
|
-
|
|
383
|
+
TypeError: If the parameter `save_obj` is not :class:`mindspore.nn.Cell` , list or dict type.
|
|
384
|
+
TypeError: If the parameter `integrated_save` or `async_save` is not bool type.
|
|
385
|
+
TypeError: If the parameter `ckpt_file_name` is not string type.
|
|
340
386
|
|
|
341
387
|
Examples:
|
|
342
388
|
>>> import mindspore as ms
|
|
343
389
|
>>>
|
|
344
|
-
>>>
|
|
345
|
-
>>>
|
|
390
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
391
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
392
|
+
>>> net = LeNet5()
|
|
393
|
+
>>> ms.save_checkpoint(net, "./lenet.ckpt",
|
|
394
|
+
... choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
|
|
395
|
+
>>> param_dict1 = ms.load_checkpoint("./lenet.ckpt")
|
|
396
|
+
>>> print(param_dict1)
|
|
397
|
+
{'conv2.weight': Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)}
|
|
398
|
+
>>> params_list = net.trainable_params()
|
|
399
|
+
>>> ms.save_checkpoint(params_list, "./lenet_list.ckpt",
|
|
400
|
+
... choice_func=lambda x: x.startswith("conv") and not x.startswith("conv2"))
|
|
401
|
+
>>> param_dict2 = ms.load_checkpoint("./lenet_list.ckpt")
|
|
402
|
+
>>> print(param_dict2)
|
|
403
|
+
{'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
|
|
404
|
+
>>> ms.save_checkpoint(param_dict2, "./lenet_dict.ckpt")
|
|
405
|
+
>>> param_dict3 = ms.load_checkpoint("./lenet_dict.ckpt")
|
|
406
|
+
>>> print(param_dict3)
|
|
407
|
+
{'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
|
|
408
|
+
|
|
409
|
+
Tutorial Examples:
|
|
410
|
+
- `Saving and Loading the Model - Saving and Loading the Model Weight
|
|
411
|
+
<https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
|
|
346
412
|
"""
|
|
347
413
|
ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name)
|
|
348
414
|
integrated_save = Validator.check_bool(integrated_save)
|
|
@@ -350,46 +416,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
350
416
|
append_dict = _check_append_dict(append_dict)
|
|
351
417
|
enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
|
|
352
418
|
enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
|
|
353
|
-
|
|
419
|
+
map_param_inc = kwargs.get('incremental', False)
|
|
354
420
|
logger.info("Execute the process of saving checkpoint files.")
|
|
355
421
|
|
|
356
|
-
|
|
357
|
-
parameter_layout_dict = save_obj.parameter_layout_dict
|
|
358
|
-
if _is_in_auto_parallel_mode() and not parameter_layout_dict:
|
|
359
|
-
parameter_layout_dict = _get_parameter_layout()
|
|
360
|
-
save_obj.init_parameters_data()
|
|
361
|
-
param_dict = OrderedDict()
|
|
362
|
-
for _, param in save_obj.parameters_and_names():
|
|
363
|
-
param_dict[param.name] = param
|
|
364
|
-
param_list = []
|
|
365
|
-
for (key, value) in param_dict.items():
|
|
366
|
-
each_param = {"name": key}
|
|
367
|
-
if isinstance(value, MapParameter):
|
|
368
|
-
param_data = []
|
|
369
|
-
for export_data in value.export_data():
|
|
370
|
-
param_data.append(Tensor(export_data))
|
|
371
|
-
each_param["data"] = param_data
|
|
372
|
-
param_list.append(each_param)
|
|
373
|
-
continue
|
|
374
|
-
|
|
375
|
-
if value.data.is_persistent_data():
|
|
376
|
-
# list save persistent_data: [Tensor, shape, type, param.key]
|
|
377
|
-
param_data = ["persistent_data"]
|
|
378
|
-
param_data.append(value.data)
|
|
379
|
-
param_data.append(value.param_info.origin_shape)
|
|
380
|
-
param_data.append(str(value.dtype))
|
|
381
|
-
param_data.append(value.key)
|
|
382
|
-
else:
|
|
383
|
-
param_data = Tensor(value.data.asnumpy())
|
|
384
|
-
|
|
385
|
-
# in automatic model parallel scenario, some parameters were split to all the devices,
|
|
386
|
-
# which should be combined before saving
|
|
387
|
-
if key in parameter_layout_dict:
|
|
388
|
-
param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data, integrated_save)
|
|
389
|
-
|
|
390
|
-
each_param["data"] = param_data
|
|
391
|
-
param_list.append(each_param)
|
|
392
|
-
save_obj = param_list
|
|
422
|
+
save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
|
|
393
423
|
|
|
394
424
|
if append_dict:
|
|
395
425
|
append_info_list = []
|
|
@@ -397,19 +427,27 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
397
427
|
if not isinstance(value, str):
|
|
398
428
|
value = Tensor(value)
|
|
399
429
|
append_info_list.append({"name": k_name, "data": value})
|
|
400
|
-
|
|
430
|
+
save_obj.extend(append_info_list)
|
|
401
431
|
|
|
402
432
|
data_list = OrderedDict()
|
|
403
433
|
with _ckpt_mutex:
|
|
404
434
|
for param in save_obj:
|
|
435
|
+
if param["name"] == "random_op":
|
|
436
|
+
data_list["random_op"] = param["data"]
|
|
437
|
+
continue
|
|
405
438
|
key = param["name"]
|
|
406
439
|
data_list[key] = []
|
|
440
|
+
if isinstance(param["data"], MapParameter):
|
|
441
|
+
data_list[param["name"]].append("mapparameter")
|
|
442
|
+
data_list[param["name"]].append(param["data"])
|
|
443
|
+
continue
|
|
407
444
|
if isinstance(param["data"], list):
|
|
408
445
|
if param["data"][0] == "persistent_data":
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
446
|
+
_save_param_list_data(data_list, key, param)
|
|
447
|
+
elif param["data"][0] == "offload_parameter":
|
|
448
|
+
data_list[key].append("offload_parameter")
|
|
449
|
+
_save_param_list_data(data_list, key, param)
|
|
450
|
+
|
|
413
451
|
if isinstance(param["data"], str):
|
|
414
452
|
data_list[key].append([0])
|
|
415
453
|
data_list[key].append('str')
|
|
@@ -435,28 +473,130 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
435
473
|
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_copy, enc_key, enc_mode), name="asyn_save_ckpt")
|
|
436
474
|
thr.start()
|
|
437
475
|
else:
|
|
438
|
-
_exec_save(ckpt_file_name, data_list, enc_key, enc_mode)
|
|
476
|
+
_exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc)
|
|
439
477
|
|
|
440
478
|
logger.info("Saving checkpoint process is finished.")
|
|
441
479
|
|
|
442
480
|
|
|
443
|
-
def
|
|
444
|
-
"""
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
for
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
481
|
+
def _convert_list_to_param_list(save_obj, choice_func):
|
|
482
|
+
"""Convert a list of Parameter to param_list."""
|
|
483
|
+
param_list = []
|
|
484
|
+
if not save_obj:
|
|
485
|
+
return param_list
|
|
486
|
+
if isinstance(save_obj[0], dict):
|
|
487
|
+
param_list = [param for param in save_obj if choice_func is None or choice_func(param["name"])]
|
|
488
|
+
else:
|
|
489
|
+
for param in save_obj:
|
|
490
|
+
if isinstance(param, Parameter):
|
|
491
|
+
if choice_func is not None and not choice_func(param.name):
|
|
492
|
+
continue
|
|
493
|
+
each_param = {"name": param.name, "data": param}
|
|
494
|
+
param_list.append(each_param)
|
|
495
|
+
else:
|
|
496
|
+
raise TypeError(f"For save_checkpoint, when save_obj is made up by list of Parameter,"
|
|
497
|
+
f"the param should be parameter, but got {type(param)}")
|
|
498
|
+
return param_list
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
def _convert_dict_to_param_dict(save_obj, choice_func):
|
|
502
|
+
"""Convert a dict of Parameter to param_list."""
|
|
503
|
+
param_list = []
|
|
504
|
+
for (key, value) in save_obj.items():
|
|
505
|
+
if isinstance(key, str) and isinstance(value, (Parameter, str)):
|
|
506
|
+
if choice_func is not None and not choice_func(key):
|
|
507
|
+
continue
|
|
508
|
+
each_param = {"name": key, "data": value}
|
|
509
|
+
param_list.append(each_param)
|
|
510
|
+
else:
|
|
511
|
+
raise TypeError(f"For save_checkpoint, when save_obj is made up by dict, the key should be str and"
|
|
512
|
+
f"value should be Parameter, but got the type of key is {type(key)} and"
|
|
513
|
+
f"the type of value is {type(value)}")
|
|
514
|
+
return param_list
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
def _convert_cell_param_and_names_to_dict(save_obj, choice_func):
|
|
518
|
+
"""Convert cell.parameters_and_names to OrderedDict."""
|
|
519
|
+
param_dict = OrderedDict()
|
|
520
|
+
for _, param in save_obj.parameters_and_names():
|
|
521
|
+
not_sliced = not param.sliced
|
|
522
|
+
is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
|
|
523
|
+
# All parameters are initialized immediately under PyNative mode, skip this judgement.
|
|
524
|
+
judgment = not_sliced or param.has_init
|
|
525
|
+
if is_graph_mode and _is_in_auto_parallel_mode() and judgment:
|
|
526
|
+
continue
|
|
527
|
+
if choice_func is not None and not choice_func(param.name):
|
|
528
|
+
continue
|
|
529
|
+
# Add suffix for cache_enabled parameter, and then parameter can carry key info.
|
|
530
|
+
# Notice that suffix needs be removed when loading into net.
|
|
531
|
+
if param.cache_enable:
|
|
532
|
+
param_dict[param.name + ".__param_key__" + str(param.key)] = param
|
|
533
|
+
else:
|
|
534
|
+
param_dict[param.name] = param
|
|
535
|
+
return param_dict
|
|
457
536
|
|
|
458
537
|
|
|
459
|
-
def
|
|
538
|
+
def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func):
|
|
539
|
+
"""Convert nn.Cell to param_list."""
|
|
540
|
+
param_list = []
|
|
541
|
+
parameter_layout_dict = save_obj.parameter_layout_dict
|
|
542
|
+
if _is_in_auto_parallel_mode() and not parameter_layout_dict:
|
|
543
|
+
parameter_layout_dict = _get_parameter_layout()
|
|
544
|
+
if not _is_in_auto_parallel_mode():
|
|
545
|
+
save_obj.init_parameters_data()
|
|
546
|
+
param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func)
|
|
547
|
+
if append_dict and "random_op" in append_dict:
|
|
548
|
+
phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key
|
|
549
|
+
if phase in save_obj.compile_cache and _executor.has_compiled(phase):
|
|
550
|
+
random_byte = _executor._graph_executor.get_random_status(phase)
|
|
551
|
+
param_list.append({"name": "random_op", "data": random_byte})
|
|
552
|
+
append_dict.pop("random_op")
|
|
553
|
+
for (key, value) in param_dict.items():
|
|
554
|
+
each_param = {"name": key}
|
|
555
|
+
if isinstance(value, MapParameter):
|
|
556
|
+
each_param["data"] = value
|
|
557
|
+
param_list.append(each_param)
|
|
558
|
+
continue
|
|
559
|
+
|
|
560
|
+
if value.data.is_persistent_data():
|
|
561
|
+
# list save persistent_data: [Tensor, shape, type, param.key]
|
|
562
|
+
param_data = ["persistent_data", value.data, value.param_info.origin_shape, str(value.dtype), value.key]
|
|
563
|
+
elif value.data.offload_file_path() != "":
|
|
564
|
+
# list save offload data: [Param, shape, type, param.key]
|
|
565
|
+
param_data = ["offload_parameter"]
|
|
566
|
+
param_tensor = value.data
|
|
567
|
+
if key in parameter_layout_dict:
|
|
568
|
+
param_tensor = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_tensor,
|
|
569
|
+
integrated_save)
|
|
570
|
+
param_data.append(param_tensor)
|
|
571
|
+
param_data.append(param_tensor.shape)
|
|
572
|
+
param_data.append(str(param_tensor.dtype))
|
|
573
|
+
param_data.append(value.key)
|
|
574
|
+
else:
|
|
575
|
+
param_data = Tensor(value.data.asnumpy())
|
|
576
|
+
|
|
577
|
+
# in automatic model parallel scenario, some parameters were split to all the devices,
|
|
578
|
+
# which should be combined before saving
|
|
579
|
+
if key in parameter_layout_dict:
|
|
580
|
+
param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
|
|
581
|
+
integrated_save)
|
|
582
|
+
|
|
583
|
+
each_param["data"] = param_data
|
|
584
|
+
param_list.append(each_param)
|
|
585
|
+
return param_list
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
def _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func):
|
|
589
|
+
"""Convert a save_obj to param_list."""
|
|
590
|
+
if isinstance(save_obj, list):
|
|
591
|
+
return _convert_list_to_param_list(save_obj, choice_func)
|
|
592
|
+
|
|
593
|
+
if isinstance(save_obj, dict):
|
|
594
|
+
return _convert_dict_to_param_dict(save_obj, choice_func)
|
|
595
|
+
|
|
596
|
+
return _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_func)
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
def _save_param_list_data(data_list, key, param):
|
|
460
600
|
"""Save persistent data into save_obj."""
|
|
461
601
|
dims = []
|
|
462
602
|
# persistent_data shape can not be ()
|
|
@@ -511,7 +651,7 @@ def load(file_name, **kwargs):
|
|
|
511
651
|
|
|
512
652
|
- obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to
|
|
513
653
|
`obfuscate_model()
|
|
514
|
-
<https://www.mindspore.cn/docs/en/r2.
|
|
654
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore/mindspore.obfuscate_model.html>`_.
|
|
515
655
|
|
|
516
656
|
Returns:
|
|
517
657
|
GraphCell, a compiled graph that can executed by `GraphCell`.
|
|
@@ -538,6 +678,10 @@ def load(file_name, **kwargs):
|
|
|
538
678
|
[[[[4. 6. 4.]
|
|
539
679
|
[6. 9. 6.]
|
|
540
680
|
[4. 6. 4.]]]]
|
|
681
|
+
|
|
682
|
+
Tutorial Examples:
|
|
683
|
+
- `Saving and Loading the Model - Saving and Loading MindIR
|
|
684
|
+
<https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-mindir>`_
|
|
541
685
|
"""
|
|
542
686
|
if not isinstance(file_name, str):
|
|
543
687
|
raise ValueError("For 'load', the argument 'file_name' must be string, but "
|
|
@@ -578,6 +722,57 @@ def load(file_name, **kwargs):
|
|
|
578
722
|
return graph
|
|
579
723
|
|
|
580
724
|
|
|
725
|
+
def export_split_mindir(file_name, device_num=8, rank_id=0, dynamic=True, sapp=False):
|
|
726
|
+
"""
|
|
727
|
+
Auto Split MindIR.
|
|
728
|
+
|
|
729
|
+
The returned object can be executed by a `GraphCell`, see class :class:`mindspore.nn.GraphCell` for more details.
|
|
730
|
+
|
|
731
|
+
Args:
|
|
732
|
+
file_name (str): MindIR file name.
|
|
733
|
+
device_num (int): device number.
|
|
734
|
+
rank_id (int): rank id.
|
|
735
|
+
dynamic (bool): Indicates whether the model is a dynamic shape mindir model.
|
|
736
|
+
sapp (bool): Indicates whether to automatically generate split strategy through SAPP.
|
|
737
|
+
|
|
738
|
+
Raises:
|
|
739
|
+
ValueError: MindIR file does not exist or `file_name` is not a string.
|
|
740
|
+
RuntimeError: Failed to split MindIR file.
|
|
741
|
+
|
|
742
|
+
Examples:
|
|
743
|
+
>>> import mindspore as ms
|
|
744
|
+
>>> context.set_context(mode=context.GRAPH_MODE)
|
|
745
|
+
>>>
|
|
746
|
+
>>> ms.export_split_mindir("net.mindir", device_num=8, rank_id=0)
|
|
747
|
+
|
|
748
|
+
"""
|
|
749
|
+
if not isinstance(file_name, str):
|
|
750
|
+
raise ValueError("For 'Split MindIR', the argument 'file_name' must be string, but "
|
|
751
|
+
"got {}.".format(type(file_name)))
|
|
752
|
+
if not file_name.endswith(".mindir"):
|
|
753
|
+
raise ValueError("For 'Split MindIR', the argument 'file_name'(MindIR file) should end with '.mindir', "
|
|
754
|
+
"please input the correct 'file_name'.")
|
|
755
|
+
if not os.path.exists(file_name):
|
|
756
|
+
raise ValueError("For 'Split MindIR', the argument 'file_name'(MindIR file) does not exist, "
|
|
757
|
+
"please check whether the 'file_name' is correct.")
|
|
758
|
+
file_name = os.path.abspath(file_name)
|
|
759
|
+
|
|
760
|
+
logger.info("Execute the process of export and split mindir.")
|
|
761
|
+
dynamic = True
|
|
762
|
+
if dynamic:
|
|
763
|
+
graph = split_dynamic_mindir(file_name, device_num, rank_id, sapp)
|
|
764
|
+
else:
|
|
765
|
+
graph = split_mindir(file_name)
|
|
766
|
+
|
|
767
|
+
if graph is None:
|
|
768
|
+
if _is_cipher_file(file_name):
|
|
769
|
+
raise RuntimeError("Export and split MindIR failed. The file may be encrypted and decrypt failed, you "
|
|
770
|
+
"can check whether the values of the arguments 'dec_key' and 'dec_mode'"
|
|
771
|
+
" are the same as when exported MindIR file, or check the file integrity.")
|
|
772
|
+
raise RuntimeError("Export and split MindIR failed.")
|
|
773
|
+
return graph
|
|
774
|
+
|
|
775
|
+
|
|
581
776
|
def _check_param_type(param_config, key, target_type, requested):
|
|
582
777
|
"""check type of parameters"""
|
|
583
778
|
if key in param_config:
|
|
@@ -655,17 +850,20 @@ def obfuscate_model(obf_config, **kwargs):
|
|
|
655
850
|
- model_inputs (list(Tensor)): The inputs of the original model, the values of Tensor can be random, which
|
|
656
851
|
is the same as using :func:`mindspore.export`.
|
|
657
852
|
- obf_ratio (Union(float, str)): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
|
|
658
|
-
should be in range of (0, 1] or in ["small", "medium", "large"].
|
|
853
|
+
should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are
|
|
854
|
+
correspond to 0.1, 0.3, and 0.6 respectively.
|
|
659
855
|
- customized_func (function): A python function used for customized function mode, which used for control
|
|
660
|
-
the switch branch of obfuscation structure. The outputs of customized_func should be boolean
|
|
661
|
-
|
|
856
|
+
the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
|
|
857
|
+
Reference to 'my_func()' in
|
|
858
|
+
`tutorials <https://www.mindspore.cn/mindarmour/docs/en/r2.0/dynamic_obfuscation_protection.html>`_).
|
|
859
|
+
This function needs to ensure that its result is constant for any input. Users can refer to opaque
|
|
662
860
|
predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface
|
|
663
861
|
when loading obfuscated model.
|
|
664
|
-
- obf_random_seed (int):
|
|
665
|
-
|
|
666
|
-
then it should be passed to :class:`nn.GraphCell()` interface when loading
|
|
667
|
-
noted that at least one of `customized_func` or `obf_random_seed` should
|
|
668
|
-
would be applied if both of them are set.
|
|
862
|
+
- obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
|
|
863
|
+
structure of obfuscated models corresponding to different random seeds is different. If
|
|
864
|
+
`obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell()` interface when loading
|
|
865
|
+
obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
|
|
866
|
+
be set, and the latter mode would be applied if both of them are set.
|
|
669
867
|
|
|
670
868
|
kwargs (dict): Configuration options dictionary.
|
|
671
869
|
|
|
@@ -685,12 +883,14 @@ def obfuscate_model(obf_config, **kwargs):
|
|
|
685
883
|
ValueError: If `original_model_path` is not exist or `original_model_path` is not end with '.mindir'.
|
|
686
884
|
|
|
687
885
|
Examples:
|
|
886
|
+
>>> import mindspore as ms
|
|
887
|
+
>>> import mindspore.nn as nn
|
|
688
888
|
>>> obf_config = {'original_model_path': "./net.mindir",
|
|
689
889
|
... 'save_model_path': "./obf_net",
|
|
690
890
|
... 'model_inputs': [input1, ],
|
|
691
891
|
... 'obf_ratio': 0.1, 'obf_random_seed': 173262358423}
|
|
692
|
-
>>> obfuscate_model(obf_config)
|
|
693
|
-
>>> obf_func = load("obf_net.mindir")
|
|
892
|
+
>>> ms.obfuscate_model(obf_config)
|
|
893
|
+
>>> obf_func = ms.load("obf_net.mindir")
|
|
694
894
|
>>> obf_net = nn.GraphCell(obf_func, obf_random_seed=173262358423)
|
|
695
895
|
>>> print(obf_net(input1).asnumpy())
|
|
696
896
|
"""
|
|
@@ -762,23 +962,24 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
762
962
|
|
|
763
963
|
Args:
|
|
764
964
|
ckpt_file_name (str): Checkpoint file name.
|
|
765
|
-
net (Cell): The network where the parameters will be loaded. Default: None
|
|
766
|
-
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
|
|
965
|
+
net (Cell): The network where the parameters will be loaded. Default: ``None`` .
|
|
966
|
+
strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
|
|
767
967
|
into net when parameter name's suffix in checkpoint file is the same as the
|
|
768
968
|
parameter in the network. When the types are inconsistent perform type conversion
|
|
769
|
-
on the parameters of the same type, such as float32 to float16. Default: False.
|
|
969
|
+
on the parameters of the same type, such as float32 to float16. Default: ``False`` .
|
|
770
970
|
filter_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`). Parameters starting with the
|
|
771
|
-
filter_prefix will not be loaded. Default: None.
|
|
772
|
-
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
|
|
773
|
-
is not required. Default: None.
|
|
774
|
-
dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
|
|
775
|
-
mode, currently supports
|
|
971
|
+
filter_prefix will not be loaded. Default: ``None`` .
|
|
972
|
+
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
|
|
973
|
+
is not required. Default: ``None`` .
|
|
974
|
+
dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
|
|
975
|
+
mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
|
|
976
|
+
Default: ``"AES-GCM"`` .
|
|
776
977
|
specify_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`). Parameters starting with the
|
|
777
|
-
specify_prefix will be loaded. Default: None.
|
|
978
|
+
specify_prefix will be loaded. Default: ``None`` .
|
|
778
979
|
choice_func (Union[None, function]) : Input value of the function is a Parameter name of type string,
|
|
779
|
-
and the return value is a bool. If returns True, the Parameter
|
|
780
|
-
that matches the custom condition will be loaded. If returns False, the Parameter that
|
|
781
|
-
matches the custom condition will be removed. Default: None.
|
|
980
|
+
and the return value is a bool. If returns ``True`` , the Parameter
|
|
981
|
+
that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
|
|
982
|
+
matches the custom condition will be removed. Default: ``None`` .
|
|
782
983
|
|
|
783
984
|
Returns:
|
|
784
985
|
Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
|
|
@@ -801,23 +1002,27 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
801
1002
|
>>> print(param_dict["conv2.weight"])
|
|
802
1003
|
Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
|
|
803
1004
|
>>> def func(param_name):
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
1005
|
+
... whether_load = False
|
|
1006
|
+
... if param_name.startswith("conv"):
|
|
1007
|
+
... whether_load = True
|
|
1008
|
+
... if param_name.startswith("conv1"):
|
|
1009
|
+
... whether_load = False
|
|
1010
|
+
... return whether_load
|
|
810
1011
|
>>> param_dict1 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
|
|
811
1012
|
>>> print(param_dict1["conv2.weight"])
|
|
812
1013
|
Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
|
|
813
1014
|
>>> def func(param_name):
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
1015
|
+
... whether_load = False
|
|
1016
|
+
... if param_name.startswith("conv1"):
|
|
1017
|
+
... whether_load = True
|
|
1018
|
+
... return whether_load
|
|
818
1019
|
>>> param_dict2 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
|
|
819
1020
|
>>> print(param_dict2)
|
|
820
1021
|
{'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
|
|
1022
|
+
|
|
1023
|
+
Tutorial Examples:
|
|
1024
|
+
- `Saving and Loading the Model - Saving and Loading the Model Weight
|
|
1025
|
+
<https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
|
|
821
1026
|
"""
|
|
822
1027
|
ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
|
|
823
1028
|
specify_prefix = _check_prefix(specify_prefix)
|
|
@@ -830,6 +1035,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
830
1035
|
parameter_dict = {}
|
|
831
1036
|
try:
|
|
832
1037
|
param_data_list = []
|
|
1038
|
+
map_data_list = [[], [], []]
|
|
1039
|
+
map_shape_list = [0, 0, 0]
|
|
833
1040
|
if specify_prefix:
|
|
834
1041
|
logger.warning("For load_checkpoint, this parameter `specity_prefix` will be deprecated, "
|
|
835
1042
|
"please use `choice_func` instead.")
|
|
@@ -837,13 +1044,19 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
837
1044
|
logger.warning("For load_checkpoint, this parameter `filter_prefix` will be deprecated, "
|
|
838
1045
|
"please use `choice_func` instead.")
|
|
839
1046
|
for element_id, element in enumerate(checkpoint_list.value):
|
|
1047
|
+
if element.tag == "random_op":
|
|
1048
|
+
parameter_dict["random_op"] = element.tensor.tensor_content
|
|
1049
|
+
continue
|
|
840
1050
|
if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
|
|
841
1051
|
continue
|
|
842
1052
|
if specify_prefix is None and filter_prefix is None and \
|
|
843
1053
|
choice_func is not None and not choice_func(element.tag):
|
|
844
1054
|
continue
|
|
845
1055
|
if element.tensor.ByteSize() == 0:
|
|
846
|
-
|
|
1056
|
+
_load_map_parameter(checkpoint_list, element, element_id, map_data_list, map_shape_list, parameter_dict)
|
|
1057
|
+
if element.tag in parameter_dict:
|
|
1058
|
+
map_data_list = [[], [], []]
|
|
1059
|
+
map_shape_list = [0, 0, 0]
|
|
847
1060
|
continue
|
|
848
1061
|
data = element.tensor.tensor_content
|
|
849
1062
|
data_type = element.tensor.tensor_type
|
|
@@ -856,7 +1069,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
856
1069
|
param_data_list.append(element_data)
|
|
857
1070
|
if (element_id == len(checkpoint_list.value) - 1) or \
|
|
858
1071
|
(element.tag != checkpoint_list.value[element_id + 1].tag):
|
|
859
|
-
|
|
1072
|
+
new_data = b"".join(param_data_list)
|
|
1073
|
+
param_data = np.frombuffer(new_data, np_type)
|
|
860
1074
|
param_data_list.clear()
|
|
861
1075
|
dims = element.tensor.dims
|
|
862
1076
|
if dims == [0] and data_type == 'str':
|
|
@@ -868,7 +1082,9 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
868
1082
|
param_data = int(param_data[0])
|
|
869
1083
|
if dims not in ([0], [1]):
|
|
870
1084
|
param_data = param_data.reshape(list(dims))
|
|
871
|
-
|
|
1085
|
+
parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
|
|
1086
|
+
parameter_dict[element.tag] = parameter
|
|
1087
|
+
_offload_if_config(parameter)
|
|
872
1088
|
|
|
873
1089
|
logger.info("Loading checkpoint files process is finished.")
|
|
874
1090
|
|
|
@@ -881,14 +1097,48 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
881
1097
|
raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
|
|
882
1098
|
f"'filter_prefix' or 'specify_prefix' are set correctly.")
|
|
883
1099
|
|
|
1100
|
+
if _warm_up_host_cache_enabled(parameter_dict):
|
|
1101
|
+
(is_worker, net_dict, warm_up_dict) = _warm_up_host_cache(parameter_dict, net)
|
|
884
1102
|
if net is not None:
|
|
885
1103
|
load_param_into_net(net, parameter_dict, strict_load)
|
|
1104
|
+
if _warm_up_host_cache_enabled(parameter_dict):
|
|
1105
|
+
_warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict)
|
|
886
1106
|
|
|
887
1107
|
return parameter_dict
|
|
888
1108
|
|
|
889
1109
|
|
|
1110
|
+
def _load_map_parameter(checkpoint_list, element, element_id, map_data_list,
|
|
1111
|
+
map_shape_list, parameter_dict):
|
|
1112
|
+
"""load map parameter."""
|
|
1113
|
+
logger.info("Checkpoint load map_parameter.")
|
|
1114
|
+
if (element_id != len(checkpoint_list.value) - 1) and \
|
|
1115
|
+
element.tag == checkpoint_list.value[element_id + 1].tag:
|
|
1116
|
+
for index, tensor in enumerate(element.maptensor.tensor):
|
|
1117
|
+
data = tensor.tensor_content
|
|
1118
|
+
data_type = tensor.tensor_type
|
|
1119
|
+
np_type = np_type_convert.get(data_type)
|
|
1120
|
+
element_data = np.frombuffer(data, np_type)
|
|
1121
|
+
map_data_list[index].append(element_data)
|
|
1122
|
+
map_shape_list[index] += tensor.dims[0]
|
|
1123
|
+
else:
|
|
1124
|
+
map_array = []
|
|
1125
|
+
for index, tensor in enumerate(element.maptensor.tensor):
|
|
1126
|
+
data = tensor.tensor_content
|
|
1127
|
+
data_type = tensor.tensor_type
|
|
1128
|
+
np_type = np_type_convert.get(data_type)
|
|
1129
|
+
element_data = np.frombuffer(data, np_type)
|
|
1130
|
+
map_data_list[index].append(element_data)
|
|
1131
|
+
new_data = b"".join(map_data_list[index])
|
|
1132
|
+
param_data = np.frombuffer(new_data, np_type)
|
|
1133
|
+
dims = tensor.dims
|
|
1134
|
+
dims[0] += map_shape_list[index]
|
|
1135
|
+
param_data = param_data.reshape(list(dims))
|
|
1136
|
+
map_array.append(param_data)
|
|
1137
|
+
parameter_dict[element.tag] = map_array
|
|
1138
|
+
|
|
1139
|
+
|
|
890
1140
|
def _check_ckpt_file_name(ckpt_file_name):
|
|
891
|
-
"""Check function load_checkpoint's
|
|
1141
|
+
"""Check function load_checkpoint's ckpt_file_name."""
|
|
892
1142
|
if not isinstance(ckpt_file_name, str):
|
|
893
1143
|
raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
|
|
894
1144
|
"but got {}.".format(type(ckpt_file_name)))
|
|
@@ -969,18 +1219,13 @@ def _whether_load_param(specify_prefix, filter_prefix, param_name):
|
|
|
969
1219
|
return whether_load
|
|
970
1220
|
|
|
971
1221
|
|
|
972
|
-
def
|
|
973
|
-
"""
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
element_data = np.frombuffer(data, np_type)
|
|
980
|
-
dims = tensor.dims
|
|
981
|
-
param_data = element_data.reshape(list(dims))
|
|
982
|
-
map_array.append(param_data)
|
|
983
|
-
parameter_dict[element.tag] = map_array
|
|
1222
|
+
def _init_parameter_data_in_parallel_mode(net, parameter_dict):
|
|
1223
|
+
"""In parallel mode, only init the paraemters in ckpt."""
|
|
1224
|
+
for _, param in net.parameters_and_names():
|
|
1225
|
+
if param.name in parameter_dict and param.has_init:
|
|
1226
|
+
logger.warning("{} is not init while load ckpt.".format(param.name))
|
|
1227
|
+
new_tensor = param.init_data()
|
|
1228
|
+
param._update_tensor_data(new_tensor)
|
|
984
1229
|
|
|
985
1230
|
|
|
986
1231
|
def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
@@ -991,10 +1236,10 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
991
1236
|
net (Cell): The network where the parameters will be loaded.
|
|
992
1237
|
parameter_dict (dict): The dictionary generated by load checkpoint file,
|
|
993
1238
|
it is a dictionary consisting of key: parameters's name, value: parameter.
|
|
994
|
-
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
|
|
1239
|
+
strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
|
|
995
1240
|
into net when parameter name's suffix in checkpoint file is the same as the
|
|
996
1241
|
parameter in the network. When the types are inconsistent perform type conversion
|
|
997
|
-
on the parameters of the same type, such as float32 to float16. Default: False.
|
|
1242
|
+
on the parameters of the same type, such as float32 to float16. Default: ``False`` .
|
|
998
1243
|
|
|
999
1244
|
Returns:
|
|
1000
1245
|
param_not_load (List), the parameter name in model which are not loaded into the network.
|
|
@@ -1006,25 +1251,33 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1006
1251
|
Examples:
|
|
1007
1252
|
>>> import mindspore as ms
|
|
1008
1253
|
>>>
|
|
1009
|
-
>>>
|
|
1254
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
1255
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
1256
|
+
>>> net = LeNet5()
|
|
1010
1257
|
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
|
|
1011
1258
|
>>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1")
|
|
1012
1259
|
>>> param_not_load, _ = ms.load_param_into_net(net, param_dict)
|
|
1013
1260
|
>>> print(param_not_load)
|
|
1014
1261
|
['conv1.weight']
|
|
1262
|
+
|
|
1263
|
+
Tutorial Examples:
|
|
1264
|
+
- `Saving and Loading the Model - Saving and Loading the Model Weight
|
|
1265
|
+
<https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-the-model-weight>`_
|
|
1015
1266
|
"""
|
|
1016
1267
|
if not isinstance(net, nn.Cell):
|
|
1017
1268
|
logger.critical("Failed to combine the net and the parameters.")
|
|
1018
1269
|
msg = ("For 'load_param_into_net', the argument 'net' should be a Cell, but got {}.".format(type(net)))
|
|
1019
1270
|
raise TypeError(msg)
|
|
1020
|
-
|
|
1021
1271
|
if not isinstance(parameter_dict, dict):
|
|
1022
1272
|
logger.critical("Failed to combine the net and the parameters.")
|
|
1023
1273
|
msg = ("For 'load_param_into_net', the argument 'parameter_dict' should be a dict, "
|
|
1024
1274
|
"but got {}.".format(type(parameter_dict)))
|
|
1025
1275
|
raise TypeError(msg)
|
|
1276
|
+
if "random_op" in parameter_dict.keys():
|
|
1277
|
+
net._add_attr("random_op_snapshot", parameter_dict["random_op"])
|
|
1278
|
+
parameter_dict.pop("random_op")
|
|
1026
1279
|
for key, value in parameter_dict.items():
|
|
1027
|
-
if not isinstance(key, str) or not isinstance(value, (Parameter, str)):
|
|
1280
|
+
if not isinstance(key, str) or not isinstance(value, (Parameter, str, list)):
|
|
1028
1281
|
logger.critical("Load parameters into net failed.")
|
|
1029
1282
|
msg = ("For 'parameter_dict', the element in the argument 'parameter_dict' should be a "
|
|
1030
1283
|
"'str' and 'Parameter' , but got {} and {}.".format(type(key), type(value)))
|
|
@@ -1032,11 +1285,20 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1032
1285
|
|
|
1033
1286
|
strict_load = Validator.check_bool(strict_load)
|
|
1034
1287
|
logger.info("Execute the process of loading parameters into net.")
|
|
1035
|
-
|
|
1288
|
+
if not _is_in_auto_parallel_mode():
|
|
1289
|
+
net.init_parameters_data()
|
|
1290
|
+
else:
|
|
1291
|
+
_init_parameter_data_in_parallel_mode(net, parameter_dict)
|
|
1036
1292
|
param_not_load = []
|
|
1037
1293
|
ckpt_not_load = list(parameter_dict.keys())
|
|
1038
1294
|
for _, param in net.parameters_and_names():
|
|
1039
1295
|
if param.name in parameter_dict:
|
|
1296
|
+
if isinstance(param, MapParameter):
|
|
1297
|
+
param.import_data(parameter_dict[param.name])
|
|
1298
|
+
continue
|
|
1299
|
+
# Add has attr protection when load server checkpoint file on worker.
|
|
1300
|
+
if not hasattr(parameter_dict[param.name], "data"):
|
|
1301
|
+
continue
|
|
1040
1302
|
new_param = copy.deepcopy(parameter_dict[param.name])
|
|
1041
1303
|
_update_param(param, new_param, strict_load)
|
|
1042
1304
|
ckpt_not_load.remove(param.name)
|
|
@@ -1061,6 +1323,72 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
1061
1323
|
return param_not_load, ckpt_not_load
|
|
1062
1324
|
|
|
1063
1325
|
|
|
1326
|
+
def _warm_up_host_cache_enabled(parameter_dict):
|
|
1327
|
+
"""Warm up host cache enabled."""
|
|
1328
|
+
if _cache_enable():
|
|
1329
|
+
return True
|
|
1330
|
+
for key in parameter_dict.keys():
|
|
1331
|
+
if key.find(".__param_key__") != -1:
|
|
1332
|
+
return True
|
|
1333
|
+
return False
|
|
1334
|
+
|
|
1335
|
+
|
|
1336
|
+
def _warm_up_host_cache(parameter_dict, net):
|
|
1337
|
+
"""Warm up host cache."""
|
|
1338
|
+
ms_role = os.getenv("MS_ROLE")
|
|
1339
|
+
is_worker = ms_role == "MS_WORKER"
|
|
1340
|
+
param_key_dict = {}
|
|
1341
|
+
# Traverse key, value in parameter_dict, warm up param key and record param key into param_key_dict.
|
|
1342
|
+
if is_worker:
|
|
1343
|
+
net.init_parameters_data()
|
|
1344
|
+
net_dict = {}
|
|
1345
|
+
for name, value in net.parameters_and_names():
|
|
1346
|
+
net_dict[name] = value
|
|
1347
|
+
for param_name, value in parameter_dict.items():
|
|
1348
|
+
pos = param_name.find(".__param_key__")
|
|
1349
|
+
if pos != -1:
|
|
1350
|
+
net_param_name = param_name[:pos]
|
|
1351
|
+
param_key_dict[param_name] = net_param_name
|
|
1352
|
+
net_value = None
|
|
1353
|
+
if net_param_name not in net_dict:
|
|
1354
|
+
logger.warning("net param name : %s is not in net", net_param_name)
|
|
1355
|
+
else:
|
|
1356
|
+
net_value = net_dict.get(net_param_name, None)
|
|
1357
|
+
pos += len(".__param_key__")
|
|
1358
|
+
param_key = int(param_name[pos:])
|
|
1359
|
+
value_is_map_parameter = isinstance(value, list) and len(value) == 3
|
|
1360
|
+
if value_is_map_parameter and (net_value is None or isinstance(net_value, Parameter)):
|
|
1361
|
+
key_tensor = Tensor.from_numpy(value[0])
|
|
1362
|
+
value_tensor = Tensor.from_numpy(value[1])
|
|
1363
|
+
status_tensor = Tensor.from_numpy(value[2])
|
|
1364
|
+
_store_warm_up_ptr_by_tensor_list(param_key, key_tensor, value_tensor, status_tensor)
|
|
1365
|
+
elif not isinstance(value, list) and isinstance(net_value, Parameter):
|
|
1366
|
+
_store_warm_up_ptr_by_tensor(param_key, value)
|
|
1367
|
+
else:
|
|
1368
|
+
logger.warning("Unknown matches parameter type %s and net_value %s", type(value), type(net_value))
|
|
1369
|
+
else:
|
|
1370
|
+
for param_name, value in parameter_dict.items():
|
|
1371
|
+
pos = param_name.find(".__param_key__")
|
|
1372
|
+
if pos != -1:
|
|
1373
|
+
net_param_name = param_name[:pos]
|
|
1374
|
+
param_key_dict[param_name] = net_param_name
|
|
1375
|
+
# Split param key from parameter_dict since worker cannot load param key.
|
|
1376
|
+
warm_up_dict = {}
|
|
1377
|
+
for key, value in param_key_dict.items():
|
|
1378
|
+
if is_worker:
|
|
1379
|
+
warm_up_dict[value] = parameter_dict.pop(key)
|
|
1380
|
+
else:
|
|
1381
|
+
parameter_dict[value] = parameter_dict.pop(key)
|
|
1382
|
+
return (is_worker, parameter_dict, warm_up_dict)
|
|
1383
|
+
|
|
1384
|
+
|
|
1385
|
+
def _warm_up_host_cache_post_process(is_worker, net_dict, warm_up_dict):
|
|
1386
|
+
"""Warm up host cache post process."""
|
|
1387
|
+
if is_worker:
|
|
1388
|
+
net_dict.update(warm_up_dict)
|
|
1389
|
+
_set_checkpoint_load_status(True)
|
|
1390
|
+
|
|
1391
|
+
|
|
1064
1392
|
def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):
|
|
1065
1393
|
"""When some net parameter did not load, try to continue loading."""
|
|
1066
1394
|
prefix_name = ""
|
|
@@ -1161,31 +1489,6 @@ def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, i
|
|
|
1161
1489
|
return param_data
|
|
1162
1490
|
|
|
1163
1491
|
|
|
1164
|
-
def _fill_param_into_net(net, parameter_list):
|
|
1165
|
-
"""
|
|
1166
|
-
Fills parameter_list into net.
|
|
1167
|
-
|
|
1168
|
-
Args:
|
|
1169
|
-
net (Cell): train network.
|
|
1170
|
-
parameter_list (list): parameters list from ge callback.
|
|
1171
|
-
"""
|
|
1172
|
-
parameter_dict = {}
|
|
1173
|
-
for each_param in parameter_list:
|
|
1174
|
-
param_name = each_param["name"]
|
|
1175
|
-
if isinstance(each_param["data"], Parameter):
|
|
1176
|
-
each_param["data"].init_data()
|
|
1177
|
-
np_val = each_param["data"].asnumpy()
|
|
1178
|
-
if np_val.shape == (1,):
|
|
1179
|
-
parameter_dict[param_name] = Parameter(np_val, name=param_name)
|
|
1180
|
-
elif np_val.shape == ():
|
|
1181
|
-
parameter_dict[param_name] = Parameter(Tensor(np_val.tolist(), mstype.pytype_to_dtype(np_val.dtype)),
|
|
1182
|
-
name=param_name)
|
|
1183
|
-
else:
|
|
1184
|
-
parameter_dict[param_name] = Parameter(Tensor(np_val), name=param_name)
|
|
1185
|
-
|
|
1186
|
-
load_param_into_net(net, parameter_dict, strict_load=True)
|
|
1187
|
-
|
|
1188
|
-
|
|
1189
1492
|
def export(net, *inputs, file_name, file_format, **kwargs):
|
|
1190
1493
|
"""
|
|
1191
1494
|
Export the MindSpore network into an offline model in the specified format.
|
|
@@ -1193,9 +1496,9 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1193
1496
|
Note:
|
|
1194
1497
|
1. When exporting AIR, ONNX format, the size of a single tensor can not exceed 2GB.
|
|
1195
1498
|
2. When file_name does not have a suffix, the system will automatically add one according to the file_format.
|
|
1196
|
-
3. Exporting functions decorated with
|
|
1197
|
-
4. When exporting a function decorated with
|
|
1198
|
-
calculations.
|
|
1499
|
+
3. Exporting functions decorated with :func:`mindspore.jit` to mindir format is supported.
|
|
1500
|
+
4. When exporting a function decorated with :func:`mindspore.jit`, the function should not involve
|
|
1501
|
+
class properties in calculations.
|
|
1199
1502
|
|
|
1200
1503
|
Args:
|
|
1201
1504
|
net (Union[Cell, function]): MindSpore network.
|
|
@@ -1231,17 +1534,20 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1231
1534
|
|
|
1232
1535
|
- type (str): The type of obfuscation, only 'dynamic' is supported until now.
|
|
1233
1536
|
- obf_ratio (float, str): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
|
|
1234
|
-
should be in range of (0, 1] or in ["small", "medium", "large"].
|
|
1537
|
+
should be in range of (0, 1] or in ["small", "medium", "large"]. "small", "medium" and "large" are
|
|
1538
|
+
correspond to 0.1, 0.3, and 0.6 respectively.
|
|
1235
1539
|
- customized_func (function): A python function used for customized function mode, which used for control
|
|
1236
|
-
the switch branch of obfuscation structure. The outputs of customized_func should be boolean
|
|
1237
|
-
|
|
1540
|
+
the switch branch of obfuscation structure. The outputs of customized_func should be boolean and const (
|
|
1541
|
+
Reference to 'my_func()' in
|
|
1542
|
+
`tutorials <https://www.mindspore.cn/mindarmour/docs/en/r2.0/dynamic_obfuscation_protection.html>`_).
|
|
1543
|
+
This function needs to ensure that its result is constant for any input. Users can refer to opaque
|
|
1238
1544
|
predicates. If customized_func is set, then it should be passed to `load()` interface when loading
|
|
1239
1545
|
obfuscated model.
|
|
1240
|
-
- obf_random_seed (int):
|
|
1241
|
-
|
|
1242
|
-
then it should be passed to :class:`nn.GraphCell()` interface when loading
|
|
1243
|
-
be noted that at least one of `customized_func` or `obf_random_seed` should
|
|
1244
|
-
would be applied if both of them are set.
|
|
1546
|
+
- obf_random_seed (int): Obfuscation random seed, which should be in (0, 9223372036854775807]. The
|
|
1547
|
+
structure of obfuscated models corresponding to different random seeds is different. If
|
|
1548
|
+
`obf_random_seed` is set, then it should be passed to :class:`nn.GraphCell()` interface when loading
|
|
1549
|
+
obfuscated model. It should be noted that at least one of `customized_func` or `obf_random_seed` should
|
|
1550
|
+
be set, and the latter mode would be applied if both of them are set.
|
|
1245
1551
|
|
|
1246
1552
|
- incremental (bool): export MindIR incrementally.
|
|
1247
1553
|
|
|
@@ -1250,10 +1556,19 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1250
1556
|
>>> import numpy as np
|
|
1251
1557
|
>>> from mindspore import Tensor
|
|
1252
1558
|
>>>
|
|
1253
|
-
>>>
|
|
1559
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
1560
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
1561
|
+
>>> net = LeNet5()
|
|
1254
1562
|
>>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
|
|
1255
1563
|
>>> ms.export(net, input_tensor, file_name='lenet', file_format='MINDIR')
|
|
1564
|
+
|
|
1565
|
+
Tutorial Examples:
|
|
1566
|
+
- `Saving and Loading the Model - Saving and Loading MindIR
|
|
1567
|
+
<https://mindspore.cn/tutorials/en/r2.2/beginner/save_load.html#saving-and-loading-mindir>`_
|
|
1256
1568
|
"""
|
|
1569
|
+
old_ms_jit_value = context.get_context("jit_syntax_level")
|
|
1570
|
+
context.set_context(jit_syntax_level=mindspore.STRICT)
|
|
1571
|
+
|
|
1257
1572
|
supported_formats = ['AIR', 'ONNX', 'MINDIR']
|
|
1258
1573
|
if file_format not in supported_formats:
|
|
1259
1574
|
raise ValueError(f"For 'export', 'file_format' must be one of {supported_formats}, but got {file_format}.")
|
|
@@ -1282,6 +1597,47 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1282
1597
|
kwargs['enc_key'], kwargs['enc_mode'] = _check_key_mode_type(file_format, **kwargs)
|
|
1283
1598
|
_export(net, file_name, file_format, *inputs, **kwargs)
|
|
1284
1599
|
|
|
1600
|
+
context.set_context(jit_syntax_level=old_ms_jit_value)
|
|
1601
|
+
|
|
1602
|
+
|
|
1603
|
+
def _get_funcgraph(net, *inputs):
|
|
1604
|
+
"""
|
|
1605
|
+
Compile the MindSpore network and get FuncGraph.
|
|
1606
|
+
|
|
1607
|
+
Arg:
|
|
1608
|
+
net (Union[Cell, function]): MindSpore network.
|
|
1609
|
+
inputs (Union[Tensor, Dataset, List, Tuple, Number, Bool]): It represents the inputs
|
|
1610
|
+
of the `net`, if the network has multiple inputs, set them together. While its type is Dataset,
|
|
1611
|
+
it represents the preprocess behavior of the `net`, data preprocess operations will be serialized.
|
|
1612
|
+
In second situation, you should adjust batch size of dataset script manually which will impact on
|
|
1613
|
+
the batch size of 'net' input. Only supports parse "image" column from dataset currently.
|
|
1614
|
+
|
|
1615
|
+
Returns:
|
|
1616
|
+
FuncGraph, a mindspore._c_expression.FuncGraph obj.
|
|
1617
|
+
|
|
1618
|
+
Raises:
|
|
1619
|
+
ValueError: input `net` is not a nn.Cell.
|
|
1620
|
+
|
|
1621
|
+
Examples:
|
|
1622
|
+
>>> import mindspore as ms
|
|
1623
|
+
>>> import numpy as np
|
|
1624
|
+
>>> from mindspore import Tensor
|
|
1625
|
+
>>>
|
|
1626
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
1627
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
1628
|
+
>>> net = LeNet5()
|
|
1629
|
+
>>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
|
|
1630
|
+
>>> ms.get_funcgraph(net, input_tensor)
|
|
1631
|
+
|
|
1632
|
+
"""
|
|
1633
|
+
if not isinstance(net, nn.Cell):
|
|
1634
|
+
raise ValueError(f"For get_funcgraph's parameter 'net', currently only support Cell right now.")
|
|
1635
|
+
phase_name = "lite_infer_predict" if _is_in_auto_parallel_mode() else "lite_infer_get_func_graph"
|
|
1636
|
+
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
|
1637
|
+
# pylint: disable=protected-access
|
|
1638
|
+
func_graph = _executor._get_func_graph(net, graph_id)
|
|
1639
|
+
return func_graph
|
|
1640
|
+
|
|
1285
1641
|
|
|
1286
1642
|
def _export(net, file_name, file_format, *inputs, **kwargs):
|
|
1287
1643
|
"""
|
|
@@ -1290,7 +1646,6 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
|
|
|
1290
1646
|
logger.info("exporting model file:%s format:%s.", file_name, file_format)
|
|
1291
1647
|
if "obf_config" in kwargs and file_format != "MINDIR":
|
|
1292
1648
|
raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.")
|
|
1293
|
-
|
|
1294
1649
|
if file_format == 'AIR':
|
|
1295
1650
|
_save_air(net, file_name, *inputs, **kwargs)
|
|
1296
1651
|
elif file_format == 'ONNX':
|
|
@@ -1454,7 +1809,7 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
|
|
|
1454
1809
|
for param_proto in model.graph.parameter:
|
|
1455
1810
|
name = param_proto.name[param_proto.name.find(":") + 1:]
|
|
1456
1811
|
param = net_dict[name]
|
|
1457
|
-
raw_data = param.data.
|
|
1812
|
+
raw_data = param.data.get_bytes()
|
|
1458
1813
|
data_length = len(raw_data)
|
|
1459
1814
|
append_size = 0
|
|
1460
1815
|
if data_length % 64 != 0:
|
|
@@ -1508,7 +1863,7 @@ def _msfunc_info(net, *inputs):
|
|
|
1508
1863
|
|
|
1509
1864
|
def _cell_info(net, incremental, *inputs):
|
|
1510
1865
|
"""Get mindir stream and net dict of cell"""
|
|
1511
|
-
phase_name = "
|
|
1866
|
+
phase_name = "export.mindir"
|
|
1512
1867
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
|
1513
1868
|
# pylint: disable=protected-access
|
|
1514
1869
|
mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir', incremental=incremental)
|
|
@@ -1581,7 +1936,7 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
|
|
|
1581
1936
|
for param_proto in model.graph.parameter:
|
|
1582
1937
|
param_name = param_proto.name[param_proto.name.find(":") + 1:]
|
|
1583
1938
|
if param_name in net_dict.keys():
|
|
1584
|
-
param_data = net_dict[param_name].data.
|
|
1939
|
+
param_data = net_dict[param_name].data.get_bytes()
|
|
1585
1940
|
param_proto.raw_data = param_data
|
|
1586
1941
|
else:
|
|
1587
1942
|
raise ValueError("The parameter '{}' is not belongs to any cell,"
|
|
@@ -1591,10 +1946,10 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
|
|
|
1591
1946
|
map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:]
|
|
1592
1947
|
if map_param_name in net_dict.keys():
|
|
1593
1948
|
map_parameter = net_dict[map_param_name]
|
|
1594
|
-
|
|
1595
|
-
map_param_proto.key_tensor.raw_data =
|
|
1596
|
-
map_param_proto.value_tensor.raw_data =
|
|
1597
|
-
map_param_proto.status_tensor.raw_data =
|
|
1949
|
+
key_bytes, value_bytes, status_bytes = map_parameter.export_bytes(incremental)
|
|
1950
|
+
map_param_proto.key_tensor.raw_data = key_bytes
|
|
1951
|
+
map_param_proto.value_tensor.raw_data = value_bytes
|
|
1952
|
+
map_param_proto.status_tensor.raw_data = status_bytes
|
|
1598
1953
|
else:
|
|
1599
1954
|
raise ValueError("The map_parameter '{}' is not belongs to any cell,"
|
|
1600
1955
|
"the data of parameter cannot be exported.".format(map_param_proto.name))
|
|
@@ -1625,7 +1980,7 @@ def _save_together(net_dict, model):
|
|
|
1625
1980
|
for param_proto in model.graph.parameter:
|
|
1626
1981
|
name = param_proto.name[param_proto.name.find(":") + 1:]
|
|
1627
1982
|
if name in net_dict.keys():
|
|
1628
|
-
data_total += sys.getsizeof(net_dict[name].data.
|
|
1983
|
+
data_total += sys.getsizeof(net_dict[name].data.get_bytes()) / 1024
|
|
1629
1984
|
else:
|
|
1630
1985
|
raise ValueError("The parameter '{}' is not belongs to any cell,"
|
|
1631
1986
|
"the data of parameter cannot be exported.".format(param_proto.name))
|
|
@@ -1656,7 +2011,7 @@ def _save_dataset_to_mindir(model, dataset):
|
|
|
1656
2011
|
|
|
1657
2012
|
def parse_print(print_file_name):
|
|
1658
2013
|
"""
|
|
1659
|
-
Parse data file generated by mindspore.ops.Print
|
|
2014
|
+
Parse data file generated by :class:`mindspore.ops.Print`.
|
|
1660
2015
|
|
|
1661
2016
|
Args:
|
|
1662
2017
|
print_file_name (str): The file name needs to be parsed.
|
|
@@ -1671,9 +2026,7 @@ def parse_print(print_file_name):
|
|
|
1671
2026
|
Examples:
|
|
1672
2027
|
>>> import numpy as np
|
|
1673
2028
|
>>> import mindspore as ms
|
|
1674
|
-
>>>
|
|
1675
|
-
>>> from mindspore import nn
|
|
1676
|
-
>>> from mindspore import Tensor
|
|
2029
|
+
>>> from mindspore import nn, Tensor, ops
|
|
1677
2030
|
>>> ms.set_context(mode=ms.GRAPH_MODE, print_file_path='log.data')
|
|
1678
2031
|
>>> class PrintInputTensor(nn.Cell):
|
|
1679
2032
|
... def __init__(self):
|
|
@@ -1688,8 +2041,7 @@ def parse_print(print_file_name):
|
|
|
1688
2041
|
>>> net = PrintInputTensor()
|
|
1689
2042
|
>>> net(input_pra)
|
|
1690
2043
|
>>>
|
|
1691
|
-
>>>
|
|
1692
|
-
>>> data = mindspore.parse_print('./log.data')
|
|
2044
|
+
>>> data = ms.parse_print('./log.data')
|
|
1693
2045
|
>>> print(data)
|
|
1694
2046
|
['print:', Tensor(shape=[2, 4], dtype=Float32, value=
|
|
1695
2047
|
[[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
|
|
@@ -1836,8 +2188,8 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
|
|
|
1836
2188
|
def restore_group_info_list(group_info_file_name):
|
|
1837
2189
|
"""
|
|
1838
2190
|
Build rank list, the checkpoint of ranks in the rank list has the same contents with the local rank
|
|
1839
|
-
who saves the group_info_file_name
|
|
1840
|
-
like "export GROUP_INFO_FILE=/data/group_info.pb".
|
|
2191
|
+
who saves the `group_info_file_name`. To save the group info file, please export GROUP_INFO_FIL
|
|
2192
|
+
environment variables like "export GROUP_INFO_FILE=/data/group_info.pb".
|
|
1841
2193
|
|
|
1842
2194
|
Args:
|
|
1843
2195
|
group_info_file_name (str): Name of group information file.
|
|
@@ -1847,10 +2199,11 @@ def restore_group_info_list(group_info_file_name):
|
|
|
1847
2199
|
|
|
1848
2200
|
Raises:
|
|
1849
2201
|
ValueError: group information file is incorrect.
|
|
1850
|
-
TypeError: group_info_file_name is not str.
|
|
2202
|
+
TypeError: `group_info_file_name` is not str.
|
|
1851
2203
|
|
|
1852
2204
|
Examples:
|
|
1853
|
-
>>>
|
|
2205
|
+
>>> import mindspore as ms
|
|
2206
|
+
>>> ms.restore_list = restore_group_info_list("./group_info.pb")
|
|
1854
2207
|
"""
|
|
1855
2208
|
if not isinstance(group_info_file_name, str):
|
|
1856
2209
|
raise TypeError(f"For 'restore_group_info_list', the argument 'group_info_file_name' should be str, "
|
|
@@ -1868,9 +2221,6 @@ def restore_group_info_list(group_info_file_name):
|
|
|
1868
2221
|
def build_searched_strategy(strategy_filename):
|
|
1869
2222
|
"""
|
|
1870
2223
|
Build strategy of every parameter in network. Used in the case of distributed inference.
|
|
1871
|
-
For details of it, please check:
|
|
1872
|
-
`Saving and Loading Models in Hybrid Parallel Mode
|
|
1873
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/save_load.html>`_.
|
|
1874
2224
|
|
|
1875
2225
|
Args:
|
|
1876
2226
|
strategy_filename (str): Name of strategy file.
|
|
@@ -1880,10 +2230,11 @@ def build_searched_strategy(strategy_filename):
|
|
|
1880
2230
|
|
|
1881
2231
|
Raises:
|
|
1882
2232
|
ValueError: Strategy file is incorrect.
|
|
1883
|
-
TypeError: strategy_filename is not a string.
|
|
2233
|
+
TypeError: `strategy_filename` is not a string.
|
|
1884
2234
|
|
|
1885
2235
|
Examples:
|
|
1886
|
-
>>>
|
|
2236
|
+
>>> import mindspore as ms
|
|
2237
|
+
>>> strategy = ms.build_searched_strategy("./strategy_train.ckpt")
|
|
1887
2238
|
"""
|
|
1888
2239
|
return _build_searched_strategy(strategy_filename)
|
|
1889
2240
|
|
|
@@ -1891,14 +2242,12 @@ def build_searched_strategy(strategy_filename):
|
|
|
1891
2242
|
def merge_sliced_parameter(sliced_parameters, strategy=None):
|
|
1892
2243
|
"""
|
|
1893
2244
|
Merge parameter slices into one parameter. Used in the case of distributed inference.
|
|
1894
|
-
For details of it, please check:
|
|
1895
|
-
`<https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/save_load.html>`_.
|
|
1896
2245
|
|
|
1897
2246
|
Args:
|
|
1898
2247
|
sliced_parameters (list[Parameter]): Parameter slices in order of rank id.
|
|
1899
2248
|
strategy (Optional[dict]): Parameter slice strategy, whose key is parameter name and
|
|
1900
2249
|
value is slice strategy of this parameter. If strategy is None, just merge
|
|
1901
|
-
parameter slices in 0 axis order. Default: None
|
|
2250
|
+
parameter slices in 0 axis order. Default: ``None``.
|
|
1902
2251
|
|
|
1903
2252
|
Returns:
|
|
1904
2253
|
Parameter, the merged parameter which has the whole data.
|
|
@@ -1986,32 +2335,128 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|
|
1986
2335
|
train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM'):
|
|
1987
2336
|
"""
|
|
1988
2337
|
Load checkpoint into net for distributed predication. Used in the case of distributed inference.
|
|
1989
|
-
For details of distributed inference, please check:
|
|
1990
|
-
`Distributed Inference
|
|
1991
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/distributed_inference.html>`_ .
|
|
1992
2338
|
|
|
1993
2339
|
Args:
|
|
1994
2340
|
network (Cell): Network for distributed predication.
|
|
1995
2341
|
checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
|
|
1996
2342
|
predict_strategy (dict): Strategy of predication process. It means that using one device to predict
|
|
1997
|
-
when setting predict_strategy as None. Default: None.
|
|
2343
|
+
when setting predict_strategy as None. Default: ``None`` .
|
|
1998
2344
|
train_strategy_filename (str): The filename of training strategy protocol buffer file.
|
|
1999
2345
|
When train_strategy_filename is None, the training strategy file will be
|
|
2000
2346
|
obtained from context.get_auto_parallel_context("strategy_ckpt_load_file").
|
|
2001
2347
|
Therefore, the training strategy file needs to be specified
|
|
2002
|
-
in at least one of them. Default: None.
|
|
2003
|
-
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
|
|
2348
|
+
in at least one of them. Default: ``None`` .
|
|
2349
|
+
strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
|
|
2004
2350
|
into net when parameter name's suffix in checkpoint file is the same as the
|
|
2005
2351
|
parameter in the network. When the types are inconsistent perform type conversion
|
|
2006
|
-
on the parameters of the same type, such as float32 to float16. Default: False.
|
|
2007
|
-
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
|
|
2008
|
-
is not required. Default: None.
|
|
2009
|
-
dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
|
|
2010
|
-
mode, currently supports 'AES-GCM', 'AES-CBC'
|
|
2352
|
+
on the parameters of the same type, such as float32 to float16. Default: ``False`` .
|
|
2353
|
+
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
|
|
2354
|
+
is not required. Default: ``None`` .
|
|
2355
|
+
dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
|
|
2356
|
+
mode, currently supports ``'AES-GCM'`` , ``'AES-CBC'`` and ``'SM4-CBC'`` .
|
|
2357
|
+
Default: ``'AES-GCM'`` .
|
|
2011
2358
|
|
|
2012
2359
|
Raises:
|
|
2013
2360
|
TypeError: The type of inputs do not match the requirements.
|
|
2014
2361
|
ValueError: Failed to load checkpoint into net.
|
|
2362
|
+
|
|
2363
|
+
Supported Platforms:
|
|
2364
|
+
``Ascend`` ``GPU``
|
|
2365
|
+
|
|
2366
|
+
Examples:
|
|
2367
|
+
.. note::
|
|
2368
|
+
Before running the following examples, you need to configure the communication environment variables.
|
|
2369
|
+
|
|
2370
|
+
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
|
|
2371
|
+
Please see the `rank table startup
|
|
2372
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
|
|
2373
|
+
for more details.
|
|
2374
|
+
|
|
2375
|
+
For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun startup
|
|
2376
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
|
|
2377
|
+
|
|
2378
|
+
For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
|
|
2379
|
+
Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
|
|
2380
|
+
|
|
2381
|
+
>>> import os
|
|
2382
|
+
>>> import numpy as np
|
|
2383
|
+
>>> import mindspore as ms
|
|
2384
|
+
>>> import mindspore.dataset as ds
|
|
2385
|
+
>>> from mindspore import nn, ops, train
|
|
2386
|
+
>>> from mindspore.communication import init
|
|
2387
|
+
>>>
|
|
2388
|
+
>>> step_per_epoch = 4
|
|
2389
|
+
>>> device_num = 8
|
|
2390
|
+
>>>
|
|
2391
|
+
>>> # Define the network structure.
|
|
2392
|
+
>>> class Net(nn.Cell):
|
|
2393
|
+
... def __init__(self, matmul_size, strategy=None):
|
|
2394
|
+
... super().__init__()
|
|
2395
|
+
... matmul_np = np.full(matmul_size, 0.5, dtype=np.float32)
|
|
2396
|
+
... self.matmul_weight = ms.Parameter(ms.Tensor(matmul_np))
|
|
2397
|
+
... self.matmul = ops.MatMul()
|
|
2398
|
+
... self.neg = ops.Neg()
|
|
2399
|
+
... if strategy is not None:
|
|
2400
|
+
... self.matmul.shard(strategy)
|
|
2401
|
+
...
|
|
2402
|
+
... def construct(self, inputs):
|
|
2403
|
+
... x = self.matmul(inputs, self.matmul_weight)
|
|
2404
|
+
... x = self.neg(x)
|
|
2405
|
+
... return x
|
|
2406
|
+
>>>
|
|
2407
|
+
>>> # Create dataset.
|
|
2408
|
+
>>> def get_dataset(*inputs):
|
|
2409
|
+
... def generate():
|
|
2410
|
+
... for _ in range(step_per_epoch):
|
|
2411
|
+
... yield inputs
|
|
2412
|
+
... return generate
|
|
2413
|
+
>>>
|
|
2414
|
+
>>> # Train network and save distributed checkpoint.
|
|
2415
|
+
>>> def train_net():
|
|
2416
|
+
... ms.set_context(mode=ms.GRAPH_MODE)
|
|
2417
|
+
... init()
|
|
2418
|
+
... np.random.seed(1)
|
|
2419
|
+
... input_data = np.random.rand(16, 96).astype(np.float32)
|
|
2420
|
+
... label_data = np.random.rand(16, 16).astype(np.float32)
|
|
2421
|
+
... fake_dataset = get_dataset(input_data, label_data)
|
|
2422
|
+
... dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
|
|
2423
|
+
...
|
|
2424
|
+
... # Set parallel strategy.
|
|
2425
|
+
... strategy = ((1, 4), (4, 1))
|
|
2426
|
+
... ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num,
|
|
2427
|
+
... strategy_ckpt_save_file="./train_strategy.ckpt")
|
|
2428
|
+
... network = Net(matmul_size=(96, 16), strategy=strategy)
|
|
2429
|
+
... net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
|
|
2430
|
+
... net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
|
|
2431
|
+
... model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
|
|
2432
|
+
... ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=False)
|
|
2433
|
+
... global_rank_id = int(os.getenv("RANK_ID"))
|
|
2434
|
+
... ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
|
|
2435
|
+
... ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
|
|
2436
|
+
... model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
|
|
2437
|
+
... ms.reset_auto_parallel_context()
|
|
2438
|
+
>>>
|
|
2439
|
+
>>> # Load distributed checkpoint and test.
|
|
2440
|
+
>>> def load_model():
|
|
2441
|
+
... ms.set_context(mode=ms.GRAPH_MODE)
|
|
2442
|
+
... init()
|
|
2443
|
+
... ms.set_auto_parallel_context(full_batch=True, parallel_mode="semi_auto_parallel",
|
|
2444
|
+
... strategy_ckpt_load_file="./train_strategy.ckpt", device_num=device_num)
|
|
2445
|
+
... predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
|
|
2446
|
+
... network = Net(matmul_size=(96, 16))
|
|
2447
|
+
... model = ms.Model(network)
|
|
2448
|
+
... predict_layout = model.infer_predict_layout(ms.Tensor(predict_data))
|
|
2449
|
+
... ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
|
|
2450
|
+
... ms.load_distributed_checkpoint(network, ckpt_file_list, predict_layout)
|
|
2451
|
+
... predict_result = model.predict(predict_data)
|
|
2452
|
+
... print(predict_result)
|
|
2453
|
+
>>>
|
|
2454
|
+
>>> train_net()
|
|
2455
|
+
>>> load_model()
|
|
2456
|
+
[[-7.3259363 -7.497216 -7.398196 ... -7.374962 -7.204874 -7.234935 ]
|
|
2457
|
+
[ 3.362938 3.3535435 3.3832688 ... 3.4263954 3.279045 3.3202887]
|
|
2458
|
+
...
|
|
2459
|
+
[ 1.6067538 1.6244187 1.5384722 ... 1.5449994 1.6195512 1.6176052]]
|
|
2015
2460
|
"""
|
|
2016
2461
|
network = Validator.check_isinstance("network", network, nn.Cell)
|
|
2017
2462
|
_check_checkpoint_file(checkpoint_filenames)
|
|
@@ -2127,6 +2572,11 @@ def async_ckpt_thread_status():
|
|
|
2127
2572
|
Returns:
|
|
2128
2573
|
bool, True, Asynchronous save checkpoint thread is running.
|
|
2129
2574
|
False, Asynchronous save checkpoint thread is not executing.
|
|
2575
|
+
|
|
2576
|
+
Examples:
|
|
2577
|
+
>>> import mindspore as ms
|
|
2578
|
+
>>> ms.async_ckpt_thread_status()
|
|
2579
|
+
False
|
|
2130
2580
|
"""
|
|
2131
2581
|
thr_list = threading.enumerate()
|
|
2132
2582
|
return True in [ele.getName() == "asyn_save_ckpt" for ele in thr_list]
|
|
@@ -2184,7 +2634,8 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
|
|
|
2184
2634
|
return merged_param
|
|
2185
2635
|
param_name = merged_param.name
|
|
2186
2636
|
tensor_layout = predict_strategy[param_name]
|
|
2187
|
-
|
|
2637
|
+
rank = get_rank()
|
|
2638
|
+
split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1], rank)
|
|
2188
2639
|
requires_grad = merged_param.requires_grad
|
|
2189
2640
|
layerwise_parallel = merged_param.layerwise_parallel
|
|
2190
2641
|
split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
|
|
@@ -2268,7 +2719,8 @@ def convert_model(mindir_file, convert_file, file_format):
|
|
|
2268
2719
|
ValueError: If the parameter `file_format` is not "ONNX".
|
|
2269
2720
|
|
|
2270
2721
|
Examples:
|
|
2271
|
-
>>>
|
|
2722
|
+
>>> import mindspore as ms
|
|
2723
|
+
>>> ms.convert_model("lenet.mindir", "lenet.onnx", "ONNX")
|
|
2272
2724
|
"""
|
|
2273
2725
|
Validator.check_file_name_by_regular(mindir_file)
|
|
2274
2726
|
Validator.check_file_name_by_regular(convert_file)
|