mindspore 2.2.14__cp39-cp39-win_amd64.whl → 2.4.0__cp39-cp39-win_amd64.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +8 -5
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +124 -25
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +2 -2
- mindspore/_extends/parse/parser.py +182 -68
- mindspore/_extends/parse/resources.py +45 -14
- mindspore/_extends/parse/standard_method.py +192 -252
- mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +67 -26
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/base.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/boost/grad_freeze.py +2 -2
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/boost/less_batch_normalization.py +9 -6
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +20 -7
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +10 -10
- mindspore/common/_stub_tensor.py +7 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +449 -129
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +17 -10
- mindspore/common/dump.py +8 -11
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +65 -30
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +34 -14
- mindspore/common/lazy_inline.py +72 -19
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +73 -21
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +276 -24
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +668 -514
- mindspore/communication/__init__.py +6 -11
- mindspore/communication/_comm_helper.py +43 -3
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +117 -104
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +455 -71
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +201 -116
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +83 -3
- mindspore/dataset/engine/datasets_text.py +39 -39
- mindspore/dataset/engine/datasets_user_defined.py +230 -141
- mindspore/dataset/engine/datasets_vision.py +78 -74
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/engine/queue.py +138 -66
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +41 -15
- mindspore/dataset/text/__init__.py +2 -5
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +7 -10
- mindspore/dataset/vision/c_transforms.py +10 -10
- mindspore/dataset/vision/py_transforms_util.py +1 -1
- mindspore/dataset/vision/transforms.py +2844 -549
- mindspore/dataset/vision/utils.py +161 -10
- mindspore/dataset/vision/validators.py +16 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/{rewrite/ast_creator_register.py → experimental/es/__init__.py} +5 -20
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +124 -15
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +66 -121
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +18 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +4 -3
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/status.h +14 -0
- mindspore/include/api/types.h +10 -10
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +3 -5
- mindspore/include/dataset/vision.h +58 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +3 -3
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +138 -103
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
- mindspore/mindrecord/tools/csv_to_mr.py +4 -9
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/cell.py +461 -323
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/layer/activation.py +292 -135
- mindspore/nn/layer/basic.py +288 -83
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +75 -66
- mindspore/nn/layer/embedding.py +221 -45
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/math.py +1 -1
- mindspore/nn/layer/normalization.py +150 -68
- mindspore/nn/layer/padding.py +64 -87
- mindspore/nn/layer/pooling.py +175 -12
- mindspore/nn/layer/rnn_cells.py +6 -16
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +55 -53
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +145 -88
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/ada_grad.py +4 -2
- mindspore/nn/optim/adadelta.py +4 -2
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +3 -3
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +5 -3
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +46 -12
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +44 -7
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +67 -68
- mindspore/numpy/array_ops.py +70 -66
- mindspore/numpy/dtypes.py +3 -3
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +147 -152
- mindspore/numpy/utils.py +3 -0
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +9 -6
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
- mindspore/ops/_grad_experimental/grad_comm_ops.py +135 -36
- mindspore/ops/_grad_experimental/grad_math_ops.py +61 -298
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +162 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +147 -59
- mindspore/ops/_vmap/vmap_nn_ops.py +292 -117
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +201 -66
- mindspore/ops/composite/math_ops.py +10 -49
- mindspore/ops/composite/multitype_ops/_compile_utils.py +192 -618
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +8 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/function/__init__.py +53 -11
- mindspore/ops/function/array_func.py +1269 -1821
- mindspore/ops/function/clip_func.py +19 -31
- mindspore/ops/function/debug_func.py +114 -5
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/grad_func.py +30 -22
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +35 -68
- mindspore/ops/function/math_func.py +1170 -2697
- mindspore/ops/function/nn_func.py +2116 -1128
- mindspore/ops/function/other_func.py +8 -8
- mindspore/ops/function/parameter_func.py +5 -93
- mindspore/ops/function/random_func.py +435 -113
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +4 -4
- mindspore/ops/function/sparse_unary_func.py +9 -16
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +16 -15
- mindspore/ops/functional.py +355 -346
- mindspore/ops/op_info_register.py +18 -45
- mindspore/ops/operations/__init__.py +38 -24
- mindspore/ops/operations/_grad_ops.py +21 -927
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +173 -607
- mindspore/ops/operations/_rl_inner_ops.py +2 -2
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +6 -36
- mindspore/ops/operations/_tensor_array.py +8 -8
- mindspore/ops/operations/array_ops.py +106 -2837
- mindspore/ops/operations/comm_ops.py +799 -127
- mindspore/ops/operations/custom_ops.py +124 -119
- mindspore/ops/operations/debug_ops.py +142 -41
- mindspore/ops/operations/image_ops.py +1 -217
- mindspore/ops/operations/inner_ops.py +5 -40
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +666 -4972
- mindspore/ops/operations/nn_ops.py +205 -2213
- mindspore/ops/operations/other_ops.py +60 -49
- mindspore/ops/operations/random_ops.py +50 -54
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/sparse_ops.py +4 -4
- mindspore/ops/primitive.py +216 -103
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +8 -4
- mindspore/parallel/_auto_parallel_context.py +100 -10
- mindspore/parallel/_cell_wrapper.py +99 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +67 -23
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +99 -22
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +2 -2
- mindspore/parallel/_utils.py +173 -6
- mindspore/parallel/algo_parameter_config.py +8 -10
- mindspore/parallel/checkpoint_transform.py +204 -38
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +279 -37
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +4 -2
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/util.py +153 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +18 -20
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
- mindspore/profiler/parser/ascend_flops_generator.py +20 -4
- mindspore/profiler/parser/ascend_hccl_generator.py +29 -278
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +148 -146
- mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
- mindspore/profiler/parser/ascend_op_generator.py +92 -42
- mindspore/profiler/parser/ascend_timeline_generator.py +298 -133
- mindspore/profiler/parser/base_timeline_generator.py +25 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +4 -393
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/minddata_parser.py +72 -3
- mindspore/profiler/parser/profiler_info.py +94 -7
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +631 -508
- mindspore/rewrite/__init__.py +2 -14
- mindspore/rewrite/api/node.py +122 -36
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +45 -29
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +637 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +115 -148
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +705 -186
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +40 -115
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +597 -263
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +7 -5
- mindspore/train/_utils.py +204 -4
- mindspore/train/amp.py +335 -295
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/__init__.py +5 -2
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +220 -43
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_landscape.py +15 -9
- mindspore/train/callback/_loss_monitor.py +5 -5
- mindspore/train/callback/_on_request_exit.py +136 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +12 -12
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/data_sink.py +6 -5
- mindspore/train/dataset_helper.py +66 -23
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +382 -76
- mindspore/train/serialization.py +787 -288
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +51 -28
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/METADATA +8 -4
- mindspore-2.4.0.dist-info/RECORD +1406 -0
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/gen_ops.py +0 -273
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -282
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/_op_impl/cpu/concat.py +0 -39
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- mindspore-2.2.14.dist-info/RECORD +0 -1924
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
|
@@ -15,50 +15,142 @@
|
|
|
15
15
|
"""obfuscate network based on rewrite interfaces."""
|
|
16
16
|
import os
|
|
17
17
|
import re
|
|
18
|
-
import secrets
|
|
19
18
|
from pathlib import Path
|
|
19
|
+
from string import Template
|
|
20
|
+
import numpy as np
|
|
20
21
|
|
|
22
|
+
import mindspore as ms
|
|
21
23
|
from mindspore import ops, nn
|
|
22
|
-
from mindspore
|
|
23
|
-
from mindspore import
|
|
24
|
-
from mindspore import
|
|
25
|
-
from mindspore.rewrite import
|
|
26
|
-
from mindspore.rewrite.parsers
|
|
27
|
-
from mindspore.rewrite.parsers.class_def_parser import ModuleParser
|
|
24
|
+
from mindspore import load_checkpoint, save_checkpoint, log
|
|
25
|
+
from mindspore.ops import functional as F
|
|
26
|
+
from mindspore.rewrite import SymbolTree, Node, NodeType, ScopedValue
|
|
27
|
+
from mindspore.rewrite.parsers import ClassDefParser
|
|
28
|
+
from mindspore.rewrite.parsers import ModuleParser
|
|
28
29
|
|
|
29
30
|
OBF_RATIOS_LENGTH = 1
|
|
30
31
|
MAX_OBF_RATIOS_NUM = 50
|
|
31
32
|
OBF_RATIOS_WIDTH = 0
|
|
32
|
-
|
|
33
|
+
|
|
34
|
+
_supported_ops = {
|
|
35
|
+
'mul': ops.Mul,
|
|
36
|
+
'matmul': ops.MatMul,
|
|
37
|
+
'invert': ops.Inv
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
_supported_config_type = [
|
|
41
|
+
'obf_metadata_config',
|
|
42
|
+
'weight_obf_config',
|
|
43
|
+
'network_obf_config'
|
|
44
|
+
]
|
|
45
|
+
|
|
46
|
+
_supported_metadata_type = [
|
|
47
|
+
'random',
|
|
48
|
+
'rearrange'
|
|
49
|
+
]
|
|
50
|
+
|
|
51
|
+
obf_medatadata_template = {
|
|
52
|
+
'name': 'obf_metadata',
|
|
53
|
+
'shape': [1,],
|
|
54
|
+
'type': 'random',
|
|
55
|
+
'save_metadata': True,
|
|
56
|
+
'metadata_op': 'invert'
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
weight_obf_template = {
|
|
60
|
+
'target': '',
|
|
61
|
+
'weight_obf_ops': [{'name': 'mul', 'input_x': 'weight', 'input_y': 'obf_metadata'}]
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
network_obf_template = {
|
|
65
|
+
'module': '',
|
|
66
|
+
'target': '',
|
|
67
|
+
'insert_new_input': [{'name': 'obf_metadata'}],
|
|
68
|
+
'insert_ops': [{'name': 'mul', 'input_x': 'weight', 'input_y': 'obf_metadata'}]
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _transform_target_modules(target_modules):
|
|
73
|
+
"""transform target_modules to obf config"""
|
|
74
|
+
obf_config = {}
|
|
75
|
+
path = target_modules[0]
|
|
76
|
+
target_list = target_modules[1].split('|')
|
|
77
|
+
max_layers = 12
|
|
78
|
+
layers = []
|
|
79
|
+
obf_medatadata = obf_medatadata_template.copy()
|
|
80
|
+
if len(target_modules) >= 3:
|
|
81
|
+
obfuscate_layers = target_modules[2].split(':')
|
|
82
|
+
if obfuscate_layers[1] != 'all':
|
|
83
|
+
max_layers = int(obfuscate_layers[1])
|
|
84
|
+
layers = [i for i in range(0, max_layers)]
|
|
85
|
+
path_new = path.replace("blocks", "blocks/${layer}")
|
|
86
|
+
network_obf_template['insert_ops'][0]['input_y'] = "obf_metadata_${layer}"
|
|
87
|
+
weight_obf_template['weight_obf_ops'][0]['input_y'] = "obf_metadata_${layer}"
|
|
88
|
+
weight_obf_template['name'] = "obf_metadata_${layer}"
|
|
89
|
+
obf_medatadata['layers'] = layers
|
|
90
|
+
else:
|
|
91
|
+
path_new = path
|
|
92
|
+
obf_config['obf_metadata_config'] = []
|
|
93
|
+
obf_config['weight_obf_config'] = []
|
|
94
|
+
obf_config['network_obf_config'] = []
|
|
95
|
+
obf_config['obf_metadata_config'].append(obf_medatadata)
|
|
96
|
+
|
|
97
|
+
for name in target_list:
|
|
98
|
+
target_weight = path_new + '/' + name + '/weight'
|
|
99
|
+
target_bias = path_new + '/' + name + '/bias'
|
|
100
|
+
weight_obf = weight_obf_template.copy()
|
|
101
|
+
weight_obf['target'] = target_weight
|
|
102
|
+
bias_obf = weight_obf_template.copy()
|
|
103
|
+
bias_obf['target'] = target_bias
|
|
104
|
+
network_obf = network_obf_template.copy()
|
|
105
|
+
network_obf['module'] = '/' + path_new
|
|
106
|
+
network_obf['target'] = name
|
|
107
|
+
if not layers:
|
|
108
|
+
weight_obf['layers'] = layers
|
|
109
|
+
bias_obf['layers'] = layers
|
|
110
|
+
network_obf['layers'] = layers
|
|
111
|
+
obf_config['weight_obf_config'].append(weight_obf)
|
|
112
|
+
obf_config['weight_obf_config'].append(bias_obf)
|
|
113
|
+
obf_config['network_obf_config'].append(network_obf)
|
|
114
|
+
return obf_config
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _get_op(op_name):
|
|
118
|
+
if op_name is None:
|
|
119
|
+
return None
|
|
120
|
+
if op_name not in _supported_ops:
|
|
121
|
+
raise KeyError(f"'op name' must be in {list(_supported_ops.keys())}, but got {op_name}.")
|
|
122
|
+
return _supported_ops[op_name]()
|
|
33
123
|
|
|
34
124
|
|
|
35
|
-
def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./', obfuscate_scale=100):
|
|
125
|
+
def obfuscate_ckpt(network, ckpt_files, target_modules=None, obf_config=None, saved_path='./', obfuscate_scale=100):
|
|
36
126
|
"""
|
|
37
|
-
|
|
38
|
-
:func:`mindspore.load_obf_params_into_net`.
|
|
39
|
-
interface.
|
|
127
|
+
Obfuscate the plaintext checkpoint files according to the obfuscation config.
|
|
40
128
|
|
|
41
129
|
Args:
|
|
42
130
|
network (nn.Cell): The original network that need to be obfuscated.
|
|
43
131
|
ckpt_files (str): The directory path of original ckpt files.
|
|
44
|
-
target_modules (list[str]): The target
|
|
45
|
-
represents the network path of target
|
|
46
|
-
The second string represents the
|
|
47
|
-
example,
|
|
48
|
-
If target_modules has the third value, it should be in the
|
|
49
|
-
'obfuscate_layers:int', which represents the number of layers
|
|
50
|
-
(such as transformer layers or resnet blocks).
|
|
51
|
-
|
|
52
|
-
|
|
132
|
+
target_modules (list[str]): The target ops that need to be obfuscated in the network. The first string
|
|
133
|
+
represents the network path of the target ops in the original network, which should be in form of
|
|
134
|
+
``"A/B/C"``. The second string represents the names of multiple target ops in the same path, which
|
|
135
|
+
should be in form of ``"D|E|F"``. For example, the target_modules of GPT2 can be ``['backbone/blocks
|
|
136
|
+
/attention', 'dense1|dense2|dense3']``. If target_modules has the third value, it should be in the
|
|
137
|
+
format of 'obfuscate_layers:all' or 'obfuscate_layers:int', which represents the number of layers
|
|
138
|
+
need to be obfuscated of duplicate layers (such as transformer layers or resnet blocks).
|
|
139
|
+
Default: ``None``.
|
|
140
|
+
obf_config (dict): The configuration of model obfuscation polices. Default: ``None``.
|
|
53
141
|
saved_path (str): The directory path for saving obfuscated ckpt files. Default: ``'./'``.
|
|
54
142
|
obfuscate_scale (Union[float, int]): Obfuscate scale of weights. The generated random obf_ratios will be in
|
|
55
143
|
range of (1 / obfuscate_scale, obfuscate_scale). Default: 100.
|
|
56
144
|
|
|
145
|
+
Returns:
|
|
146
|
+
dict[str], obf_metadata, which is the necessary data that needs to be load when running obfuscated network.
|
|
147
|
+
|
|
57
148
|
Raises:
|
|
58
149
|
TypeError: If `network` is not nn.Cell.
|
|
59
150
|
TypeError: If `ckpt_files` is not string or `saved_path` is not string.
|
|
60
151
|
TypeError: If `target_modules` is not list.
|
|
61
152
|
TypeError: If target_modules's elements are not string.
|
|
153
|
+
TypeError: If obf_config is not dict.
|
|
62
154
|
ValueError: If `ckpt_files` is not exist or `saved_path` is not exist.
|
|
63
155
|
ValueError: If the number of elements of `target_modules` is less than ``2``.
|
|
64
156
|
ValueError: If the first string of `target_modules` contains characters other than uppercase and lowercase
|
|
@@ -68,54 +160,91 @@ def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./', ob
|
|
|
68
160
|
ValueError: If the third string of `target_modules` is not in the format of 'obfuscate_layers:all' or
|
|
69
161
|
'obfuscate_layers:int'.
|
|
70
162
|
|
|
71
|
-
Returns:
|
|
72
|
-
list[float], obf_ratios, which is the necessary data that needs to be load when running obfuscated network.
|
|
73
|
-
|
|
74
163
|
Examples:
|
|
75
164
|
>>> from mindspore import obfuscate_ckpt, save_checkpoint
|
|
76
|
-
>>> # Refer to https://gitee.com/mindspore/docs/blob/
|
|
165
|
+
>>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
77
166
|
>>> net = LeNet5()
|
|
78
167
|
>>> save_checkpoint(net, './test_net.ckpt')
|
|
79
168
|
>>> target_modules = ['', 'fc1|fc2']
|
|
80
|
-
>>> obfuscate_ckpt(net,
|
|
169
|
+
>>> obfuscate_ckpt(net, './', target_modules=target_modules, saved_path='./')
|
|
81
170
|
"""
|
|
171
|
+
def _gen_obfuscate_tensor(tensor_shape, tensor_type='rearrange'):
|
|
172
|
+
obf_tensor = None
|
|
173
|
+
if tensor_type == 'rearrange':
|
|
174
|
+
if len(tensor_shape) == 1:
|
|
175
|
+
obf_tensor = ms.Tensor(np.random.permutation(tensor_shape[0]), dtype=ms.int32)
|
|
176
|
+
if len(tensor_shape) == 2:
|
|
177
|
+
tensor = ms.Tensor(np.identity(tensor_shape[0]), dtype=ms.int32)
|
|
178
|
+
p = ms.Tensor(np.random.permutation(tensor_shape[1]), dtype=ms.int32)
|
|
179
|
+
obf_tensor = tensor[:, p]
|
|
180
|
+
if tensor_type == 'random':
|
|
181
|
+
obf_tensor = ms.Tensor(np.random.randint(1, obfuscate_scale, size=tensor_shape), dtype=ms.float16)
|
|
182
|
+
return obf_tensor
|
|
183
|
+
|
|
184
|
+
def _gen_obf_metadata(config):
|
|
185
|
+
name = config.get('name')
|
|
186
|
+
if name is None:
|
|
187
|
+
return False
|
|
188
|
+
save_metadata = config.get('save_metadata', False)
|
|
189
|
+
metadata_op_name = config.get('metadata_op')
|
|
190
|
+
layers = config.get('layers')
|
|
191
|
+
if not layers:
|
|
192
|
+
if not obf_metadata.get(name):
|
|
193
|
+
obf_tensor = _gen_obfuscate_tensor(config.get('shape'), config.get('type'))
|
|
194
|
+
obf_metadata[name] = obf_tensor
|
|
195
|
+
if save_metadata:
|
|
196
|
+
saved_obf_tensor = obf_tensor
|
|
197
|
+
if metadata_op_name is not None:
|
|
198
|
+
metadata_op = _get_op(metadata_op_name)
|
|
199
|
+
saved_obf_tensor = metadata_op(saved_obf_tensor)
|
|
200
|
+
if saved_obf_tensor is not None:
|
|
201
|
+
saved_metadata[name] = saved_obf_tensor.asnumpy()
|
|
202
|
+
else:
|
|
203
|
+
for layer in layers:
|
|
204
|
+
strTemplate = Template(name)
|
|
205
|
+
obf_name = strTemplate.safe_substitute({"layer": str(layer)})
|
|
206
|
+
obf_tensor = _gen_obfuscate_tensor(config.get('shape'), config.get('type'))
|
|
207
|
+
obf_metadata[obf_name] = obf_tensor
|
|
208
|
+
if save_metadata:
|
|
209
|
+
saved_obf_tensor = obf_tensor
|
|
210
|
+
if metadata_op_name is not None:
|
|
211
|
+
metadata_op = _get_op(metadata_op_name)
|
|
212
|
+
saved_obf_tensor = metadata_op(saved_obf_tensor)
|
|
213
|
+
if saved_obf_tensor is not None:
|
|
214
|
+
saved_metadata[obf_name] = saved_obf_tensor.asnumpy()
|
|
215
|
+
return True
|
|
216
|
+
|
|
82
217
|
if not isinstance(network, nn.Cell):
|
|
83
218
|
raise TypeError("network must be nn.Cell, but got {}.".format(type(network)))
|
|
84
219
|
_check_dir_path('ckpt_files', ckpt_files)
|
|
85
220
|
_check_dir_path('saved_path', saved_path)
|
|
86
|
-
|
|
87
|
-
if
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
if not
|
|
94
|
-
raise
|
|
95
|
-
|
|
221
|
+
|
|
222
|
+
if obf_config is None:
|
|
223
|
+
if not _check_valid_target(network, target_modules):
|
|
224
|
+
raise ValueError("{} is not exist, please check the input 'target_modules'.".format(target_modules))
|
|
225
|
+
log.warning("'target_modules and obf_ratios' will be deprecated and "
|
|
226
|
+
"removed in a future version, use 'obf_config' instead.")
|
|
227
|
+
obf_config = _transform_target_modules(target_modules)
|
|
228
|
+
if not isinstance(obf_config, dict):
|
|
229
|
+
raise TypeError("obf_config type should be dict, but got {}.".format(type(obf_config)))
|
|
230
|
+
if not obf_config or not _check_valid_obf_config(obf_config, 'obf_metadata_config')\
|
|
231
|
+
or not _check_valid_obf_config(obf_config, 'weight_obf_config'):
|
|
232
|
+
raise ValueError("'obf_config' is empty or not valid, please check the input.")
|
|
233
|
+
obf_metadata = {}
|
|
234
|
+
obf_metadata_config = obf_config.get('obf_metadata_config', [])
|
|
235
|
+
saved_metadata = {}
|
|
236
|
+
for config in obf_metadata_config:
|
|
237
|
+
_gen_obf_metadata(config)
|
|
96
238
|
if (not isinstance(obfuscate_scale, (float, int))) or (obfuscate_scale <= 1):
|
|
97
239
|
raise ValueError("obfuscate_scale must be float or int, and larger than 1, but got {}."
|
|
98
240
|
.format(obfuscate_scale))
|
|
99
|
-
# generate and save obf_ratios to saved_path
|
|
100
|
-
path_list = to_split_modules[0].split('/')
|
|
101
|
-
target_list = to_split_modules[1].split('|')
|
|
102
|
-
global OBF_RATIOS_LENGTH
|
|
103
|
-
number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
|
|
104
|
-
if number_of_ratios > MAX_OBF_RATIOS_NUM:
|
|
105
|
-
OBF_RATIOS_LENGTH = MAX_OBF_RATIOS_NUM // OBF_RATIOS_WIDTH
|
|
106
|
-
number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
|
|
107
|
-
obf_ratios = []
|
|
108
|
-
secrets_generator = secrets.SystemRandom()
|
|
109
|
-
for _ in range(number_of_ratios):
|
|
110
|
-
secure_float = secrets_generator.uniform(1 / obfuscate_scale, obfuscate_scale)
|
|
111
|
-
obf_ratios.append(secure_float)
|
|
112
241
|
# start obfuscate ckpt
|
|
113
242
|
ckpt_dir_files = os.listdir(ckpt_files)
|
|
114
243
|
for ckpt_name in ckpt_dir_files:
|
|
115
|
-
sub_path = os.path.
|
|
244
|
+
sub_path = os.path.realpath(ckpt_files) + '/' + ckpt_name
|
|
116
245
|
if Path(sub_path).is_dir():
|
|
117
246
|
sub_ckpt_file_list = os.listdir(sub_path)
|
|
118
|
-
new_saved_path = os.path.
|
|
247
|
+
new_saved_path = os.path.realpath(saved_path) + '/' + ckpt_name
|
|
119
248
|
if not os.path.exists(new_saved_path):
|
|
120
249
|
try:
|
|
121
250
|
os.mkdir(new_saved_path, mode=0o700)
|
|
@@ -124,71 +253,148 @@ def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./', ob
|
|
|
124
253
|
for sub_ckpt_name in sub_ckpt_file_list:
|
|
125
254
|
if not sub_ckpt_name.endswith('.ckpt'):
|
|
126
255
|
continue
|
|
127
|
-
_obfuscate_single_ckpt(os.path.
|
|
128
|
-
|
|
256
|
+
_obfuscate_single_ckpt(os.path.realpath(sub_path) + '/' + sub_ckpt_name, obf_metadata,
|
|
257
|
+
obf_config, new_saved_path)
|
|
129
258
|
else:
|
|
130
259
|
if not ckpt_name.endswith('.ckpt'):
|
|
131
260
|
continue
|
|
132
|
-
_obfuscate_single_ckpt(os.path.
|
|
133
|
-
|
|
134
|
-
return
|
|
261
|
+
_obfuscate_single_ckpt(os.path.realpath(ckpt_files) + '/' + ckpt_name,
|
|
262
|
+
obf_metadata, obf_config, saved_path)
|
|
263
|
+
return saved_metadata
|
|
135
264
|
|
|
136
265
|
|
|
137
|
-
def _obfuscate_single_ckpt(ckpt_name,
|
|
266
|
+
def _obfuscate_single_ckpt(ckpt_name, obf_metadata, obf_config, saved_path):
|
|
138
267
|
"""Obfuscate single ckpt file"""
|
|
139
|
-
|
|
268
|
+
def _get_op_input_name(obf_op, name_key='input_x', layer=0):
|
|
269
|
+
op_name = obf_op.get('name')
|
|
270
|
+
input_name = obf_op.get(name_key)
|
|
271
|
+
if input_name is None:
|
|
272
|
+
log.error("can not find input: {} for op: {}.".format(name_key, op_name))
|
|
273
|
+
return None
|
|
274
|
+
strTemplate = Template(input_name)
|
|
275
|
+
input_name = strTemplate.safe_substitute({"layer": str(layer)})
|
|
276
|
+
return input_name
|
|
277
|
+
|
|
278
|
+
def _get_op_input(input_name, obf_param):
|
|
279
|
+
op_input = obf_metadata.get(input_name, None) if input_name.startswith('obf_metadata') else obf_param
|
|
280
|
+
return op_input
|
|
281
|
+
|
|
282
|
+
def _obfuscate_param(param, obf_metadata, obf_ops, layer=0):
|
|
283
|
+
param_dtype = F.dtype(param)
|
|
284
|
+
obf_param = param
|
|
285
|
+
for i in range(len(obf_ops)):
|
|
286
|
+
op_name = obf_ops[i].get('name')
|
|
287
|
+
if not isinstance(op_name, str):
|
|
288
|
+
raise TypeError('{} should be str type, but got {}'.format(op_name, type(op_name)))
|
|
289
|
+
if op_name == 'mul':
|
|
290
|
+
input_x = obf_param
|
|
291
|
+
input_y_name = _get_op_input_name(obf_ops[i], 'input_y', layer)
|
|
292
|
+
input_y = obf_metadata.get(input_y_name)
|
|
293
|
+
if input_x is None or input_y is None:
|
|
294
|
+
log.error("input_x or input_y is None")
|
|
295
|
+
return None
|
|
296
|
+
input_y = F.cast(input_y, param_dtype)
|
|
297
|
+
obf_param = ops.mul(input_x, input_y)
|
|
298
|
+
elif op_name == 'permuate':
|
|
299
|
+
input_x_name = _get_op_input_name(obf_ops[i], 'input_x', layer)
|
|
300
|
+
p = obf_metadata.get(input_x_name, None)
|
|
301
|
+
if p is None or obf_param is None:
|
|
302
|
+
log.error("input_x or param is None")
|
|
303
|
+
return None
|
|
304
|
+
obf_param = obf_param[p]
|
|
305
|
+
elif op_name == 'matmul':
|
|
306
|
+
input_x_name = _get_op_input_name(obf_ops[i], 'input_x', layer)
|
|
307
|
+
input_y_name = _get_op_input_name(obf_ops[i], 'input_y', layer)
|
|
308
|
+
input_x = _get_op_input(input_x_name, obf_param)
|
|
309
|
+
input_y = _get_op_input(input_y_name, obf_param)
|
|
310
|
+
if input_x is None or input_y is None:
|
|
311
|
+
log.error("the input_x or input_y of op: {} is None.".format(op_name))
|
|
312
|
+
return None
|
|
313
|
+
input_x = ops.transpose(input_x, (1, 0)) if obf_ops[i].get('transpose_a', False) else input_x
|
|
314
|
+
input_y = ops.transpose(input_y, (1, 0)) if obf_ops[i].get('transpose_b', False) else input_y
|
|
315
|
+
obf_param = ops.matmul(F.cast(input_x, param_dtype), F.cast(input_y, param_dtype))
|
|
316
|
+
else:
|
|
317
|
+
log.error("unsupported op, op must be matmul or permuate or mul, but got {}."
|
|
318
|
+
.format(op_name))
|
|
319
|
+
return None
|
|
320
|
+
return obf_param
|
|
321
|
+
|
|
140
322
|
try:
|
|
141
323
|
ckpt_param = load_checkpoint(ckpt_name)
|
|
142
324
|
except (ValueError, TypeError, OSError):
|
|
143
|
-
|
|
144
|
-
return
|
|
145
|
-
|
|
325
|
+
log.error("Load checkpoint failed for file {}.".format(ckpt_name))
|
|
326
|
+
return False
|
|
327
|
+
|
|
328
|
+
weight_obf_config = obf_config.get('weight_obf_config', [])
|
|
146
329
|
for item in ckpt_param:
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
if
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
330
|
+
item_split = item.split('.')
|
|
331
|
+
param_path = '/'.join(item_split[:len(item_split)])
|
|
332
|
+
for obf_target in weight_obf_config:
|
|
333
|
+
if not isinstance(obf_target, dict):
|
|
334
|
+
raise TypeError('{} should be dict type, but got {}'.format(obf_target, type(obf_target)))
|
|
335
|
+
target = obf_target.get('target', None)
|
|
336
|
+
layers = obf_target.get('layers', [])
|
|
337
|
+
obf_ops = obf_target.get('weight_obf_ops', None)
|
|
338
|
+
if not target or not obf_ops:
|
|
339
|
+
raise KeyError("target or obf_ops is None.")
|
|
340
|
+
if not layers:
|
|
341
|
+
if target == param_path:
|
|
342
|
+
obf_param = _obfuscate_param(ckpt_param[item].value(), obf_metadata, obf_ops)
|
|
343
|
+
if obf_param is None:
|
|
344
|
+
log.error("obfuscate weight {} failed.".format(item))
|
|
345
|
+
return False
|
|
346
|
+
ckpt_param[item].set_data(obf_param)
|
|
347
|
+
for layer in layers:
|
|
348
|
+
strTemplate = Template(target)
|
|
349
|
+
target_path = strTemplate.safe_substitute({"layer": str(layer)})
|
|
350
|
+
if target_path == param_path:
|
|
351
|
+
obf_param = _obfuscate_param(ckpt_param[item].value(), obf_metadata, obf_ops, layer)
|
|
352
|
+
if obf_param is None:
|
|
353
|
+
log.error("obfuscate weight {} failed.".format(item))
|
|
354
|
+
return False
|
|
355
|
+
ckpt_param[item].set_data(obf_param)
|
|
356
|
+
|
|
157
357
|
# save the obfuscated model to saved_path
|
|
158
358
|
obf_param_list = []
|
|
159
359
|
for item in ckpt_param:
|
|
160
360
|
obf_param_list.append({'name': item, 'data': ckpt_param[item]})
|
|
161
361
|
ckpt_file_name = ckpt_name.split('/')[-1]
|
|
162
362
|
obf_ckpt_file_name = ckpt_file_name.split('.')[0] + '_obf' + '.ckpt'
|
|
163
|
-
save_checkpoint(obf_param_list, os.path.
|
|
164
|
-
return
|
|
363
|
+
save_checkpoint(obf_param_list, os.path.realpath(saved_path) + '/' + obf_ckpt_file_name)
|
|
364
|
+
return True
|
|
165
365
|
|
|
166
366
|
|
|
167
|
-
def load_obf_params_into_net(network, target_modules, obf_ratios,
|
|
367
|
+
def load_obf_params_into_net(network, target_modules=None, obf_ratios=None, obf_config=None,
|
|
368
|
+
data_parallel_num=1, **kwargs):
|
|
168
369
|
"""
|
|
169
|
-
|
|
170
|
-
interface.
|
|
370
|
+
Modify model structure according to obfuscation config and load obfuscated checkpoint into obfuscated network.
|
|
171
371
|
|
|
172
372
|
Args:
|
|
173
373
|
network (nn.Cell): The original network that need to be obfuscated.
|
|
174
|
-
target_modules (list[str]): The target
|
|
175
|
-
represents the network path of target
|
|
176
|
-
The second string represents the
|
|
177
|
-
example, thr target_modules of GPT2 can be ``['backbone
|
|
178
|
-
If target_modules has the third value, it should be
|
|
179
|
-
'obfuscate_layers:int', which represents the number of
|
|
180
|
-
(such as transformer layers or resnet blocks).
|
|
374
|
+
target_modules (list[str]): The target ops that need to be obfuscated in the network. The first string
|
|
375
|
+
represents the network path of the target ops in the original network, which should be in form of
|
|
376
|
+
``"A/B/C"``. The second string represents the names of multiple target ops in the same path, which
|
|
377
|
+
should be in form of ``"D|E|F"``. For example, thr target_modules of GPT2 can be ``['backbone
|
|
378
|
+
/blocks/attention', 'dense1|dense2|dense3']``. If target_modules has the third value, it should be
|
|
379
|
+
in the format of 'obfuscate_layers:all' or 'obfuscate_layers:int', which represents the number of
|
|
380
|
+
layers need to be obfuscated of duplicate layers (such as transformer layers or resnet blocks).
|
|
381
|
+
Default: ``None``.
|
|
382
|
+
obf_ratios (Tensor): The obf ratios generated when execute :func:`mindspore.obfuscate_ckpt`. Default: ``None``.
|
|
383
|
+
obf_config (dict): The configuration of model obfuscation polices. Default: ``None``.
|
|
181
384
|
data_parallel_num (int): The data parallel number of parallel training. Default: 1.
|
|
182
|
-
obf_ratios (Tensor): The obf ratios generated when execute :func:`mindspore.obfuscate_ckpt`.
|
|
183
385
|
kwargs (dict): Configuration options dictionary.
|
|
184
386
|
|
|
185
387
|
- ignored_func_decorators (list[str]): The name list of function decorators in network's python code.
|
|
186
388
|
- ignored_class_decorators (list[str]): The name list of class decorators in network's python code.
|
|
187
389
|
|
|
390
|
+
Returns:
|
|
391
|
+
nn.Cell, new_net, which is the obfuscated network.
|
|
392
|
+
|
|
188
393
|
Raises:
|
|
189
394
|
TypeError: If `network` is not nn.Cell.
|
|
190
395
|
TypeError: If `obf_ratios` is not Tensor.
|
|
191
396
|
TypeError: If `target_modules` is not list.
|
|
397
|
+
TypeError: If `obf_config` is not dict.
|
|
192
398
|
TypeError: If target_modules's elements are not string.
|
|
193
399
|
ValueError: If the number of elements of `target_modules` is less than ``2``.
|
|
194
400
|
ValueError: If `obf_ratios` is empty Tensor.
|
|
@@ -204,45 +410,38 @@ def load_obf_params_into_net(network, target_modules, obf_ratios, data_parallel_
|
|
|
204
410
|
>>> from mindspore import obfuscate_ckpt, save_checkpoint, load_checkpoint, Tensor
|
|
205
411
|
>>> import mindspore.common.dtype as mstype
|
|
206
412
|
>>> import numpy as np
|
|
207
|
-
>>> # Refer to https://gitee.com/mindspore/docs/blob/
|
|
413
|
+
>>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
208
414
|
>>> net = LeNet5()
|
|
209
415
|
>>> save_checkpoint(net, './test_net.ckpt')
|
|
210
416
|
>>> target_modules = ['', 'fc1|fc2']
|
|
211
417
|
>>> # obfuscate ckpt files
|
|
212
|
-
>>> obfuscate_ckpt(net,
|
|
418
|
+
>>> obfuscate_ckpt(net, './', target_modules=target_modules, saved_path='./')
|
|
213
419
|
>>> # load obf ckpt into network
|
|
214
420
|
>>> new_net = LeNet5()
|
|
215
421
|
>>> load_checkpoint('./test_net_obf.ckpt', new_net)
|
|
216
|
-
>>>
|
|
217
|
-
>>> obf_net = load_obf_params_into_net(new_net, target_modules, obf_ratios)
|
|
422
|
+
>>> obf_net = load_obf_params_into_net(new_net, target_modules)
|
|
218
423
|
"""
|
|
219
424
|
if not isinstance(network, nn.Cell):
|
|
220
425
|
raise TypeError("network must be nn.Cell, but got {}.".format(type(network)))
|
|
221
|
-
if
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
426
|
+
if obf_config is None:
|
|
427
|
+
if not _check_valid_target(network, target_modules):
|
|
428
|
+
raise ValueError("{} is not exist, please check the input 'target_modules'.".format(target_modules))
|
|
429
|
+
log.warning("'target_modules and obf_ratios' will be deprecated and "
|
|
430
|
+
"removed in a future version, use 'obf_config' instead.")
|
|
431
|
+
obf_config = _transform_target_modules(target_modules)
|
|
432
|
+
|
|
433
|
+
if not isinstance(obf_config, dict):
|
|
434
|
+
raise TypeError('{} should be dict type, but got {}'.format(obf_config, type(obf_config)))
|
|
435
|
+
|
|
436
|
+
if not obf_config or not _check_valid_obf_config(obf_config, 'network_obf_config'):
|
|
437
|
+
raise ValueError("'obf_config' is empty or not valid, please check the input.")
|
|
438
|
+
|
|
227
439
|
if (not isinstance(data_parallel_num, int)) or (data_parallel_num <= 0):
|
|
228
440
|
raise ValueError("data_parallel_num must be positive number, but got {}.".format(data_parallel_num))
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
target_list = []
|
|
234
|
-
for _ in range(path_len):
|
|
235
|
-
target_list.append([])
|
|
236
|
-
target_list.append(target_modules[1].split('|'))
|
|
237
|
-
global MAX_OBF_RATIOS_NUM, OBF_RATIOS_LENGTH
|
|
238
|
-
number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
|
|
239
|
-
if number_of_ratios > MAX_OBF_RATIOS_NUM:
|
|
240
|
-
OBF_RATIOS_LENGTH = MAX_OBF_RATIOS_NUM // OBF_RATIOS_WIDTH
|
|
241
|
-
number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
|
|
242
|
-
MAX_OBF_RATIOS_NUM = number_of_ratios
|
|
243
|
-
rewrite_network = _obfuscate_network(network, path_list, target_list, data_parallel_num=data_parallel_num, **kwargs)
|
|
244
|
-
setattr(rewrite_network, 'obf_ratios', obf_ratios)
|
|
245
|
-
return rewrite_network
|
|
441
|
+
|
|
442
|
+
network_obf_config = obf_config.get('network_obf_config', [])
|
|
443
|
+
new_net = _obfuscate_network(network, network_obf_config, data_parallel_num=data_parallel_num, **kwargs)
|
|
444
|
+
return new_net
|
|
246
445
|
|
|
247
446
|
|
|
248
447
|
def _check_dir_path(name, dir_path):
|
|
@@ -255,15 +454,6 @@ def _check_dir_path(name, dir_path):
|
|
|
255
454
|
raise TypeError("{} must be a directory path, but got {}.".format(name, dir_path))
|
|
256
455
|
|
|
257
456
|
|
|
258
|
-
def _judge_layer_index(layer_name):
|
|
259
|
-
"""Judge the layer index of target layers"""
|
|
260
|
-
split_name = layer_name.split('.')
|
|
261
|
-
for split_str in split_name[:]:
|
|
262
|
-
if split_str.isdigit():
|
|
263
|
-
return int(split_str)
|
|
264
|
-
return 0
|
|
265
|
-
|
|
266
|
-
|
|
267
457
|
def _check_valid_target(network, target_modules):
|
|
268
458
|
"""check whether the input 'target_modules' exists"""
|
|
269
459
|
if not isinstance(target_modules, list):
|
|
@@ -314,7 +504,7 @@ def _check_valid_target(network, target_modules):
|
|
|
314
504
|
OBF_RATIOS_WIDTH = 0
|
|
315
505
|
for target in target_list:
|
|
316
506
|
if not hasattr(net, target):
|
|
317
|
-
|
|
507
|
+
log.warning("{} does not exist in the path {}".format(target, target_modules[0]))
|
|
318
508
|
else:
|
|
319
509
|
OBF_RATIOS_WIDTH += 1
|
|
320
510
|
if OBF_RATIOS_WIDTH == 0:
|
|
@@ -323,6 +513,118 @@ def _check_valid_target(network, target_modules):
|
|
|
323
513
|
return True
|
|
324
514
|
|
|
325
515
|
|
|
516
|
+
def _check_ops_info(ops_info):
|
|
517
|
+
"""check ops info config"""
|
|
518
|
+
for op in ops_info:
|
|
519
|
+
op_name = op.get('name')
|
|
520
|
+
if not isinstance(op_name, str):
|
|
521
|
+
raise TypeError("op_name type should be str, but got {}.".format(type(op_name)))
|
|
522
|
+
input_x_name = op.get('input_x')
|
|
523
|
+
if not isinstance(input_x_name, str):
|
|
524
|
+
raise TypeError("input_x_name type should be str, but got {}.".format(type(input_x_name)))
|
|
525
|
+
input_y_name = op.get('input_y')
|
|
526
|
+
if not isinstance(input_y_name, str):
|
|
527
|
+
raise TypeError("input_y_name type should be str, but got {}.".format(type(input_y_name)))
|
|
528
|
+
if not isinstance(op.get('transpose_a', False), bool):
|
|
529
|
+
raise TypeError("transpose_a type should be bool, but got {}.".format(type(op.get('transpose_a'))))
|
|
530
|
+
if not isinstance(op.get('transpose_b', False), bool):
|
|
531
|
+
raise TypeError("transpose_b type should be bool, but got {}.".format(type(op.get('transpose_b'))))
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
def _check_new_input_info(insert_new_input):
|
|
535
|
+
"""check new input config"""
|
|
536
|
+
if not isinstance(insert_new_input, list):
|
|
537
|
+
raise TypeError("obf_config[][]['insert_new_input'] type should be list, but got {}."
|
|
538
|
+
.format(type(insert_new_input)))
|
|
539
|
+
for new_input in insert_new_input:
|
|
540
|
+
input_name = new_input.get('name')
|
|
541
|
+
if not isinstance(input_name, str):
|
|
542
|
+
raise TypeError("obf_config[][]['insert_new_input'][]['name'] type should be str, but got {}."
|
|
543
|
+
.format(type(input_name)))
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
def _check_obf_metadata_config(config):
|
|
547
|
+
"""check obf metadata config"""
|
|
548
|
+
name = config.get('name')
|
|
549
|
+
if not name or not isinstance(name, str):
|
|
550
|
+
raise TypeError("obf_config[][]['name'] type should be str, but got {}.".format(type(name)))
|
|
551
|
+
shape = config.get('shape')
|
|
552
|
+
if not shape or not isinstance(shape, list):
|
|
553
|
+
raise TypeError("obf_config[][]['shape'] type should be list, but got {}.".format(type(shape)))
|
|
554
|
+
for item in shape:
|
|
555
|
+
if not isinstance(item, int):
|
|
556
|
+
raise TypeError("shape[] type should be int, but got {}.".format(type(item)))
|
|
557
|
+
save_metadata = config.get('save_metadata', True)
|
|
558
|
+
if not isinstance(save_metadata, bool):
|
|
559
|
+
raise TypeError("obf_config[][]['save_metadata'] type should be bool, but got {}."
|
|
560
|
+
.format(type(save_metadata)))
|
|
561
|
+
metadata_type = config.get('type')
|
|
562
|
+
if metadata_type is not None:
|
|
563
|
+
if not isinstance(metadata_type, str) or metadata_type not in _supported_metadata_type:
|
|
564
|
+
raise TypeError("obf_config[][]['type'] should be str and must in {}, but got {}."
|
|
565
|
+
.format(str(_supported_metadata_type), type(metadata_type)))
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
def _check_weight_obf_config(config):
|
|
569
|
+
"""check weight obfuscation config"""
|
|
570
|
+
target = config.get('target')
|
|
571
|
+
if not target or not isinstance(target, str):
|
|
572
|
+
raise TypeError("obf_config[][]['target'] type should be str, but got {}.".format(type(target)))
|
|
573
|
+
weight_obf_ops = config.get('weight_obf_ops', [])
|
|
574
|
+
if not isinstance(weight_obf_ops, list):
|
|
575
|
+
raise TypeError("obf_config[][]['weight_obf_ops'] type should be list, but got {}."
|
|
576
|
+
.format(type(weight_obf_ops)))
|
|
577
|
+
_check_ops_info(weight_obf_ops)
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
def _check_network_obf_config(config):
|
|
581
|
+
"""check network obfuscation config"""
|
|
582
|
+
target = config.get('target')
|
|
583
|
+
if not target or not isinstance(target, str):
|
|
584
|
+
raise TypeError("obf_config[][]['target'] type should be str, but got {}.".format(type(target)))
|
|
585
|
+
module = config.get('module')
|
|
586
|
+
if not module or not isinstance(module, str):
|
|
587
|
+
raise TypeError("obf_config[][]['module'] type should be str, but got {}.".format(type(module)))
|
|
588
|
+
insert_new_input = config.get('insert_new_input', [])
|
|
589
|
+
_check_new_input_info(insert_new_input)
|
|
590
|
+
insert_ops = config.get('insert_ops', [])
|
|
591
|
+
if not isinstance(insert_ops, list):
|
|
592
|
+
raise TypeError("obf_config[][]['insert_ops'] type should be list, but got {}.".format(type(insert_ops)))
|
|
593
|
+
_check_ops_info(insert_ops)
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
def _check_valid_obf_config(obf_config, config_type):
    """
    Check the section of an obfuscation config selected by `config_type`.

    Args:
        obf_config (dict): Obfuscation config. Every key must be a str contained in
            `_supported_config_type`; the value stored under `config_type` must be a list of dicts.
        config_type (str): Name of the section to validate, must be in `_supported_config_type`.

    Returns:
        bool, True when every check passes.

    Raises:
        TypeError: If `obf_config` is not a dict, if `config_type` or any key of `obf_config` is not a
            supported str, if the selected section is not a list of dicts, or if an entry's optional
            'layers' value is not a list of int. The per-type checkers may raise as well.
    """
    # Reject a non-dict container first: otherwise the `.get()` / key iteration below would fail with
    # an AttributeError instead of the TypeError every other check in this function reports.
    if not isinstance(obf_config, dict):
        raise TypeError("obf_config type should be dict, but got {}.".format(type(obf_config)))
    if not isinstance(config_type, str) or config_type not in _supported_config_type:
        raise TypeError("config_type must be str, and in {}, but got {}."
                        .format(str(_supported_config_type), config_type))
    # All sections present in the config must be known ones, not only the one being validated.
    for config_type_item in obf_config:
        if not isinstance(config_type_item, str) or config_type_item not in _supported_config_type:
            raise TypeError("config_type must be str, and in {}, but got {}."
                            .format(str(_supported_config_type), config_type_item))
    config_list = obf_config.get(config_type)
    if not isinstance(config_list, list):
        raise TypeError("obf_config[] type should be list, but got {}.".format(type(config_list)))

    for config in config_list:
        if not isinstance(config, dict):
            raise TypeError("obf_config[][] type should be dict, but got {}.".format(type(config)))
        if config_type == 'obf_metadata_config':
            _check_obf_metadata_config(config)
        elif config_type == 'weight_obf_config':
            _check_weight_obf_config(config)
        elif config_type == 'network_obf_config':
            _check_network_obf_config(config)
        # 'layers' is optional for every config type; when given it must be a list of int.
        layers = config.get('layers')
        if layers is not None:
            if not isinstance(layers, list):
                raise TypeError("obf_config[][]['layers'] type should be list, but got {}.".format(type(layers)))
            for layer in layers:
                if not isinstance(layer, int):
                    raise TypeError("obf_config[][]['layers'][] type should be int, but got {}.".format(type(layer)))
    return True
|
|
626
|
+
|
|
627
|
+
|
|
326
628
|
def _update_max_obf_ratios_num(target_modules):
|
|
327
629
|
"""Update MAX_OBF_RATIOS_NUM"""
|
|
328
630
|
if len(target_modules) >= 3:
|
|
@@ -341,170 +643,163 @@ def _update_max_obf_ratios_num(target_modules):
|
|
|
341
643
|
MAX_OBF_RATIOS_NUM = int(obfuscate_layers[1]) * OBF_RATIOS_WIDTH
|
|
342
644
|
|
|
343
645
|
|
|
344
|
-
def _get_default_target_modules(ckpt_files):
|
|
345
|
-
"""Get the default or suggested target modules, if the target modules is None."""
|
|
346
|
-
|
|
347
|
-
def _split_to_path_and_target(module, target):
|
|
348
|
-
# split module into path list and target list
|
|
349
|
-
target_index = module.index(target)
|
|
350
|
-
path = module[:target_index - 1]
|
|
351
|
-
target = module[target_index:].split('/')[0]
|
|
352
|
-
return path, target
|
|
353
|
-
|
|
354
|
-
def _find_default_obfuscate_modules(net_path):
|
|
355
|
-
# find modules including the default paths
|
|
356
|
-
default_module = {'attention'}
|
|
357
|
-
for module in default_module:
|
|
358
|
-
if module in net_path and module not in candidate_modules:
|
|
359
|
-
candidate_modules.append(net_path)
|
|
360
|
-
# find the default targets in the default module
|
|
361
|
-
default_target = {'dense', 'query', 'key', 'value'}
|
|
362
|
-
for target in default_target:
|
|
363
|
-
for candidate in candidate_modules:
|
|
364
|
-
if target in candidate:
|
|
365
|
-
path, target = _split_to_path_and_target(candidate, target)
|
|
366
|
-
if path not in paths:
|
|
367
|
-
paths.append(path)
|
|
368
|
-
if target not in targets:
|
|
369
|
-
targets.append(target)
|
|
370
|
-
|
|
371
|
-
def _find_suggested_obfuscate_modules(net_path):
|
|
372
|
-
default_target = {'dense', 'query', 'key', 'value'}
|
|
373
|
-
for target in default_target:
|
|
374
|
-
# find the suggest modules
|
|
375
|
-
if target in net_path:
|
|
376
|
-
path, target = _split_to_path_and_target(net_path, target)
|
|
377
|
-
if [path, target] not in suggest_modules:
|
|
378
|
-
suggest_modules.append([path, target])
|
|
379
|
-
|
|
380
|
-
# store the potential candidate_modules
|
|
381
|
-
candidate_modules = []
|
|
382
|
-
suggest_modules = []
|
|
383
|
-
paths = []
|
|
384
|
-
targets = []
|
|
385
|
-
ckpt_dir_files = os.listdir(ckpt_files)
|
|
386
|
-
for ckpt_name in ckpt_dir_files:
|
|
387
|
-
if not ckpt_name.endswith('.ckpt'):
|
|
388
|
-
continue
|
|
389
|
-
try:
|
|
390
|
-
ckpt_param = load_checkpoint(os.path.abspath(ckpt_files) + '/' + ckpt_name)
|
|
391
|
-
except (ValueError, TypeError, OSError):
|
|
392
|
-
logger.error("Load checkpoint failed for file {}.".format(os.path.abspath(ckpt_files) + '/' + ckpt_name))
|
|
393
|
-
return None
|
|
394
|
-
for item in ckpt_param:
|
|
395
|
-
param_path = _remove_digit(item)
|
|
396
|
-
param_path = '/'.join(param_path)
|
|
397
|
-
# find candidate modules including the default paths and append candidate_modules
|
|
398
|
-
_find_default_obfuscate_modules(param_path)
|
|
399
|
-
# give the suggested modules and find the default targets in the default module
|
|
400
|
-
_find_suggested_obfuscate_modules(param_path)
|
|
401
|
-
if paths and targets:
|
|
402
|
-
target_modules = [paths[0], '|'.join(targets)]
|
|
403
|
-
logger.warning("The default obfuscate modules is obtained:{}".format(target_modules))
|
|
404
|
-
return target_modules
|
|
405
|
-
# logging the suggested target module
|
|
406
|
-
logger.warning("The default obfuscate modules can not be obtained. The suggested possible paths are given below: {}"
|
|
407
|
-
.format(suggest_modules))
|
|
408
|
-
raise ValueError("Can not get the default path, please specify the path in the form of ['A/B/C', 'D1|D2']")
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
def _get_valid_module(item, path_list, target_list):
|
|
412
|
-
"""get the valid module"""
|
|
413
|
-
number_path = len(path_list)
|
|
414
|
-
net_path = _remove_digit(item)
|
|
415
|
-
net_path = '/'.join(net_path[:number_path])
|
|
416
|
-
tar_path = '/'.join(path_list)
|
|
417
|
-
# update the weights with obf_ratios in target module
|
|
418
|
-
if net_path == tar_path:
|
|
419
|
-
for target in target_list:
|
|
420
|
-
if target in item.split('.'):
|
|
421
|
-
target_index = item.split('.').index(target)
|
|
422
|
-
module = ''.join(item.split('.')[:target_index + 1])
|
|
423
|
-
return module
|
|
424
|
-
return None
|
|
425
|
-
|
|
426
|
-
|
|
427
646
|
def _remove_digit(item):
    """
    Remove the purely numeric segments from a '_'-separated name.

    Args:
        item (str): Name to clean, e.g. a parameter path segment such as 'layers_0'.

    Returns:
        str, `item` with every '_'-separated segment for which `str.isdigit()` is true dropped and
        the remaining segments re-joined with '_'. Empty segments are kept.
    """
    # Single pass filter instead of repeatedly calling list.remove on a copy of the split result.
    return '_'.join(segment for segment in item.split('_') if not segment.isdigit())
|
|
653
|
+
|
|
434
654
|
|
|
655
|
+
def _remove_scope(item):
    """
    Remove the 'self' scope segments from a '.'-separated name.

    Args:
        item (str): Dotted name, e.g. 'self.layers'.

    Returns:
        str, `item` with every segment that is exactly 'self' dropped and the remaining segments
        re-joined with '.'.
    """
    # Single pass filter instead of repeatedly calling list.remove on a copy of the split result.
    return '.'.join(segment for segment in item.split('.') if segment != 'self')
|
|
435
662
|
|
|
436
|
-
def _obfuscate_network(model, path_list, target_list, data_parallel_num=1, **kwargs):
|
|
437
|
-
"""obfuscate original network, including add mul operation and add inputs for passing obf_ratio."""
|
|
438
663
|
|
|
439
|
-
|
|
664
|
+
def _obfuscate_network(model, obf_config=None, data_parallel_num=1, **kwargs):
|
|
665
|
+
"""obfuscate original network, including add deobfuscation ops and add inputs for passing obf_metadata."""
|
|
666
|
+
|
|
667
|
+
def _insert_input(stree: SymbolTree, arg_name: str = 'obf_metadata'):
    """
    Add one extra input named `arg_name` to `stree`, placed right after its last input node.

    Args:
        stree (SymbolTree): Symbol tree of the (sub)network that receives the new input.
        arg_name (str): Name of the argument to add. Default: 'obf_metadata'.
    """
    input_nodes = [candidate for candidate in stree.nodes()
                   if candidate.get_node_type() == NodeType.Input]
    # None when the tree has no input node, exactly like the untouched initial value in a scan loop.
    last_input = input_nodes[-1] if input_nodes else None
    insert_position = stree.after(last_input)
    # the inserted input node name would be 'input_obf_metadata'
    new_input_node = last_input.create_input(arg_name)
    stree.insert(insert_position, new_input_node)
|
|
449
677
|
|
|
450
|
-
def
|
|
451
|
-
"""
|
|
452
|
-
|
|
453
|
-
input_y_node
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
if data_parallel_num > 1:
|
|
459
|
-
logger.info("Data parallel number is: {}".format(data_parallel_num))
|
|
460
|
-
new_mul_node = node.create_call_cell(cell=ops.Mul().shard(((data_parallel_num, 1), ())),
|
|
461
|
-
targets=target_list, args=arg_list, name='mul')
|
|
678
|
+
def _update_subnet(substree: SymbolTree, subnode: Node):
    """
    Pass the 'obf_metadata' input of an obfuscated sub-tree through the node that calls it.

    Args:
        substree (SymbolTree): Symbol tree of the sub-network, expected to hold a node named
            'input_obf_metadata'.
        subnode (Node): Node that invokes the sub-network; it receives an `obf_metadata` keyword argument.

    Returns:
        bool, False (after logging an error) if 'input_obf_metadata' is not found in `substree`,
        True otherwise.
    """
    metadata_input = substree.get_node("input_obf_metadata")
    if metadata_input is None:
        log.error("can not find input node: obf_metadata for net: {}.".format(subnode.get_name()))
        return False
    # Nodes exposing get_handler() take the kwarg on their handler, all others take it directly.
    kwarg_owner = subnode.get_handler() if hasattr(subnode, 'get_handler') else subnode
    kwarg_owner.append_kwarg({"obf_metadata": metadata_input.get_targets()[0]})
    return True
|
|
689
|
+
|
|
690
|
+
def _insert_ops(stree: SymbolTree, node: Node, insert_ops: list):
    """
    Chain the ops described by `insert_ops` directly after `node`.

    Each new op is called with the targets of the node it follows plus the entry of the
    'input_obf_metadata' input indexed by the op's 'input_y' value, and writes back to the same targets.

    Args:
        stree (SymbolTree): Symbol tree that owns `node` and the 'input_obf_metadata' input node.
        node (Node): Node after which the first op is inserted.
        insert_ops (list): Op descriptions; each one is indexed with 'name' and 'input_y'.

    Raises:
        ValueError: If `stree` has no node named 'input_obf_metadata'.
    """
    previous = node
    for op_info in insert_ops:
        call_args = previous.get_targets().copy()
        metadata_node = stree.get_node("input_obf_metadata")
        if metadata_node is None:
            raise ValueError("can not find input node: obf_metadata for net: {}.".format(previous.get_name()))
        metadata_name = metadata_node.get_targets()[0].value
        index = op_info['input_y']
        # The extra argument is the expression `<metadata>["<index>"]` built as a naming value.
        call_args.append(ScopedValue.create_naming_value(metadata_name + f'["{index}"]'))
        call_targets = previous.get_targets().copy()
        name = op_info['name']
        cell = _get_op(name)
        if data_parallel_num > 1:
            cell = cell.shard(((data_parallel_num, 1), ()))
        inserted = previous.create_call_cell(cell=cell, targets=call_targets, args=call_args, name=name)
        stree.insert(stree.after(previous), inserted)
        # The next op is appended after the one just inserted.
        previous = inserted
|
|
466
713
|
|
|
467
|
-
def
|
|
714
|
+
def _insert_ops_by_name(stree: SymbolTree, after_name_list: list, module: str):
    """
    Insert the configured ops after every node of `stree` whose name is listed in `after_name_list`.

    Args:
        stree (SymbolTree): Symbol tree to scan.
        after_name_list (list): Names of the nodes to insert after; nothing happens when it is empty or None.
        module (str): Module path; `module + '/' + name` is the key looked up in `insert_ops_map`.
    """
    if after_name_list:
        for candidate in stree.nodes():
            for wanted_name in after_name_list:
                if candidate.get_name() != wanted_name:
                    continue
                _insert_ops(stree, candidate, insert_ops_map[module+'/'+wanted_name])
|
|
723
|
+
|
|
724
|
+
def _process_controlflow_node(node: Node, stree: SymbolTree, full_path: str, path: str, targets: dict):
    """
    Obfuscate the target sub-networks reachable through a control flow node.

    Args:
        node (Node): Control flow node to process; its handler is used when it exposes `get_handler`.
        stree (SymbolTree): Symbol tree that contains `node` and the 'input_obf_metadata' input node.
        full_path (str): Path of `node` built from the unmodified node names.
        path (str): Path used for matching against the keys of `targets`.
        targets (dict): Maps a module path to the list of node names to insert ops after.

    Returns:
        bool, False as soon as any nested step fails, True otherwise.
    """
    ctrl = node
    if hasattr(node, 'get_handler'):
        ctrl = node.get_handler()
    cell_loop_name = ''
    find_cell_loop = False
    # NOTE(review): nesting reconstructed from a flattened source - the scan of the inputs is assumed
    # to run only when the control flow node has loop variables; confirm against the original file.
    if hasattr(ctrl, "loop_vars") and ctrl.loop_vars:
        cell_loop_name = ctrl.loop_vars[0]
        for loop_input in ctrl.get_inputs():
            if loop_input.get_node_type() != NodeType.CellContainer:
                continue
            find_cell_loop = True
            full_node_name = loop_input.get_name()
            node_name = _remove_digit(_remove_scope(full_node_name))
            if not _process_cellcontainer_node(loop_input, full_path+'/'+full_node_name,
                                               path+'/'+node_name, targets):
                log.error("_process_cellcontainer_node for node: {} failed.".format(node_name))
                return False

    for c_node in ctrl.nodes():
        c_node_name = c_node.get_name()
        c_node_type = c_node.get_node_type()
        if c_node_type == NodeType.ControlFlow:
            # Nested control flow keeps the same matching path, only the full path grows.
            if not _process_controlflow_node(c_node, stree, full_path+'/'+c_node_name, path, targets):
                return False
            continue
        if c_node_type == NodeType.Tree and _is_target_module(path + '/' + c_node_name, targets):
            sub_stree = SymbolTree(c_node.symbol_tree)
            _insert_input(sub_stree, arg_name='obf_metadata')
            _insert_ops_by_name(sub_stree, after_name_list=targets.get(path + '/' + c_node_name, None),
                                module=path + '/' + c_node_name)
            if not _traverse(sub_stree, full_path+'/'+c_node_name, path+'/'+c_node_name, targets):
                log.error("_traverse for node: {} failed.".format(c_node_name))
                return False
            if not _update_subnet(sub_stree, c_node):
                log.error("_update_subnet for node: {} failed.".format(c_node_name))
                return False
            continue
        if find_cell_loop and c_node_type == NodeType.CallFunction and c_node_name.startswith(cell_loop_name):
            # Calls whose name starts with the loop variable receive the metadata as a kwarg.
            input_y_node = stree.get_node("input_obf_metadata")
            if input_y_node is None:
                log.error("input_y_node for node: {} is None.".format(c_node_name))
                return False
            c_node.append_kwarg({"obf_metadata": input_y_node.get_targets()[0]})
    return True
|
|
764
|
+
|
|
765
|
+
def _process_cellcontainer_node(node: Node, full_path: str, path: str, targets: dict):
    """
    Obfuscate every sub-network held by a cell container node.

    Args:
        node (Node): Cell container node; its handler is used when it exposes `get_handler`.
        full_path (str): Path of the container built from the unmodified node names.
        path (str): Path of the container used for matching against the keys of `targets`.
        targets (dict): Maps a module path to the list of node names to insert ops after.

    Returns:
        bool, False as soon as traversing one of the sub-networks fails, True otherwise.
    """
    cellcontainer = node.get_handler() if hasattr(node, 'get_handler') else node
    # Fetch the node list once instead of calling nodes() again for every index.
    cell_nodes = cellcontainer.nodes()
    # The path check does not depend on the individual cell, so it is evaluated a single time.
    if not _is_target_module(path, targets):
        return True
    for index, cell_node in enumerate(cell_nodes):
        if cell_node.get_node_type() != NodeType.Tree:
            continue
        # insert input for each sub_stree in cell_container
        sub_stree = SymbolTree(cell_node.symbol_tree)
        _insert_input(sub_stree, arg_name='obf_metadata')
        _insert_ops_by_name(sub_stree, after_name_list=targets.get(path, None), module=path)
        # The position inside the container becomes the next path segment.
        if not _traverse(sub_stree, full_path + '/' + str(index), path + '/' + str(index), targets):
            return False
    return True
|
|
777
|
+
|
|
778
|
+
def _is_target_module(path, targets):
    """
    Tell whether `path` leads to at least one configured target module.

    Args:
        path (str): Module path to test.
        targets (dict): Dict whose keys are the target module paths.

    Returns:
        bool, True if any key of `targets` starts with `path`, False otherwise (including when
        `targets` is empty). The comparison is a plain string prefix match.
    """
    return any(target_module.startswith(path) for target_module in targets)
|
|
783
|
+
|
|
784
|
+
def _traverse(stree: SymbolTree, full_path: str, path: str, targets: dict):
    """
    Walk the nodes of `stree` and obfuscate the control flow nodes and target sub-networks found.

    Args:
        stree (SymbolTree): Symbol tree to walk.
        full_path (str): Path of `stree` built from the unmodified node names.
        path (str): Path of `stree` used for matching against the keys of `targets`.
        targets (dict): Maps a module path to the list of node names to insert ops after.

    Returns:
        bool, False (after logging an error) as soon as one step fails, True otherwise.
    """
    for child in stree.nodes():
        node_name = child.get_name()
        if child.get_node_type() == NodeType.ControlFlow:
            if not _process_controlflow_node(child, stree, full_path + '/' + node_name, path, targets):
                log.error("process controlflow node: {} failed.".format(child.get_name()))
                return False
            continue
        if not (child.get_node_type() == NodeType.Tree and _is_target_module(path + '/' + node_name, targets)):
            continue
        # Sub-network on the way to a target: add the metadata input, insert its ops, then recurse.
        sub_stree = child.get_sub_tree()
        _insert_input(sub_stree, arg_name='obf_metadata')
        _insert_ops_by_name(sub_stree, after_name_list=targets.get(path + '/' + node_name, None),
                            module=path + '/' + node_name)
        if not _traverse(sub_stree, full_path + '/' + node_name, path + '/' + node_name, targets):
            log.error("traverse sub_stree for node: {} failed.".format(child.get_name()))
            return False
        if not _update_subnet(sub_stree, child):
            log.error("update subnet for node: {} failed.".format(child.get_name()))
            return False
    return True
|
|
508
803
|
|
|
509
804
|
def _register_denied_func_decorators(fn):
|
|
510
805
|
"""set the function decorators which should be denied for parse"""
|
|
@@ -533,9 +828,48 @@ def _obfuscate_network(model, path_list, target_list, data_parallel_num=1, **kwa
|
|
|
533
828
|
if kw_class_dec and not isinstance(kw_class_dec[0], str):
|
|
534
829
|
raise TypeError('elements of {} should be str, but got {}'.format(kw_class_dec, type(kw_class_dec[0])))
|
|
535
830
|
|
|
831
|
+
targets = {}
|
|
832
|
+
insert_ops_map = {}
|
|
833
|
+
for obf_item in obf_config:
|
|
834
|
+
module = obf_item.get('module', None)
|
|
835
|
+
target = obf_item.get('target', None)
|
|
836
|
+
insert_ops_info = obf_item.get('insert_ops', None)
|
|
837
|
+
layers = obf_item.get('layers', [])
|
|
838
|
+
if not layers:
|
|
839
|
+
real_insert_ops_info = []
|
|
840
|
+
if not targets.get(module, None):
|
|
841
|
+
targets[module] = []
|
|
842
|
+
if target not in targets[module]:
|
|
843
|
+
targets[module].append(target)
|
|
844
|
+
target_path = module + '/' + target
|
|
845
|
+
for op_info in insert_ops_info:
|
|
846
|
+
real_op_info = op_info.copy()
|
|
847
|
+
real_insert_ops_info.append(real_op_info)
|
|
848
|
+
insert_ops_map[target_path] = real_insert_ops_info
|
|
849
|
+
for layer in layers:
|
|
850
|
+
real_insert_ops_info = []
|
|
851
|
+
strTemplate = Template(module)
|
|
852
|
+
real_module = strTemplate.safe_substitute({"layer": str(layer)})
|
|
853
|
+
if not targets.get(real_module, None):
|
|
854
|
+
targets[real_module] = []
|
|
855
|
+
if target not in targets[real_module]:
|
|
856
|
+
targets[real_module].append(target)
|
|
857
|
+
target_path = real_module + '/' + target
|
|
858
|
+
for op_info in insert_ops_info:
|
|
859
|
+
real_op_info = op_info.copy()
|
|
860
|
+
strTemplate = Template(real_op_info['input_x'])
|
|
861
|
+
real_op_info['input_x'] = strTemplate.safe_substitute({"layer": str(layer)})
|
|
862
|
+
strTemplate = Template(real_op_info['input_y'])
|
|
863
|
+
real_op_info['input_y'] = strTemplate.safe_substitute({"layer": str(layer)})
|
|
864
|
+
real_insert_ops_info.append(real_op_info)
|
|
865
|
+
insert_ops_map[target_path] = real_insert_ops_info
|
|
866
|
+
|
|
867
|
+
root_path = ""
|
|
536
868
|
main_stree = SymbolTree.create(model)
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
869
|
+
_insert_input(main_stree, arg_name='obf_metadata')
|
|
870
|
+
_insert_ops_by_name(main_stree, after_name_list=targets.get(root_path, None), module=root_path)
|
|
871
|
+
if not _traverse(main_stree, full_path=root_path, path=root_path, targets=targets):
|
|
872
|
+
log.error("_traverse for root_path: {} failed.".format(root_path))
|
|
873
|
+
return None
|
|
540
874
|
new_net = main_stree.get_network()
|
|
541
875
|
return new_net
|