mindspore 2.1.0__cp37-none-any.whl → 2.2.10__cp37-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +46 -19
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +38 -0
- mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +74 -104
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +141 -88
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +84 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +28 -19
- mindspore/ops/operations/_grad_ops.py +72 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +189 -141
- mindspore/ops/operations/nn_ops.py +794 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +477 -517
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/container.py
CHANGED
|
@@ -15,12 +15,12 @@
|
|
|
15
15
|
"""container"""
|
|
16
16
|
from __future__ import absolute_import
|
|
17
17
|
|
|
18
|
-
from collections import OrderedDict
|
|
18
|
+
from collections import OrderedDict, abc
|
|
19
19
|
from abc import abstractmethod
|
|
20
20
|
|
|
21
21
|
from mindspore.nn.cell import Cell
|
|
22
22
|
|
|
23
|
-
__all__ = ['SequentialCell', 'CellList']
|
|
23
|
+
__all__ = ['SequentialCell', 'CellList', 'CellDict']
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
def _valid_index(cell_num, index, op_name=None):
|
|
@@ -34,6 +34,20 @@ def _valid_index(cell_num, index, op_name=None):
|
|
|
34
34
|
return index % cell_num
|
|
35
35
|
|
|
36
36
|
|
|
37
|
+
def _valid_index_for_inserting(cell_num, index, op_name=None):
    """
    Internal function, used to check the type and value of an index at which a Cell
    is going to be inserted into SequentialCell or CellList.

    Args:
        cell_num (int): Number of cells currently held by the container.
        index (int): Requested insert position. Every value in
            [-cell_num, cell_num] is accepted; `cell_num` itself means "after the last cell".
        op_name (str, optional): Name used as prefix of the error messages. Default: ``None`` .

    Returns:
        int, the non-negative insert position. Negative values are wrapped with
        `index % cell_num`; `index` is returned unchanged when the container is empty or
        when `index` equals `cell_num`.

    Raises:
        TypeError: If `index` is not an int.
        IndexError: If `index` is outside of [-cell_num, cell_num].
    """
    if op_name:
        msg_prefix = f"For '{op_name}', the"
    else:
        msg_prefix = "The"

    if not isinstance(index, int):
        raise TypeError(f"{msg_prefix} type of 'index' must be int, but got {type(index).__name__}.")

    if index < -cell_num or index > cell_num:
        raise IndexError(f"{msg_prefix} value of 'index' must be a number in range [{-cell_num}, {cell_num}], "
                         f"but got {index}.")

    # The modulo must be skipped for an empty container (division by zero) and for the
    # one-past-the-end position (it would wrap around to 0).
    if cell_num == 0 or index == cell_num:
        return index
    return index % cell_num
|
|
49
|
+
|
|
50
|
+
|
|
37
51
|
def _valid_cell(cell, op_name=None):
|
|
38
52
|
"""Internal function, used to check whether the input cell is a subclass of Cell."""
|
|
39
53
|
if issubclass(cell.__class__, Cell):
|
|
@@ -109,7 +123,7 @@ class _CellListBase:
|
|
|
109
123
|
class SequentialCell(Cell):
|
|
110
124
|
"""
|
|
111
125
|
Sequential Cell container. For more details about Cell, please refer to
|
|
112
|
-
`Cell <https://www.mindspore.cn/docs/en/r2.
|
|
126
|
+
`Cell <https://www.mindspore.cn/docs/en/r2.2/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell>`_.
|
|
113
127
|
|
|
114
128
|
A list of Cells will be added to it in the order they are passed in the constructor.
|
|
115
129
|
Alternatively, an ordered dict of cells can also be passed in.
|
|
@@ -290,7 +304,7 @@ class SequentialCell(Cell):
|
|
|
290
304
|
cell(Cell): The Cell to be inserted.
|
|
291
305
|
"""
|
|
292
306
|
cls_name = self.__class__.__name__
|
|
293
|
-
idx =
|
|
307
|
+
idx = _valid_index_for_inserting(len(self), index, cls_name)
|
|
294
308
|
_valid_cell(cell, cls_name)
|
|
295
309
|
length = len(self)
|
|
296
310
|
prefix, key_index = _get_prefix_and_index(self._cells)
|
|
@@ -311,10 +325,11 @@ class SequentialCell(Cell):
|
|
|
311
325
|
class CellList(_CellListBase, Cell):
|
|
312
326
|
"""
|
|
313
327
|
Holds Cells in a list. For more details about Cell, please refer to
|
|
314
|
-
`Cell <https://www.mindspore.cn/docs/en/r2.
|
|
328
|
+
`Cell <https://www.mindspore.cn/docs/en/r2.2/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell>`_.
|
|
315
329
|
|
|
316
|
-
CellList can be used like a regular Python list, the Cells it contains have been initialized
|
|
317
|
-
|
|
330
|
+
CellList can be used like a regular Python list, the Cells it contains have been initialized and
|
|
331
|
+
the types of Cells it contains can not be CellDict.
|
|
332
|
+
Unlike the SequentialCell, the cells in CellList are not connected.
|
|
318
333
|
|
|
319
334
|
Args:
|
|
320
335
|
args (list, optional): List of subclass of Cell.
|
|
@@ -413,7 +428,7 @@ class CellList(_CellListBase, Cell):
|
|
|
413
428
|
cell(Cell): The Cell to be inserted.
|
|
414
429
|
"""
|
|
415
430
|
cls_name = self.__class__.__name__
|
|
416
|
-
idx =
|
|
431
|
+
idx = _valid_index_for_inserting(len(self), index, cls_name)
|
|
417
432
|
_valid_cell(cell, cls_name)
|
|
418
433
|
length = len(self)
|
|
419
434
|
prefix, key_index = _get_prefix_and_index(self._cells)
|
|
@@ -433,7 +448,7 @@ class CellList(_CellListBase, Cell):
|
|
|
433
448
|
Appends Cells from a Python iterable to the end of the list.
|
|
434
449
|
|
|
435
450
|
Args:
|
|
436
|
-
cells(list): The Cells to be extended.
|
|
451
|
+
cells(list): The Cells to be extended, the types of Cells can not be CellDict.
|
|
437
452
|
|
|
438
453
|
Raises:
|
|
439
454
|
TypeError: If the argument cells are not a list of Cells.
|
|
@@ -444,6 +459,9 @@ class CellList(_CellListBase, Cell):
|
|
|
444
459
|
f"should be instance of list, but got {type(cells).__name__}.")
|
|
445
460
|
prefix, _ = _get_prefix_and_index(self._cells)
|
|
446
461
|
for cell in cells:
|
|
462
|
+
if isinstance(cell, CellDict):
|
|
463
|
+
raise TypeError(f"For '{cls_name}', the type of cell can not be CellDict, "
|
|
464
|
+
f"but got {type(cell).__name__}.")
|
|
447
465
|
if _valid_cell(cell, cls_name):
|
|
448
466
|
if self._auto_prefix:
|
|
449
467
|
cell.update_parameters_name(prefix + str(len(self)) + ".")
|
|
@@ -470,3 +488,247 @@ class CellList(_CellListBase, Cell):
|
|
|
470
488
|
|
|
471
489
|
def construct(self, *inputs):
|
|
472
490
|
raise NotImplementedError
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
class _CellDictBase:
    """
    An interface marking a Cell container whose cells are accessed by key, like a dict.

    Instances carry the attribute `__cell_as_dict__` set to ``True`` . Subclasses are
    expected to supply `__len__` and `__getitem__` ; calling `construct` on the base
    class always fails.

    NOTE(review): which component reads `__cell_as_dict__` is not visible in this file;
    presumably it lets the framework treat the container as a dict of Cells - confirm.
    """
    def __init__(self):
        """Initialize _CellDictBase."""
        self.__cell_as_dict__ = True

    @abstractmethod
    def __len__(self):
        """Number of cells in the container, to be implemented by the subclass."""

    @abstractmethod
    def __getitem__(self, index):
        """Cell stored under `index` , to be implemented by the subclass."""

    def construct(self):
        raise NotImplementedError
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
class CellDict(_CellDictBase, Cell):
    """
    Holds Cells in a dictionary. For more details about `Cell` , please refer to :class:`mindspore.nn.Cell` .

    `CellDict` can be used like a regular Python dictionary.

    Args:
        args (iterable, optional): An iterable of key-value pairs of (key, Cell), the type of key-value pairs is
                                   (string, Cell); Or a mapping(dictionary) from string to Cell.
                                   The type of Cell can not be CellDict, CellList or SequentialCell.
                                   The key can not be same with the attributes of class Cell, can not contain '.',
                                   can not be an empty string.
                                   The key of type string is used to search corresponding Cell in the CellDict.
        kwargs (dict): Reserved for keyword argument to be expanded. The keyword `auto_prefix` is
                       forwarded to `Cell.__init__` ; it defaults to ``True`` when absent.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import collections
        >>> from collections import OrderedDict
        >>> import mindspore as ms
        >>> import numpy as np
        >>> from mindspore import Tensor, nn
        >>>
        >>> cell_dict = nn.CellDict({'conv': nn.Conv2d(10, 6, 5),
        ...                          'relu': nn.ReLU(),
        ...                          'max_pool2d': nn.MaxPool2d(kernel_size=4, stride=4)})
        >>> print(len(cell_dict))
        3
        >>> cell_dict.clear()
        >>> print(len(cell_dict))
        0
        >>> ordered_cells = OrderedDict([('conv', nn.Conv2d(10, 6, 5, pad_mode='valid')),
        ...                              ('relu', nn.ReLU()),
        ...                              ('max_pool2d', nn.MaxPool2d(kernel_size=2, stride=2))])
        >>> cell_dict.update(ordered_cells)
        >>> x = Tensor(np.ones([1, 10, 6, 10]), ms.float32)
        >>> for cell in cell_dict.values():
        ...     x = cell(x)
        >>> print(x.shape)
        (1, 6, 1, 3)
        >>> x = Tensor(np.ones([1, 10, 6, 10]), ms.float32)
        >>> for item in cell_dict.items():
        ...     x = item[1](x)
        >>> print(x.shape)
        (1, 6, 1, 3)
        >>> print(cell_dict.keys())
        odict_keys(['conv', 'relu', 'max_pool2d'])
        >>> pop_cell = cell_dict.pop('conv')
        >>> x = Tensor(np.ones([1, 10, 6, 5]), ms.float32)
        >>> x = pop_cell(x)
        >>> print(x.shape)
        (1, 6, 2, 1)
        >>> print(len(cell_dict))
        2
    """
    def __init__(self, *args, **kwargs):
        """Initialize CellDict."""
        # The lookup key used to be misspelled ("auto_preifx"), so passing `auto_prefix`
        # always ended in a KeyError. Read the key that is actually checked for.
        auto_prefix = kwargs.get("auto_prefix", True)
        _CellDictBase.__init__(self)
        Cell.__init__(self, auto_prefix)
        # Only a single positional argument is consumed; any other number of positional
        # arguments leaves the CellDict empty.
        if len(args) == 1:
            self.update(args[0])

    def __getitem__(self, key):
        return self._cells[key]

    def __setitem__(self, key, cell):
        # Validate first so that an invalid pair never renames parameters or lands in `_cells`.
        self._validate_key(key)
        self._validate_cell_type(cell)
        self._update_cell_para_name(key, cell)
        self._cells[key] = cell

    def __delitem__(self, key):
        del self._cells[key]

    def __len__(self):
        return len(self._cells)

    def __iter__(self):
        return iter(self._cells)

    def __contains__(self, key):
        return key in self._cells

    def _validate_key(self, key):
        """
        Validate key.

        Raises:
            TypeError: If `key` is not a string.
            KeyError: If `key` names an existing attribute of this object that is not one of
                its cells, contains '.', or is an empty string.
        """
        cls_name = self.__class__.__name__
        if not isinstance(key, str):
            raise TypeError(f"For '{cls_name}', the type of key should be string "
                            f"but got {type(key).__name__}.")
        # Keys already present in `_cells` may be overwritten, so they are exempt from the attribute check.
        if hasattr(self, key) and key not in self._cells:
            raise KeyError(f"For '{cls_name}', the key can not be same with the attributes of Cell, "
                           f"but got key {key}.")
        if '.' in key:
            raise KeyError(f"For '{cls_name}', key can not contain \".\", "
                           f"but got key {key}")
        if key == '':
            raise KeyError(f"For '{cls_name}', key can not be empty string \"\", "
                           f"but got key {key}")

    def _validate_cell_type(self, cell):
        """
        Validate cell type.

        Raises:
            TypeError: If `cell` is None, is not a Cell, or is a CellDict, CellList or SequentialCell.
        """
        cls_name = self.__class__.__name__
        if cell is None:
            raise TypeError(f"For '{cls_name}', cell can not be None.")
        if not isinstance(cell, Cell):
            raise TypeError(f"For '{cls_name}', the type of cell should be Cell, "
                            f"but got {type(cell).__name__}.")
        if isinstance(cell, (CellDict, CellList, SequentialCell)):
            raise TypeError(f"For '{cls_name}', the type of cell can not be CellDict, CellList or SequentialCell, "
                            f"but got {type(cell).__name__}.")

    def _update_cell_para_name(self, key, cell):
        """Update the parameter names of `cell` with the prefix of this container followed by `key` ."""
        if self._auto_prefix:
            prefix, _ = _get_prefix_and_index(self._cells)
            cell.update_parameters_name(prefix + key + ".")

    def clear(self):
        """
        Remove all Cells from the CellDict.
        """
        return self._cells.clear()

    def pop(self, key):
        """
        Remove key from the CellDict and return its cell.

        Args:
            key (string): key to pop from the CellDict.

        Raises:
            KeyError: If `key` not exist in CellDict when attempt to access cell.
        """
        value = self[key]
        del self[key]
        return value

    def keys(self):
        """
        Return an iterable of the CellDict keys.

        Returns:
            An iterable object.
        """
        return self._cells.keys()

    def values(self):
        """
        Return an iterable of the CellDict values.

        Returns:
            An iterable object.
        """
        return self._cells.values()

    def items(self):
        """
        Return an iterable of the CellDict key-value pairs.

        Returns:
            An iterable object.
        """
        return self._cells.items()

    def update(self, cells):
        """
        Update the CellDict by overwriting the existing keys with the key-value pairs from a mapping or an iterable.

        Args:
            cells (iterable): An iterable of key-value pairs of (key, Cell), the type of key-value pairs is
                              (string, Cell); Or a mapping(dictionary) from string to Cell.
                              The type of Cell can not be CellDict, CellList or SequentialCell.
                              The key can not be same with the attributes of class Cell, can not contain '.',
                              can not be an empty string.

        Note:
            If the `cells` is a CellDict, an OrderedDict or an iterable containing key-value pairs,
            the order of newly added elements is maintained.

        Raises:
            TypeError: If `cells` is not an iterable object.
            TypeError: If key-value pairs in `cells` are not iterable objects.
            ValueError: If the length of key-value pairs in `cells` is not 2.
            TypeError: If the cell in `cells` is None.
            TypeError: If the type of cell in `cells` is not Cell.
            TypeError: If the type of cell in `cells` is CellDict, CellList or SequentialCell.
            TypeError: If the type of key in `cells` is not string.
            KeyError: If the key in `cells` is same with the attributes of class Cell.
            KeyError: If the key in `cells` contain ".".
            KeyError: If the key in `cells` is an empty string.
        """
        if not isinstance(cells, abc.Iterable):
            raise TypeError("CellDict.update() should be called with an "
                            f"iterable of key-value pairs, but got {type(cells).__name__}")
        # OrderedDict is already covered by abc.Mapping; CellDict is listed explicitly because
        # this class does not derive from a Mapping type.
        if isinstance(cells, (CellDict, abc.Mapping)):
            for key, cell in cells.items():
                self[key] = cell
        else:
            # `position` replaces the former loop variable `id` , which shadowed the builtin.
            for position, k_v in enumerate(cells):
                if not isinstance(k_v, abc.Iterable):
                    raise TypeError(f"CellDict update sequence element #{position} should be Iterable; "
                                    f"but got {type(k_v).__name__}")
                if len(k_v) != 2:
                    raise ValueError(f"CellDict update sequence element #{position}, length should be 2; "
                                     f"but has length {len(k_v)}")
                self[k_v[0]] = k_v[1]

    def construct(self, *inputs):
        raise NotImplementedError
|