mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff shows the differences in content between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/basic.py
CHANGED
|
@@ -24,7 +24,7 @@ from mindspore import context, log as logger
|
|
|
24
24
|
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
|
|
25
25
|
from mindspore.common.seed import _get_graph_seed
|
|
26
26
|
from mindspore.common.tensor import Tensor
|
|
27
|
-
from mindspore.common.initializer import initializer
|
|
27
|
+
from mindspore.common.initializer import initializer, HeUniform, Uniform
|
|
28
28
|
from mindspore.ops import operations as P
|
|
29
29
|
from mindspore.ops import functional as F
|
|
30
30
|
from mindspore.ops.operations import _inner_ops as inner
|
|
@@ -75,9 +75,11 @@ class L1Regularizer(Cell):
|
|
|
75
75
|
``Ascend`` ``GPU`` ``CPU``
|
|
76
76
|
|
|
77
77
|
Examples:
|
|
78
|
+
>>> import mindspore as ms
|
|
79
|
+
>>> import numpy as np
|
|
78
80
|
>>> scale = 0.5
|
|
79
|
-
>>> net = nn.L1Regularizer(scale)
|
|
80
|
-
>>> weights = Tensor(np.array([[1.0, -2.0], [-3.0, 4.0]]).astype(np.float32))
|
|
81
|
+
>>> net = ms.nn.L1Regularizer(scale)
|
|
82
|
+
>>> weights = ms.Tensor(np.array([[1.0, -2.0], [-3.0, 4.0]]).astype(np.float32))
|
|
81
83
|
>>> output = net(weights)
|
|
82
84
|
>>> print(output.asnumpy())
|
|
83
85
|
5.0
|
|
@@ -108,8 +110,9 @@ class Dropout(Cell):
|
|
|
108
110
|
r"""
|
|
109
111
|
Dropout layer for the input.
|
|
110
112
|
|
|
111
|
-
Dropout is a regularization
|
|
112
|
-
according to
|
|
113
|
+
Dropout is a means of regularization that reduces overfitting by preventing correlations between neuronal nodes.
|
|
114
|
+
The operator randomly sets some neurons output to 0 according to `p`, which means the probability of discarding
|
|
115
|
+
during training. And the return will be multiplied by :math:`\frac{1}{1-p}` during training.
|
|
113
116
|
During inference, this layer returns the same Tensor as `x`.
|
|
114
117
|
|
|
115
118
|
This technique is proposed in paper `Dropout: A Simple Way to Prevent Neural Networks from Overfitting
|
|
@@ -126,10 +129,10 @@ class Dropout(Cell):
|
|
|
126
129
|
|
|
127
130
|
Args:
|
|
128
131
|
keep_prob (float): Deprecated. The keep rate, greater than 0 and less than or equal to 1.
|
|
129
|
-
E.g. rate=0.9, dropping out 10% of input neurons. Default: 0.5.
|
|
132
|
+
E.g. rate=0.9, dropping out 10% of input neurons. Default: ``0.5`` .
|
|
130
133
|
p (Union[float, int, None]): The dropout rate, greater than or equal to 0 and less than 1.
|
|
131
|
-
E.g. rate=0.9, dropping out 90% of input neurons. Default: None.
|
|
132
|
-
dtype (:class:`mindspore.dtype`): Data type of `input`. Default:
|
|
134
|
+
E.g. rate=0.9, dropping out 90% of input neurons. Default: ``None`` .
|
|
135
|
+
dtype (:class:`mindspore.dtype`): Data type of `input`. Default: ``mstype.float32`` .
|
|
133
136
|
|
|
134
137
|
Inputs:
|
|
135
138
|
- **x** (Tensor) - The input of Dropout with data type of float16 or float32.
|
|
@@ -149,6 +152,9 @@ class Dropout(Cell):
|
|
|
149
152
|
``Ascend`` ``GPU`` ``CPU``
|
|
150
153
|
|
|
151
154
|
Examples:
|
|
155
|
+
>>> import mindspore
|
|
156
|
+
>>> from mindspore import Tensor, nn
|
|
157
|
+
>>> import numpy as np
|
|
152
158
|
>>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
|
|
153
159
|
>>> net = nn.Dropout(p=0.2)
|
|
154
160
|
>>> net.set_train()
|
|
@@ -216,7 +222,7 @@ class Dropout1d(Cell):
|
|
|
216
222
|
|
|
217
223
|
Args:
|
|
218
224
|
p (float, optional): The dropping probability of a channel, between 0 and 1, e.g. `p` = 0.8,
|
|
219
|
-
which means an 80% chance of being set to 0. Default: 0.5.
|
|
225
|
+
which means an 80% chance of being set to 0. Default: ``0.5`` .
|
|
220
226
|
|
|
221
227
|
Inputs:
|
|
222
228
|
- **x** (Tensor) - A tensor with shape :math:`(N, C, L)` or :math:`(C, L)`, where `N` is the batch size,
|
|
@@ -224,7 +230,7 @@ class Dropout1d(Cell):
|
|
|
224
230
|
int64, float16, float32 or float64.
|
|
225
231
|
|
|
226
232
|
Outputs:
|
|
227
|
-
Tensor,
|
|
233
|
+
Tensor, has the same shape and data type as `x`.
|
|
228
234
|
|
|
229
235
|
Raises:
|
|
230
236
|
TypeError: If `x` is not a Tensor.
|
|
@@ -238,10 +244,9 @@ class Dropout1d(Cell):
|
|
|
238
244
|
Examples:
|
|
239
245
|
>>> import numpy as np
|
|
240
246
|
>>> import mindspore as ms
|
|
241
|
-
>>>
|
|
242
|
-
>>> op = nn.Dropout1d(p=0.6)
|
|
247
|
+
>>> op = ms.nn.Dropout1d(p=0.6)
|
|
243
248
|
>>> op.training = True
|
|
244
|
-
>>> a = Tensor(np.ones((3, 3)), ms.float32)
|
|
249
|
+
>>> a = ms.Tensor(np.ones((3, 3)), ms.float32)
|
|
245
250
|
>>> output = op(a)
|
|
246
251
|
"""
|
|
247
252
|
|
|
@@ -281,6 +286,9 @@ class Dropout2d(Cell):
|
|
|
281
286
|
``Ascend`` ``GPU`` ``CPU``
|
|
282
287
|
|
|
283
288
|
Examples:
|
|
289
|
+
>>> import mindspore
|
|
290
|
+
>>> from mindspore import Tensor, nn
|
|
291
|
+
>>> import numpy as np
|
|
284
292
|
>>> dropout = nn.Dropout2d(p=0.5)
|
|
285
293
|
>>> x = Tensor(np.ones([2, 1, 2, 3]), mindspore.float32)
|
|
286
294
|
>>> output = dropout(x)
|
|
@@ -306,7 +314,7 @@ class Dropout2d(Cell):
|
|
|
306
314
|
return out
|
|
307
315
|
|
|
308
316
|
def extend_repr(self):
|
|
309
|
-
return
|
|
317
|
+
return f"p={self.keep_prob}"
|
|
310
318
|
|
|
311
319
|
|
|
312
320
|
class Dropout3d(Cell):
|
|
@@ -329,6 +337,9 @@ class Dropout3d(Cell):
|
|
|
329
337
|
``Ascend`` ``GPU`` ``CPU``
|
|
330
338
|
|
|
331
339
|
Examples:
|
|
340
|
+
>>> import mindspore
|
|
341
|
+
>>> from mindspore import Tensor, nn
|
|
342
|
+
>>> import numpy as np
|
|
332
343
|
>>> dropout = nn.Dropout3d(p=0.5)
|
|
333
344
|
>>> x = Tensor(np.ones([2, 1, 2, 1, 2]), mindspore.float32)
|
|
334
345
|
>>> output = dropout(x)
|
|
@@ -354,7 +365,7 @@ class Dropout3d(Cell):
|
|
|
354
365
|
return out
|
|
355
366
|
|
|
356
367
|
def extend_repr(self):
|
|
357
|
-
return 'p={
|
|
368
|
+
return f'p={self.keep_prob}'
|
|
358
369
|
|
|
359
370
|
|
|
360
371
|
class Upsample(Cell):
|
|
@@ -365,8 +376,9 @@ class Upsample(Cell):
|
|
|
365
376
|
``Ascend`` ``GPU`` ``CPU``
|
|
366
377
|
|
|
367
378
|
Examples:
|
|
368
|
-
>>>
|
|
369
|
-
>>>
|
|
379
|
+
>>> import mindspore as ms
|
|
380
|
+
>>> x = ms.Tensor([[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]]])
|
|
381
|
+
>>> upsample = ms.nn.Upsample(size=(5, 5))
|
|
370
382
|
>>> out = upsample(x)
|
|
371
383
|
>>> print(x.asnumpy())
|
|
372
384
|
[[[[1. 2. 3. 4.]
|
|
@@ -401,8 +413,8 @@ class Flatten(Cell):
|
|
|
401
413
|
Flatten the input Tensor along dimensions from `start_dim` to `end_dim`.
|
|
402
414
|
|
|
403
415
|
Args:
|
|
404
|
-
start_dim (int, optional): The first dimension to flatten. Default: 1.
|
|
405
|
-
end_dim (int, optional): The last dimension to flatten. Default:
|
|
416
|
+
start_dim (int, optional): The first dimension to flatten. Default: ``1`` .
|
|
417
|
+
end_dim (int, optional): The last dimension to flatten. Default: ``-1`` .
|
|
406
418
|
|
|
407
419
|
Inputs:
|
|
408
420
|
- **x** (Tensor) - The input Tensor to be flattened.
|
|
@@ -421,6 +433,9 @@ class Flatten(Cell):
|
|
|
421
433
|
``Ascend`` ``GPU`` ``CPU``
|
|
422
434
|
|
|
423
435
|
Examples:
|
|
436
|
+
>>> import mindspore
|
|
437
|
+
>>> from mindspore import Tensor, nn
|
|
438
|
+
>>> import numpy as np
|
|
424
439
|
>>> x = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32)
|
|
425
440
|
>>> net = nn.Flatten()
|
|
426
441
|
>>> output = net(x)
|
|
@@ -439,13 +454,15 @@ class Flatten(Cell):
|
|
|
439
454
|
self.start_dim = start_dim
|
|
440
455
|
self.end_dim = end_dim
|
|
441
456
|
|
|
457
|
+
def check_axis_valid(self, axis, ndim):
|
|
458
|
+
if axis < -ndim or axis >= ndim:
|
|
459
|
+
raise ValueError("'start_dim' or 'end_dim' out of range.")
|
|
460
|
+
|
|
442
461
|
def construct(self, x):
|
|
443
462
|
x_rank = F.rank(x)
|
|
444
463
|
ndim = x_rank if x_rank != 0 else 1
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
if self.end_dim < -ndim or self.end_dim >= ndim:
|
|
448
|
-
const_utils.raise_value_error("'end_dim' out of range.")
|
|
464
|
+
self.check_axis_valid(self.start_dim, ndim)
|
|
465
|
+
self.check_axis_valid(self.end_dim, ndim)
|
|
449
466
|
return F.flatten(x, start_dim=self.start_dim, end_dim=self.end_dim)
|
|
450
467
|
|
|
451
468
|
|
|
@@ -458,22 +475,22 @@ def check_dense_input_shape(x, prim_name=None):
|
|
|
458
475
|
|
|
459
476
|
|
|
460
477
|
class Identity(Cell):
|
|
461
|
-
"""
|
|
462
|
-
|
|
478
|
+
r"""
|
|
479
|
+
A placeholder identity operator that returns the same as input.
|
|
463
480
|
|
|
464
481
|
Inputs:
|
|
465
|
-
- **x** (
|
|
482
|
+
- **x** (Any) - The input of Identity.
|
|
466
483
|
|
|
467
484
|
Outputs:
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
Raises:
|
|
471
|
-
TypeError: If `x` is not a Tensor.
|
|
485
|
+
The same as `x`.
|
|
472
486
|
|
|
473
487
|
Supported Platforms:
|
|
474
488
|
``Ascend`` ``GPU`` ``CPU``
|
|
475
489
|
|
|
476
490
|
Examples:
|
|
491
|
+
>>> import mindspore
|
|
492
|
+
>>> from mindspore import Tensor, nn
|
|
493
|
+
>>> import numpy as np
|
|
477
494
|
>>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int64)
|
|
478
495
|
>>> net = nn.Identity()
|
|
479
496
|
>>> output = net(x)
|
|
@@ -484,11 +501,9 @@ class Identity(Cell):
|
|
|
484
501
|
def __init__(self):
|
|
485
502
|
"""Initialize Identity."""
|
|
486
503
|
super(Identity, self).__init__()
|
|
487
|
-
self.identity = P.Identity()
|
|
488
504
|
|
|
489
505
|
def construct(self, x):
|
|
490
|
-
|
|
491
|
-
return out
|
|
506
|
+
return x
|
|
492
507
|
|
|
493
508
|
|
|
494
509
|
class Dense(Cell):
|
|
@@ -509,13 +524,16 @@ class Dense(Cell):
|
|
|
509
524
|
in_channels (int): The number of channels in the input space.
|
|
510
525
|
out_channels (int): The number of channels in the output space.
|
|
511
526
|
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
|
512
|
-
is same as `x`. The values of str refer to the function `initializer`. Default:
|
|
527
|
+
is same as `x`. The values of str refer to the function `initializer`. Default: ``None`` ,
|
|
528
|
+
weight will be initialized using HeUniform.
|
|
513
529
|
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
|
|
514
|
-
same as `x`. The values of str refer to the function `initializer`. Default:
|
|
515
|
-
|
|
530
|
+
same as `x`. The values of str refer to the function `initializer`. Default: ``None`` ,
|
|
531
|
+
bias will be initialized using Uniform.
|
|
532
|
+
has_bias (bool): Specifies whether the layer uses a bias vector :math:`\text{bias}`. Default: ``True``.
|
|
516
533
|
activation (Union[str, Cell, Primitive, None]): activation function applied to the output of the fully connected
|
|
517
534
|
layer. Both activation name, e.g. 'relu', and mindspore activation function, e.g. mindspore.ops.ReLU(),
|
|
518
|
-
are supported. Default: None.
|
|
535
|
+
are supported. Default: ``None`` .
|
|
536
|
+
dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .
|
|
519
537
|
|
|
520
538
|
Inputs:
|
|
521
539
|
- **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`. The `in_channels` in `Args` should be equal
|
|
@@ -537,6 +555,9 @@ class Dense(Cell):
|
|
|
537
555
|
``Ascend`` ``GPU`` ``CPU``
|
|
538
556
|
|
|
539
557
|
Examples:
|
|
558
|
+
>>> import mindspore
|
|
559
|
+
>>> from mindspore import Tensor, nn
|
|
560
|
+
>>> import numpy as np
|
|
540
561
|
>>> x = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
|
|
541
562
|
>>> net = nn.Dense(3, 4)
|
|
542
563
|
>>> output = net(x)
|
|
@@ -548,10 +569,11 @@ class Dense(Cell):
|
|
|
548
569
|
def __init__(self,
|
|
549
570
|
in_channels,
|
|
550
571
|
out_channels,
|
|
551
|
-
weight_init=
|
|
552
|
-
bias_init=
|
|
572
|
+
weight_init=None,
|
|
573
|
+
bias_init=None,
|
|
553
574
|
has_bias=True,
|
|
554
|
-
activation=None
|
|
575
|
+
activation=None,
|
|
576
|
+
dtype=mstype.float32):
|
|
555
577
|
"""Initialize Dense."""
|
|
556
578
|
super(Dense, self).__init__()
|
|
557
579
|
self.in_channels = Validator.check_positive_int(
|
|
@@ -570,8 +592,10 @@ class Dense(Cell):
|
|
|
570
592
|
f"be equal to 2, and the first dim must be equal to 'out_channels', and the "
|
|
571
593
|
f"second dim must be equal to 'in_channels'. But got 'weight_init': {weight_init}, "
|
|
572
594
|
f"'out_channels': {out_channels}, 'in_channels': {in_channels}.")
|
|
595
|
+
if weight_init is None:
|
|
596
|
+
weight_init = HeUniform(math.sqrt(5))
|
|
573
597
|
self.weight = Parameter(initializer(
|
|
574
|
-
weight_init, [out_channels, in_channels]), name="weight")
|
|
598
|
+
weight_init, [out_channels, in_channels], dtype=dtype), name="weight")
|
|
575
599
|
|
|
576
600
|
self.bias = None
|
|
577
601
|
if self.has_bias:
|
|
@@ -580,8 +604,11 @@ class Dense(Cell):
|
|
|
580
604
|
raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' must "
|
|
581
605
|
f"be equal to 1, and the first dim must be equal to 'out_channels'. But got "
|
|
582
606
|
f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
|
|
607
|
+
if bias_init is None:
|
|
608
|
+
bound = 1 / math.sqrt(in_channels)
|
|
609
|
+
bias_init = Uniform(scale=bound)
|
|
583
610
|
self.bias = Parameter(initializer(
|
|
584
|
-
bias_init, [out_channels]), name="bias")
|
|
611
|
+
bias_init, [out_channels], dtype=dtype), name="bias")
|
|
585
612
|
self.bias_add = P.BiasAdd()
|
|
586
613
|
|
|
587
614
|
self.matmul = P.MatMul(transpose_b=True)
|
|
@@ -608,12 +635,11 @@ class Dense(Cell):
|
|
|
608
635
|
return x
|
|
609
636
|
|
|
610
637
|
def extend_repr(self):
|
|
611
|
-
s = 'input_channels={}, output_channels={}'
|
|
612
|
-
self.in_channels, self.out_channels)
|
|
638
|
+
s = f'input_channels={self.in_channels}, output_channels={self.out_channels}'
|
|
613
639
|
if self.has_bias:
|
|
614
|
-
s += ', has_bias={
|
|
640
|
+
s += f', has_bias={self.has_bias}'
|
|
615
641
|
if self.activation_flag:
|
|
616
|
-
s += ', activation={
|
|
642
|
+
s += f', activation={self.activation}'
|
|
617
643
|
return s
|
|
618
644
|
|
|
619
645
|
|
|
@@ -660,7 +686,7 @@ class ClipByNorm(Cell):
|
|
|
660
686
|
|
|
661
687
|
Args:
|
|
662
688
|
axis (Union[None, int, tuple(int)]): Compute the L2-norm along the specified dimension.
|
|
663
|
-
Default: None, all dimensions to calculate.
|
|
689
|
+
Default: ``None`` , all dimensions to calculate.
|
|
664
690
|
|
|
665
691
|
Inputs:
|
|
666
692
|
- **x** (Tensor) - Tensor of shape N-D. The type must be float32 or float16.
|
|
@@ -678,6 +704,9 @@ class ClipByNorm(Cell):
|
|
|
678
704
|
``Ascend`` ``GPU`` ``CPU``
|
|
679
705
|
|
|
680
706
|
Examples:
|
|
707
|
+
>>> import mindspore
|
|
708
|
+
>>> from mindspore import Tensor, nn
|
|
709
|
+
>>> import numpy as np
|
|
681
710
|
>>> net = nn.ClipByNorm()
|
|
682
711
|
>>> x = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
|
|
683
712
|
>>> clip_norm = Tensor(np.array([100]).astype(np.float32))
|
|
@@ -699,8 +728,8 @@ class ClipByNorm(Cell):
|
|
|
699
728
|
|
|
700
729
|
class Norm(Cell):
|
|
701
730
|
r"""
|
|
702
|
-
|
|
703
|
-
|
|
731
|
+
The Norm class will be deprecated in the future,
|
|
732
|
+
this function can be replaced by :func:`ops.norm`
|
|
704
733
|
"""
|
|
705
734
|
|
|
706
735
|
@deprecated("2.0", "ops.norm", False)
|
|
@@ -723,13 +752,13 @@ class Norm(Cell):
|
|
|
723
752
|
return x
|
|
724
753
|
|
|
725
754
|
def extend_repr(self):
|
|
726
|
-
return 'axis={}, keep_dims={
|
|
755
|
+
return f'axis={self.axis}, keep_dims={self.keep_dims}'
|
|
727
756
|
|
|
728
757
|
|
|
729
758
|
class OneHot(Cell):
|
|
730
759
|
"""
|
|
731
|
-
|
|
732
|
-
|
|
760
|
+
The OneHot class will be deprecated in the future,
|
|
761
|
+
this function can be replaced by :func:`ops.one_hot`
|
|
733
762
|
"""
|
|
734
763
|
|
|
735
764
|
@deprecated("2.0", "ops.one_hot", False)
|
|
@@ -769,8 +798,8 @@ class Pad(Cell):
|
|
|
769
798
|
# 2nd dimension of output is paddings[1][0] + 3 + paddings[1][1] = 2 + 3 + 2 = 7.
|
|
770
799
|
# So the shape of output is (5, 7).
|
|
771
800
|
|
|
772
|
-
mode (str): Specifies padding mode. The optional values are "CONSTANT", "REFLECT", "SYMMETRIC".
|
|
773
|
-
Default: "CONSTANT".
|
|
801
|
+
mode (str): Specifies padding mode. The optional values are ``"CONSTANT"`` , ``"REFLECT"`` , ``"SYMMETRIC"`` .
|
|
802
|
+
Default: ``"CONSTANT"`` .
|
|
774
803
|
|
|
775
804
|
Inputs:
|
|
776
805
|
- **x** (Tensor) - The input tensor.
|
|
@@ -792,14 +821,14 @@ class Pad(Cell):
|
|
|
792
821
|
Raises:
|
|
793
822
|
TypeError: If `paddings` is not a tuple.
|
|
794
823
|
ValueError: If length of `paddings` is more than 4 or its shape is not :math:`(N, 2)` .
|
|
795
|
-
ValueError: If `mode` is not one of "CONSTANT", "REFLECT", "SYMMETRIC".
|
|
824
|
+
ValueError: If `mode` is not one of ``"CONSTANT"``, ``"REFLECT"``, ``"SYMMETRIC"``.
|
|
796
825
|
|
|
797
826
|
Supported Platforms:
|
|
798
827
|
``Ascend`` ``GPU`` ``CPU``
|
|
799
828
|
|
|
800
829
|
Examples:
|
|
801
|
-
>>>
|
|
802
|
-
>>>
|
|
830
|
+
>>> import mindspore
|
|
831
|
+
>>> from mindspore import Tensor, nn, ops
|
|
803
832
|
>>> import numpy as np
|
|
804
833
|
>>> # If `mode` is "CONSTANT"
|
|
805
834
|
>>> class Net(nn.Cell):
|
|
@@ -912,7 +941,7 @@ class Pad(Cell):
|
|
|
912
941
|
return x
|
|
913
942
|
|
|
914
943
|
|
|
915
|
-
@
|
|
944
|
+
@_primexpr
|
|
916
945
|
def bilinear(shape, size, scale, align_corners, prim_name=None):
|
|
917
946
|
"""Check input and calculate shape"""
|
|
918
947
|
msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
|
|
@@ -945,6 +974,8 @@ class ResizeBilinear(Cell):
|
|
|
945
974
|
Deprecated
|
|
946
975
|
|
|
947
976
|
Examples:
|
|
977
|
+
>>> import mindspore
|
|
978
|
+
>>> from mindspore import Tensor, nn
|
|
948
979
|
>>> x = Tensor([[[[1, 2, 3, 4], [5, 6, 7, 8]]]], mindspore.float32)
|
|
949
980
|
>>> resize_bilinear = nn.ResizeBilinear()
|
|
950
981
|
>>> result = resize_bilinear(x, size=(5,5))
|
|
@@ -988,20 +1019,22 @@ class Unfold(Cell):
|
|
|
988
1019
|
must be a tuple or list of int, and the format is [1, stride_row, stride_col, 1].
|
|
989
1020
|
rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dimension
|
|
990
1021
|
pixel positions, must be a tuple or a list of integers, and the format is [1, rate_row, rate_col, 1].
|
|
991
|
-
padding (str): The type of padding algorithm, is a string whose value is "same" or "valid", not case
|
|
992
|
-
sensitive. Default: "valid".
|
|
1022
|
+
padding (str): The type of padding algorithm, is a string whose value is ``"same"`` or ``"valid"`` , not case
|
|
1023
|
+
sensitive. Default: ``"valid"`` .
|
|
993
1024
|
|
|
994
|
-
- same: Means that the patch can take the part beyond the original image, and this part is filled with 0.
|
|
1025
|
+
- ``"same"``: Means that the patch can take the part beyond the original image, and this part is filled
|
|
1026
|
+
with 0.
|
|
995
1027
|
|
|
996
|
-
- valid: Means that the taken patch area must be completely covered in the original image.
|
|
1028
|
+
- ``"valid"``: Means that the taken patch area must be completely covered in the original image.
|
|
997
1029
|
|
|
998
1030
|
Inputs:
|
|
999
|
-
- **x** (Tensor) - A 4-D tensor whose shape is [in_batch, in_depth, in_row, in_col] and
|
|
1000
|
-
data type is number.
|
|
1031
|
+
- **x** (Tensor) - A 4-D tensor whose shape is :math:`[in\_batch, in\_depth, in\_row, in\_col]`
|
|
1032
|
+
and data type is number.
|
|
1001
1033
|
|
|
1002
1034
|
Outputs:
|
|
1003
1035
|
Tensor, a 4-D tensor whose data type is same as `x`,
|
|
1004
|
-
and the shape is [out_batch, out_depth, out_row, out_col] where `out_batch` is the same as the `in_batch`.
|
|
1036
|
+
and the shape is :math:`(out\_batch, out\_depth, out\_row, out\_col)`
|
|
1037
|
+
where `out_batch` is the same as the `in_batch`.
|
|
1005
1038
|
|
|
1006
1039
|
- :math:`out\_depth = ksize\_row * ksize\_col * in\_depth`
|
|
1007
1040
|
- :math:`out\_row = (in\_row - (ksize\_row + (ksize\_row - 1) * (rate\_row - 1))) // stride\_row + 1`
|
|
@@ -1009,17 +1042,20 @@ class Unfold(Cell):
|
|
|
1009
1042
|
|
|
1010
1043
|
Raises:
|
|
1011
1044
|
TypeError: If `ksizes`, `strides` or `rates` is neither a tuple nor list.
|
|
1012
|
-
ValueError: If shape of `ksizes`, `strides` or `rates` is not (1, x_row, x_col, 1).
|
|
1045
|
+
ValueError: If shape of `ksizes`, `strides` or `rates` is not :math:`(1, x\_row, x\_col, 1)`.
|
|
1013
1046
|
ValueError: If the second and third element of `ksizes`, `strides` or `rates` is less than 1.
|
|
1014
1047
|
|
|
1015
1048
|
Supported Platforms:
|
|
1016
1049
|
``Ascend`` ``GPU``
|
|
1017
1050
|
|
|
1018
1051
|
Examples:
|
|
1019
|
-
>>>
|
|
1052
|
+
>>> import mindspore
|
|
1053
|
+
>>> from mindspore import Tensor, nn
|
|
1054
|
+
>>> import numpy as np
|
|
1055
|
+
>>> net = nn.Unfold(ksizes=[1, 2, 2, 1], strides=[1, 2, 2, 1], rates=[1, 2, 2, 1])
|
|
1020
1056
|
>>> # As stated in the above code:
|
|
1021
1057
|
>>> # ksize_row = 2, ksize_col = 2, rate_row = 2, rate_col = 2, stride_row = 2, stride_col = 2.
|
|
1022
|
-
>>> image = Tensor(np.ones([2, 3, 6, 6]), dtype=
|
|
1058
|
+
>>> image = Tensor(np.ones([2, 3, 6, 6]), dtype=mindspore.float16)
|
|
1023
1059
|
>>> # in_batch = 2, in_depth = 3, in_row = 6, in_col = 6.
|
|
1024
1060
|
>>> # Substituting the formula to get:
|
|
1025
1061
|
>>> # out_batch = in_batch = 2
|
|
@@ -1041,7 +1077,8 @@ class Unfold(Cell):
|
|
|
1041
1077
|
if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
|
|
1042
1078
|
raise ValueError(f"For '{prim_name}' the format of '{arg_name}s' must be [1, {arg_name}_row, "
|
|
1043
1079
|
f"{arg_name}_col, 1], but got {arg_val}.")
|
|
1044
|
-
|
|
1080
|
+
is_int = isinstance(arg_val[1], int) and isinstance(arg_val[2], int)
|
|
1081
|
+
if not is_int or arg_val[1] < 1 or arg_val[2] < 1:
|
|
1045
1082
|
raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in '{arg_name}s' must be "
|
|
1046
1083
|
f"an positive integer number, but got {arg_name}_row is {arg_val[1]}, "
|
|
1047
1084
|
f"{arg_name}_col is {arg_val[2]}")
|
|
@@ -1070,8 +1107,8 @@ def tril(x_shape, x_dtype, k):
|
|
|
1070
1107
|
|
|
1071
1108
|
class Tril(Cell):
|
|
1072
1109
|
"""
|
|
1073
|
-
|
|
1074
|
-
|
|
1110
|
+
The Tril class will be deprecated in the future,
|
|
1111
|
+
this function can be replaced by :func:`ops.tril`
|
|
1075
1112
|
"""
|
|
1076
1113
|
|
|
1077
1114
|
@deprecated("2.0", "ops.tril", False)
|
|
@@ -1099,8 +1136,8 @@ def triu(x_shape, x_dtype, k):
|
|
|
1099
1136
|
|
|
1100
1137
|
class Triu(Cell):
|
|
1101
1138
|
"""
|
|
1102
|
-
|
|
1103
|
-
|
|
1139
|
+
The Triu class will be deprecated in the future,
|
|
1140
|
+
this function can be replaced by :func:`ops.triu`
|
|
1104
1141
|
"""
|
|
1105
1142
|
|
|
1106
1143
|
@deprecated("2.0", "ops.triu", False)
|
|
@@ -1149,8 +1186,8 @@ def _get_matrix_diag_part_assist(x_shape, x_dtype):
|
|
|
1149
1186
|
|
|
1150
1187
|
class MatrixDiag(Cell):
|
|
1151
1188
|
r"""
|
|
1152
|
-
|
|
1153
|
-
|
|
1189
|
+
The MatrixDiag class will be deprecated in the future,
|
|
1190
|
+
this function can be replaced by :func:`ops.diag`
|
|
1154
1191
|
"""
|
|
1155
1192
|
|
|
1156
1193
|
@deprecated("2.0", "ops.diag", False)
|
|
@@ -1170,8 +1207,8 @@ class MatrixDiag(Cell):
|
|
|
1170
1207
|
|
|
1171
1208
|
class MatrixDiagPart(Cell):
|
|
1172
1209
|
r"""
|
|
1173
|
-
|
|
1174
|
-
|
|
1210
|
+
The MatrixDiagPart class will be deprecated in the future,
|
|
1211
|
+
this function can be replaced by :func:`ops.diagonal`
|
|
1175
1212
|
"""
|
|
1176
1213
|
|
|
1177
1214
|
@deprecated("2.0", "ops.diagonal", False)
|
|
@@ -1221,6 +1258,8 @@ class MatrixSetDiag(Cell):
|
|
|
1221
1258
|
``Ascend``
|
|
1222
1259
|
|
|
1223
1260
|
Examples:
|
|
1261
|
+
>>> import mindspore
|
|
1262
|
+
>>> from mindspore import Tensor, nn
|
|
1224
1263
|
>>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
|
|
1225
1264
|
>>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
|
|
1226
1265
|
>>> matrix_set_diag = nn.MatrixSetDiag()
|
|
@@ -1255,8 +1294,8 @@ def _check_input_dim(axis, dim, cls_name):
|
|
|
1255
1294
|
|
|
1256
1295
|
class Roll(Cell):
|
|
1257
1296
|
"""
|
|
1258
|
-
|
|
1259
|
-
|
|
1297
|
+
The Roll class will be deprecated in the future,
|
|
1298
|
+
this function can be replaced by :func:`ops.roll`
|
|
1260
1299
|
"""
|
|
1261
1300
|
|
|
1262
1301
|
@deprecated("2.0", "ops.roll", False)
|
|
@@ -1350,6 +1389,9 @@ class Unflatten(Cell):
|
|
|
1350
1389
|
``Ascend`` ``GPU`` ``CPU``
|
|
1351
1390
|
|
|
1352
1391
|
Examples:
|
|
1392
|
+
>>> import mindspore
|
|
1393
|
+
>>> from mindspore import Tensor, nn
|
|
1394
|
+
>>> import numpy as np
|
|
1353
1395
|
>>> input = Tensor(np.arange(0, 100).reshape(2, 10, 5), mindspore.float32)
|
|
1354
1396
|
>>> net = nn.Unflatten(1, (2, 5))
|
|
1355
1397
|
>>> output = net(input)
|
|
@@ -45,35 +45,35 @@ class ChannelShuffle(Cell):
|
|
|
45
45
|
``Ascend`` ``GPU`` ``CPU``
|
|
46
46
|
|
|
47
47
|
Examples:
|
|
48
|
-
>>>
|
|
49
|
-
>>>
|
|
48
|
+
>>> import mindspore as ms
|
|
49
|
+
>>> import numpy as np
|
|
50
|
+
>>> channel_shuffle = ms.nn.ChannelShuffle(2)
|
|
51
|
+
>>> x = ms.Tensor(np.arange(16).astype(np.int32).reshape(1, 4, 2, 2))
|
|
50
52
|
>>> print(x)
|
|
51
|
-
[[[[0
|
|
52
|
-
[2
|
|
53
|
-
[[4
|
|
54
|
-
[6
|
|
55
|
-
[[8
|
|
56
|
-
[10 11]]
|
|
57
|
-
[[12 13]
|
|
58
|
-
[14 15]]
|
|
59
|
-
]]
|
|
53
|
+
[[[[ 0 1]
|
|
54
|
+
[ 2 3]]
|
|
55
|
+
[[ 4 5]
|
|
56
|
+
[ 6 7]]
|
|
57
|
+
[[ 8 9]
|
|
58
|
+
[10 11]]
|
|
59
|
+
[[12 13]
|
|
60
|
+
[14 15]]]]
|
|
60
61
|
>>> output = channel_shuffle(x)
|
|
61
62
|
>>> print(output)
|
|
62
|
-
[[[[0
|
|
63
|
-
[2
|
|
64
|
-
[[8
|
|
65
|
-
[10 11]]
|
|
66
|
-
[[4
|
|
67
|
-
[6
|
|
68
|
-
[[12 13]
|
|
69
|
-
[14 15]]
|
|
70
|
-
]]
|
|
63
|
+
[[[[ 0 1]
|
|
64
|
+
[ 2 3]]
|
|
65
|
+
[[ 8 9]
|
|
66
|
+
[10 11]]
|
|
67
|
+
[[ 4 5]
|
|
68
|
+
[ 6 7]]
|
|
69
|
+
[[12 13]
|
|
70
|
+
[14 15]]]]
|
|
71
71
|
"""
|
|
72
72
|
def __init__(self, groups):
|
|
73
73
|
"""Initialize ChannelShuffle."""
|
|
74
74
|
super(ChannelShuffle, self).__init__()
|
|
75
75
|
if not isinstance(groups, int):
|
|
76
|
-
raise TypeError("For ChannelShuffle, the param `groups` must be int, but got {}.".format(type(groups)))
|
|
76
|
+
raise TypeError(f"For ChannelShuffle, the param `groups` must be int, but got {type(groups)}.")
|
|
77
77
|
if groups < 1:
|
|
78
78
|
raise ValueError(f"For ChannelShuffle, the param `groups` must be larger than 0, but got {groups}.")
|
|
79
79
|
|