mindspore-1.10.0-cp38-cp38-win_amd64.whl → mindspore-2.0.0rc1-cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/ConcurrencyCheck.dll +0 -0
- mindspore/CppBuildInsights.dll +0 -0
- mindspore/CppCoreCheck.dll +0 -0
- mindspore/EnumIndex.dll +0 -0
- mindspore/EspXEngine.dll +0 -0
- mindspore/HResultCheck.dll +0 -0
- mindspore/KernelTraceControl.dll +0 -0
- mindspore/LocalESPC.dll +0 -0
- mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
- mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
- mindspore/VariantClear.dll +0 -0
- mindspore/__init__.py +9 -4
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/amp.py +53 -58
- mindspore/atlprov.dll +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/cfgpersist.dll +0 -0
- mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
- mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +291 -56
- mindspore/d3dcompiler_47.dll +0 -0
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/dpcmi.dll +0 -0
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/perf_msvcbuildinsights.dll +0 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/libatomic-1.dll +0 -0
- mindspore/libgcc_s_seh-1.dll +0 -0
- mindspore/libgfortran-4.dll +0 -0
- mindspore/libgomp-1.dll +0 -0
- mindspore/libjpeg-62.dll +0 -0
- mindspore/libmindspore.dll +0 -0
- mindspore/libmindspore_common.dll +0 -0
- mindspore/libmindspore_core.dll +0 -0
- mindspore/libmindspore_glog.dll +0 -0
- mindspore/libnnacl.dll +0 -0
- mindspore/libopencv_core452.dll +0 -0
- mindspore/libopencv_imgcodecs452.dll +0 -0
- mindspore/libopencv_imgproc452.dll +0 -0
- mindspore/libquadmath-0.dll +0 -0
- mindspore/libsqlite3.dll +0 -0
- mindspore/libssp-0.dll +0 -0
- mindspore/libstdc++-6.dll +0 -0
- mindspore/libtinyxml2.dll +0 -0
- mindspore/libturbojpeg.dll +0 -0
- mindspore/libwinpthread-1.dll +0 -0
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
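Several of the renames above relocate whole modules rather than editing code in place: the metrics package moves from mindspore/nn/metrics/ to mindspore/train/metrics/, and mindspore/nn/transformer becomes the internal package mindspore/parallel/_transformer while a new public mindspore/nn/layer/transformer.py is added. A minimal migration sketch for the metrics move, assuming the class names inside the moved files (e.g. Accuracy) are unchanged by the relocation; the exact public re-export paths should be confirmed against the 2.0.0rc1 documentation:

# mindspore 1.10.0 (old location, left side of the {nn → train}/metrics renames above)
# from mindspore.nn.metrics import Accuracy

# mindspore 2.0.0rc1 (new location implied by the file moves above)
from mindspore.train.metrics import Accuracy

metric = Accuracy('classification')  # constructor assumed unchanged by the move
metric.clear()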
mindspore/train/serialization.py
CHANGED
|
@@ -18,7 +18,6 @@ from __future__ import absolute_import
|
|
|
18
18
|
from __future__ import division
|
|
19
19
|
|
|
20
20
|
import copy
|
|
21
|
-
import functools
|
|
22
21
|
import json
|
|
23
22
|
import os
|
|
24
23
|
import shutil
|
|
@@ -41,24 +40,28 @@ import mindspore
|
|
|
41
40
|
import mindspore.nn as nn
|
|
42
41
|
from mindspore import context
|
|
43
42
|
from mindspore import log as logger
|
|
44
|
-
from mindspore._checkparam import check_input_data, check_input_dataset
|
|
43
|
+
from mindspore._checkparam import check_input_data, check_input_dataset
|
|
44
|
+
from mindspore import _checkparam as Validator
|
|
45
45
|
from mindspore.common import dtype as mstype
|
|
46
46
|
from mindspore.common.api import _cell_graph_executor as _executor
|
|
47
|
+
from mindspore.common.api import _MindsporeFunctionExecutor
|
|
48
|
+
from mindspore.common.api import _get_parameter_layout
|
|
49
|
+
from mindspore.common.api import _generate_branch_control_input
|
|
47
50
|
from mindspore.common.initializer import initializer, One
|
|
48
51
|
from mindspore.common.parameter import Parameter
|
|
49
52
|
from mindspore.common.tensor import Tensor
|
|
50
53
|
from mindspore.common._utils import is_shape_unknown
|
|
51
54
|
from mindspore.communication.management import get_rank, get_group_size
|
|
52
|
-
from mindspore.
|
|
55
|
+
from mindspore.experimental import MapParameter
|
|
53
56
|
from mindspore.parallel._cell_wrapper import get_allgather_cell
|
|
54
57
|
from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_tensor_slice_index
|
|
55
58
|
from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
|
|
56
|
-
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
|
|
57
|
-
from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy
|
|
59
|
+
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices, _is_in_auto_parallel_mode
|
|
60
|
+
from mindspore.parallel._parallel_serialization import _convert_to_list, _convert_to_layout, _build_searched_strategy, \
|
|
58
61
|
_restore_group_info_list
|
|
59
62
|
from mindspore.train._utils import read_proto
|
|
60
|
-
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file
|
|
61
|
-
|
|
63
|
+
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, dynamic_obfuscate_mindir
|
|
64
|
+
from ..ops.operations._opaque_predicate_registry import add_opaque_predicate, clean_funcs
|
|
62
65
|
|
|
63
66
|
tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
|
|
64
67
|
"Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
|
|
@@ -81,6 +84,7 @@ PROTO_LIMIT_SIZE = 1024 * 1024 * 2
|
|
|
81
84
|
TOTAL_SAVE = 1024 * 1024
|
|
82
85
|
PARAMETER_SPLIT_SIZE = 1024 * 1024 * 1024
|
|
83
86
|
ENCRYPT_BLOCK_SIZE = 64 * 1024
|
|
87
|
+
INT_64_MAX = 9223372036854775807
|
|
84
88
|
|
|
85
89
|
|
|
86
90
|
def _special_process_par(par, new_par):
|
|
@@ -213,6 +217,13 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
|
|
|
213
217
|
plain_data = BytesIO()
|
|
214
218
|
|
|
215
219
|
for name, value in data_list.items():
|
|
220
|
+
if value[0] == "mapparameter":
|
|
221
|
+
_write_mapparameter(name, value, f)
|
|
222
|
+
continue
|
|
223
|
+
if isinstance(value[2], Tensor):
|
|
224
|
+
_write_hugeparameter(name, value, f)
|
|
225
|
+
continue
|
|
226
|
+
|
|
216
227
|
data_size = value[2].nbytes / 1024
|
|
217
228
|
if data_size > SLICE_SIZE:
|
|
218
229
|
slice_count = math.ceil(data_size / SLICE_SIZE)
|
|
@@ -250,6 +261,41 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
         raise e


+def _write_mapparameter(name, value, f):
+    """Write map parameter into protobuf file."""
+    checkpoint_list = Checkpoint()
+    param_value = checkpoint_list.value.add()
+    param_value.tag = name
+    map_tensor = param_value.maptensor
+    for v in value[1:]:
+        tensor = map_tensor.tensor.add()
+        tensor.dims.extend(v[0])
+        tensor.tensor_type = v[1]
+        tensor.tensor_content = v[2].tobytes()
+    f.write(checkpoint_list.SerializeToString())
+
+
+def _write_hugeparameter(name, value, f):
+    """Write huge parameter into protobuf file."""
+    slice_num = value[2].slice_num
+    offset = 0
+    max_size = value[0][0]
+    for param_slice in range(slice_num):
+        checkpoint_list = Checkpoint()
+        param_value = checkpoint_list.value.add()
+        param_value.tag = name
+        param_tensor = param_value.tensor
+        param_tensor.dims.extend(value[0])
+        param_tensor.tensor_type = value[1]
+        param_key = value[3]
+        numpy_data = value[2].asnumpy_of_slice_persistent_data(param_key, param_slice)
+        if offset + numpy_data.shape[0] > max_size:
+            numpy_data = numpy_data[:max_size - offset]
+        param_tensor.tensor_content = numpy_data.tobytes()
+        f.write(checkpoint_list.SerializeToString())
+        offset += numpy_data.shape[0]
+
+
 def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
     """Check save_obj and ckpt_file_name for save_checkpoint."""
     if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
@@ -259,7 +305,7 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name):
         raise TypeError("For 'save_checkpoint', the parameter {} for checkpoint file name is invalid,"
                         "'ckpt_file_name' must be "
                         "string, but got {}.".format(ckpt_file_name, type(ckpt_file_name)))
-    ckpt_file_name = os.path.
+    ckpt_file_name = os.path.abspath(ckpt_file_name)
     if os.path.isdir(ckpt_file_name):
         raise IsADirectoryError("For 'save_checkpoint', the parameter `ckpt_file_name`: {} is a directory, "
                                 "it must be a file name.".format(ckpt_file_name))
@@ -286,7 +332,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
         enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption
             is not required. Default: None.
         enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption
-            mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.
+            mode, currently supports 'AES-GCM', 'AES-CBC' and 'SM4-CBC'. Default: 'AES-GCM'.

     Raises:
         TypeError: If the parameter save_obj is not `nn.Cell` or list type. And if the parameter `integrated_save`
@@ -308,6 +354,9 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
     logger.info("Execute the process of saving checkpoint files.")

     if isinstance(save_obj, nn.Cell):
+        parameter_layout_dict = save_obj.parameter_layout_dict
+        if _is_in_auto_parallel_mode() and not parameter_layout_dict:
+            parameter_layout_dict = _get_parameter_layout()
         save_obj.init_parameters_data()
         param_dict = OrderedDict()
         for _, param in save_obj.parameters_and_names():
@@ -315,12 +364,28 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
         param_list = []
         for (key, value) in param_dict.items():
             each_param = {"name": key}
-
+            if isinstance(value, MapParameter):
+                param_data = []
+                for export_data in value.export_data():
+                    param_data.append(Tensor(export_data))
+                each_param["data"] = param_data
+                param_list.append(each_param)
+                continue
+
+            if value.data.is_persistent_data():
+                # list save persistent_data: [Tensor, shape, type, param.key]
+                param_data = ["persistent_data"]
+                param_data.append(value.data)
+                param_data.append(value.param_info.origin_shape)
+                param_data.append(str(value.dtype))
+                param_data.append(value.key)
+            else:
+                param_data = Tensor(value.data.asnumpy())

             # in automatic model parallel scenario, some parameters were split to all the devices,
             # which should be combined before saving
-            if key in
-                param_data = _get_merged_param_data(save_obj, key, param_data, integrated_save)
+            if key in parameter_layout_dict:
+                param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data, integrated_save)

             each_param["data"] = param_data
             param_list.append(each_param)
@@ -339,6 +404,12 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
         for param in save_obj:
             key = param["name"]
             data_list[key] = []
+            if isinstance(param["data"], list):
+                if param["data"][0] == "persistent_data":
+                    _save_persistent_data(data_list, key, param)
+                else:
+                    _save_mapparameter(data_list, param)
+                continue
             if isinstance(param["data"], str):
                 data_list[key].append([0])
                 data_list[key].append('str')
@@ -369,6 +440,34 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
     logger.info("Saving checkpoint process is finished.")


+def _save_mapparameter(data_list, param):
+    """Save map parameter into save_obj."""
+    data_list[param["name"]].append("mapparameter")
+    for value in param["data"]:
+        dims = []
+        tmp_list = []
+        for dim in value.shape:
+            dims.append(dim)
+        tmp_list.append(dims)
+        tensor_type = str(value.dtype)
+        tmp_list.append(tensor_type)
+        data = value.asnumpy().reshape(-1)
+        tmp_list.append(data)
+        data_list[param["name"]].append(tmp_list)
+
+
+def _save_persistent_data(data_list, key, param):
+    """Save persistent data into save_obj."""
+    dims = []
+    # persistent_data shape can not be ()
+    for dim in param['data'][2]:
+        dims.append(dim)
+    data_list[key].append(dims)
+    data_list[key].append(param['data'][3])
+    data_list[key].append(param['data'][1])
+    data_list[key].append(param['data'][4])
+
+
 def _check_append_dict(append_dict):
     """Check the argument append_dict for save_checkpoint."""
     if append_dict is None:
@@ -383,6 +482,15 @@ def _check_append_dict(append_dict):
     return append_dict


+def _check_load_obfuscate(**kwargs):
+    if 'obf_func' in kwargs.keys():
+        customized_func = _check_customized_func(kwargs.get('obf_func'))
+        clean_funcs()
+        add_opaque_predicate(customized_func.__name__, customized_func)
+        return True
+    return False
+
+
 def load(file_name, **kwargs):
     """
     Load MindIR.
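In practice this means `load` now accepts an `obf_func` keyword for models obfuscated in customized-function mode. A minimal sketch, assuming an obfuscated `obf_net.mindir` already exists on disk; `my_opaque_func` is an illustrative name, and the function must take two inputs and return a constant boolean, as `_check_customized_func` enforces.

```python
import mindspore as ms

def my_opaque_func(x1, x2):
    # must return the same boolean for any input (an opaque predicate)
    return True

# illustrative only: the same function must have been used when obfuscating the model
graph = ms.load("obf_net.mindir", obf_func=my_opaque_func)
```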
@@ -397,9 +505,14 @@ def load(file_name, **kwargs):
         - dec_key (bytes): Byte-type key used for decryption. The valid length is 16, 24, or 32.
         - dec_mode (Union[str, function]): Specifies the decryption mode, to take effect when dec_key is set.

-          - Option: 'AES-GCM', 'AES-CBC' or customized decryption. Default: 'AES-GCM'.
+          - Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC' or customized decryption. Default: 'AES-GCM'.
           - For details of using the customized decryption, please check the `tutorial
-            <https://mindspore.cn/mindarmour/docs/en/
+            <https://mindspore.cn/mindarmour/docs/en/r2.0/model_encrypt_protection.html>`_.
+
+        - obf_func (function): A python function used for loading obfuscated MindIR model, which can refer to
+          `obfuscate_model()
+          <https://www.mindspore.cn/docs/en/r2.0/api_python/mindspore/mindspore.obfuscate_model.html>`_.

     Returns:
         GraphCell, a compiled graph that can be executed by `GraphCell`.
@@ -412,6 +525,8 @@ def load(file_name, **kwargs):
         >>> import mindspore as ms
         >>> import mindspore.nn as nn
         >>> from mindspore import Tensor
+        >>> from mindspore import context
+        >>> context.set_context(mode=context.GRAPH_MODE)
         >>>
         >>> net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
         >>> input_tensor = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
@@ -433,7 +548,10 @@ def load(file_name, **kwargs):
     if not os.path.exists(file_name):
         raise ValueError("For 'load', the argument 'file_name'(MindIR file) does not exist, "
                          "please check whether the 'file_name' is correct.")
-    file_name = os.path.
+    file_name = os.path.abspath(file_name)
+
+    # set customized functions for dynamic obfuscation
+    obfuscated = _check_load_obfuscate(**kwargs)

     logger.info("Execute the process of loading mindir.")
     if 'dec_key' in kwargs.keys():
@@ -447,9 +565,9 @@ def load(file_name, **kwargs):
             else:
                 dec_mode = Validator.check_isinstance('dec_mode', kwargs.get('dec_mode'), str)
         graph = load_mindir(file_name, dec_key=dec_key, key_len=len(dec_key), dec_mode=dec_mode,
-                            decrypt=dec_func)
+                            decrypt=dec_func, obfuscated=obfuscated)
     else:
-        graph = load_mindir(file_name)
+        graph = load_mindir(file_name, obfuscated=obfuscated)

     if graph is None:
         if _is_cipher_file(file_name):
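Both branches now forward the obfuscation flag to `load_mindir`. A small usage sketch for the encrypted path, assuming `lenet.mindir` was exported with the same 16-byte key; the key value here is a placeholder, and 'SM4-CBC' is one of the newly documented modes.

```python
import mindspore as ms
import mindspore.nn as nn

dec_key = b"0123456789ABCDEF"  # placeholder 16-byte key
graph = ms.load("lenet.mindir", dec_key=dec_key, dec_mode="SM4-CBC")
net = nn.GraphCell(graph)
```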
@@ -460,14 +578,187 @@ def load(file_name, **kwargs):
     return graph


+def _check_param_type(param_config, key, target_type, requested):
+    """check type of parameters"""
+    if key in param_config:
+        if not isinstance(param_config[key], target_type):
+            raise TypeError("The type of {} must be {}, but got {}.".format(key, target_type, type(param_config[key])))
+        if key == 'obf_random_seed':
+            if param_config[key] > INT_64_MAX or param_config[key] <= 0:
+                raise ValueError(
+                    "'obf_random_seed' must be in (0, INT_64_MAX({})], but got {}.".format(INT_64_MAX,
+                                                                                           param_config[key]))
+        return param_config[key]
+    if requested:
+        raise ValueError("The parameter {} is requested, but not got.".format(key))
+    if key == "obf_random_seed":
+        return 0
+    return None
+
+
+def _check_customized_func(customized_func):
+    """ check customized function of dynamic obfuscation """
+    if not callable(customized_func):
+        raise TypeError(
+            "'customized_func' must be a function, but not got {}.".format(type(customized_func)))
+    # test customized_func
+    try:
+        func_result = customized_func(1.0, 1.0)
+    except Exception as ex:
+        raise TypeError("customized_func must be a function with two inputs, but got exception: {}".format(ex))
+    else:
+        if not isinstance(func_result, bool):
+            raise TypeError("Return value of customized_func must be boolean, but got: {}".format(type(func_result)))
+    return customized_func
+
+
+def _check_obfuscate_params(obf_config):
+    """Check obfuscation parameters, including obf_random_seed, obf_ratio, customized_func"""
+    if 'obf_random_seed' not in obf_config.keys() and 'customized_func' not in obf_config.keys():
+        raise ValueError(
+            "At least one of 'obf_random_seed' or 'customized_func' must be set in obf_config, but got None of them.")
+    obfuscate_type = _check_param_type(obf_config, "type", str, False)
+    if obfuscate_type not in (None, "dynamic"):
+        raise ValueError("Only 'dynamic' type is supported by now, but got {}.".format(obfuscate_type))
+    if ('obf_ratio' in obf_config) and isinstance(obf_config['obf_ratio'], str):
+        if obf_config['obf_ratio'] not in ["small", "medium", "large"]:
+            raise ValueError("'obf_ratio' can only be 'small', 'medium', 'large' or float, but got {}.".format(
+                obf_config['obf_ratio']))
+        ratio_dict = {"small": 0.1, "medium": 0.3, "large": 0.6}
+        obf_config['obf_ratio'] = ratio_dict.get(obf_config['obf_ratio'])
+    obf_ratio = _check_param_type(obf_config, "obf_ratio", float, True)
+    if (obf_ratio <= 0) or (obf_ratio > 1):
+        raise ValueError("'obf_ratio' must be in (0, 1] if it is a float, but got {}.".format(obf_config['obf_ratio']))
+    customized_funcs = []
+    if 'customized_func' in obf_config.keys():
+        device_target = context.get_context('device_target')
+        if device_target in ["GPU", "Ascend"]:
+            raise ValueError(
+                "Customized func mode only support 'device_target'='CPU, but got {}.".format(device_target))
+        customized_funcs.append(_check_customized_func(obf_config['customized_func']))
+    obf_random_seed = _check_param_type(obf_config, "obf_random_seed", int, False)
+    return obf_ratio, customized_funcs, obf_random_seed
+
+
+def obfuscate_model(obf_config, **kwargs):
+    """
+    Obfuscate a model of MindIR format. Obfuscation means changing the struct of a network without affecting its
+    predict correctness. The obfuscated model can prevent attackers from stealing the model.
+
+    Args:
+        obf_config (dict): obfuscation config.
+
+            - type (str): The type of obfuscation, only 'dynamic' is supported until now.
+            - original_model_path (str): The path of the MindIR format model that needs to be obfuscated. If the
+              original model is encrypted, then enc_key and enc_mode should be provided.
+            - save_model_path (str): The path to save the obfuscated model.
+            - model_inputs (list(Tensor)): The inputs of the original model, the values of Tensor can be random, which
+              is the same as using :func:`mindspore.export`.
+            - obf_ratio (Union(float, str)): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
+              should be in range of (0, 1] or in ["small", "medium", "large"].
+            - customized_func (function): A python function used for customized function mode, which is used to
+              control the switch branch of obfuscation structure. The outputs of customized_func should be boolean.
+              This function needs to ensure that its result is constant for any input. Users can refer to opaque
+              predicates. If customized_func is set, then it should be passed to :func:`mindspore.load` interface
+              when loading obfuscated model.
+            - obf_random_seed (int): The random seed used to determine the distribution of confusion branches and the
+              weight confusion coefficient, which should be in (0, 9223372036854775807]. If `obf_random_seed` is set,
+              then it should be passed to :class:`nn.GraphCell()` interface when loading obfuscated model. It should be
+              noted that at least one of `customized_func` or `obf_random_seed` should be set, and the latter mode
+              would be applied if both of them are set.
+
+        kwargs (dict): Configuration options dictionary.
+
+            - enc_key (bytes): Byte type key used for encryption. The valid length is 16, 24, or 32.
+            - enc_mode (str): Specifies the encryption mode, to take effect when enc_key is set.
+              Option: 'AES-GCM' | 'AES-CBC' | 'SM4-CBC'. Default: 'AES-GCM'.
+
+    Raises:
+        TypeError: If `obf_config` is not a dict.
+        ValueError: If `enc_key` is passed and `enc_mode` is not in ["AES-GCM", "AES-CBC", "SM4-CBC"].
+        ValueError: If `original_model_path` is not provided in `obf_config`.
+        ValueError: If the model saved in `original_model_path` has been obfuscated.
+        ValueError: If `save_model_path` is not provided in `obf_config`.
+        ValueError: If `obf_ratio` is not provided in `obf_config`.
+        ValueError: If both `customized_func` and `obf_random_seed` are not provided in `obf_config`.
+        ValueError: If `obf_random_seed` is not in (0, 9223372036854775807].
+        ValueError: If `original_model_path` does not exist or does not end with '.mindir'.
+
+    Examples:
+        >>> obf_config = {'original_model_path': "./net.mindir",
+        ...               'save_model_path': "./obf_net",
+        ...               'model_inputs': [input1, ],
+        ...               'obf_ratio': 0.1, 'obf_random_seed': 173262358423}
+        >>> obfuscate_model(obf_config)
+        >>> obf_func = load("obf_net.mindir")
+        >>> obf_net = nn.GraphCell(obf_func, obf_random_seed=173262358423)
+        >>> print(obf_net(input1).asnumpy())
+    """
+    if not isinstance(obf_config, dict):
+        raise TypeError("'obf_config' must be a dict, but got {}.".format(type(obf_config)))
+    file_path = _check_param_type(obf_config, "original_model_path", str, True)
+    if not file_path.endswith(".mindir"):
+        raise ValueError("For 'obfuscate_model', the argument 'file_path'(MindIR file) should end with '.mindir', "
+                         "please input the correct 'file_path'.")
+    if not os.path.exists(file_path):
+        raise ValueError("For 'obfuscate_model', the argument 'file_path'(MindIR file) does not exist, "
+                         "please check whether the 'file_path' is correct.")
+    saved_path = _check_param_type(obf_config, "save_model_path", str, True)
+    model_inputs = _check_param_type(obf_config, "model_inputs", list, True)
+    for item in model_inputs:
+        if not isinstance(item, Tensor):
+            raise TypeError("The item in 'model_inputs' must be Tensor, but got {}.".format(type(item)))
+        if -1 in item.shape:
+            raise ValueError(
+                "Dynamic shape input is not supported now, but got the shape of inputs: {}.".format(item.shape))
+    obf_ratio, customized_funcs, obf_random_seed = _check_obfuscate_params(obf_config)
+    if customized_funcs and obf_random_seed > 0:
+        logger.warning("Although 'customized_func' and 'obf_random_seed' are set, the 'obf_random_seed' mode would be"
+                       " applied, remember to set 'obf_random_seed' when loading obfuscated model.")
+
+    if obf_random_seed == 0:  # apply customized_func mode
+        clean_funcs()
+        for func in customized_funcs:
+            add_opaque_predicate(func.__name__, func)
+        branch_control_input = 0
+    else:  # apply password mode
+        branch_control_input = _generate_branch_control_input(obf_random_seed)
+
+    if 'enc_key' in kwargs.keys():
+        enc_key = Validator.check_isinstance('enc_key', kwargs.get('enc_key'), bytes)
+        enc_mode = "AES-GCM"
+        if 'enc_mode' in kwargs.keys():
+            enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
+            if enc_mode not in ["AES-GCM", "AES-CBC", "SM4-CBC"]:
+                raise ValueError(
+                    "Only MindIR files that encrypted with 'AES-GCM', 'AES-CBC' or 'SM4-CBC' is supported for"
+                    "obfuscate_model(), but got {}.".format(enc_mode))
+        obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio,
+                                             branch_control_input=branch_control_input, dec_key=enc_key,
+                                             key_len=len(enc_key),
+                                             dec_mode=enc_mode)
+    else:
+        obf_graph = dynamic_obfuscate_mindir(file_name=file_path, obf_ratio=obf_ratio,
+                                             branch_control_input=branch_control_input)
+
+    obf_net = nn.GraphCell(obf_graph)
+    if obf_random_seed != 0:
+        append_y_tensor = Tensor(np.ones((1, 1)).astype(np.int32))
+        model_inputs += [append_y_tensor,]
+    export(obf_net, *model_inputs, file_name=saved_path, file_format="MINDIR", **kwargs)
+
+
 def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
-                    dec_key=None, dec_mode="AES-GCM", specify_prefix=None):
+                    dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None):
     """
     Load checkpoint info from a specified file.

     Note:
-
-
+        - `specify_prefix` and `filter_prefix` do not affect each other.
+        - If none of the parameters are loaded from checkpoint file, it will throw ValueError.
+        - `specify_prefix` and `filter_prefix` are in the process of being deprecated;
+          `choice_func` is recommended instead.
+          And using either of those two args will override `choice_func` at the same time.

     Args:
         ckpt_file_name (str): Checkpoint file name.
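As the checks above show, 'obf_ratio' can be given either as a float in (0, 1] or as one of three preset names. A standalone sketch of that normalization, using a hypothetical helper name; the preset-to-float mapping mirrors the `ratio_dict` in `_check_obfuscate_params`.

```python
# Hypothetical helper illustrating how string presets collapse to floats.
RATIO_DICT = {"small": 0.1, "medium": 0.3, "large": 0.6}

def normalize_obf_ratio(obf_ratio):
    if isinstance(obf_ratio, str):
        if obf_ratio not in RATIO_DICT:
            raise ValueError("'obf_ratio' can only be 'small', 'medium', 'large' or float")
        obf_ratio = RATIO_DICT[obf_ratio]
    if not 0 < obf_ratio <= 1:
        raise ValueError("'obf_ratio' must be in (0, 1]")
    return float(obf_ratio)

print(normalize_obf_ratio("medium"))  # 0.3
print(normalize_obf_ratio(0.5))       # 0.5
```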
@@ -476,20 +767,25 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
             into net when parameter name's suffix in checkpoint file is the same as the
             parameter in the network. When the types are inconsistent perform type conversion
             on the parameters of the same type, such as float32 to float16. Default: False.
-        filter_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the
-            will not be loaded. Default: None.
+        filter_prefix (Union[str, list[str], tuple[str]]): Deprecated (see `choice_func`). Parameters starting with the
+            filter_prefix will not be loaded. Default: None.
         dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
             is not required. Default: None.
         dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
-            mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.
-        specify_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the
-            will be loaded. Default: None.
+            mode, currently supports 'AES-GCM', 'AES-CBC' and 'SM4-CBC'. Default: 'AES-GCM'.
+        specify_prefix (Union[str, list[str], tuple[str]]): Deprecated (see `choice_func`). Parameters starting with the
+            specify_prefix will be loaded. Default: None.
+        choice_func (Union[None, function]): Input value of the function is a Parameter name of type string,
+            and the return value is a bool. If returns True, the Parameter
+            that matches the custom condition will be loaded. If returns False, the Parameter that
+            matches the custom condition will be removed. Default: None.

     Returns:
         Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
-        :func:`mindspore.save_checkpoint` and the `append_info` parameter of :class:`CheckpointConfig`
-        save the checkpoint, `append_dict` and `append_info` are dict types, and their value are string,
-        return value obtained by loading checkpoint is string, and in other cases the return value is
+        :func:`mindspore.save_checkpoint` and the `append_info` parameter of :class:`mindspore.train.CheckpointConfig`
+        are used to save the checkpoint, `append_dict` and `append_info` are dict types, and their value are string,
+        then the return value obtained by loading checkpoint is string, and in other cases the return value is
+        Parameter.

     Raises:
         ValueError: Checkpoint file's format is incorrect.
@@ -500,9 +796,28 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
         >>> import mindspore as ms
         >>>
         >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
-        >>> param_dict = ms.load_checkpoint(ckpt_file_name,
+        >>> param_dict = ms.load_checkpoint(ckpt_file_name,
+        ...                                 choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
         >>> print(param_dict["conv2.weight"])
         Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
+        >>> def func(param_name):
+        >>>     whether_load = False
+        >>>     if param_name.startswith("conv"):
+        >>>         whether_load = True
+        >>>     if param_name.startswith("conv1"):
+        >>>         whether_load = False
+        >>>     return whether_load
+        >>> param_dict1 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
+        >>> print(param_dict1["conv2.weight"])
+        Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
+        >>> def func(param_name):
+        >>>     whether_load = False
+        >>>     if param_name.startswith("conv1"):
+        >>>         whether_load = True
+        >>>     return whether_load
+        >>> param_dict2 = ms.load_checkpoint(ckpt_file_name, choice_func=func)
+        >>> print(param_dict2)
+        {'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
     """
     ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
     specify_prefix = _check_prefix(specify_prefix)
@@ -515,15 +830,27 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
     parameter_dict = {}
     try:
         param_data_list = []
+        if specify_prefix:
+            logger.warning("For load_checkpoint, this parameter `specity_prefix` will be deprecated, "
+                           "please use `choice_func` instead.")
+        if filter_prefix:
+            logger.warning("For load_checkpoint, this parameter `filter_prefix` will be deprecated, "
+                           "please use `choice_func` instead.")
         for element_id, element in enumerate(checkpoint_list.value):
             if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
                 continue
+            if specify_prefix is None and filter_prefix is None and \
+                    choice_func is not None and not choice_func(element.tag):
+                continue
+            if element.tensor.ByteSize() == 0:
+                _load_mapparameter(element, parameter_dict)
+                continue
             data = element.tensor.tensor_content
             data_type = element.tensor.tensor_type
             np_type = tensor_to_np_type.get(data_type)
             ms_type = tensor_to_ms_type[data_type]
             if data_type == 'str':
-                str_length = int(len(data)/4)
+                str_length = int(len(data) / 4)
                 np_type = np_type + str(str_length)
             element_data = np.frombuffer(data, np_type)
             param_data_list.append(element_data)
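The filtering order above matters: the deprecated prefix arguments are evaluated first, and `choice_func` is only consulted when neither of them was given. A condensed sketch of that precedence; `should_load` is a hypothetical helper, simplified to single-string prefixes rather than the list/tuple forms the real `_whether_load_param` also accepts.

```python
def should_load(tag, specify_prefix=None, filter_prefix=None, choice_func=None):
    # Hypothetical condensation of the loop above, single-string prefixes only.
    if specify_prefix is not None and not tag.startswith(specify_prefix):
        return False
    if filter_prefix is not None and tag.startswith(filter_prefix):
        return False
    if specify_prefix is None and filter_prefix is None and choice_func is not None:
        return bool(choice_func(tag))
    return True

print(should_load("conv2.weight", choice_func=lambda x: x.startswith("conv")))  # True
print(should_load("fc1.bias", choice_func=lambda x: x.startswith("conv")))      # False
print(should_load("conv1.weight", filter_prefix="conv1"))                       # False
```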
@@ -548,7 +875,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
     except BaseException as e:
         logger.critical("Failed to load the checkpoint file '%s'.", ckpt_file_name)
         raise ValueError(e.__str__() + "\nFor 'load_checkpoint', "
-
+                         "failed to load the checkpoint file {}.".format(ckpt_file_name)) from e

     if not parameter_dict:
         raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
@@ -570,7 +897,7 @@ def _check_ckpt_file_name(ckpt_file_name):
         raise ValueError("For 'load_checkpoint', the checkpoint file should end with '.ckpt', please "
                          "input the correct 'ckpt_file_name'.")

-    ckpt_file_name = os.path.
+    ckpt_file_name = os.path.abspath(ckpt_file_name)
     if not os.path.exists(ckpt_file_name):
         raise ValueError("For 'load_checkpoint', the checkpoint file: {} does not exist, please check "
                          "whether the 'ckpt_file_name' is correct.".format(ckpt_file_name))
@@ -616,10 +943,10 @@ def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode):
     except BaseException as e:
         if _is_cipher_file(ckpt_file_name):
             err_info = "Failed to read the checkpoint file {}. The file may be encrypted or tempered with, " \
-
+                       "please pass in the correct 'dec_key' or check the file integrity.".format(ckpt_file_name)
         else:
             err_info = "Failed to read the checkpoint file {}. May not have permission to read it, please check" \
-
+                       " the correct of the file.".format(ckpt_file_name)
         logger.error(err_info)
         raise ValueError(err_info) from e
     return checkpoint_list
@@ -642,6 +969,20 @@ def _whether_load_param(specify_prefix, filter_prefix, param_name):
     return whether_load


+def _load_mapparameter(element, parameter_dict):
+    """Load map parameter from ckpt file."""
+    map_array = []
+    for tensor in element.maptensor.tensor:
+        data = tensor.tensor_content
+        data_type = tensor.tensor_type
+        np_type = tensor_to_np_type.get(data_type)
+        element_data = np.frombuffer(data, np_type)
+        dims = tensor.dims
+        param_data = element_data.reshape(list(dims))
+        map_array.append(param_data)
+    parameter_dict[element.tag] = map_array
+
+
 def load_param_into_net(net, parameter_dict, strict_load=False):
     """
     Load parameters into network, return parameter list that are not loaded in the network.
@@ -656,7 +997,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
             on the parameters of the same type, such as float32 to float16. Default: False.

     Returns:
-        List, the parameter name which are not loaded into the network.
+        param_not_load (List), the parameter name in model which are not loaded into the network.
+        ckpt_not_load (List), the parameter name in checkpoint file which are not loaded into the network.

     Raises:
         TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
@@ -667,7 +1009,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
         >>> net = Net()
         >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
         >>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1")
-        >>> param_not_load = ms.load_param_into_net(net, param_dict)
+        >>> param_not_load, _ = ms.load_param_into_net(net, param_dict)
         >>> print(param_not_load)
         ['conv1.weight']
     """
@@ -692,10 +1034,12 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
     logger.info("Execute the process of loading parameters into net.")
     net.init_parameters_data()
     param_not_load = []
+    ckpt_not_load = list(parameter_dict.keys())
     for _, param in net.parameters_and_names():
         if param.name in parameter_dict:
             new_param = copy.deepcopy(parameter_dict[param.name])
             _update_param(param, new_param, strict_load)
+            ckpt_not_load.remove(param.name)
         else:
             param_not_load.append(param.name)

@@ -714,7 +1058,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
                        "when training and loading checkpoint.".format(len(param_not_load)))
         for param_name in param_not_load:
             logger.warning("{} is not loaded.".format(param_name))
-    return param_not_load
+    return param_not_load, ckpt_not_load


 def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load):
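Because of this change, callers that unpacked a single list from `load_param_into_net` now need to unpack a pair. A minimal sketch under that assumption; the network here is a stand-in and the checkpoint path is a placeholder.

```python
import mindspore as ms
import mindspore.nn as nn

net = nn.Conv2d(1, 6, kernel_size=5)   # stand-in network for illustration
param_dict = ms.load_checkpoint("./checkpoint/LeNet5-1_32.ckpt")
# the single return value became a pair: unpack both lists
param_not_load, ckpt_not_load = ms.load_param_into_net(net, param_dict)
print(param_not_load)   # network parameters with no matching checkpoint entry
print(ckpt_not_load)    # checkpoint entries that were never consumed
```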
@@ -753,7 +1097,7 @@ def _save_graph(network, file_name):
     """
     logger.info("Execute the process of saving graph.")

-    file_name = os.path.
+    file_name = os.path.abspath(file_name)
     graph_pb = network.get_func_graph_proto()
     if graph_pb:
         with open(file_name, "wb") as f:
@@ -761,7 +1105,7 @@ def _save_graph(network, file_name):
             f.write(graph_pb)


-def _get_merged_param_data(net, param_name, param_data, integrated_save):
+def _get_merged_param_data(net, parameter_layout_dict, param_name, param_data, integrated_save):
     """
     Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.

@@ -773,7 +1117,7 @@ def _get_merged_param_data(net, param_name, param_data, integrated_save):
     Returns:
         Tensor, the combined tensor which with the whole data value.
     """
-    layout =
+    layout = parameter_layout_dict[param_name]
     if len(layout) < 6:
         logger.info("The layout dict does not contain the key %s", param_name)
         return param_data
@@ -839,19 +1183,22 @@ def _fill_param_into_net(net, parameter_list):
         else:
             parameter_dict[param_name] = Parameter(Tensor(np_val), name=param_name)

-    load_param_into_net(net, parameter_dict)
+    load_param_into_net(net, parameter_dict, strict_load=True)


-def export(net, *inputs, file_name, file_format
+def export(net, *inputs, file_name, file_format, **kwargs):
     """
     Export the MindSpore network into an offline model in the specified format.

     Note:
         1. When exporting AIR, ONNX format, the size of a single tensor can not exceed 2GB.
         2. When file_name does not have a suffix, the system will automatically add one according to the file_format.
+        3. Exporting functions decorated with 'jit' to mindir format is supported.
+        4. When exporting a function decorated with 'jit', the function should not involve class properties in
+           calculations.

     Args:
-        net (Cell): MindSpore network.
+        net (Union[Cell, function]): MindSpore network.
         inputs (Union[Tensor, Dataset, List, Tuple, Number, Bool]): It represents the inputs
             of the `net`, if the network has multiple inputs, set them together. While its type is Dataset,
             it represents the preprocess behavior of the `net`, data preprocess operations will be serialized.
@@ -859,7 +1206,6 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
             the batch size of 'net' input. Only supports parse "image" column from dataset currently.
         file_name (str): File name of the model to be exported.
         file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model.
-            Default: 'AIR'.

             - AIR: Ascend Intermediate Representation. An intermediate representation format of Ascend model.
             - ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
@@ -868,20 +1214,36 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):

         kwargs (dict): Configuration options dictionary.

-            - quant_mode (str): If the network is a quantization aware training network, the quant_mode should
-              be set to "QUANT", else the quant_mode should be set to "NONQUANT".
-            - mean (float): The mean of input data after preprocessing, used for quantizing the first layer of network.
-              Default: 127.5.
-            - std_dev (float): The variance of input data after preprocessing,
-              used for quantizing the first layer of the network. Default: 127.5.
             - enc_key (byte): Byte-type key used for encryption. The valid length is 16, 24, or 32.
             - enc_mode (Union[str, function]): Specifies the encryption mode, to take effect when enc_key is set.

              - For 'AIR' and 'ONNX' models, only customized encryption is supported.
-              - For 'MINDIR', all options are supported. Option: 'AES-GCM', 'AES-CBC'
+              - For 'MINDIR', all options are supported. Option: 'AES-GCM', 'AES-CBC', 'SM4-CBC'
+                or Customized encryption.
                Default: 'AES-GCM'.
              - For details of using the customized encryption, please check the `tutorial
-                <https://mindspore.cn/mindarmour/docs/en/
+                <https://mindspore.cn/mindarmour/docs/en/r2.0/model_encrypt_protection.html>`_.
+
+            - dataset (Dataset): Specifies the preprocessing method of the dataset, which is used to import the
+              preprocessing of the dataset into MindIR.
+
+            - obf_config (dict): obfuscation config.
+
+              - type (str): The type of obfuscation, only 'dynamic' is supported until now.
+              - obf_ratio (float, str): The ratio of nodes in original model that would be obfuscated. `obf_ratio`
+                should be in range of (0, 1] or in ["small", "medium", "large"].
+              - customized_func (function): A python function used for customized function mode, which is used to
+                control the switch branch of obfuscation structure. The outputs of customized_func should be boolean.
+                This function needs to ensure that its result is constant for any input. Users can refer to opaque
+                predicates. If customized_func is set, then it should be passed to `load()` interface when loading
+                obfuscated model.
+              - obf_random_seed (int): The random seed used to determine the distribution of confusion branches and the
+                weight confusion coefficient, which should be in (0, 9223372036854775807]. If `obf_random_seed` is set,
+                then it should be passed to :class:`nn.GraphCell()` interface when loading obfuscated model. It should
+                be noted that at least one of `customized_func` or `obf_random_seed` should be set, and the latter mode
+                would be applied if both of them are set.
+
+            - incremental (bool): export MindIR incrementally.

     Examples:
         >>> import mindspore as ms
@@ -892,7 +1254,12 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
         >>> input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
         >>> ms.export(net, input_tensor, file_name='lenet', file_format='MINDIR')
     """
+    supported_formats = ['AIR', 'ONNX', 'MINDIR']
+    if file_format not in supported_formats:
+        raise ValueError(f"For 'export', 'file_format' must be one of {supported_formats}, but got {file_format}.")
+    Validator.check_file_name_by_regular(file_name)
     logger.info("exporting model file:%s format:%s.", file_name, file_format)
+
     if check_input_dataset(*inputs, dataset_type=mindspore.dataset.Dataset):
         if len(inputs) != 1:
             raise RuntimeError(f"You can only serialize one dataset into MindIR, got " + str(len(inputs)) + " datasets")
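With the format check moved to the top of `export`, an unsupported `file_format` now fails before any compilation work. A small sketch of an encrypted MINDIR export; the key bytes are placeholders, and 'SM4-CBC' is used only to show the newly accepted mode.

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor

net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
x = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
ms.export(net, x, file_name="conv", file_format="MINDIR",
          enc_key=b"0123456789ABCDEF", enc_mode="SM4-CBC")
```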
@@ -910,15 +1277,10 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
                                + str(columns))
         inputs = tuple(inputs_col)

-
-    file_name = os.path.realpath(file_name)
-    net = _quant_export(net, *inputs, file_format=file_format, **kwargs)
+    file_name = os.path.abspath(file_name)
     if 'enc_key' in kwargs.keys():
-        enc_key, enc_mode = _check_key_mode_type(file_format, **kwargs)
-
-        _export(net, file_name, file_format, *inputs, enc_key=enc_key, enc_mode=enc_mode, dataset=dataset)
-    else:
-        _export(net, file_name, file_format, *inputs, **kwargs)
+        kwargs['enc_key'], kwargs['enc_mode'] = _check_key_mode_type(file_format, **kwargs)
+    _export(net, file_name, file_format, *inputs, **kwargs)


 def _export(net, file_name, file_format, *inputs, **kwargs):
@@ -926,19 +1288,8 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
     It is an internal conversion function. Export the MindSpore prediction model to a file in the specified format.
     """
     logger.info("exporting model file:%s format:%s.", file_name, file_format)
-
-
-        logger.warning(f"For 'export', format 'GEIR' is deprecated, "
-                       f"it would be removed in future release, use 'AIR' instead.")
-        file_format = 'AIR'
-
-    supported_formats = ['AIR', 'ONNX', 'MINDIR']
-    if file_format not in supported_formats:
-        raise ValueError(f"For 'export', 'file_format' must be one of {supported_formats}, but got {file_format}.")
-    # When dumping ONNX file, switch network mode to infer when it is training(NOTE: ONNX only designed for prediction)
-    is_dump_onnx_in_training = net.training and file_format == 'ONNX'
-    if is_dump_onnx_in_training:
-        net.set_train(mode=False)
+    if "obf_config" in kwargs and file_format != "MINDIR":
+        raise ValueError(f"Dynamic obfuscation only support for MindIR format, but got {file_format} format.")

     if file_format == 'AIR':
         _save_air(net, file_name, *inputs, **kwargs)
@@ -947,9 +1298,6 @@ def _export(net, file_name, file_format, *inputs, **kwargs):
     elif file_format == 'MINDIR':
         _save_mindir(net, file_name, *inputs, **kwargs)

-    if is_dump_onnx_in_training:
-        net.set_train(mode=True)
-

 def _check_key_mode_type(file_format, **kwargs):
     """check enc_key and enc_mode are valid"""
@@ -966,9 +1314,9 @@ def _check_key_mode_type(file_format, **kwargs):
         if file_format in ('AIR', 'ONNX'):
             raise ValueError(f"AIR/ONNX only support customized encryption, but got {enc_mode}.")

-    if enc_mode in ('AES-CBC', 'AES-GCM'):
+    if enc_mode in ('AES-CBC', 'AES-GCM', 'SM4-CBC'):
         return enc_key, enc_mode
-    raise ValueError(f"MindIR only support AES-GCM/AES-CBC encryption, but got {enc_mode}")
+    raise ValueError(f"MindIR only support AES-GCM/AES-CBC/SM4-CBC encryption, but got {enc_mode}")


 def _save_air(net, file_name, *inputs, **kwargs):
@@ -980,7 +1328,7 @@ def _save_air(net, file_name, *inputs, **kwargs):
     if os.path.exists(file_name):
         os.chmod(file_name, stat.S_IWUSR)
     if "/" in file_name:
-        real_path = os.path.
+        real_path = os.path.abspath(file_name[:file_name.rfind("/")])
         os.makedirs(real_path, exist_ok=True)
     if 'enc_key' in kwargs.keys() and 'enc_mode' in kwargs.keys():
         _executor.export(file_name, graph_id, enc_key=kwargs.get('enc_key'), encrypt_func=kwargs.get('enc_mode'))
@@ -991,6 +1339,12 @@ def _save_air(net, file_name, *inputs, **kwargs):

 def _save_onnx(net, file_name, *inputs, **kwargs):
     """Save ONNX format file."""
+    # When dumping ONNX file, switch network mode to infer when it is training(NOTE: ONNX only designed for prediction)
+    if not isinstance(net, nn.Cell):
+        raise ValueError(f"Export ONNX format model only support nn.Cell object, but got {type(net)}.")
+    _check_dynamic_input(inputs)
+    cell_mode = net.training
+    net.set_train(mode=False)
     total_size = _calculation_net_size(net)
     if total_size > PROTO_LIMIT_SIZE:
         raise RuntimeError('Export onnx model failed. Network size is: {}G, it exceeded the protobuf: {}G limit.'
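The ONNX path now records the cell's training flag, switches to inference mode for the dump, and restores the flag afterwards via the `net.set_train(mode=cell_mode)` call added in the next hunk; dynamic-shape inputs are rejected up front. A brief sketch of the observable behaviour, assuming the stand-in conv layer and input shown here.

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor

net = nn.Conv2d(1, 1, kernel_size=3, weight_init="ones")
net.set_train(True)
x = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))
ms.export(net, x, file_name="conv", file_format="ONNX")
print(net.training)  # True: the training flag is restored after the ONNX dump
```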
@@ -1008,6 +1362,13 @@ def _save_onnx(net, file_name, *inputs, **kwargs):
     with open(file_name, 'wb') as f:
         f.write(onnx_stream)
     os.chmod(file_name, stat.S_IRUSR)
+    net.set_train(mode=cell_mode)
+
+
+def _check_dynamic_input(inputs):
+    for ele in inputs:
+        if isinstance(ele, Tensor) and -1 in ele.shape:
+            raise ValueError(f"Export ONNX format model not support dynamic shape mode.")


 def _generate_front_info_for_param_data_file(is_encrypt, kwargs):
@@ -1057,10 +1418,24 @@ def _get_data_file(is_encrypt, kwargs, data_file_name):
     return f, parameter_size, offset


-def
+def _encrypt_data(is_encrypt, write_data, kwargs):
+    """Encrypt parameter data."""
+    if is_encrypt():
+        if callable(kwargs.get('enc_mode')):
+            enc_func = kwargs.get('enc_mode')
+            write_data = enc_func(write_data, kwargs.get('enc_key'))
+        else:
+            write_data = _encrypt(write_data, len(write_data), kwargs.get('enc_key'),
+                                  len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
+    return write_data
+
+
+def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
     """The function to save parameter data."""
     logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.")
     # save parameter
+    if model.graph.map_parameter:
+        raise ValueError("MapParameter not support save in split MindIR file now.")
     file_prefix = file_name.split("/")[-1]
     if file_prefix.endswith(".mindir"):
         file_prefix = file_prefix[:-7]
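When `enc_mode` is a callable, `_encrypt_data` applies it directly to the serialized bytes together with `enc_key`. A toy sketch of a callable with that `(data, key) -> bytes` shape; the XOR transform is purely illustrative and is not a real cipher.

```python
def xor_encrypt(plain_bytes, key):
    # toy (data, key) -> bytes callable; a real deployment would use an actual cipher
    return bytes(b ^ key[i % len(key)] for i, b in enumerate(plain_bytes))

# illustrative call shape, mirroring the callable branch above
print(xor_encrypt(b"parameter bytes", b"0123456789ABCDEF"))
```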
@@ -1095,13 +1470,7 @@ def _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs):
             param_proto.external_data.offset = offset
             write_data = raw_data + bytes(append_size)
             offset += (data_length + append_size)
-
-            if callable(kwargs.get('enc_mode')):
-                enc_func = kwargs.get('enc_mode')
-                write_data = enc_func(write_data, kwargs.get('enc_key'))
-            else:
-                write_data = _encrypt(write_data, len(write_data), kwargs.get('enc_key'),
-                                      len(kwargs.get('enc_key')), kwargs.get('enc_mode'))
+            write_data = _encrypt_data(is_encrypt, write_data, kwargs)
             f.write(write_data)

     graph_file_name = os.path.join(dirname, file_prefix + "_graph.mindir")
@@ -1124,18 +1493,74 @@ def _spilt_save(net_dict, model, file_name, is_encrypt, **kwargs):
         os.chmod(data_file_name, stat.S_IRUSR)


-def
-    """
-
+def _msfunc_info(net, *inputs):
+    """Get mindir stream and parameter dict of ms_function"""
+    # pylint: disable=protected-access
+    net_dict = OrderedDict()
+    _ms_func_executor = _MindsporeFunctionExecutor(net, time.time() * 1e9)
+    graph_id = _ms_func_executor.compile(net.__name__, *inputs)
+    mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
+    params = _ms_func_executor._graph_executor.get_params(graph_id)
+    for name, value in params.items():
+        net_dict[name] = Parameter(value, name=name)
+    return mindir_stream, net_dict

-    phase_name = "predict" if net._auto_parallel_mode else "export.mindir"

-
-
+def _cell_info(net, incremental, *inputs):
+    """Get mindir stream and net dict of cell"""
+    phase_name = "predict" if _is_in_auto_parallel_mode() else "export.mindir"
+    graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
     # pylint: disable=protected-access
-    mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
+    mindir_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir', incremental=incremental)
+    # clean obfuscation config to prevent the next call
+    _executor.obfuscate_config = None

     net_dict = net.parameters_dict()
+    return mindir_stream, net_dict
+
+
+def _set_obfuscate_config(**kwargs):
+    """Set obfuscation config for executor."""
+    logger.warning("Obfuscate model.")
+    if 'enc_mode' in kwargs.keys():
+        enc_mode = Validator.check_isinstance('enc_mode', kwargs.get('enc_mode'), str)
+        if enc_mode not in ["AES-GCM", "AES-CBC", "SM4-CBC"]:
+            raise ValueError(
+                "Only MindIR files that encrypted with 'AES-GCM', 'AES-CBC' or 'SM4-CBC' is supported for"
+                "obfuscation, but got {}.".format(enc_mode))
+    obf_ratio, customized_funcs, obf_random_seed = _check_obfuscate_params(kwargs.get('obf_config'))
+    if customized_funcs and obf_random_seed > 0:
+        logger.warning("Although 'customized_func' and 'obf_random_seed' are set, the 'obf_random_seed' mode would be"
+                       " applied, remember to set 'obf_random_seed' when loading obfuscated model.")
+
+    if obf_random_seed == 0:  # apply customized_func mode
+        device_target = context.get_context('device_target')
+        if device_target in ["GPU", "Ascend"]:
+            raise ValueError(
+                "Customized func mode only support 'device_target'='CPU, but got {}.".format(device_target))
+        clean_funcs()
+        for func in customized_funcs:
+            add_opaque_predicate(func.__name__, func)
+    _executor.obfuscate_config = {'obf_ratio': obf_ratio, 'obf_random_seed': obf_random_seed}
+
+
+def _save_mindir(net, file_name, *inputs, **kwargs):
+    """Save MindIR format file."""
+    # set obfuscate configs
+    if 'obf_config' in kwargs.keys():
+        _set_obfuscate_config(**kwargs)
+        for item in inputs:
+            if -1 in item.shape:
+                raise ValueError(
+                    "Dynamic shape input is not supported now, but got the shape of inputs: {}.".format(item.shape))
+
+    incremental = kwargs.get('incremental', False)
+
+    model = mindir_model()
+    if not isinstance(net, nn.Cell):
+        mindir_stream, net_dict = _msfunc_info(net, *inputs)
+    else:
+        mindir_stream, net_dict = _cell_info(net, incremental, *inputs)
     model.ParseFromString(mindir_stream)

     if kwargs.get('dataset'):
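Because `_save_mindir` now routes non-Cell callables through `_msfunc_info`, a function decorated with `jit` can be exported directly to MINDIR, as the export docstring notes above. A minimal sketch under that assumption; names and shapes are illustrative.

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

@ms.jit
def add(x, y):
    return ops.add(x, y)

x = Tensor(np.ones((2, 2), np.float32))
y = Tensor(np.ones((2, 2), np.float32))
# a non-Cell callable takes the ms_function export path shown above
ms.export(add, x, y, file_name="add", file_format="MINDIR")
```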
@@ -1148,7 +1573,7 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
     if save_together:
         _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs)
     else:
-
+        _split_save(net_dict, model, file_name, is_encrypt, **kwargs)


 def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
@@ -1159,9 +1584,20 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
             param_data = net_dict[param_name].data.asnumpy().tobytes()
             param_proto.raw_data = param_data
         else:
-
-
-
+            raise ValueError("The parameter '{}' is not belongs to any cell,"
+                             "the data of parameter cannot be exported.".format(param_proto.name))
+    incremental = kwargs.get('incremental', False)
+    for map_param_proto in model.graph.map_parameter:
+        map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:]
+        if map_param_name in net_dict.keys():
+            map_parameter = net_dict[map_param_name]
+            key_nparr, value_nparr, status_nparr = map_parameter.export_data(incremental)
+            map_param_proto.key_tensor.raw_data = key_nparr.tobytes()
+            map_param_proto.value_tensor.raw_data = value_nparr.tobytes()
+            map_param_proto.status_tensor.raw_data = status_nparr.tobytes()
+        else:
+            raise ValueError("The map_parameter '{}' is not belongs to any cell,"
+                             "the data of parameter cannot be exported.".format(map_param_proto.name))
     if not file_name.endswith('.mindir'):
         file_name += ".mindir"
     current_path = os.path.abspath(file_name)
@@ -1191,9 +1627,8 @@ def _save_together(net_dict, model):
         if name in net_dict.keys():
             data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024
         else:
-
-
-                             .format(param_proto.name))
+            raise ValueError("The parameter '{}' is not belongs to any cell,"
+                             "the data of parameter cannot be exported.".format(param_proto.name))
         if data_total > TOTAL_SAVE:
             return False
     return True
@@ -1214,67 +1649,11 @@ def _save_dataset_to_mindir(model, dataset):
         model.preprocessor.op.add()
         model.preprocessor.op[-1].input_columns = json.dumps(op['input_columns'])
         model.preprocessor.op[-1].output_columns = json.dumps(op['output_columns'])
-        model.preprocessor.op[-1].project_columns = json.dumps(op['project_columns'])
         model.preprocessor.op[-1].op_type = json.dumps(op['op_type'])
         model.preprocessor.op[-1].operations = json.dumps(op['operations'])
         model.preprocessor.op[-1].offload = op['offload'] if 'offload' in op.keys() else False


-def quant_mode_manage(func):
-    """Inherit the quant_mode in old version."""
-    @functools.wraps(func)
-    def warpper(network, *inputs, file_format, **kwargs):
-        if 'quant_mode' not in kwargs:
-            return network
-        quant_mode = kwargs.get('quant_mode')
-        if not isinstance(quant_mode, str):
-            raise TypeError("For 'export', the type of 'quant_mode' should be string, "
-                            "but got {}.".format(type(quant_mode)))
-        if quant_mode in ('AUTO', 'MANUAL'):
-            kwargs['quant_mode'] = 'QUANT'
-        return func(network, *inputs, file_format=file_format, **kwargs)
-
-    return warpper
-
-
-@quant_mode_manage
-def _quant_export(network, *inputs, file_format, **kwargs):
-    """Exports MindSpore quantization predict model to deploy with AIR and MINDIR."""
-    supported_device = ["Ascend", "GPU"]
-    supported_formats = ['AIR', 'MINDIR']
-    quant_mode_formats = ['QUANT', 'NONQUANT']
-
-    quant_mode = kwargs['quant_mode']
-    if quant_mode not in quant_mode_formats:
-        raise KeyError(f"For 'export', the argument 'quant_mode' must be one of {quant_mode_formats}, "
-                       f"but got {quant_mode}.")
-    if quant_mode == 'NONQUANT':
-        return network
-    quant_net = copy.deepcopy(network)
-    quant_net._create_time = int(time.time() * 1e9)
-
-    mean = 127.5 if kwargs.get('mean', None) is None else kwargs.get('mean')
-    std_dev = 127.5 if kwargs.get('std_dev', None) is None else kwargs.get('std_dev')
-    mean = Validator.check_value_type("mean", mean, (int, float))
-    std_dev = Validator.check_value_type("std_dev", std_dev, (int, float))
-
-    if context.get_context('device_target') not in supported_device:
-        raise KeyError(f"For 'export', quant export only support {supported_device} device target now, "
-                       f"but got {context.get_context('device_target')}")
-
-    if file_format not in supported_formats:
-        raise ValueError(f"For 'export', quant export only support 'file_format' {supported_formats}, "
-                         f"but got {file_format}.")
-
-    quant_net.set_train(False)
-    if file_format == "MINDIR":
-        exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True)
-    else:
-        exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs)
-    deploy_net = exporter.run()
-    return deploy_net
-
-
 def parse_print(print_file_name):
     """
     Parse data file generated by mindspore.ops.Print.
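
The largest removal in this hunk drops the quantization export path (`quant_mode_manage` and `_quant_export`). For readers tracking the behaviour that disappears, here is a standalone sketch of the decorator pattern those helpers used, with illustrative names rather than MindSpore's:

    import functools

    def normalize_quant_mode(func):
        """Collapse the legacy 'AUTO'/'MANUAL' values to 'QUANT' before delegating."""
        @functools.wraps(func)
        def wrapper(network, *inputs, file_format, **kwargs):
            if 'quant_mode' not in kwargs:
                return network                        # nothing to export in quant form
            if kwargs['quant_mode'] in ('AUTO', 'MANUAL'):
                kwargs['quant_mode'] = 'QUANT'        # old aliases map to one canonical value
            return func(network, *inputs, file_format=file_format, **kwargs)
        return wrapper

    @normalize_quant_mode
    def export_stub(network, *inputs, file_format, **kwargs):
        return f"{file_format}:{kwargs.get('quant_mode')}"

    print(export_stub("net", file_format="MINDIR", quant_mode="AUTO"))   # -> MINDIR:QUANT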
@@ -1316,7 +1695,7 @@ def parse_print(print_file_name):
         [[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00],
         [ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])]
     """
-    print_file_path = os.path.
+    print_file_path = os.path.abspath(print_file_name)

     if os.path.getsize(print_file_path) == 0:
         raise ValueError("For 'parse_print', the print file may be empty, please make sure enter the correct "
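
This change makes `parse_print` resolve the path with `os.path.abspath` before checking the file size. A hedged usage sketch, assuming a dump file named "print.data" (illustrative) was already produced by `mindspore.ops.Print` with `print_file_path` configured in the context:

    # Sketch only: assumes "print.data" already exists, produced by ops.Print in graph
    # mode with context.set_context(print_file_path="print.data") on an Ascend backend.
    import mindspore as ms

    recorded = ms.parse_print("print.data")   # list of the tensors/strings that were printed
    for item in recorded:
        print(item)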
@@ -1490,7 +1869,8 @@ def build_searched_strategy(strategy_filename):
     """
     Build strategy of every parameter in network. Used in the case of distributed inference.
     For details of it, please check:
-
+    `Saving and Loading Models in Hybrid Parallel Mode
+    <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/save_load.html>`_.

     Args:
         strategy_filename (str): Name of strategy file.
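
Only the docstring link is restored here. For orientation, a hedged usage sketch of the documented entry point (the strategy file name is illustrative and must come from an earlier auto-parallel training run):

    # Illustrative: "strategy.ckpt" is a strategy file saved during auto-parallel training,
    # e.g. via set_auto_parallel_context(strategy_ckpt_save_file="strategy.ckpt").
    import mindspore as ms

    strategy = ms.build_searched_strategy("strategy.ckpt")
    for name, layout in strategy.items():
        print(name, layout)   # per-parameter slicing layout used for distributed inference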
@@ -1512,7 +1892,7 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
     """
     Merge parameter slices into one parameter. Used in the case of distributed inference.
     For details of it, please check:
-    `<https://www.mindspore.cn/tutorials/experts/en/
+    `<https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/save_load.html>`_.

     Args:
         sliced_parameters (list[Parameter]): Parameter slices in order of rank id.
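
Again only the link changes; for context, a hedged sketch of merging per-rank slices back into one parameter (values and the parameter name are made up; with `strategy=None` the slices are concatenated in rank order, per the function's documentation):

    # Illustrative values; both slices must share the parameter name.
    import numpy as np
    import mindspore as ms
    from mindspore import Parameter, Tensor

    slices = [
        Parameter(Tensor(np.array([0.0, 1.0], dtype=np.float32)), "net.weight"),
        Parameter(Tensor(np.array([2.0, 3.0], dtype=np.float32)), "net.weight"),
    ]
    merged = ms.merge_sliced_parameter(slices)   # strategy=None -> plain concatenation
    print(merged.data.asnumpy())                 # expected: [0. 1. 2. 3.]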
@@ -1608,7 +1988,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
     Load checkpoint into net for distributed predication. Used in the case of distributed inference.
     For details of distributed inference, please check:
     `Distributed Inference
-    <https://www.mindspore.cn/tutorials/experts/en/
+    <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/distributed_inference.html>`_ .

     Args:
         network (Cell): Network for distributed predication.
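
The hunk above only updates the tutorial link, but for orientation, a hedged sketch of how this entry point is typically called (the checkpoint and strategy file names and the tiny net are placeholders, not files this package ships):

    # Placeholder sketch: "rank_*.ckpt" and "strategy.ckpt" are illustrative file names.
    import mindspore as ms
    from mindspore import nn

    net = nn.Dense(2, 2)                                   # stand-in for the real predict network
    ckpt_files = ["rank_0.ckpt", "rank_1.ckpt"]            # one checkpoint per training rank
    strategy = ms.build_searched_strategy("strategy.ckpt")
    ms.load_distributed_checkpoint(network=net,
                                   checkpoint_filenames=ckpt_files,
                                   predict_strategy=strategy)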
@@ -1627,7 +2007,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
         dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
                                       is not required. Default: None.
         dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
-                        mode, currently supports 'AES-GCM' and '
+                        mode, currently supports 'AES-GCM', 'AES-CBC' and 'SM4-CBC'. Default: 'AES-GCM'.

     Raises:
         TypeError: The type of inputs do not match the requirements.
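
The updated docstring now lists three decryption modes. A hedged sketch focused on the decryption arguments documented above (the key, file names, and net are illustrative; 'SM4-CBC' is one of the listed modes):

    import mindspore as ms
    from mindspore import nn

    ms.load_distributed_checkpoint(network=nn.Dense(2, 2),
                                   checkpoint_filenames=["rank_0.ckpt", "rank_1.ckpt"],
                                   predict_strategy=None,
                                   dec_key=b"0123456789ABCDEF",   # illustrative 16-byte key
                                   dec_mode="SM4-CBC")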
@@ -1720,8 +2100,8 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
                 logger.critical("Failed to load opt shard slice in load distributed checkpoint for {}. Data shape is {}"
                                 " and group is {}".format(param.name, split_param.data.shape, opt_shard_group))
                 raise RuntimeError(e.__str__() + f"\nFor 'load_distributed_checkpoint', failed to load opt shard slice"
-
-
+                                   f" in load distributed checkpoint for {param.name}. Data shape is "
+                                   f"{split_param.data.shape} and group is {opt_shard_group}.") from e
             split_param = Parameter(Tensor(data_slice), param.name,
                                     split_param.requires_grad, split_param.layerwise_parallel)
         param_dict[param.name] = split_param
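
The rewritten raise keeps the message in one statement and chains the original exception with `from e`, so the root cause stays attached as `__cause__`. A small generic sketch of that pattern (names are illustrative, not MindSpore code):

    def load_slice(shape):
        try:
            raise RuntimeError("device memory exhausted")   # stand-in for the real failure
        except RuntimeError as e:
            raise RuntimeError(f"Failed to load opt shard slice. Data shape is {shape}.") from e

    try:
        load_slice((1024, 1024))
    except RuntimeError as err:
        print(err)             # the re-raised, user-facing message
        print(err.__cause__)   # the original error preserved by `from e`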
@@ -1843,7 +2223,7 @@ def _get_mindir_inputs(file_name):
         >>> input_tensor = get_mindir_inputs("lenet.mindir")
     """
     Validator.check_file_name_by_regular(file_name)
-    file_name = os.path.
+    file_name = os.path.abspath(file_name)
     model = read_proto(file_name)
     input_tensor = []

@@ -1874,8 +2254,8 @@ def convert_model(mindir_file, convert_file, file_format):
     """
     Convert mindir model to other format model. Current version only support convert to "ONNX" format.

-
-    This is an experimental
+    .. warning::
+        This is an experimental API that is subject to change or deletion.

     Args:
         mindir_file (str): MindIR file name.