mindspore 2.0.0a0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/train/serialization.py
CHANGED
|
@@ -25,7 +25,6 @@ import stat
|
|
|
25
25
|
import threading
|
|
26
26
|
from threading import Thread, Lock
|
|
27
27
|
from collections import defaultdict, OrderedDict
|
|
28
|
-
from functools import wraps
|
|
29
28
|
from io import BytesIO
|
|
30
29
|
|
|
31
30
|
import math
|
|
@@ -41,17 +40,19 @@ import mindspore
|
|
|
41
40
|
import mindspore.nn as nn
|
|
42
41
|
from mindspore import context
|
|
43
42
|
from mindspore import log as logger
|
|
44
|
-
from mindspore._checkparam import check_input_data, check_input_dataset
|
|
43
|
+
from mindspore._checkparam import check_input_data, check_input_dataset
|
|
44
|
+
from mindspore import _checkparam as Validator
|
|
45
45
|
from mindspore.common import dtype as mstype
|
|
46
46
|
from mindspore.common.api import _cell_graph_executor as _executor
|
|
47
47
|
from mindspore.common.api import _MindsporeFunctionExecutor
|
|
48
48
|
from mindspore.common.api import _get_parameter_layout
|
|
49
|
+
from mindspore.common.api import _generate_branch_control_input
|
|
49
50
|
from mindspore.common.initializer import initializer, One
|
|
50
51
|
from mindspore.common.parameter import Parameter
|
|
51
52
|
from mindspore.common.tensor import Tensor
|
|
52
53
|
from mindspore.common._utils import is_shape_unknown
|
|
53
54
|
from mindspore.communication.management import get_rank, get_group_size
|
|
54
|
-
from mindspore.
|
|
55
|
+
from mindspore.experimental import MapParameter
|
|
55
56
|
from mindspore.parallel._cell_wrapper import get_allgather_cell
|
|
56
57
|
from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
|
|
57
58
|
from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
|
|
@@ -216,6 +217,13 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
|
|
|
216
217
|
plain_data = BytesIO()
|
|
217
218
|
|
|
218
219
|
for name, value in data_list.items():
|
|
220
|
+
if value[0] == "mapparameter":
|
|
221
|
+
_write_mapparameter(name, value, f)
|
|
222
|
+
continue
|
|
223
|
+
if isinstance(value[2], Tensor):
|
|
224
|
+
_write_hugeparameter(name, value, f)
|
|
225
|
+
continue
|
|
226
|
+
|
|
219
227
|
data_size = value[2].nbytes / 1024
|
|
220
228
|
if data_size > SLICE_SIZE:
|
|
221
229
|
slice_count = math.ceil(data_size / SLICE_SIZE)
|
|
@@ -253,6 +261,41 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
|
|
|
253
261
|
raise e
|
|
254
262
|
|
|
255
263
|
|
|
264
|
+
def _write_mapparameter(name, value, f):
|
|
265
|
+
"""Write map parameter into protobuf file."""
|
|
266
|
+
checkpoint_list = Checkpoint()
|
|
267
|
+
param_value = checkpoint_list.value.add()
|
|
268
|
+
param_value.tag = name
|
|
269
|
+
map_tensor = param_value.maptensor
|
|
270
|
+
for v in value[1:]:
|
|
271
|
+
tensor = map_tensor.tensor.add()
|
|
272
|
+
tensor.dims.extend(v[0])
|
|
273
|
+
tensor.tensor_type = v[1]
|
|
274
|
+
tensor.tensor_content = v[2].tobytes()
|
|
275
|
+
f.write(checkpoint_list.SerializeToString())
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def _write_hugeparameter(name, value, f):
|
|
279
|
+
"""Write huge parameter into protobuf file."""
|
|
280
|
+
slice_num = value[2].slice_num
|
|
281
|
+
offset = 0
|
|
282
|
+
max_size = value[0][0]
|
|
283
|
+
for param_slice in range(slice_num):
|
|
284
|
+
checkpoint_list = Checkpoint()
|
|
285
|
+
param_value = checkpoint_list.value.add()
|
|
286
|
+
param_value.tag = name
|
|
287
|
+
param_tensor = param_value.tensor
|
|
288
|
+
param_tensor.dims.extend(value[0])
|
|
289
|
+
param_tensor.tensor_type = value[1]
|
|
290
|
+
param_key = value[3]
|
|
291
|
+
numpy_data = value[2].asnumpy_of_slice_persistent_data(param_key, param_slice)
|
|
292
|
+
if offset + numpy_data.shape[0] > max_size:
|
|
293
|
+
numpy_data = numpy_data[:max_size - offset]
|
|
294
|
+
param_tensor.tensor_content = numpy_data.tobytes()
|
|
295
|
+
f.write(checkpoint_list.SerializeToString())
|
|
296
|
+
offset += numpy_data.shape[0]
|
|
297
|
+
|
|
298
|
+
|
|
256
299
|
def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
|
|
257
300
|
"""Check save_obj and ckpt_file_name for save_checkpoint."""
|
|
258
301
|
if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
|
|
@@ -262,7 +305,7 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
|
|
|
262
305
|
raise TypeError("For 'save_checkpoint', the parameter {} for checkpoint file name is invalid,"
|
|
263
306
|
"'ckpt_file_name' must be "
|
|
264
307
|
"string, but got {}.".format(ckpt_file_name, type(ckpt_file_name)))
|
|
265
|
-
ckpt_file_name = os.path.
|
|
308
|
+
ckpt_file_name = os.path.abspath(ckpt_file_name)
|
|
266
309
|
if os.path.isdir(ckpt_file_name):
|
|
267
310
|
raise IsADirectoryError("For 'save_checkpoint', the parameter `ckpt_file_name`: {} is a directory, "
|
|
268
311
|
"it must be a file name.".format(ckpt_file_name))
|
|
@@ -321,7 +364,23 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
321
364
|
param_list = []
|
|
322
365
|
for (key, value) in param_dict.items():
|
|
323
366
|
each_param = {"name": key}
|
|
324
|
-
|
|
367
|
+
if isinstance(value, MapParameter):
|
|
368
|
+
param_data = []
|
|
369
|
+
for export_data in value.export_data():
|
|
370
|
+
param_data.append(Tensor(export_data))
|
|
371
|
+
each_param["data"] = param_data
|
|
372
|
+
param_list.append(each_param)
|
|
373
|
+
continue
|
|
374
|
+
|
|
375
|
+
if value.data.is_persistent_data():
|
|
376
|
+
# list save persistent_data: [Tensor, shape, type, param.key]
|
|
377
|
+
param_data = ["persistent_data"]
|
|
378
|
+
param_data.append(value.data)
|
|
379
|
+
param_data.append(value.param_info.origin_shape)
|
|
380
|
+
param_data.append(str(value.dtype))
|
|
381
|
+
param_data.append(value.key)
|
|
382
|
+
else:
|
|
383
|
+
param_data = Tensor(value.data.asnumpy())
|
|
325
384
|
|
|
326
385
|
# in automatic model parallel scenario, some parameters were split to all the devices,
|
|
327
386
|
# which should be combined before saving
|
|
@@ -345,6 +404,12 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
345
404
|
for param in save_obj:
|
|
346
405
|
key = param["name"]
|
|
347
406
|
data_list[key] = []
|
|
407
|
+
if isinstance(param["data"], list):
|
|
408
|
+
if param["data"][0] == "persistent_data":
|
|
409
|
+
_save_persistent_data(data_list, key, param)
|
|
410
|
+
else:
|
|
411
|
+
_save_mapparameter(data_list, param)
|
|
412
|
+
continue
|
|
348
413
|
if isinstance(param["data"], str):
|
|
349
414
|
data_list[key].append([0])
|
|
350
415
|
data_list[key].append('str')
|
|
@@ -375,6 +440,34 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|
|
375
440
|
logger.info("Saving checkpoint process is finished.")
|
|
376
441
|
|
|
377
442
|
|
|
443
|
+
def _save_mapparameter(data_list, param):
|
|
444
|
+
"""Save map parameter into save_obj."""
|
|
445
|
+
data_list[param["name"]].append("mapparameter")
|
|
446
|
+
for value in param["data"]:
|
|
447
|
+
dims = []
|
|
448
|
+
tmp_list = []
|
|
449
|
+
for dim in value.shape:
|
|
450
|
+
dims.append(dim)
|
|
451
|
+
tmp_list.append(dims)
|
|
452
|
+
tensor_type = str(value.dtype)
|
|
453
|
+
tmp_list.append(tensor_type)
|
|
454
|
+
data = value.asnumpy().reshape(-1)
|
|
455
|
+
tmp_list.append(data)
|
|
456
|
+
data_list[param["name"]].append(tmp_list)
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def _save_persistent_data(data_list, key, param):
|
|
460
|
+
"""Save persistent data into save_obj."""
|
|
461
|
+
dims = []
|
|
462
|
+
# persistent_data shape can not be ()
|
|
463
|
+
for dim in param['data'][2]:
|
|
464
|
+
dims.append(dim)
|
|
465
|
+
data_list[key].append(dims)
|
|
466
|
+
data_list[key].append(param['data'][3])
|
|
467
|
+
data_list[key].append(param['data'][1])
|
|
468
|
+
data_list[key].append(param['data'][4])
|
|
469
|
+
|
|
470
|
+
|
|
378
471
|
def _check_append_dict(append_dict):
|
|
379
472
|
"""Check the argument append_dict for save_checkpoint."""
|
|
380
473
|
if append_dict is None:
|
|
@@ -414,11 +507,11 @@ def load(file_name, **kwargs):
|
|
|
414
507
|
|
|
415
508
|
- Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: 'AES-GCM'.
|
|
416
509
|
- For details of using the customized decryption, please check the `tutorial
|
|
417
|
-
<https://mindspore.cn/mindarmour/docs/en/r2.0
|
|
510
|
+
<https://mindspore.cn/mindarmour/docs/en/r2.0/model_encrypt_protection.html>`_.
|
|
418
511
|
|
|
419
512
|
- obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to
|
|
420
513
|
`obfuscate_model()
|
|
421
|
-
<https://www.mindspore.cn/docs/en/r2.0
|
|
514
|
+
<https://www.mindspore.cn/docs/en/r2.0/api_python/mindspore/mindspore.obfuscate_model.html>` .
|
|
422
515
|
|
|
423
516
|
Returns:
|
|
424
517
|
GraphCell, a compiled graph that can executed by `GraphCell`.
|
|
@@ -432,6 +525,8 @@ def load(file_name, **kwargs):
|
|
|
432
525
|
>>> import mindspore as ms
|
|
433
526
|
>>> import mindspore.nn as nn
|
|
434
527
|
>>> from mindspore import Tensor
|
|
528
|
+
>>> from mindspore import context
|
|
529
|
+
>>> context.set_context(mode=context.GRAPH_MODE)
|
|
435
530
|
>>>
|
|
436
531
|
>>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
|
|
437
532
|
>>> input_tensor = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
|
|
@@ -453,7 +548,7 @@ def load(file_name, **kwargs):
|
|
|
453
548
|
if not os.path.exists(file_name):
|
|
454
549
|
raise ValueError("For 'load', the argument 'file_name'(MindIR file) does not exist, "
|
|
455
550
|
"please check whether the 'file_name' is correct.")
|
|
456
|
-
file_name = os.path.
|
|
551
|
+
file_name = os.path.abspath(file_name)
|
|
457
552
|
|
|
458
553
|
# set customized functions for dynamic obfuscation
|
|
459
554
|
obfuscated = _check_load_obfuscate(**kwargs)
|
|
@@ -488,14 +583,15 @@ def _check_param_type(param_config, key, target_type, requested):
|
|
|
488
583
|
if key in param_config:
|
|
489
584
|
if not isinstance(param_config[key], target_type):
|
|
490
585
|
raise TypeError("The type of {} must be {}, but got {}.".format(key, target_type, type(param_config[key])))
|
|
491
|
-
if key == '
|
|
586
|
+
if key == 'obf_random_seed':
|
|
492
587
|
if param_config[key] > INT_64_MAX or param_config[key] <= 0:
|
|
493
588
|
raise ValueError(
|
|
494
|
-
"'
|
|
589
|
+
"'obf_random_seed' must be in (0, INT_64_MAX({})], but got {}.".format(INT_64_MAX,
|
|
590
|
+
param_config[key]))
|
|
495
591
|
return param_config[key]
|
|
496
592
|
if requested:
|
|
497
593
|
raise ValueError("The parameter {} is requested, but not got.".format(key))
|
|
498
|
-
if key == "
|
|
594
|
+
if key == "obf_random_seed":
|
|
499
595
|
return 0
|
|
500
596
|
return None
|
|
501
597
|
|
|
@@ -517,10 +613,10 @@ def _check_customized_func(customized_func):
|
|
|
517
613
|
|
|
518
614
|
|
|
519
615
|
def _check_obfuscate_params(obf_config):
|
|
520
|
-
"""
|
|
521
|
-
if '
|
|
616
|
+
"""Check obfuscation parameters, including obf_random_seed, obf_ratio, customized_func"""
|
|
617
|
+
if 'obf_random_seed' not in obf_config.keys() and 'customized_func' not in obf_config.keys():
|
|
522
618
|
raise ValueError(
|
|
523
|
-
"At least one of '
|
|
619
|
+
"At least one of 'obf_random_seed' or 'customized_func' must be set in obf_config, but got None of them.")
|
|
524
620
|
obfuscate_type = _check_param_type(obf_config, "type", str, False)
|
|
525
621
|
if obfuscate_type not in (None, "dynamic"):
|
|
526
622
|
raise ValueError("Only 'dynamic' type is supported by now, but got {}.".format(obfuscate_type))
|
|
@@ -535,9 +631,13 @@ def _check_obfuscate_params(obf_config):
|
|
|
535
631
|
raise ValueError("'obf_ratio' must be in (0, 1] if it is a float, but got {}.".format(obf_config['obf_ratio']))
|
|
536
632
|
customized_funcs = []
|
|
537
633
|
if 'customized_func' in obf_config.keys():
|
|
634
|
+
device_target = context.get_context('device_target')
|
|
635
|
+
if device_target in ["GPU", "Ascend"]:
|
|
636
|
+
raise ValueError(
|
|
637
|
+
"Customized func mode only support 'device_target'='CPU, but got {}.".format(device_target))
|
|
538
638
|
customized_funcs.append(_check_customized_func(obf_config['customized_func']))
|
|
539
|
-
|
|
540
|
-
return obf_ratio, customized_funcs,
|
|
639
|
+
obf_random_seed = _check_param_type(obf_config, "obf_random_seed", int, False)
|
|
640
|
+
return obf_ratio, customized_funcs, obf_random_seed
|
|
541
641
|
|
|
542
642
|
|
|
543
643
|
def obfuscate_model(obf_config, **kwargs):
|
|
@@ -553,18 +653,19 @@ def obfuscate_model(obf_config, **kwargs):
|
|
|
553
653
|
model is encrypted, then enc_key and enc_mode should be provided.
|
|
554
654
|
- save_model_path (str): The path to save the obfuscated model.
|
|
555
655
|
- model_inputs (list(Tensor)): The inputs of the original model, the values of Tensor can be random, which
|
|
556
|
-
is the same as using
|
|
656
|
+
is the same as using :func:`mindspore.export`.
|
|
557
657
|
- obf_ratio (Union(float, str)): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
|
|
558
658
|
should be in range of (0, 1] or in ["small", "medium", "large"].
|
|
559
659
|
- customized_func (function): A python function used for customized function mode, which used for control
|
|
560
660
|
the switch branch of obfuscation structure. The outputs of customized_func should be boolean. This
|
|
561
661
|
function needs to ensure that its result is constant for any input. Users can refer to opaque
|
|
562
|
-
predicates. If customized_func is set, then it should be passed to
|
|
563
|
-
obfuscated model.
|
|
564
|
-
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
662
|
+
predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface
|
|
663
|
+
when loading obfuscated model.
|
|
664
|
+
- obf_random_seed (int): The random seed used for determine the distribution of confusion branches and the
|
|
665
|
+
weight confusion coefficient, which should be in (0, 9223372036854775807]. If `obf_random_seed` is set,
|
|
666
|
+
then it should be passed to :class:`nn.GraphCell()` interface when loading obfuscated model. It should be
|
|
667
|
+
noted that at least one of `customized_func` or `obf_random_seed` should be set, and the latter mode
|
|
668
|
+
would be applied if both of them are set.
|
|
568
669
|
|
|
569
670
|
kwargs (dict): Configuration options dictionary.
|
|
570
671
|
|
|
@@ -573,24 +674,24 @@ def obfuscate_model(obf_config, **kwargs):
|
|
|
573
674
|
Option: 'AES-GCM' | 'AES-CBC' | 'SM4-CBC'. Default: 'AES-GCM'.
|
|
574
675
|
|
|
575
676
|
Raises:
|
|
576
|
-
TypeError: If obf_config is not a dict.
|
|
577
|
-
ValueError: If enc_key is passed and enc_mode is not in ["AES-GCM", "AES-CBC", "SM4-CBC"].
|
|
578
|
-
ValueError: If original_model_path is not provided in obf_config
|
|
579
|
-
ValueError: If the model saved in original_model_path has been obfuscated.
|
|
580
|
-
ValueError: If save_model_path is not provided in obf_config
|
|
581
|
-
ValueError: If obf_ratio is not provided in obf_config
|
|
582
|
-
ValueError: If both customized_func and
|
|
583
|
-
ValueError: If
|
|
584
|
-
ValueError: If
|
|
677
|
+
TypeError: If `obf_config` is not a dict.
|
|
678
|
+
ValueError: If `enc_key` is passed and `enc_mode` is not in ["AES-GCM", "AES-CBC", "SM4-CBC"].
|
|
679
|
+
ValueError: If `original_model_path` is not provided in `obf_config`.
|
|
680
|
+
ValueError: If the model saved in `original_model_path` has been obfuscated.
|
|
681
|
+
ValueError: If `save_model_path` is not provided in `obf_config`.
|
|
682
|
+
ValueError: If `obf_ratio` is not provided in `obf_config`.
|
|
683
|
+
ValueError: If both `customized_func` and `obf_random_seed` are not provided in `obf_config`.
|
|
684
|
+
ValueError: If `obf_random_seed` is not in (0, 9223372036854775807].
|
|
685
|
+
ValueError: If `original_model_path` is not exist or `original_model_path` is not end with '.mindir'.
|
|
585
686
|
|
|
586
687
|
Examples:
|
|
587
688
|
>>> obf_config = {'original_model_path': "./net.mindir",
|
|
588
689
|
... 'save_model_path': "./obf_net",
|
|
589
690
|
... 'model_inputs': [input1, ],
|
|
590
|
-
... 'obf_ratio': 0.1, '
|
|
691
|
+
... 'obf_ratio': 0.1, 'obf_random_seed': 173262358423}
|
|
591
692
|
>>> obfuscate_model(obf_config)
|
|
592
693
|
>>> obf_func = load("obf_net.mindir")
|
|
593
|
-
>>> obf_net = nn.GraphCell(obf_func,
|
|
694
|
+
>>> obf_net = nn.GraphCell(obf_func, obf_random_seed=173262358423)
|
|
594
695
|
>>> print(obf_net(input1).asnumpy())
|
|
595
696
|
"""
|
|
596
697
|
if not isinstance(obf_config, dict):
|
|
@@ -610,22 +711,18 @@ def obfuscate_model(obf_config, **kwargs):
|
|
|
610
711
|
if -1 in item.shape:
|
|
611
712
|
raise ValueError(
|
|
612
713
|
"Dynamic shape input is not supported now, but got the shape of inputs: {}.".format(item.shape))
|
|
613
|
-
obf_ratio, customized_funcs,
|
|
614
|
-
if customized_funcs and
|
|
615
|
-
logger.warning("Although 'customized_func' and '
|
|
616
|
-
" applied, remember to set '
|
|
714
|
+
obf_ratio, customized_funcs, obf_random_seed = _check_obfuscate_params(obf_config)
|
|
715
|
+
if customized_funcs and obf_random_seed > 0:
|
|
716
|
+
logger.warning("Although 'customized_func' and 'obf_random_seed' are set, the 'obf_random_seed' mode would be"
|
|
717
|
+
" applied, remember to set 'obf_random_seed' when loading obfuscated model.")
|
|
617
718
|
|
|
618
|
-
if
|
|
719
|
+
if obf_random_seed == 0: # apply customized_func mode
|
|
619
720
|
clean_funcs()
|
|
620
721
|
for func in customized_funcs:
|
|
621
722
|
add_opaque_predicate(func.__name__, func)
|
|
622
|
-
|
|
623
|
-
else:
|
|
624
|
-
|
|
625
|
-
int_max = 2 ** 31 - 1
|
|
626
|
-
np.random.seed(obf_password % seed_max)
|
|
627
|
-
append_password = np.random.randint(int_max)
|
|
628
|
-
obf_password %= int_max
|
|
723
|
+
branch_control_input = 0
|
|
724
|
+
else: # apply password mode
|
|
725
|
+
branch_control_input = _generate_branch_control_input(obf_random_seed)
|
|
629
726
|
|
|
630
727
|
if 'enc_key' in kwargs.keys():
|
|
631
728
|
enc_key = Validator.check_isinstance('enc_key', kwargs.get('enc_key'), bytes)
|
|
@@ -636,29 +733,32 @@ def obfuscate_model(obf_config, **kwargs):
|
|
|
636
733
|
raise ValueError(
|
|
637
734
|
"Only MindIR files that encrypted with 'AES-GCM', 'AES-CBC' or 'SM4-CBC' is supported for"
|
|
638
735
|
"obfuscate_model(), but got {}.".format(enc_mode))
|
|
639
|
-
obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio,
|
|
640
|
-
|
|
736
|
+
obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio,
|
|
737
|
+
branch_control_input=branch_control_input, dec_key=enc_key,
|
|
738
|
+
key_len=len(enc_key),
|
|
641
739
|
dec_mode=enc_mode)
|
|
642
740
|
else:
|
|
643
|
-
obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio,
|
|
644
|
-
|
|
741
|
+
obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio,
|
|
742
|
+
branch_control_input=branch_control_input)
|
|
645
743
|
|
|
646
744
|
obf_net = nn.GraphCell(obf_graph)
|
|
647
|
-
if
|
|
648
|
-
y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
|
|
745
|
+
if obf_random_seed != 0:
|
|
649
746
|
append_y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
|
|
650
|
-
model_inputs += [
|
|
747
|
+
model_inputs += [append_y_tensor,]
|
|
651
748
|
export(obf_net, *model_inputs, file_name=saved_path, file_format="MINDIR", **kwargs)
|
|
652
749
|
|
|
653
750
|
|
|
654
751
|
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
|
|
655
|
-
dec_key=None, dec_mode="AES-GCM", specify_prefix=None):
|
|
752
|
+
dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None):
|
|
656
753
|
"""
|
|
657
754
|
Load checkpoint info from a specified file.
|
|
658
755
|
|
|
659
756
|
Note:
|
|
660
|
-
|
|
661
|
-
|
|
757
|
+
- `specify_prefix` and `filter_prefix` do not affect each other.
|
|
758
|
+
- If none of the parameters are loaded from checkpoint file, it will throw ValueError.
|
|
759
|
+
- `specify_prefix` and `filter_prefix` are in the process of being deprecated,
|
|
760
|
+
`choice_func` is recommended instead.
|
|
761
|
+
And using either of those two args will override `choice_func` at the same time.
|
|
662
762
|
|
|
663
763
|
Args:
|
|
664
764
|
ckpt_file_name (str): Checkpoint file name.
|
|
@@ -667,14 +767,18 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
667
767
|
into net when parameter name's suffix in checkpoint file is the same as the
|
|
668
768
|
parameter in the network. When the types are inconsistent perform type conversion
|
|
669
769
|
on the parameters of the same type, such as float32 to float16. Default: False.
|
|
670
|
-
filter_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the
|
|
671
|
-
will not be loaded. Default: None.
|
|
770
|
+
filter_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`). Parameters starting with the
|
|
771
|
+
filter_prefix will not be loaded. Default: None.
|
|
672
772
|
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
|
|
673
773
|
is not required. Default: None.
|
|
674
774
|
dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
|
|
675
775
|
mode, currently supports 'AES-GCM' and 'AES-CBC' and 'SM4-CBC'. Default: 'AES-GCM'.
|
|
676
|
-
specify_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the
|
|
677
|
-
will be loaded. Default: None.
|
|
776
|
+
specify_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`). Parameters starting with the
|
|
777
|
+
specify_prefix will be loaded. Default: None.
|
|
778
|
+
choice_func (Union[None, function]) : Input value of the function is a Parameter name of type string,
|
|
779
|
+
and the return value is a bool. If returns True, the Parameter
|
|
780
|
+
that matches the custom condition will be loaded. If returns False, the Parameter that
|
|
781
|
+
matches the custom condition will be removed. Default: None.
|
|
678
782
|
|
|
679
783
|
Returns:
|
|
680
784
|
Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
|
|
@@ -692,9 +796,28 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
692
796
|
>>> import mindspore as ms
|
|
693
797
|
>>>
|
|
694
798
|
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
|
|
695
|
-
>>> param_dict = ms.load_checkpoint(ckpt_file_name,
|
|
799
|
+
>>> param_dict = ms.load_checkpoint(ckpt_file_name,
|
|
800
|
+
... choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
|
|
696
801
|
>>> print(param_dict["conv2.weight"])
|
|
697
802
|
Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
|
|
803
|
+
>>> def func(param_name):
|
|
804
|
+
>>> whether_load = False
|
|
805
|
+
>>> if param_name.startswith("conv"):
|
|
806
|
+
>>> whether_load = True
|
|
807
|
+
>>> if param_name.startswith("conv1"):
|
|
808
|
+
>>> whether_load = False
|
|
809
|
+
>>> return whether_load
|
|
810
|
+
>>> param_dict1 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
|
|
811
|
+
>>> print(param_dict1["conv2.weight"])
|
|
812
|
+
Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
|
|
813
|
+
>>> def func(param_name):
|
|
814
|
+
>>> whether_load = False
|
|
815
|
+
>>> if param_name.startswith("conv1"):
|
|
816
|
+
>>> whether_load = True
|
|
817
|
+
>>> return whether_load
|
|
818
|
+
>>> param_dict2 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
|
|
819
|
+
>>> print(param_dict2)
|
|
820
|
+
{'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
|
|
698
821
|
"""
|
|
699
822
|
ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
|
|
700
823
|
specify_prefix = _check_prefix(specify_prefix)
|
|
@@ -707,9 +830,21 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
707
830
|
parameter_dict = {}
|
|
708
831
|
try:
|
|
709
832
|
param_data_list = []
|
|
833
|
+
if specify_prefix:
|
|
834
|
+
logger.warning("For load_checkpoint, this parameter `specity_prefix` will be deprecated, "
|
|
835
|
+
"please use `choice_func` instead.")
|
|
836
|
+
if filter_prefix:
|
|
837
|
+
logger.warning("For load_checkpoint, this parameter `filter_prefix` will be deprecated, "
|
|
838
|
+
"please use `choice_func` instead.")
|
|
710
839
|
for element_id, element in enumerate(checkpoint_list.value):
|
|
711
840
|
if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
|
|
712
841
|
continue
|
|
842
|
+
if specify_prefix is None and filter_prefix is None and \
|
|
843
|
+
choice_func is not None and not choice_func(element.tag):
|
|
844
|
+
continue
|
|
845
|
+
if element.tensor.ByteSize() == 0:
|
|
846
|
+
_load_mapparameter(element, parameter_dict)
|
|
847
|
+
continue
|
|
713
848
|
data = element.tensor.tensor_content
|
|
714
849
|
data_type = element.tensor.tensor_type
|
|
715
850
|
np_type = tensor_to_np_type.get(data_type)
|
|
@@ -762,7 +897,7 @@ def _check_ckpt_file_name(ckpt_file_name):
|
|
|
762
897
|
raise ValueError("For 'load_checkpoint', the checkpoint file should end with '.ckpt', please "
|
|
763
898
|
"input the correct 'ckpt_file_name'.")
|
|
764
899
|
|
|
765
|
-
ckpt_file_name = os.path.
|
|
900
|
+
ckpt_file_name = os.path.abspath(ckpt_file_name)
|
|
766
901
|
if not os.path.exists(ckpt_file_name):
|
|
767
902
|
raise ValueError("For 'load_checkpoint', the checkpoint file: {} does not exist, please check "
|
|
768
903
|
"whether the 'ckpt_file_name' is correct.".format(ckpt_file_name))
|
|
@@ -834,6 +969,20 @@ def _whether_load_param(specify_prefix, filter_prefix, param_name):
|
|
|
834
969
|
return whether_load
|
|
835
970
|
|
|
836
971
|
|
|
972
|
+
def _load_mapparameter(element, parameter_dict):
|
|
973
|
+
"""Load map parameter from ckpt file."""
|
|
974
|
+
map_array = []
|
|
975
|
+
for tensor in element.maptensor.tensor:
|
|
976
|
+
data = tensor.tensor_content
|
|
977
|
+
data_type = tensor.tensor_type
|
|
978
|
+
np_type = tensor_to_np_type.get(data_type)
|
|
979
|
+
element_data = np.frombuffer(data, np_type)
|
|
980
|
+
dims = tensor.dims
|
|
981
|
+
param_data = element_data.reshape(list(dims))
|
|
982
|
+
map_array.append(param_data)
|
|
983
|
+
parameter_dict[element.tag] = map_array
|
|
984
|
+
|
|
985
|
+
|
|
837
986
|
def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
838
987
|
"""
|
|
839
988
|
Load parameters into network, return parameter list that are not loaded in the network.
|
|
@@ -848,7 +997,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
848
997
|
on the parameters of the same type, such as float32 to float16. Default: False.
|
|
849
998
|
|
|
850
999
|
Returns:
|
|
851
|
-
List, the parameter name which are not loaded into the network.
|
|
1000
|
+
param_not_load (List), the parameter name in model which are not loaded into the network.
|
|
1001
|
+
ckpt_not_load (List), the parameter name in checkpoint file which are not loaded into the network.
|
|
852
1002
|
|
|
853
1003
|
Raises:
|
|
854
1004
|
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
|
|
@@ -859,7 +1009,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
859
1009
|
>>> net = Net()
|
|
860
1010
|
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
|
|
861
1011
|
>>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1")
|
|
862
|
-
>>> param_not_load = ms.load_param_into_net(net, param_dict)
|
|
1012
|
+
>>> param_not_load, _ = ms.load_param_into_net(net, param_dict)
|
|
863
1013
|
>>> print(param_not_load)
|
|
864
1014
|
['conv1.weight']
|
|
865
1015
|
"""
|
|
@@ -884,10 +1034,12 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
884
1034
|
logger.info("Execute the process of loading parameters into net.")
|
|
885
1035
|
net.init_parameters_data()
|
|
886
1036
|
param_not_load = []
|
|
1037
|
+
ckpt_not_load = list(parameter_dict.keys())
|
|
887
1038
|
for _, param in net.parameters_and_names():
|
|
888
1039
|
if param.name in parameter_dict:
|
|
889
1040
|
new_param = copy.deepcopy(parameter_dict[param.name])
|
|
890
1041
|
_update_param(param, new_param, strict_load)
|
|
1042
|
+
ckpt_not_load.remove(param.name)
|
|
891
1043
|
else:
|
|
892
1044
|
param_not_load.append(param.name)
|
|
893
1045
|
|
|
@@ -906,7 +1058,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
906
1058
|
"when training and loading checkpoint.".format(len(param_not_load)))
|
|
907
1059
|
for param_name in param_not_load:
|
|
908
1060
|
logger.warning("{} is not loaded.".format(param_name))
|
|
909
|
-
return param_not_load
|
|
1061
|
+
return param_not_load, ckpt_not_load
|
|
910
1062
|
|
|
911
1063
|
|
|
912
1064
|
def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):
|
|
@@ -945,7 +1097,7 @@ def _save_graph(network, file_name):
|
|
|
945
1097
|
"""
|
|
946
1098
|
logger.info("Execute the process of saving graph.")
|
|
947
1099
|
|
|
948
|
-
file_name = os.path.
|
|
1100
|
+
file_name = os.path.abspath(file_name)
|
|
949
1101
|
graph_pb = network.get_func_graph_proto()
|
|
950
1102
|
if graph_pb:
|
|
951
1103
|
with open(file_name, "wb") as f:
|
|
@@ -1031,7 +1183,7 @@ def _fill_param_into_net(net, parameter_list):
|
|
|
1031
1183
|
else:
|
|
1032
1184
|
parameter_dict[param_name] = Parameter(Tensor(np_val), name=param_name)
|
|
1033
1185
|
|
|
1034
|
-
load_param_into_net(net, parameter_dict)
|
|
1186
|
+
load_param_into_net(net, parameter_dict, strict_load=True)
|
|
1035
1187
|
|
|
1036
1188
|
|
|
1037
1189
|
def export(net, *inputs, file_name, file_format, **kwargs):
|
|
@@ -1062,12 +1214,6 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1062
1214
|
|
|
1063
1215
|
kwargs (dict): Configuration options dictionary.
|
|
1064
1216
|
|
|
1065
|
-
- quant_mode (str): If the network is a quantization aware training network, the quant_mode should
|
|
1066
|
-
be set to "QUANT", else the quant_mode should be set to "NONQUANT".
|
|
1067
|
-
- mean (float): The mean of input data after preprocessing, used for quantizing the first layer of network.
|
|
1068
|
-
Default: 127.5.
|
|
1069
|
-
- std_dev (float): The variance of input data after preprocessing,
|
|
1070
|
-
used for quantizing the first layer of the network. Default: 127.5.
|
|
1071
1217
|
- enc_key (byte): Byte-type key used for encryption. The valid length is 16, 24, or 32.
|
|
1072
1218
|
- enc_mode (Union[str, function]): Specifies the encryption mode, to take effect when enc_key is set.
|
|
1073
1219
|
|
|
@@ -1076,7 +1222,7 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1076
1222
|
or Customized encryption.
|
|
1077
1223
|
Default: 'AES-GCM'.
|
|
1078
1224
|
- For details of using the customized encryption, please check the `tutorial
|
|
1079
|
-
<https://mindspore.cn/mindarmour/docs/en/r2.0
|
|
1225
|
+
<https://mindspore.cn/mindarmour/docs/en/r2.0/model_encrypt_protection.html>`_.
|
|
1080
1226
|
|
|
1081
1227
|
- dataset (Dataset): Specifies the preprocessing method of the dataset, which is used to import the
|
|
1082
1228
|
preprocessing of the dataset into MindIR.
|
|
@@ -1091,10 +1237,14 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1091
1237
|
function needs to ensure that its result is constant for any input. Users can refer to opaque
|
|
1092
1238
|
predicates. If customized_func is set, then it should be passed to `load()` interface when loading
|
|
1093
1239
|
obfuscated model.
|
|
1094
|
-
-
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
|
|
1240
|
+
- obf_random_seed (int): The random seed used for determine the distribution of confusion branches and the
|
|
1241
|
+
weight confusion coefficient, which should be in (0, 9223372036854775807]. If `obf_random_seed` is set,
|
|
1242
|
+
then it should be passed to :class:`nn.GraphCell()` interface when loading obfuscated model. It should
|
|
1243
|
+
be noted that at least one of `customized_func` or `obf_random_seed` should be set, and the latter mode
|
|
1244
|
+
would be applied if both of them are set.
|
|
1245
|
+
|
|
1246
|
+
- incremental (bool): export MindIR incrementally.
|
|
1247
|
+
|
|
1098
1248
|
Examples:
|
|
1099
1249
|
>>> import mindspore as ms
|
|
1100
1250
|
>>> import numpy as np
|
|
@@ -1127,8 +1277,7 @@ def export(net, *inputs, file_name, file_format, **kwargs):
|
|
|
1127
1277
|
+ str(columns))
|
|
1128
1278
|
inputs = tuple(inputs_col)
|
|
1129
1279
|
|
|
1130
|
-
file_name = os.path.
|
|
1131
|
-
net = _quant_export(net, *inputs, file_format=file_format, **kwargs)
|
|
1280
|
+
file_name = os.path.abspath(file_name)
|
|
1132
1281
|
if 'enc_key' in kwargs.keys():
|
|
1133
1282
|
kwargs['enc_key'], kwargs['enc_mode'] = _check_key_mode_type(file_format, **kwargs)
|
|
1134
1283
|
_export(net, file_name, file_format, *inputs, **kwargs)
|
|
@@ -1139,19 +1288,8 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
|
|
|
1139
1288
|
It is an internal conversion function. Export the MindSpore prediction model to a file in the specified format.
|
|
1140
1289
|
"""
|
|
1141
1290
|
logger.info("exporting model file:%s format:%s.", file_name, file_format)
|
|
1142
|
-
|
|
1143
|
-
|
|
1144
|
-
logger.warning(f"For 'export', format 'GEIR' is deprecated, "
|
|
1145
|
-
f"it would be removed in future release, use 'AIR' instead.")
|
|
1146
|
-
file_format = 'AIR'
|
|
1147
|
-
|
|
1148
|
-
# When dumping ONNX file, switch network mode to infer when it is training(NOTE: ONNX only designed for prediction)
|
|
1149
|
-
is_dump_onnx_in_training = False
|
|
1150
|
-
if hasattr(net, 'training'):
|
|
1151
|
-
is_dump_onnx_in_training = net.training and file_format == 'ONNX'
|
|
1152
|
-
|
|
1153
|
-
if is_dump_onnx_in_training:
|
|
1154
|
-
net.set_train(mode=False)
|
|
1291
|
+
if "obf_config" in kwargs and file_format != "MINDIR":
|
|
1292
|
+
raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.")
|
|
1155
1293
|
|
|
1156
1294
|
if file_format == 'AIR':
|
|
1157
1295
|
_save_air(net, file_name, *inputs, **kwargs)
|
|
@@ -1160,9 +1298,6 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
|
|
|
1160
1298
|
elif file_format == 'MINDIR':
|
|
1161
1299
|
_save_mindir(net, file_name, *inputs, **kwargs)
|
|
1162
1300
|
|
|
1163
|
-
if is_dump_onnx_in_training:
|
|
1164
|
-
net.set_train(mode=True)
|
|
1165
|
-
|
|
1166
1301
|
|
|
1167
1302
|
def _check_key_mode_type(file_format, **kwargs):
|
|
1168
1303
|
"""check enc_key and enc_mode are valid"""
|
|
@@ -1193,7 +1328,7 @@ def _save_air(net, file_name, *inputs, **kwargs):
|
|
|
1193
1328
|
if os.path.exists(file_name):
|
|
1194
1329
|
os.chmod(file_name, stat.S_IWUSR)
|
|
1195
1330
|
if "/" in file_name:
|
|
1196
|
-
real_path = os.path.
|
|
1331
|
+
real_path = os.path.abspath(file_name[:file_name.rfind("/")])
|
|
1197
1332
|
os.makedirs(real_path, exist_ok=True)
|
|
1198
1333
|
if 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys():
|
|
1199
1334
|
_executor.export(file_name, graph_id, enc_key=kwargs.get('enc_key'), encrypt_func=kwargs.get('enc_mode'))
|
|
@@ -1204,6 +1339,12 @@ def _save_air(net, file_name, *inputs, **kwargs):
|
|
|
1204
1339
|
|
|
1205
1340
|
def _save_onnx(net, file_name, *inputs, **kwargs):
|
|
1206
1341
|
"""Save ONNX format file."""
|
|
1342
|
+
# When dumping ONNX file, switch network mode to infer when it is training(NOTE: ONNX only designed for prediction)
|
|
1343
|
+
if not isinstance(net, nn.Cell):
|
|
1344
|
+
raise ValueError(f"Export ONNX format model only support nn.Cell object, but got {type(net)}.")
|
|
1345
|
+
_check_dynamic_input(inputs)
|
|
1346
|
+
cell_mode = net.training
|
|
1347
|
+
net.set_train(mode=False)
|
|
1207
1348
|
total_size = _calculation_net_size(net)
|
|
1208
1349
|
if total_size > PROTO_LIMIT_SIZE:
|
|
1209
1350
|
raise RuntimeError('Export onnx model failed. Network size is: {}G, it exceeded the protobuf: {}G limit.'
|
|
@@ -1221,6 +1362,13 @@ def _save_onnx(net, file_name, *inputs, **kwargs):
|
|
|
1221
1362
|
with open(file_name, 'wb') as f:
|
|
1222
1363
|
f.write(onnx_stream)
|
|
1223
1364
|
os.chmod(file_name, stat.S_IRUSR)
|
|
1365
|
+
net.set_train(mode=cell_mode)
|
|
1366
|
+
|
|
1367
|
+
|
|
1368
|
+
def _check_dynamic_input(inputs):
|
|
1369
|
+
for ele in inputs:
|
|
1370
|
+
if isinstance(ele, Tensor) and -1 in ele.shape:
|
|
1371
|
+
raise ValueError(f"Export ONNX format model not support dynamic shape mode.")
|
|
1224
1372
|
|
|
1225
1373
|
|
|
1226
1374
|
def _generate_front_info_for_param_data_file(is_encrypt, kwargs):
|
|
@@ -1270,10 +1418,24 @@ def _get_data_file(is_encrypt, kwargs, data_file_name):
|
|
|
1270
1418
|
return f, parameter_size, offset
|
|
1271
1419
|
|
|
1272
1420
|
|
|
1273
|
-
def
|
|
1421
|
+
def _encrypt_data(is_encrypt, write_data, kwargs):
|
|
1422
|
+
"""Encrypt parameter data."""
|
|
1423
|
+
if is_encrypt():
|
|
1424
|
+
if callable(kwargs.get('enc_mode')):
|
|
1425
|
+
enc_func = kwargs.get('enc_mode')
|
|
1426
|
+
write_data = enc_func(write_data, kwargs.get('enc_key'))
|
|
1427
|
+
else:
|
|
1428
|
+
write_data = _encrypt(write_data, len(write_data), kwargs.get('enc_key'),
|
|
1429
|
+
len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
|
|
1430
|
+
return write_data
|
|
1431
|
+
|
|
1432
|
+
|
|
1433
|
+
def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
|
|
1274
1434
|
"""The function to save parameter data."""
|
|
1275
1435
|
logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.")
|
|
1276
1436
|
# save parameter
|
|
1437
|
+
if model.graph.map_parameter:
|
|
1438
|
+
raise ValueError("MapParameter not support save in split MindIR file now.")
|
|
1277
1439
|
file_prefix = file_name.split("/")[-1]
|
|
1278
1440
|
if file_prefix.endswith(".mindir"):
|
|
1279
1441
|
file_prefix = file_prefix[:-7]
|
|
@@ -1308,13 +1470,7 @@ def _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs):
|
|
|
1308
1470
|
param_proto.external_data.offset = offset
|
|
1309
1471
|
write_data = raw_data + bytes(append_size)
|
|
1310
1472
|
offset += (data_length + append_size)
|
|
1311
|
-
|
|
1312
|
-
if callable(kwargs.get('enc_mode')):
|
|
1313
|
-
enc_func = kwargs.get('enc_mode')
|
|
1314
|
-
write_data = enc_func(write_data, kwargs.get('enc_key'))
|
|
1315
|
-
else:
|
|
1316
|
-
write_data = _encrypt(write_data, len(write_data), kwargs.get('enc_key'),
|
|
1317
|
-
len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
|
|
1473
|
+
write_data = _encrypt_data(is_encrypt, write_data, kwargs)
|
|
1318
1474
|
f.write(write_data)
|
|
1319
1475
|
|
|
1320
1476
|
graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
|
|
@@ -1342,7 +1498,7 @@ def _msfunc_info(net, *inputs):
|
|
|
1342
1498
|
# pylint: disable=protected-access
|
|
1343
1499
|
net_dict = OrderedDict()
|
|
1344
1500
|
_ms_func_executor = _MindsporeFunctionExecutor(net, time.time() * 1e9)
|
|
1345
|
-
graph_id = _ms_func_executor.compile(
|
|
1501
|
+
graph_id = _ms_func_executor.compile(net.__name__, *inputs)
|
|
1346
1502
|
mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
|
|
1347
1503
|
params = _ms_func_executor._graph_executor.get_params(graph_id)
|
|
1348
1504
|
for name, value in params.items():
|
|
@@ -1350,12 +1506,12 @@ def _msfunc_info(net, *inputs):
|
|
|
1350
1506
|
return mindir_stream, net_dict
|
|
1351
1507
|
|
|
1352
1508
|
|
|
1353
|
-
def _cell_info(net, *inputs):
|
|
1509
|
+
def _cell_info(net, incremental, *inputs):
|
|
1354
1510
|
"""Get mindir stream and net dict of cell"""
|
|
1355
1511
|
phase_name = "predict" if _is_in_auto_parallel_mode() else "export.mindir"
|
|
1356
1512
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
|
1357
1513
|
# pylint: disable=protected-access
|
|
1358
|
-
mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
|
|
1514
|
+
mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir', incremental=incremental)
|
|
1359
1515
|
# clean obfuscation config to prevent the next call
|
|
1360
1516
|
_executor.obfuscate_config = None
|
|
1361
1517
|
|
|
@@ -1372,16 +1528,20 @@ def _set_obfuscate_config(**kwargs):
|
|
|
1372
1528
|
raise ValueError(
|
|
1373
1529
|
"Only MindIR files that encrypted with 'AES-GCM', 'AES-CBC' or 'SM4-CBC' is supported for"
|
|
1374
1530
|
"obfuscation, but got {}.".format(enc_mode))
|
|
1375
|
-
obf_ratio, customized_funcs,
|
|
1376
|
-
if customized_funcs and
|
|
1377
|
-
logger.warning("Although 'customized_func' and '
|
|
1378
|
-
" applied, remember to set '
|
|
1379
|
-
|
|
1380
|
-
if
|
|
1531
|
+
obf_ratio, customized_funcs, obf_random_seed = _check_obfuscate_params(kwargs.get('obf_config'))
|
|
1532
|
+
if customized_funcs and obf_random_seed > 0:
|
|
1533
|
+
logger.warning("Although 'customized_func' and 'obf_random_seed' are set, the 'obf_random_seed' mode would be"
|
|
1534
|
+
" applied, remember to set 'obf_random_seed' when loading obfuscated model.")
|
|
1535
|
+
|
|
1536
|
+
if obf_random_seed == 0: # apply customized_func mode
|
|
1537
|
+
device_target = context.get_context('device_target')
|
|
1538
|
+
if device_target in ["GPU", "Ascend"]:
|
|
1539
|
+
raise ValueError(
|
|
1540
|
+
"Customized func mode only support 'device_target'='CPU, but got {}.".format(device_target))
|
|
1381
1541
|
clean_funcs()
|
|
1382
1542
|
for func in customized_funcs:
|
|
1383
1543
|
add_opaque_predicate(func.__name__, func)
|
|
1384
|
-
_executor.obfuscate_config = {'obf_ratio': obf_ratio, '
|
|
1544
|
+
_executor.obfuscate_config = {'obf_ratio': obf_ratio, 'obf_random_seed': obf_random_seed}
|
|
1385
1545
|
|
|
1386
1546
|
|
|
1387
1547
|
def _save_mindir(net, file_name, *inputs, **kwargs):
|
|
@@ -1394,11 +1554,13 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
|
|
|
1394
1554
|
raise ValueError(
|
|
1395
1555
|
"Dynamic shape input is not supported now, but got the shape of inputs: {}.".format(item.shape))
|
|
1396
1556
|
|
|
1557
|
+
incremental = kwargs.get('incremental', False)
|
|
1558
|
+
|
|
1397
1559
|
model = mindir_model()
|
|
1398
1560
|
if not isinstance(net, nn.Cell):
|
|
1399
1561
|
mindir_stream, net_dict = _msfunc_info(net, *inputs)
|
|
1400
1562
|
else:
|
|
1401
|
-
mindir_stream, net_dict = _cell_info(net, *inputs)
|
|
1563
|
+
mindir_stream, net_dict = _cell_info(net, incremental, *inputs)
|
|
1402
1564
|
model.ParseFromString(mindir_stream)
|
|
1403
1565
|
|
|
1404
1566
|
if kwargs.get('dataset'):
|
|
@@ -1411,7 +1573,7 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
|
|
|
1411
1573
|
if save_together:
|
|
1412
1574
|
_save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs)
|
|
1413
1575
|
else:
|
|
1414
|
-
|
|
1576
|
+
_split_save(net_dict, model, file_name, is_encrypt, **kwargs)
|
|
1415
1577
|
|
|
1416
1578
|
|
|
1417
1579
|
def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
|
|
@@ -1422,19 +1584,20 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
|
|
|
1422
1584
|
param_data = net_dict[param_name].data.asnumpy().tobytes()
|
|
1423
1585
|
param_proto.raw_data = param_data
|
|
1424
1586
|
else:
|
|
1425
|
-
|
|
1426
|
-
|
|
1587
|
+
raise ValueError("The parameter '{}' is not belongs to any cell,"
|
|
1588
|
+
"the data of parameter cannot be exported.".format(param_proto.name))
|
|
1589
|
+
incremental = kwargs.get('incremental', False)
|
|
1427
1590
|
for map_param_proto in model.graph.map_parameter:
|
|
1428
1591
|
map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:]
|
|
1429
1592
|
if map_param_name in net_dict.keys():
|
|
1430
1593
|
map_parameter = net_dict[map_param_name]
|
|
1431
|
-
key_nparr, value_nparr, status_nparr = map_parameter.export_data()
|
|
1594
|
+
key_nparr, value_nparr, status_nparr = map_parameter.export_data(incremental)
|
|
1432
1595
|
map_param_proto.key_tensor.raw_data = key_nparr.tobytes()
|
|
1433
1596
|
map_param_proto.value_tensor.raw_data = value_nparr.tobytes()
|
|
1434
1597
|
map_param_proto.status_tensor.raw_data = status_nparr.tobytes()
|
|
1435
1598
|
else:
|
|
1436
|
-
|
|
1437
|
-
|
|
1599
|
+
raise ValueError("The map_parameter '{}' is not belongs to any cell,"
|
|
1600
|
+
"the data of parameter cannot be exported.".format(map_param_proto.name))
|
|
1438
1601
|
if not file_name.endswith('.mindir'):
|
|
1439
1602
|
file_name += ".mindir"
|
|
1440
1603
|
current_path = os.path.abspath(file_name)
|
|
@@ -1464,8 +1627,8 @@ def _save_together(net_dict, model):
|
|
|
1464
1627
|
if name in net_dict.keys():
|
|
1465
1628
|
data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024
|
|
1466
1629
|
else:
|
|
1467
|
-
|
|
1468
|
-
|
|
1630
|
+
raise ValueError("The parameter '{}' is not belongs to any cell,"
|
|
1631
|
+
"the data of parameter cannot be exported.".format(param_proto.name))
|
|
1469
1632
|
if data_total > TOTAL_SAVE:
|
|
1470
1633
|
return False
|
|
1471
1634
|
return True
|
|
@@ -1491,62 +1654,6 @@ def _save_dataset_to_mindir(model, dataset):
|
|
|
1491
1654
|
model.preprocessor.op[-1].offload = op['offload'] if 'offload' in op.keys() else False
|
|
1492
1655
|
|
|
1493
1656
|
|
|
1494
|
-
def quant_mode_manage(func):
|
|
1495
|
-
"""Inherit the quant_mode in old version."""
|
|
1496
|
-
|
|
1497
|
-
@wraps(func)
|
|
1498
|
-
def wrapper(network, *inputs, file_format, **kwargs):
|
|
1499
|
-
if 'quant_mode' not in kwargs:
|
|
1500
|
-
return network
|
|
1501
|
-
quant_mode = kwargs.get('quant_mode')
|
|
1502
|
-
if not isinstance(quant_mode, str):
|
|
1503
|
-
raise TypeError("For 'export', the type of 'quant_mode' should be string, "
|
|
1504
|
-
"but got {}.".format(type(quant_mode)))
|
|
1505
|
-
if quant_mode in ('AUTO', 'MANUAL'):
|
|
1506
|
-
kwargs['quant_mode'] = 'QUANT'
|
|
1507
|
-
return func(network, *inputs, file_format=file_format, **kwargs)
|
|
1508
|
-
|
|
1509
|
-
return wrapper
|
|
1510
|
-
|
|
1511
|
-
|
|
1512
|
-
@quant_mode_manage
|
|
1513
|
-
def _quant_export(network, *inputs, file_format, **kwargs):
|
|
1514
|
-
"""Exports MindSpore quantization predict model to deploy with AIR and MINDIR."""
|
|
1515
|
-
supported_device = ["Ascend", "GPU"]
|
|
1516
|
-
supported_formats = ['AIR', 'MINDIR']
|
|
1517
|
-
quant_mode_formats = ['QUANT', 'NONQUANT']
|
|
1518
|
-
|
|
1519
|
-
quant_mode = kwargs['quant_mode']
|
|
1520
|
-
if quant_mode not in quant_mode_formats:
|
|
1521
|
-
raise KeyError(f"For 'export', the argument 'quant_mode' must be one of {quant_mode_formats}, "
|
|
1522
|
-
f"but got {quant_mode}.")
|
|
1523
|
-
if quant_mode == 'NONQUANT':
|
|
1524
|
-
return network
|
|
1525
|
-
quant_net = copy.deepcopy(network)
|
|
1526
|
-
quant_net._create_time = int(time.time() * 1e9)
|
|
1527
|
-
|
|
1528
|
-
mean = 127.5 if kwargs.get('mean', None) is None else kwargs.get('mean')
|
|
1529
|
-
std_dev = 127.5 if kwargs.get('std_dev', None) is None else kwargs.get('std_dev')
|
|
1530
|
-
mean = Validator.check_value_type("mean", mean, (int, float))
|
|
1531
|
-
std_dev = Validator.check_value_type("std_dev", std_dev, (int, float))
|
|
1532
|
-
|
|
1533
|
-
if context.get_context('device_target') not in supported_device:
|
|
1534
|
-
raise KeyError(f"For 'export', quant export only support {supported_device} device target now, "
|
|
1535
|
-
f"but got {context.get_context('device_target')}")
|
|
1536
|
-
|
|
1537
|
-
if file_format not in supported_formats:
|
|
1538
|
-
raise ValueError(f"For 'export', quant export only support 'file_format' {supported_formats}, "
|
|
1539
|
-
f"but got {file_format}.")
|
|
1540
|
-
|
|
1541
|
-
quant_net.set_train(False)
|
|
1542
|
-
if file_format == "MINDIR":
|
|
1543
|
-
exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True)
|
|
1544
|
-
else:
|
|
1545
|
-
exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs)
|
|
1546
|
-
deploy_net = exporter.run()
|
|
1547
|
-
return deploy_net
|
|
1548
|
-
|
|
1549
|
-
|
|
1550
1657
|
def parse_print(print_file_name):
|
|
1551
1658
|
"""
|
|
1552
1659
|
Parse data file generated by mindspore.ops.Print.
|
|
@@ -1588,7 +1695,7 @@ def parse_print(print_file_name):
|
|
|
1588
1695
|
[[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
|
|
1589
1696
|
[ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
|
|
1590
1697
|
"""
|
|
1591
|
-
print_file_path = os.path.
|
|
1698
|
+
print_file_path = os.path.abspath(print_file_name)
|
|
1592
1699
|
|
|
1593
1700
|
if os.path.getsize(print_file_path) == 0:
|
|
1594
1701
|
raise ValueError("For 'parse_print', the print file may be empty, please make sure enter the correct "
|
|
@@ -1763,7 +1870,7 @@ def build_searched_strategy(strategy_filename):
|
|
|
1763
1870
|
Build strategy of every parameter in network. Used in the case of distributed inference.
|
|
1764
1871
|
For details of it, please check:
|
|
1765
1872
|
`Saving and Loading Models in Hybrid Parallel Mode
|
|
1766
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.0
|
|
1873
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/save_load.html>`_.
|
|
1767
1874
|
|
|
1768
1875
|
Args:
|
|
1769
1876
|
strategy_filename (str): Name of strategy file.
|
|
@@ -1785,7 +1892,7 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
|
|
|
1785
1892
|
"""
|
|
1786
1893
|
Merge parameter slices into one parameter. Used in the case of distributed inference.
|
|
1787
1894
|
For details of it, please check:
|
|
1788
|
-
`<https://www.mindspore.cn/tutorials/experts/en/r2.0
|
|
1895
|
+
`<https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/save_load.html>`_.
|
|
1789
1896
|
|
|
1790
1897
|
Args:
|
|
1791
1898
|
sliced_parameters (list[Parameter]): Parameter slices in order of rank id.
|
|
@@ -1880,8 +1987,8 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|
|
1880
1987
|
"""
|
|
1881
1988
|
Load checkpoint into net for distributed predication. Used in the case of distributed inference.
|
|
1882
1989
|
For details of distributed inference, please check:
|
|
1883
|
-
`Distributed Inference
|
|
1884
|
-
distributed_inference.html>`_ .
|
|
1990
|
+
`Distributed Inference
|
|
1991
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/distributed_inference.html>`_ .
|
|
1885
1992
|
|
|
1886
1993
|
Args:
|
|
1887
1994
|
network (Cell): Network for distributed predication.
|
|
@@ -2116,7 +2223,7 @@ def _get_mindir_inputs(file_name):
|
|
|
2116
2223
|
>>> input_tensor = get_mindir_inputs("lenet.mindir")
|
|
2117
2224
|
"""
|
|
2118
2225
|
Validator.check_file_name_by_regular(file_name)
|
|
2119
|
-
file_name = os.path.
|
|
2226
|
+
file_name = os.path.abspath(file_name)
|
|
2120
2227
|
model = read_proto(file_name)
|
|
2121
2228
|
input_tensor = []
|
|
2122
2229
|
|
|
@@ -2147,8 +2254,8 @@ def convert_model(mindir_file, convert_file, file_format):
|
|
|
2147
2254
|
"""
|
|
2148
2255
|
Convert mindir model to other format model. Current version only support convert to "ONNX" format.
|
|
2149
2256
|
|
|
2150
|
-
|
|
2151
|
-
This is an experimental
|
|
2257
|
+
.. warning::
|
|
2258
|
+
This is an experimental API that is subject to change or deletion.
|
|
2152
2259
|
|
|
2153
2260
|
Args:
|
|
2154
2261
|
mindir_file (str): MindIR file name.
|