mindspore 2.1.0__cp37-cp37m-manylinux1_x86_64.whl → 2.2.11__cp37-cp37m-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +139 -22
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +56 -1
- mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,541 @@
|
|
|
1
|
+
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""obfuscate network based on rewrite interfaces."""
|
|
16
|
+
import os
|
|
17
|
+
import re
|
|
18
|
+
import secrets
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
|
|
21
|
+
from mindspore import ops, nn
|
|
22
|
+
from mindspore.common.tensor import Tensor
|
|
23
|
+
from mindspore import log as logger
|
|
24
|
+
from mindspore import load_checkpoint, save_checkpoint
|
|
25
|
+
from mindspore.rewrite import SymbolTree, Node, NodeType, TreeNodeHelper, ScopedValue
|
|
26
|
+
from mindspore.rewrite.parsers.class_def_parser import ClassDefParser
|
|
27
|
+
from mindspore.rewrite.parsers.class_def_parser import ModuleParser
|
|
28
|
+
|
|
29
|
+
# Module-level state shared by the obfuscation helpers below. Several functions rebind these names through
# ``global``; the values here are only the initial defaults.
# Number of repeated layers: reset to 1, then multiplied by the length of every nn.CellList on the target path.
OBF_RATIOS_LENGTH = 1
# Number of target sub-modules that were found at the end of the target path.
OBF_RATIOS_WIDTH = 0
# Upper bound on the number of obfuscation ratios that are generated / applied.
MAX_OBF_RATIOS_NUM = 50
# Running counter of the Mul nodes inserted while rewriting a network.
OBF_RATIOS_INSERT_INDEX = 0
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./', obfuscate_scale=100):
    """
    Obfuscate the plaintext checkpoint files. Usually used in conjunction with
    :func:`mindspore.load_obf_params_into_net` interface.

    Args:
        network (nn.Cell): The original network that needs to be obfuscated.
        ckpt_files (str): The directory path of original ckpt files. ``*.ckpt`` files directly inside it and
            inside its immediate sub-directories are processed; the sub-directory names are mirrored under
            `saved_path`.
        target_modules (list[str]): The target module of network that needs to be obfuscated. The first string
            represents the network path of target module in original network, which should be in form of ``'A/B/C'``.
            The second string represents the obfuscation target module, which should be in form of ``'D|E|F'``. For
            example, the target_modules of GPT2 can be ``['backbone/blocks/attention', 'dense1|dense2|dense3']``.
            If target_modules has the third value, it should be in the format of 'obfuscate_layers:all' or
            'obfuscate_layers:int', which represents the number of layers that need to be obfuscated among
            duplicate layers (such as transformer layers or resnet blocks). If target_modules is ``None``, the
            function would search target modules by itself. If found, the searched target module would be used,
            otherwise suggested target modules would be given with warning log. The list passed in is not
            modified. Default: ``None``.
        saved_path (str): The directory path for saving obfuscated ckpt files. Default: ``'./'``.
        obfuscate_scale (Union[float, int]): Obfuscate scale of weights. The generated random obf_ratios will be in
            range of (1 / obfuscate_scale, obfuscate_scale). Default: 100.

    Raises:
        TypeError: If `network` is not nn.Cell.
        TypeError: If `ckpt_files` is not string or `saved_path` is not string.
        TypeError: If `target_modules` is not list.
        TypeError: If target_modules's elements are not string.
        ValueError: If `ckpt_files` does not exist or `saved_path` does not exist.
        ValueError: If the number of elements of `target_modules` is less than ``2``.
        ValueError: If the first string of `target_modules` contains characters other than uppercase and lowercase
            letters, numbers, ``'_'`` and ``'/'``.
        ValueError: If the second string of `target_modules` is empty or contains characters other than uppercase and
            lowercase letters, numbers, ``'_'`` and ``'|'``.
        ValueError: If the third string of `target_modules` is not in the format of 'obfuscate_layers:all' or
            'obfuscate_layers:int'.
        ValueError: If `obfuscate_scale` is not a float or int larger than 1.

    Returns:
        list[float], obf_ratios, which is the necessary data that needs to be loaded when running the obfuscated
        network. It is only returned, not written to disk; each obfuscated checkpoint is written next to its
        original name with an ``_obf`` suffix.

    Examples:
        >>> from mindspore import obfuscate_ckpt, save_checkpoint
        >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> save_checkpoint(net, './test_net.ckpt')
        >>> target_modules = ['', 'fc1|fc2']
        >>> obf_ratios = obfuscate_ckpt(net, './', target_modules=target_modules, saved_path='./')
    """
    global OBF_RATIOS_LENGTH
    if not isinstance(network, nn.Cell):
        raise TypeError("network must be nn.Cell, but got {}.".format(type(network)))
    _check_dir_path('ckpt_files', ckpt_files)
    _check_dir_path('saved_path', saved_path)
    # Try to find default target modules
    if target_modules is None:
        to_split_modules = _get_default_target_modules(ckpt_files)
    else:
        # Copy a list argument so that normalizing '/' below does not modify the caller's list.
        to_split_modules = list(target_modules) if isinstance(target_modules, list) else target_modules
        if len(to_split_modules) >= 1 and to_split_modules[0] == '/':
            # '/' is accepted as an alias of '' (the main network).
            to_split_modules[0] = ''
    # Besides validating, _check_valid_target records OBF_RATIOS_LENGTH / OBF_RATIOS_WIDTH / MAX_OBF_RATIOS_NUM.
    if not _check_valid_target(network, to_split_modules):
        raise ValueError("The obfuscate module path {} is not exist, please check the input 'target_modules'."
                         .format(to_split_modules))
    if (not isinstance(obfuscate_scale, (float, int))) or (obfuscate_scale <= 1):
        raise ValueError("obfuscate_scale must be float or int, and larger than 1, but got {}."
                         .format(obfuscate_scale))
    path_list = to_split_modules[0].split('/')
    target_list = to_split_modules[1].split('|')
    # Cap the number of ratios at MAX_OBF_RATIOS_NUM by reducing the number of obfuscated layers.
    number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
    if number_of_ratios > MAX_OBF_RATIOS_NUM:
        OBF_RATIOS_LENGTH = MAX_OBF_RATIOS_NUM // OBF_RATIOS_WIDTH
        number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
    # Generate the secret ratios with a CSPRNG; they are returned to the caller, not saved to saved_path.
    secrets_generator = secrets.SystemRandom()
    obf_ratios = [secrets_generator.uniform(1 / obfuscate_scale, obfuscate_scale) for _ in range(number_of_ratios)]
    # start obfuscate ckpt
    ckpt_root = os.path.abspath(ckpt_files)
    for ckpt_name in os.listdir(ckpt_files):
        sub_path = os.path.join(ckpt_root, ckpt_name)
        if Path(sub_path).is_dir():
            # One level of sub-directories (e.g. one per rank) is mirrored under saved_path.
            sub_ckpt_file_list = os.listdir(sub_path)
            new_saved_path = os.path.join(os.path.abspath(saved_path), ckpt_name)
            if not os.path.exists(new_saved_path):
                try:
                    os.mkdir(new_saved_path, mode=0o700)
                except FileExistsError:
                    pass
            for sub_ckpt_name in sub_ckpt_file_list:
                if not sub_ckpt_name.endswith('.ckpt'):
                    continue
                _obfuscate_single_ckpt(os.path.join(sub_path, sub_ckpt_name), obf_ratios, path_list,
                                       target_list, new_saved_path)
        elif ckpt_name.endswith('.ckpt'):
            _obfuscate_single_ckpt(sub_path, obf_ratios, path_list, target_list, saved_path)
    return obf_ratios
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _obfuscate_single_ckpt(ckpt_name, obf_ratios, path_list, target_list, saved_path):
    """
    Obfuscate a single checkpoint file and save the result as ``<stem>_obf.ckpt`` under `saved_path`.

    Every parameter that belongs to a target module (per `_get_valid_module`) is divided by one entry of
    `obf_ratios`. Parameters whose layer index (per `_judge_layer_index`) is not smaller than the module-level
    OBF_RATIOS_LENGTH are left untouched.

    Args:
        ckpt_name (str): Full path of the checkpoint file to obfuscate.
        obf_ratios (list[float]): Ratios generated by `obfuscate_ckpt`.
        path_list (list[str]): Network path of the target modules, already split on '/'.
        target_list (list[str]): Target module names, already split on '|'.
        saved_path (str): Directory in which the obfuscated checkpoint is written.

    Returns:
        None. A checkpoint that fails to load is logged and skipped.
    """
    module_has_been_obfuscated = set()
    try:
        ckpt_param = load_checkpoint(ckpt_name)
    except (ValueError, TypeError, OSError):
        logger.error("Load checkpoint failed for file {}.".format(ckpt_name))
        return None
    obf_ratios_index = -1
    for item in ckpt_param:
        module = _get_valid_module(item, path_list, target_list)
        if not module:
            continue
        layer_index = _judge_layer_index(item)
        if layer_index >= OBF_RATIOS_LENGTH:
            continue
        if module not in module_has_been_obfuscated:
            # The ratio column advances only when a new module is first seen, so consecutive parameters of one
            # module (e.g. weight and bias) share the same ratio.
            module_has_been_obfuscated.add(module)
            obf_ratios_index += 1
        ratio_total_index = layer_index * OBF_RATIOS_WIDTH + obf_ratios_index % OBF_RATIOS_WIDTH
        ckpt_param[item].set_data(ckpt_param[item].value() / obf_ratios[ratio_total_index])
    # save the obfuscated model to saved_path
    obf_param_list = [{'name': item, 'data': ckpt_param[item]} for item in ckpt_param]
    # splitext strips only the final extension, so 'net.rank0.ckpt' and 'net.rank1.ckpt' no longer both map to
    # 'net_obf.ckpt' and overwrite each other.
    stem = os.path.splitext(os.path.basename(ckpt_name))[0]
    save_checkpoint(obf_param_list, os.path.join(os.path.abspath(saved_path), stem + '_obf.ckpt'))
    return None
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def load_obf_params_into_net(network, target_modules, obf_ratios, data_parallel_num=1, **kwargs):
    """
    Load obfuscate ratios into obfuscated network. Usually used in conjunction with :func:`mindspore.obfuscate_ckpt`
    interface.

    Args:
        network (nn.Cell): The network to be rewritten, typically loaded with checkpoints produced by
            :func:`mindspore.obfuscate_ckpt`.
        target_modules (list[str]): The target module of network that needs to be obfuscated. The first string
            represents the network path of target module in original network, which should be in form of ``'A/B/C'``
            (``'/'`` and ``''`` both denote the main network). The second string represents the obfuscation target
            module, which should be in form of ``'D|E|F'``. For example, the target_modules of GPT2 can be
            ``['backbone/blocks/attention', 'dense1|dense2|dense3']``. If target_modules has the third value, it
            should be in the format of 'obfuscate_layers:all' or 'obfuscate_layers:int', which represents the number
            of layers that need to be obfuscated among duplicate layers (such as transformer layers or resnet
            blocks). The list passed in is not modified.
        obf_ratios (Tensor): The obf ratios generated when execute :func:`mindspore.obfuscate_ckpt`.
        data_parallel_num (int): The data parallel number of parallel training. Default: 1.
        kwargs (dict): Configuration options dictionary.

            - ignored_func_decorators (list[str]): The name list of function decorators in network's python code.
            - ignored_class_decorators (list[str]): The name list of class decorators in network's python code.

    Returns:
        nn.Cell, the rewritten network, which also carries `obf_ratios` as its ``obf_ratios`` attribute.

    Raises:
        TypeError: If `network` is not nn.Cell.
        TypeError: If `obf_ratios` is not Tensor.
        TypeError: If `target_modules` is not list.
        TypeError: If target_modules's elements are not string.
        ValueError: If the number of elements of `target_modules` is less than ``2``.
        ValueError: If `obf_ratios` is empty Tensor.
        ValueError: If the first string of `target_modules` contains characters other than uppercase and lowercase
            letters, numbers, ``'_'`` and ``'/'``.
        ValueError: If the second string of `target_modules` is empty or contains characters other than uppercase and
            lowercase letters, numbers, ``'_'`` and ``'|'``.
        ValueError: If the third string of `target_modules` is not in the format of 'obfuscate_layers:all' or
            'obfuscate_layers:int'.
        ValueError: If `data_parallel_num` is not a positive int.
        TypeError: If `ignored_func_decorators` is not list[str] or `ignored_class_decorators` is not list[str].

    Examples:
        >>> from mindspore import obfuscate_ckpt, save_checkpoint, load_checkpoint, Tensor
        >>> import mindspore.common.dtype as mstype
        >>> import numpy as np
        >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> save_checkpoint(net, './test_net.ckpt')
        >>> target_modules = ['', 'fc1|fc2']
        >>> # obfuscate ckpt files; the returned ratios are needed below
        >>> obf_ratios = obfuscate_ckpt(net, './', target_modules=target_modules, saved_path='./')
        >>> # load obf ckpt into network
        >>> new_net = LeNet5()
        >>> load_checkpoint('./test_net_obf.ckpt', new_net)
        >>> obf_net = load_obf_params_into_net(new_net, target_modules, Tensor(np.array(obf_ratios), mstype.float16))
    """
    global MAX_OBF_RATIOS_NUM, OBF_RATIOS_LENGTH
    if not isinstance(network, nn.Cell):
        raise TypeError("network must be nn.Cell, but got {}.".format(type(network)))
    if not isinstance(obf_ratios, Tensor):
        raise TypeError("obf_ratios must be MindSpore Tensor, but got {}.".format(type(obf_ratios)))
    if obf_ratios.size == 0:
        raise ValueError("obf_ratios can not be empty.")
    if isinstance(target_modules, list):
        # Work on a copy and normalize '/' to '' *before* validation, exactly as obfuscate_ckpt does;
        # _check_valid_target itself rejects '/' as a non-existent path.
        target_modules = list(target_modules)
        if target_modules and target_modules[0] == '/':
            target_modules[0] = ''
    if not _check_valid_target(network, target_modules):
        raise ValueError("{} is not exist, please check the input 'target_modules'.".format(target_modules))
    if (not isinstance(data_parallel_num, int)) or (data_parallel_num <= 0):
        raise ValueError("data_parallel_num must be positive number, but got {}.".format(data_parallel_num))
    path_list = target_modules[0].split('/')
    # target_list[k] holds the node names to obfuscate at nesting depth k; only the deepest entry is non-empty.
    target_list = [[] for _ in path_list]
    target_list.append(target_modules[1].split('|'))
    # Same cap as in obfuscate_ckpt, so both sides agree on the number of ratios.
    number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
    if number_of_ratios > MAX_OBF_RATIOS_NUM:
        OBF_RATIOS_LENGTH = MAX_OBF_RATIOS_NUM // OBF_RATIOS_WIDTH
        number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
    # Bounds the number of Mul nodes _obfuscate_network may insert.
    MAX_OBF_RATIOS_NUM = number_of_ratios
    rewrite_network = _obfuscate_network(network, path_list, target_list, data_parallel_num=data_parallel_num, **kwargs)
    setattr(rewrite_network, 'obf_ratios', obf_ratios)
    return rewrite_network
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def _check_dir_path(name, dir_path):
|
|
249
|
+
"""check directory path"""
|
|
250
|
+
if not isinstance(dir_path, str):
|
|
251
|
+
raise TypeError("{} must be string, but got {}.".format(name, type(dir_path)))
|
|
252
|
+
if not os.path.exists(dir_path):
|
|
253
|
+
raise ValueError("{} is not exist, please check the input {}.".format(dir_path, name))
|
|
254
|
+
if not Path(dir_path).is_dir():
|
|
255
|
+
raise TypeError("{} must be a directory path, but got {}.".format(name, dir_path))
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def _judge_layer_index(layer_name):
|
|
259
|
+
"""Judge the layer index of target layers"""
|
|
260
|
+
split_name = layer_name.split('.')
|
|
261
|
+
for split_str in split_name[:]:
|
|
262
|
+
if split_str.isdigit():
|
|
263
|
+
return int(split_str)
|
|
264
|
+
return 0
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _check_valid_target(network, target_modules):
|
|
268
|
+
"""check whether the input 'target_modules' exists"""
|
|
269
|
+
if not isinstance(target_modules, list):
|
|
270
|
+
raise TypeError("target_modules type should be list, but got {}.".format(type(target_modules)))
|
|
271
|
+
if len(target_modules) < 2:
|
|
272
|
+
raise ValueError("target_modules should contain at least two string values, in the form of ['A/B/C', 'D1|D2'],"
|
|
273
|
+
"but got {}.".format(target_modules))
|
|
274
|
+
if (not isinstance(target_modules[0], str)) or (not isinstance(target_modules[1], str)):
|
|
275
|
+
raise TypeError("The values of target_modules should be string, but got {} and {}.".
|
|
276
|
+
format(type(target_modules[0]), type(target_modules[1])))
|
|
277
|
+
|
|
278
|
+
if not target_modules[1]:
|
|
279
|
+
raise ValueError("{} should be a non-empty string value, in the form of 'D1|D2'"
|
|
280
|
+
.format(target_modules[1]))
|
|
281
|
+
if not re.fullmatch(pattern=r'([a-zA-Z]*[0-9]*\/*_*)*', string=target_modules[0]) \
|
|
282
|
+
or not re.fullmatch(pattern=r'([a-zA-Z]*[0-9]*\|*_*)*', string=target_modules[1]):
|
|
283
|
+
raise ValueError("please check the input 'target_modules'{},it should be in the form of ['A/B/C', 'D1|D2']."
|
|
284
|
+
"target_modules[0] can only contain uppercase and lowercase letters, numbers, '_' and '/',"
|
|
285
|
+
"target_modules[1] can only contain uppercase and lowercase letters, numbers, '_' and '|'"
|
|
286
|
+
.format(target_modules))
|
|
287
|
+
# target_modules[0] is allowed to be '', it means the main network path
|
|
288
|
+
path_list = target_modules[0].split('/')
|
|
289
|
+
target_list = target_modules[1].split('|')
|
|
290
|
+
net = network
|
|
291
|
+
# DFS check whether path_list is valid
|
|
292
|
+
stk = [net]
|
|
293
|
+
i = 0
|
|
294
|
+
global OBF_RATIOS_LENGTH
|
|
295
|
+
OBF_RATIOS_LENGTH = 1
|
|
296
|
+
while stk and i < len(path_list):
|
|
297
|
+
net = stk.pop()
|
|
298
|
+
if hasattr(net, path_list[i]):
|
|
299
|
+
net = getattr(net, path_list[i])
|
|
300
|
+
i += 1
|
|
301
|
+
if isinstance(net, nn.CellList):
|
|
302
|
+
OBF_RATIOS_LENGTH *= len(net)
|
|
303
|
+
for n in net:
|
|
304
|
+
stk.append(n)
|
|
305
|
+
elif isinstance(net, nn.Cell):
|
|
306
|
+
stk.append(net)
|
|
307
|
+
else:
|
|
308
|
+
raise TypeError("Target_modules[0] should be a subgraph and it's type should be nn.Cell(nn.CellList),"
|
|
309
|
+
"but got type {}".format(type(net)))
|
|
310
|
+
if target_modules[0] != '' and i != len(path_list):
|
|
311
|
+
raise ValueError("the path {} does not exist.".format(target_modules[0]))
|
|
312
|
+
# check whether target_list is valid
|
|
313
|
+
global OBF_RATIOS_WIDTH
|
|
314
|
+
OBF_RATIOS_WIDTH = 0
|
|
315
|
+
for target in target_list:
|
|
316
|
+
if not hasattr(net, target):
|
|
317
|
+
logger.warning("{} does not exist in the path {}".format(target, target_modules[0]))
|
|
318
|
+
else:
|
|
319
|
+
OBF_RATIOS_WIDTH += 1
|
|
320
|
+
if OBF_RATIOS_WIDTH == 0:
|
|
321
|
+
raise ValueError("all targets {} do not exist in the path {}.".format(target_list, target_modules[0]))
|
|
322
|
+
_update_max_obf_ratios_num(target_modules)
|
|
323
|
+
return True
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def _update_max_obf_ratios_num(target_modules):
|
|
327
|
+
"""Update MAX_OBF_RATIOS_NUM"""
|
|
328
|
+
if len(target_modules) >= 3:
|
|
329
|
+
obfuscate_layers = target_modules[2].split(':')
|
|
330
|
+
if len(obfuscate_layers) != 2 or obfuscate_layers[0] != 'obfuscate_layers':
|
|
331
|
+
raise ValueError("The third value of target_modules should be in the format of 'obfuscate_layers:all' or"
|
|
332
|
+
"'obfuscate_layers:int'")
|
|
333
|
+
global MAX_OBF_RATIOS_NUM
|
|
334
|
+
if obfuscate_layers[1] == 'all':
|
|
335
|
+
MAX_OBF_RATIOS_NUM = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
|
|
336
|
+
else:
|
|
337
|
+
if not obfuscate_layers[1].isdigit():
|
|
338
|
+
raise ValueError(
|
|
339
|
+
"The third value of target_modules should be in the format of 'obfuscate_layers:all' or"
|
|
340
|
+
"'obfuscate_layers:int'")
|
|
341
|
+
MAX_OBF_RATIOS_NUM = int(obfuscate_layers[1]) * OBF_RATIOS_WIDTH
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def _get_default_target_modules(ckpt_files):
    """
    Get the default or suggested target modules, if the target modules is None.

    Every ``*.ckpt`` file directly inside `ckpt_files` is loaded and each parameter name is turned into a path by
    dropping purely numeric components and joining the rest with '/'. Paths containing 'attention' become candidate
    modules; the paths/targets found in them for the keywords 'dense', 'query', 'key' and 'value' form the default
    target modules. Any path containing one of those keywords is also recorded as a suggestion for the warning log.

    Returns:
        list[str]: ``[path, 'target1|target2|...']`` when a default was found, or None if a checkpoint failed to
        load (the failure is logged).

    Raises:
        ValueError: If no default target modules can be derived.
    """
    # Tuples rather than sets: iteration order of a str set depends on hash randomization, which made the
    # returned path and target order differ from one process to the next.
    default_modules = ('attention',)
    default_targets = ('dense', 'query', 'key', 'value')

    def _split_to_path_and_target(module, target):
        # split module into path list and target list
        target_index = module.index(target)
        # The character before the target is expected to be a '/' separator. A target at the very start has no
        # parent path (module[:-1] would have chopped off the last character instead).
        path = module[:target_index - 1] if target_index > 0 else ''
        found_target = module[target_index:].split('/')[0]
        return path, found_target

    def _find_default_obfuscate_modules(net_path):
        # find modules including the default paths; de-duplicate on the path itself
        for module in default_modules:
            if module in net_path and net_path not in candidate_modules:
                candidate_modules.append(net_path)
        # find the default targets in the default module
        for target in default_targets:
            for candidate in candidate_modules:
                if target in candidate:
                    # Bind to a new name: reusing 'target' would make later candidates search for e.g. 'dense1'.
                    path, found_target = _split_to_path_and_target(candidate, target)
                    if path not in paths:
                        paths.append(path)
                    if found_target not in targets:
                        targets.append(found_target)

    def _find_suggested_obfuscate_modules(net_path):
        for target in default_targets:
            # find the suggest modules
            if target in net_path:
                path, found_target = _split_to_path_and_target(net_path, target)
                if [path, found_target] not in suggest_modules:
                    suggest_modules.append([path, found_target])

    # store the potential candidate_modules
    candidate_modules = []
    suggest_modules = []
    paths = []
    targets = []
    ckpt_root = os.path.abspath(ckpt_files)
    for ckpt_name in os.listdir(ckpt_files):
        if not ckpt_name.endswith('.ckpt'):
            continue
        ckpt_path = os.path.join(ckpt_root, ckpt_name)
        try:
            ckpt_param = load_checkpoint(ckpt_path)
        except (ValueError, TypeError, OSError):
            logger.error("Load checkpoint failed for file {}.".format(ckpt_path))
            return None
        for item in ckpt_param:
            param_path = '/'.join(_remove_digit(item))
            # find candidate modules including the default paths and append candidate_modules
            _find_default_obfuscate_modules(param_path)
            # give the suggested modules and find the default targets in the default module
            _find_suggested_obfuscate_modules(param_path)
    if paths and targets:
        target_modules = [paths[0], '|'.join(targets)]
        logger.warning("The default obfuscate modules is obtained:{}".format(target_modules))
        return target_modules
    # logging the suggested target module
    logger.warning("The default obfuscate modules can not be obtained. The suggested possible paths are given below: {}"
                   .format(suggest_modules))
    raise ValueError("Can not get the default path, please specify the path in the form of ['A/B/C', 'D1|D2']")
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def _get_valid_module(item, path_list, target_list):
    """
    Return an identifier of the target module that parameter `item` belongs to, or None.

    `item` matches when its first ``len(path_list)`` non-numeric '.'-components, joined with '/', equal
    ``'/'.join(path_list)`` and one of its '.'-components equals a name in `target_list` (the first matching
    name in `target_list` wins). The identifier is the components of `item` up to and including that target,
    concatenated without a separator.
    """
    # NOTE(review): for path_list == [''] (main network) the prefix below is the first component of `item`,
    # so it equals '' only for names with an empty first component; confirm this is intended.
    prefix = '/'.join(_remove_digit(item)[:len(path_list)])
    if prefix != '/'.join(path_list):
        return None
    parts = item.split('.')
    for target in target_list:
        if target in parts:
            return ''.join(parts[:parts.index(target) + 1])
    return None
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def _remove_digit(item):
|
|
428
|
+
"""remove digit in the parameter path"""
|
|
429
|
+
param_path = item.split('.')
|
|
430
|
+
for tmp_str in param_path[:]:
|
|
431
|
+
if tmp_str.isdigit():
|
|
432
|
+
param_path.remove(tmp_str)
|
|
433
|
+
return param_path
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
def _obfuscate_network(model, path_list, target_list, data_parallel_num=1, **kwargs):
|
|
437
|
+
"""obfuscate original network, including add mul operation and add inputs for passing obf_ratio."""
|
|
438
|
+
|
|
439
|
+
def _insert_input(stree: SymbolTree, arg_name: str = 'y_obf'):
|
|
440
|
+
"""add inputs for passing obf_ratio"""
|
|
441
|
+
last_input = None
|
|
442
|
+
for node in stree.nodes():
|
|
443
|
+
if node.get_node_type() == NodeType.Input:
|
|
444
|
+
last_input = node
|
|
445
|
+
position = stree.after(last_input)
|
|
446
|
+
# the insert input node name would be 'input_y_obf'
|
|
447
|
+
new_input_node = last_input.create_input(arg_name)
|
|
448
|
+
stree.insert(position, new_input_node)
|
|
449
|
+
|
|
450
|
+
def _insert_mul(stree: SymbolTree, node: Node, index: int):
|
|
451
|
+
"""add mul operation for original network"""
|
|
452
|
+
arg_list = node.get_targets().copy()
|
|
453
|
+
input_y_node = stree.get_node("input_y_obf")
|
|
454
|
+
v: str = input_y_node.get_targets()[0].value
|
|
455
|
+
sv: ScopedValue = ScopedValue.create_naming_value(v + f'[{index}]')
|
|
456
|
+
arg_list.append(sv)
|
|
457
|
+
target_list = node.get_targets().copy()
|
|
458
|
+
if data_parallel_num > 1:
|
|
459
|
+
logger.info("Data parallel number is: {}".format(data_parallel_num))
|
|
460
|
+
new_mul_node = node.create_call_cell(cell=ops.Mul().shard(((data_parallel_num, 1), ())),
|
|
461
|
+
targets=target_list, args=arg_list, name='mul')
|
|
462
|
+
else:
|
|
463
|
+
new_mul_node = node.create_call_cell(cell=ops.Mul(), targets=target_list, args=arg_list, name='mul')
|
|
464
|
+
position = stree.after(node)
|
|
465
|
+
stree.insert(position, new_mul_node)
|
|
466
|
+
|
|
467
|
+
def _insert_mul_by_name(stree: SymbolTree, after_name_list: list):
|
|
468
|
+
"""add mul operation after the target nodes according the name of them"""
|
|
469
|
+
if not after_name_list:
|
|
470
|
+
return
|
|
471
|
+
for node in stree.nodes():
|
|
472
|
+
for after_name in after_name_list:
|
|
473
|
+
if node.get_name() == after_name:
|
|
474
|
+
global OBF_RATIOS_INSERT_INDEX
|
|
475
|
+
if OBF_RATIOS_INSERT_INDEX < MAX_OBF_RATIOS_NUM:
|
|
476
|
+
_insert_mul(stree, node, OBF_RATIOS_INSERT_INDEX)
|
|
477
|
+
OBF_RATIOS_INSERT_INDEX += 1
|
|
478
|
+
|
|
479
|
+
def _update_subnet(stree: SymbolTree, substree: SymbolTree, subnode: Node):
|
|
480
|
+
"""update the network once the subnet is obfuscated"""
|
|
481
|
+
new_net = substree.get_network()
|
|
482
|
+
input_y_node = substree.get_node("input_y_obf")
|
|
483
|
+
if input_y_node is None:
|
|
484
|
+
return
|
|
485
|
+
arg_list = subnode.get_args().copy()
|
|
486
|
+
kwargs_list = list(subnode.get_kwargs().values())
|
|
487
|
+
arg_list.extend(kwargs_list)
|
|
488
|
+
v: str = input_y_node.get_targets()[0].value
|
|
489
|
+
arg_obf: ScopedValue = ScopedValue.create_naming_value("y_obf=" + v)
|
|
490
|
+
arg_list.append(arg_obf)
|
|
491
|
+
target_list = subnode.get_targets().copy()
|
|
492
|
+
name = subnode.get_name()
|
|
493
|
+
new_node = subnode.create_call_cell(cell=new_net, targets=target_list, args=arg_list, name=name)
|
|
494
|
+
stree.replace(subnode, [new_node])
|
|
495
|
+
|
|
496
|
+
def _traverse(stree, i=0):
|
|
497
|
+
"""traverse and obfuscate the original network"""
|
|
498
|
+
if len(path_list) == i:
|
|
499
|
+
return
|
|
500
|
+
for node in stree.nodes():
|
|
501
|
+
node_name = node.get_name()
|
|
502
|
+
if node.get_node_type() == NodeType.Tree and node_name.startswith(path_list[i]):
|
|
503
|
+
sub_stree = TreeNodeHelper.get_sub_tree(node)
|
|
504
|
+
_traverse(sub_stree, i + 1)
|
|
505
|
+
_insert_input(sub_stree, arg_name='y_obf')
|
|
506
|
+
_insert_mul_by_name(sub_stree, after_name_list=target_list[i + 1])
|
|
507
|
+
_update_subnet(stree, sub_stree, node)
|
|
508
|
+
|
|
509
|
+
def _register_denied_func_decorators(fn):
|
|
510
|
+
"""set the function decorators which should be denied for parse"""
|
|
511
|
+
name = "denied_function_decorator_list"
|
|
512
|
+
setattr(ClassDefParser, name, fn)
|
|
513
|
+
|
|
514
|
+
def _register_denied_class_decorators(fn):
|
|
515
|
+
"""set the class decorators which should be denied for parse"""
|
|
516
|
+
name = "denied_class_decorator_list"
|
|
517
|
+
setattr(ModuleParser, name, fn)
|
|
518
|
+
|
|
519
|
+
if 'ignored_func_decorators' in kwargs.keys():
|
|
520
|
+
kw_func_dec = kwargs["ignored_func_decorators"]
|
|
521
|
+
if not isinstance(kw_func_dec, list):
|
|
522
|
+
raise TypeError('{} should be list, but got {}'.format(kw_func_dec, type(kw_func_dec)))
|
|
523
|
+
if kw_func_dec and not isinstance(kw_func_dec[0], str):
|
|
524
|
+
raise TypeError('elements of {} should be str, but got {}'.format(kw_func_dec, type(kw_func_dec[0])))
|
|
525
|
+
_register_denied_func_decorators(kw_func_dec)
|
|
526
|
+
else:
|
|
527
|
+
_register_denied_func_decorators(["_args_type_validator_check", "_LogActionOnce", "cell_attr_register"])
|
|
528
|
+
if 'ignored_class_decorators' in kwargs.keys():
|
|
529
|
+
kw_class_dec = kwargs["ignored_class_decorators"]
|
|
530
|
+
_register_denied_class_decorators(kw_class_dec)
|
|
531
|
+
if not isinstance(kw_class_dec, list):
|
|
532
|
+
raise TypeError('{} should be list[str] type, but got {}'.format(kw_class_dec, type(kw_class_dec)))
|
|
533
|
+
if kw_class_dec and not isinstance(kw_class_dec[0], str):
|
|
534
|
+
raise TypeError('elements of {} should be str, but got {}'.format(kw_class_dec, type(kw_class_dec[0])))
|
|
535
|
+
|
|
536
|
+
main_stree = SymbolTree.create(model)
|
|
537
|
+
_traverse(main_stree, 0)
|
|
538
|
+
_insert_input(main_stree, arg_name='y_obf')
|
|
539
|
+
_insert_mul_by_name(main_stree, after_name_list=target_list[0])
|
|
540
|
+
new_net = main_stree.get_network()
|
|
541
|
+
return new_net
|
mindspore/scipy/linalg.py
CHANGED
|
@@ -461,8 +461,8 @@ def lu_pivots_to_permutation(pivots, permutation_size: int):
|
|
|
461
461
|
loc = mnp.ix_(*(mnp.arange(0, b) for b in batch_dims))
|
|
462
462
|
x = permutation[..., i]
|
|
463
463
|
y = permutation[loc + (j,)]
|
|
464
|
-
permutation[..., i] = y
|
|
465
464
|
permutation[loc + (j,)] = x
|
|
465
|
+
permutation[..., i] = y
|
|
466
466
|
return permutation
|
|
467
467
|
|
|
468
468
|
|
mindspore/scipy/ops.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright 2021-2023 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -156,14 +156,64 @@ class LU(PrimitiveWithInfer):
|
|
|
156
156
|
|
|
157
157
|
|
|
158
158
|
class LinearSumAssignment(Primitive):
|
|
159
|
-
"""
|
|
159
|
+
r"""
|
|
160
|
+
Solve the linear sum assignment problem.
|
|
161
|
+
|
|
162
|
+
The assignment problem is represented as follows:
|
|
163
|
+
|
|
164
|
+
.. math::
|
|
165
|
+
min\sum_{i}^{} \sum_{j}^{} C_{i,j} X_{i,j}
|
|
166
|
+
|
|
167
|
+
where :math:`C` is cost matrix, :math:`X_{i,j} = 1` means column :math:`j` is assigned to row :math:`i` .
|
|
168
|
+
|
|
169
|
+
Inputs:
|
|
170
|
+
- **cost_matrix** (Tensor) - 2-D cost matrix. Tensor of shape :math:`(M, N)` .
|
|
171
|
+
- **dimension_limit** (Tensor, optional) - A scalar used to limit the actual size of the 2nd dimension of
|
|
172
|
+
``cost_matrix``. Default is ``Tensor(sys.maxsize)``, which means no limitation. The type is 0-D int64
|
|
173
|
+
Tensor.
|
|
174
|
+
- **maximize** (bool) - Calculate a maximum weight matching if true, otherwise calculate a minimum weight
|
|
175
|
+
matching.
|
|
176
|
+
|
|
177
|
+
Outputs:
|
|
178
|
+
A tuple of tensors containing 'row_idx' and 'col_idx'.
|
|
179
|
+
|
|
180
|
+
- **row_idx** (Tensor) - Row indices of the problem. If `dimension_limit` is given, -1 would be padded at the
|
|
181
|
+
end. The shape is :math:`(N, )` , where :math:`N` is the minimum value of `cost_matrix` dimension.
|
|
182
|
+
- **col_idx** (Tensor) - Column indices of the problem. If `dimension_limit` is given, -1 would be padded at
|
|
183
|
+
the end. The shape is :math:`(N, )` , where :math:`N` is the minimum value of `cost_matrix` dimension.
|
|
184
|
+
|
|
185
|
+
Raises:
|
|
186
|
+
TypeError: If the data type of `cost_matrix` is not the type in [float16, float32, float64,
|
|
187
|
+
int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool]
|
|
188
|
+
TypeError: If the type of `maximize` is not bool.
|
|
189
|
+
TypeError: If the data type of `dimension_limit` is not int64.
|
|
190
|
+
ValueError: If the rank of `cost_matrix` is not 2.
|
|
191
|
+
ValueError: If the number of input args is not 3.
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
Supported Platforms:
|
|
195
|
+
``Ascend`` ``CPU``
|
|
196
|
+
|
|
197
|
+
Examples:
|
|
198
|
+
>>> import mindspore as ms
|
|
199
|
+
>>> import numpy as np
|
|
200
|
+
>>> from mindspore import Tensor
|
|
201
|
+
>>> from mindspore.scipy.ops import LinearSumAssignment
|
|
202
|
+
>>> lsap = LinearSumAssignment()
|
|
203
|
+
>>> cost_matrix = Tensor(np.array([[2, 3, 3], [3, 2, 3], [3, 3, 2]])).astype(ms.float64)
|
|
204
|
+
>>> dimension_limit = Tensor(2)
|
|
205
|
+
>>> maximize = False
|
|
206
|
+
>>> a, b = lsap(cost_matrix, dimension_limit, maximize)
|
|
207
|
+
>>> print(a)
|
|
208
|
+
[0 1 -1]
|
|
209
|
+
>>> print(b)
|
|
210
|
+
[0 1 -1]
|
|
211
|
+
"""
|
|
160
212
|
|
|
161
213
|
@prim_attr_register
|
|
162
214
|
def __init__(self):
|
|
163
|
-
super().__init__("LinearSumAssignment")
|
|
215
|
+
super().__init__(name="LinearSumAssignment")
|
|
164
216
|
self.init_prim_io_names(inputs=['cost_matrix', 'dimension_limit', 'maximize'], outputs=['row_ind', 'col_ind'])
|
|
165
|
-
self.add_prim_attr("cust_aicpu", "mindspore_aicpu_kernels")
|
|
166
|
-
|
|
167
217
|
|
|
168
218
|
# pylint: disable=C0413,W0611
|
|
169
219
|
from .ops_grad import get_bprpo_eigh, get_bprpo_trsm
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright 2021-2023 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -15,5 +15,6 @@
|
|
|
15
15
|
"""Optimize submodule"""
|
|
16
16
|
from .minimize import minimize
|
|
17
17
|
from .line_search import line_search
|
|
18
|
+
from .linear_sum_assignment import linear_sum_assignment
|
|
18
19
|
|
|
19
|
-
__all__ = ["minimize", "line_search"]
|
|
20
|
+
__all__ = ["minimize", "line_search", "linear_sum_assignment"]
|