mindspore 1.10.0__cp37-none-any.whl → 2.0.0rc1__cp37-none-any.whl
This diff shows the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of mindspore has been flagged as potentially problematic.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +9064 -0
- mindspore/__init__.py +9 -4
- mindspore/_akg/akg/composite/build_module.py +11 -0
- mindspore/_akg/akg/config/repository_cuda.json +11 -0
- mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -3
- mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +53 -58
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/config/super_bar_config.json +512 -0
- mindspore/context.py +291 -56
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/include/mindapi/base/type_id.h +42 -3
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libicudata.so.69 +0 -0
- mindspore/lib/libicui18n.so.69 +0 -0
- mindspore/lib/libicuuc.so.69 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libmpi_collective.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/{libakg.so → plugin/cpu/libakg.so} +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/scipy/linalg.py +13 -117
- mindspore/scipy/ops.py +5 -71
- mindspore/scipy/ops_grad.py +1 -25
- mindspore/scipy/ops_wrapper.py +1 -1
- mindspore/scipy/optimize/_bfgs.py +1 -1
- mindspore/scipy/optimize/_lagrange.py +200 -0
- mindspore/scipy/optimize/line_search.py +3 -2
- mindspore/scipy/optimize/minimize.py +43 -6
- mindspore/scipy/sparse/__init__.py +2 -2
- mindspore/scipy/sparse/linalg.py +5 -465
- mindspore/scipy/utils.py +2 -1
- mindspore/scipy/utils_const.py +7 -1
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +899 -675
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
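
Among the renames above, the most user-visible is the move of the metrics package from mindspore.nn to mindspore.train (the {nn → train}/metrics entries), with a new 53-line mindspore/nn/metrics.py that is likely a compatibility shim; mindspore.nn.transformer likewise becomes the private mindspore.parallel._transformer package. A minimal before/after sketch of the metrics move, assuming only that the import paths follow the renamed files listed above:

# mindspore 1.10.0
# from mindspore.nn.metrics import Accuracy

# mindspore 2.0.0rc1: nn/metrics/* now lives under train/metrics/*
from mindspore.train.metrics import Accuracy

metric = Accuracy(eval_type='classification')
metric.clear()  # reset internal state before a new evaluation pass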
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2020-
|
|
1
|
+
# Copyright 2020-2023 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -18,6 +18,7 @@ from __future__ import division
|
|
|
18
18
|
|
|
19
19
|
import itertools
|
|
20
20
|
import numbers
|
|
21
|
+
import hashlib
|
|
21
22
|
|
|
22
23
|
from mindspore.ops import operations as P
|
|
23
24
|
from mindspore.ops import functional as F
|
|
@@ -25,22 +26,26 @@ from mindspore.ops.operations import _inner_ops as inner
|
|
|
25
26
|
from mindspore.common.parameter import Parameter
|
|
26
27
|
from mindspore.common.initializer import initializer, Initializer
|
|
27
28
|
from mindspore.common.tensor import Tensor
|
|
28
|
-
from mindspore.
|
|
29
|
-
from mindspore.ops.primitive import constexpr
|
|
29
|
+
from mindspore.ops.primitive import constexpr, _primexpr
|
|
30
30
|
import mindspore.context as context
|
|
31
|
-
from mindspore
|
|
32
|
-
from mindspore._checkparam import Validator as validator
|
|
31
|
+
from mindspore import _checkparam as validator
|
|
33
32
|
from mindspore._extends import cell_attr_register
|
|
34
33
|
from mindspore.communication.management import get_group_size, get_rank
|
|
35
34
|
from mindspore.communication import management
|
|
36
35
|
from mindspore.common import dtype as mstype
|
|
37
36
|
from mindspore.parallel._utils import _is_in_auto_parallel_mode
|
|
38
37
|
from mindspore.nn.cell import Cell
|
|
38
|
+
from mindspore import log as logger
|
|
39
39
|
|
|
40
40
|
__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
|
|
41
|
-
'
|
|
41
|
+
'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']
|
|
42
42
|
|
|
43
|
-
|
|
43
|
+
|
|
44
|
+
def _check_dim(val, target, cls_name):
|
|
45
|
+
def _check(val, target, cls_name):
|
|
46
|
+
if val != target:
|
|
47
|
+
raise ValueError(f"For '{cls_name}', the in_shape must have {target} dims, but got {val}.")
|
|
48
|
+
_check(val, target, cls_name)
|
|
44
49
|
|
|
45
50
|
|
|
46
51
|
class _BatchNorm(Cell):
|
|
@@ -57,9 +62,6 @@ class _BatchNorm(Cell):
|
|
|
57
62
|
moving_mean_init='zeros',
|
|
58
63
|
moving_var_init='ones',
|
|
59
64
|
use_batch_statistics=None,
|
|
60
|
-
device_num_each_group=1,
|
|
61
|
-
process_groups=0,
|
|
62
|
-
input_dims='2d',
|
|
63
65
|
data_format='NCHW'):
|
|
64
66
|
"""Initialize _BatchNorm."""
|
|
65
67
|
super(_BatchNorm, self).__init__()
|
|
@@ -70,7 +72,6 @@ class _BatchNorm(Cell):
|
|
|
70
72
|
if momentum < 0 or momentum > 1:
|
|
71
73
|
raise ValueError(f"For '{self.cls_name}', the 'momentum' must be a number in range [0, 1], "
|
|
72
74
|
f"but got {momentum}.")
|
|
73
|
-
self.input_dims = input_dims
|
|
74
75
|
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
|
|
75
76
|
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
|
|
76
77
|
raise ValueError(f"For '{self.cls_name}', the 'NHWC' format only support in GPU target, but got device "
|
|
@@ -93,39 +94,8 @@ class _BatchNorm(Cell):
|
|
|
93
94
|
gamma_init, num_features), name="gamma", requires_grad=affine)
|
|
94
95
|
self.beta = Parameter(initializer(
|
|
95
96
|
beta_init, num_features), name="beta", requires_grad=affine)
|
|
96
|
-
|
|
97
|
-
self.cls_name)
|
|
98
|
-
self.process_groups = process_groups
|
|
99
|
-
self.is_global = False
|
|
97
|
+
|
|
100
98
|
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
|
101
|
-
global SYNC_BN_GROUP_NAME
|
|
102
|
-
# for GlobalBatchNorm
|
|
103
|
-
if self.group_device_num != 1:
|
|
104
|
-
self.rank_id = get_rank()
|
|
105
|
-
self.rank_size = get_group_size()
|
|
106
|
-
self.device_list = [i for i in range(0, self.rank_size)]
|
|
107
|
-
self.rank_list = self.list_group(self.device_list, self.group_device_num)
|
|
108
|
-
self.rank_list_idx = len(self.rank_list)
|
|
109
|
-
self._create_global_groups()
|
|
110
|
-
# for SyncBatchNorm
|
|
111
|
-
if self.process_groups != 0:
|
|
112
|
-
self.rank_id = get_rank()
|
|
113
|
-
self.rank_size = get_group_size()
|
|
114
|
-
if self.process_groups is not None:
|
|
115
|
-
validator.check_isinstance("process_groups", self.process_groups, list)
|
|
116
|
-
self._check_rank_ids(self.process_groups, self.rank_size)
|
|
117
|
-
self._create_sync_groups()
|
|
118
|
-
elif self.rank_size > 1:
|
|
119
|
-
self.is_global = True
|
|
120
|
-
self.group_device_num = self.rank_size
|
|
121
|
-
self.device_list = [i for i in range(0, self.rank_size)]
|
|
122
|
-
if context.get_context("device_target") == "Ascend":
|
|
123
|
-
if SYNC_BN_GROUP_NAME == "":
|
|
124
|
-
SYNC_BN_GROUP_NAME = "sync_bn_group0"
|
|
125
|
-
management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
|
|
126
|
-
elif context.get_context("device_target") == "GPU":
|
|
127
|
-
if SYNC_BN_GROUP_NAME == "":
|
|
128
|
-
SYNC_BN_GROUP_NAME = "nccl_world_group"
|
|
129
99
|
|
|
130
100
|
self.shape = P.Shape()
|
|
131
101
|
self.reduce_mean = P.ReduceMean(keep_dims=True)
|
|
@@ -137,20 +107,11 @@ class _BatchNorm(Cell):
         self._target = context.get_context("device_target")
         self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
         self.momentum = 1.0 - momentum
-        if context.get_context("enable_ge"):
-            self.is_ge_backend = True
-        else:
-            self.is_ge_backend = False

         self.bn_train = P.BatchNorm(is_training=True,
                                     epsilon=self.eps,
                                     momentum=self.momentum,
                                     data_format=self.format)
-        if self.is_global:
-            self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
-                                                momentum=self.momentum,
-                                                group=SYNC_BN_GROUP_NAME,
-                                                device_num=self.group_device_num)

         self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
         if _is_in_auto_parallel_mode():
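Note that `self.momentum = 1.0 - momentum` is kept by the rewrite: the `nn` layers take `momentum` as the fraction of the running statistic to keep, while the underlying `P.BatchNorm` primitive receives the complementary weight of the new batch statistic. A minimal arithmetic sketch of the resulting update (plain Python, illustrative values):

    nn_momentum = 0.9                  # momentum as passed to nn.BatchNorm2d
    prim_momentum = 1.0 - nn_momentum  # momentum as handed to P.BatchNorm
    moving_mean, batch_mean = 0.0, 2.0
    moving_mean = moving_mean * nn_momentum + batch_mean * prim_momentum
    assert abs(moving_mean - 0.2) < 1e-9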
@@ -166,22 +127,15 @@ class _BatchNorm(Cell):
         self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
         self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)

-
-
-
-
-
-
-        if len(world_rank) % group_size != 0:
-            raise ValueError(f"For '{self.cls_name}', the dimension of device_list must be divisible by "
-                             f"'device_num_each_group', but got the length of device_list: {len(world_rank)}, "
-                             f"'device_num_each_group': {group_size}.")
-        world_rank_list = zip(*(iter(world_rank),) * group_size)
-        group_list = [list(i) for i in world_rank_list]
-        return group_list
+
+    @staticmethod
+    @_primexpr
+    def _check_input_dim(shape, cls_name):
+        raise NotImplementedError
+

     def construct(self, x):
-
+        self._check_input_dim(self.shape(x), self.cls_name)
         if self.use_batch_statistics is None:
             if self.training:
                 return self.bn_train(x,
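The base class now exposes `_check_input_dim` as an abstract `@staticmethod` under `@_primexpr`, so the shape check runs at graph-compile time, and `construct` calls it on every subclass. A sketch of how a subclass plugs into the hook, reusing `_BatchNorm` and `_primexpr` from this module (the 4-dim rule here is purely illustrative):

    class _Example4dNorm(_BatchNorm):
        @staticmethod
        @_primexpr
        def _check_input_dim(shape, cls_name):
            # reject anything that is not (N, C, H, W)
            if len(shape) != 4:
                raise ValueError(f"For '{cls_name}', the input must have 4 dims, "
                                 f"but got {len(shape)}.")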
@@ -214,98 +168,11 @@ class _BatchNorm(Cell):
                 self.num_features, self.eps, 1.0 - self.momentum, self.gamma, self.beta, \
                 self.moving_mean, self.moving_variance)

-    def _check_data_dim(self, x):
-        raise NotImplementedError
-
-    def _check_rank_ids(self, process_groups, rank_size):
-        seen = set()
-        for rid in itertools.chain(*process_groups):
-            validator.check_int_range(rid, 0, rank_size, Rel.INC_LEFT, "rank id in process_groups", self.cls_name)
-            if rid in seen:
-                raise ValueError(f"For '{self.cls_name}', rank id in 'process_groups' must not be duplicated, "
-                                 f"but got {process_groups}.")
-            seen.add(rid)
-
-    def _create_global_groups(self):
-        for i in range(self.rank_list_idx):
-            if self.rank_id in self.rank_list[i]:
-                self.is_global = True
-                global SYNC_BN_GROUP_NAME
-                if SYNC_BN_GROUP_NAME == "":
-                    SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
-                    management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i])
-
-    def _create_sync_groups(self):
-        for i in range(len(self.process_groups)):
-            validator.check_isinstance("process_groups[%d]" % i, self.process_groups[i], list)
-            self.group_device_num = len(self.process_groups[i])
-            if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
-                self.is_global = True
-                global SYNC_BN_GROUP_NAME
-                if SYNC_BN_GROUP_NAME == "":
-                    SYNC_BN_GROUP_NAME = "sync_bn_group%d" % i
-                    management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])
-
-
-@constexpr
-def _channel_check(channel, num_channel, prim_name=None):
-    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
-    if channel != num_channel:
-        raise ValueError(f"{msg_prefix} channel(the second dim of the input 'x') must be equal to num_channels, "
-                         f"but got channel: {channel}, num_channels: {num_channel}.")
-
-
-@constexpr
-def _shape_check(in_shape, prim_name=None):
-    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
-    if len(in_shape) != 4:
-        raise ValueError(f"{msg_prefix} in_shape must has 4 dims, but got the length of in_shape: {len(in_shape)}.")
-
-
-@constexpr
-def _shape_check_bn(in_shape, in_dims, prim_name=None):
-    """check input dims of batch norm."""
-    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
-    dim = len(in_shape)
-    if in_dims == '1d' and dim != 2:
-        raise ValueError(f"{msg_prefix} in_shape must have 2 dims, but got {len(in_shape)}.")
-    if in_dims == '2d' and dim != 4:
-        raise ValueError(f"{msg_prefix} in_shape must have 4 dims, but got {len(in_shape)}.")
-    if in_dims == '3d' and dim != 5:
-        raise ValueError(f"{msg_prefix} in_shape must have 5 dims, but got {len(in_shape)}.")
-    if in_dims == 'both' and dim != 2 and dim != 4:
-        raise ValueError(f"{msg_prefix} in_shape must have 2 dims or 4 dims, but got {len(in_shape)}.")
-
-
-@constexpr
-def _shape_check_in(in_shape, in_dims, prim_name=None):
-    """check input dims of batch norm."""
-    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
-    dim = len(in_shape)
-    if in_dims == '1d' and dim != 3:
-        raise ValueError(f"{msg_prefix} in_shape must have 3 dims, but got {len(in_shape)}.")
-    if in_dims == '2d' and dim != 4:
-        raise ValueError(f"{msg_prefix} in_shape must have 4 dims, but got {len(in_shape)}.")
-    if in_dims == '3d' and dim != 5:
-        raise ValueError(f"{msg_prefix} in_shape must have 5 dims, but got {len(in_shape)}.")
-
-
-@constexpr
-def _shape_infer(x_shape, num_feature):
-    """global Batch Normalization shape and axes infer"""
-    if len(x_shape) == 4:
-        axes = (0, 2, 3)
-        re_shape = (1, num_feature, 1, 1)
-    else:
-        axes = (0,)
-        re_shape = (1, num_feature)
-    return axes, re_shape
-

 class BatchNorm1d(_BatchNorm):
     r"""
     This layer
-    applies Batch Normalization over a 2D input (a mini-batch of 1D inputs) to
+    applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D or 2D inputs) to
     reduce internal covariate shift. Batch Normalization is widely used in convolutional networks.
     For the detailed contents, refer to `Batch Normalization: Accelerating Deep Network Training by
     Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It
@@ -320,14 +187,14 @@ class BatchNorm1d(_BatchNorm):
     recommended to be changed after net was initialized.

     Args:
-        num_features (int): `C`
-        eps (float):
+        num_features (int): number of features or channels `C` of the input `x` .
+        eps (float): :math:`\epsilon` added to the denominator for numerical stability. Default: 1e-5.
         momentum (float): A floating hyperparameter of the momentum for the
             running_mean and running_var computation. Default: 0.9.
-        affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
-        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
+        affine (bool): A bool value. When set to True, :math:`\gamma` and :math:`\beta` can be learned. Default: True.
+        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
             The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
-        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
+        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
             The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
         moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
             The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
@@ -337,12 +204,15 @@ class BatchNorm1d(_BatchNorm):
         use the mean value and variance value of specified value. If None, the training process will use the mean
         and variance of current batch data and track the running mean and variance, the evaluation process will use
         the running mean and variance. Default: None.
+        data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
+            Default: 'NCHW'.

     Inputs:
-        - **x** (Tensor) - Tensor of shape :math:`(N,
+        - **x** (Tensor) - Tensor of shape :math:`(N, C)` or :math:`(N, C, L)` ,
+          where `N` is the batch size, `C` is the number of features or channels, and `L` is the sequence length.

     Outputs:
-        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N,
+        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C)` or :math:`(N, C, L)` .

     Raises:
         TypeError: If `num_features` is not an int.
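Since `BatchNorm1d` now accepts both input layouts, a quick smoke test might look like the following (a sketch assuming a standard MindSpore 2.0 install; shapes mirror the docstring above):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    bn = nn.BatchNorm1d(num_features=4)
    out_2d = bn(Tensor(np.ones([2, 4]), ms.float32))     # (N, C)
    out_3d = bn(Tensor(np.ones([2, 4, 3]), ms.float32))  # (N, C, L)
    print(out_2d.shape, out_3d.shape)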
@@ -366,31 +236,14 @@ class BatchNorm1d(_BatchNorm):
      [ 0.4999975   0.399998    0.59999704  0.89999545]]
     """

-
-
-
-
-
-
-
-
-                 moving_var_init='ones',
-                 use_batch_statistics=None):
-        """Initialize BatchNorm1d."""
-        super(BatchNorm1d, self).__init__(num_features,
-                                          eps,
-                                          momentum,
-                                          affine,
-                                          gamma_init,
-                                          beta_init,
-                                          moving_mean_init,
-                                          moving_var_init,
-                                          use_batch_statistics,
-                                          input_dims='1d')
-
-    def _check_data_dim(self, x):
-        if x.ndim != 2:
-            pass
+    @staticmethod
+    @_primexpr
+    def _check_input_dim(shape, cls_name):
+        def _check(dim):
+            if dim not in (2, 3):
+                raise ValueError(f"For '{cls_name}', the input must have 2 dims or 3 dims, but got {dim}.")
+        dim = len(shape)
+        _check(dim)


 class BatchNorm2d(_BatchNorm):
@@ -412,22 +265,22 @@ class BatchNorm2d(_BatchNorm):
     Note that the formula for updating the :math:`moving\_mean` and :math:`moving\_var` is

     .. math::
-        \text{moving_mean}=\text{moving_mean
-        \text{moving_var}=\text{moving_var
+        \text{moving_mean}=\text{moving_mean*momentum}+μ_β\text{*(1−momentum)}\\
+        \text{moving_var}=\text{moving_var*momentum}+σ^2_β\text{*(1−momentum)}

     where :math:`moving\_mean` is the updated mean, :math:`moving\_var` is the updated variance,
     :math:`μ_β, σ^2_β` are the observed value (mean and variance) of each batch of data.

     Args:
-        num_features (int): The number of channels of the input tensor. Expected input size is (N, C, H, W)
+        num_features (int): The number of channels of the input tensor. Expected input size is :math:`(N, C, H, W)`,
             `C` represents the number of channels.
-        eps (float):
+        eps (float): :math:`\epsilon` added to the denominator for numerical stability. Default: 1e-5.
         momentum (float): A floating hyperparameter of the momentum for the
             running_mean and running_var computation. Default: 0.9.
-        affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
-        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
+        affine (bool): A bool value. When set to True, :math:`\gamma` and :math:`\beta` can be learned. Default: True.
+        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
             The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
-        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
+        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
             The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
         moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
             The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
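Worked numerically, the restored formula updates both running statistics from the per-channel batch moments (numpy is used only to compute those moments; all values are illustrative):

    import numpy as np

    momentum = 0.9
    x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)  # toy (N, C) batch
    mu, var = x.mean(axis=0), x.var(axis=0)
    moving_mean = np.zeros(2, dtype=np.float32)
    moving_var = np.ones(2, dtype=np.float32)
    moving_mean = moving_mean * momentum + mu * (1 - momentum)
    moving_var = moving_var * momentum + var * (1 - momentum)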
@@ -446,10 +299,10 @@ class BatchNorm2d(_BatchNorm):
             Default: 'NCHW'.

     Inputs:
-        - **x** (Tensor) - Tensor of shape :math:`(N,
+        - **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`.

     Outputs:
-        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N,
+        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C, H, W)`.

     Raises:
         TypeError: If `num_features` is not an int.
@@ -477,46 +330,11 @@ class BatchNorm2d(_BatchNorm):
       [ 0.999995  0.999995 ]]]]
     """

-
-
-
-
-
-                 gamma_init='ones',
-                 beta_init='zeros',
-                 moving_mean_init='zeros',
-                 moving_var_init='ones',
-                 use_batch_statistics=None,
-                 data_format='NCHW'):
-        """Initialize BatchNorm2d."""
-        super(BatchNorm2d, self).__init__(num_features,
-                                          eps,
-                                          momentum,
-                                          affine,
-                                          gamma_init,
-                                          beta_init,
-                                          moving_mean_init,
-                                          moving_var_init,
-                                          use_batch_statistics,
-                                          input_dims='2d',
-                                          data_format=data_format)
-
-    def _check_data_dim(self, x):
-        if x.ndim != 4:
-            pass
-
-
-@constexpr
-def _check_3d_shape(input_shape, prim_name=None):
-    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
-    if len(input_shape) != 5:
-        raise ValueError(f"{msg_prefix} input_shape must be 5-dimensional, but got the length of input_shape: "
-                         f"{len(input_shape)}.")
-
-
-@constexpr
-def _check_dtype(dtype, valid_dtypes, args_name, prim_name=None):
-    validator.check_type_name(args_name, dtype, valid_dtypes, prim_name)
+    @staticmethod
+    @_primexpr
+    def _check_input_dim(shape, cls_name):
+        dim = len(shape)
+        _check_dim(dim, 4, cls_name)


 class BatchNorm3d(Cell):
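`_check_dim(dim, 4, cls_name)` calls a shared module-level helper that this hunk does not show; judging from the error messages used elsewhere in the file, it presumably looks roughly like the following (a hypothetical reconstruction, not the verbatim helper):

    @_primexpr
    def _check_dim(dim, target_dim, cls_name):
        # hypothetical: same wording as the inline checks in this diff
        if dim != target_dim:
            raise ValueError(f"For '{cls_name}', the input must have {target_dim} dims, "
                             f"but got {dim}.")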
@@ -536,7 +354,7 @@ class BatchNorm3d(Cell):
     where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.

     Args:
-        num_features (int): `C` from an expected input of size (N, C, D, H, W).
+        num_features (int): `C` from an expected input of size :math:`(N, C, D, H, W)` .
         eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
         momentum (float): A floating hyperparameter of the momentum for the
             running_mean and running_var computation. Default: 0.9.
@@ -553,7 +371,6 @@ class BatchNorm3d(Cell):
         use the mean value and variance value of specified value. If None, the training process will use the mean
         and variance of current batch data and track the running mean and variance, the evaluation process will use
         the running mean and variance. Default: None.
-        data_format (str): The optional value for data format is 'NCDHW'. Default: 'NCDHW'.

     Inputs:
         - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
@@ -566,7 +383,6 @@ class BatchNorm3d(Cell):
         TypeError: If `eps` is not a float.
         ValueError: If `num_features` is less than 1.
         ValueError: If `momentum` is not in range [0, 1].
-        ValueError: If `data_format` is not 'NCDHW'.

     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -591,12 +407,9 @@ class BatchNorm3d(Cell):
                  beta_init='zeros',
                  moving_mean_init='zeros',
                  moving_var_init='ones',
-                 use_batch_statistics=None,
-                 data_format='NCDHW'):
+                 use_batch_statistics=None):
         """Initialize BatchNorm3d."""
         super(BatchNorm3d, self).__init__()
-        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
-        self.reshape = P.Reshape()
         self.bn2d = BatchNorm2d(num_features=num_features,
                                 eps=eps,
                                 momentum=momentum,
@@ -607,57 +420,33 @@ class BatchNorm3d(Cell):
                                 moving_var_init=moving_var_init,
                                 use_batch_statistics=use_batch_statistics,
                                 data_format="NCHW")
+        self.shape = P.Shape()
+        self.reshape = P.Reshape()

-
-
-
-
-
+    @staticmethod
+    @_primexpr
+    def _check_input_dim(shape, cls_name):
+        dim = len(shape)
+        _check_dim(dim, 5, cls_name)
+
+
+    def construct(self, x):
+        x_shape = self.shape(x)
+        self._check_input_dim(x_shape, self.cls_name)
+        x = self.reshape(x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
+        bn2d_out = self.bn2d(x)
         bn3d_out = self.reshape(bn2d_out, x_shape)
         return bn3d_out

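`BatchNorm3d` still delegates to `BatchNorm2d` by folding the depth axis into the height axis; the reshape leaves the set of elements per channel unchanged, so the per-channel batch statistics are identical. A numpy check of that invariant (illustrative only):

    import numpy as np

    x = np.random.randn(2, 3, 4, 5, 6)   # (N, C, D, H, W)
    folded = x.reshape(2, 3, 4 * 5, 6)   # (N, C, D*H, W)
    assert np.allclose(x.mean(axis=(0, 2, 3, 4)), folded.mean(axis=(0, 2, 3)))
    assert np.allclose(x.var(axis=(0, 2, 3, 4)), folded.var(axis=(0, 2, 3)))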
-
-    r"""
-    The GlobalBatchNorm interface is deprecated, please use the :class:`mindspore.nn.SyncBatchNorm` instead.
+SYNCBN_GROUP_DICT = None

-    Supported Platforms:
-        deprecated
-    """

-
-
-
-
-                 affine=True,
-                 gamma_init='ones',
-                 beta_init='zeros',
-                 moving_mean_init='zeros',
-                 moving_var_init='ones',
-                 use_batch_statistics=None,
-                 device_num_each_group=2):
-        """Initialize GlobalBatchNorm."""
-        super(GlobalBatchNorm, self).__init__(num_features,
-                                              eps,
-                                              momentum,
-                                              affine,
-                                              gamma_init,
-                                              beta_init,
-                                              moving_mean_init,
-                                              moving_var_init,
-                                              use_batch_statistics,
-                                              device_num_each_group,
-                                              input_dims='both')
-        self.group_device_num = validator.check_positive_int(device_num_each_group, "device_num_each_group",
-                                                             self.cls_name)
-        if self.group_device_num <= 1:
-            raise ValueError(f"For '{self.cls_name}', the 'device_num_each_group' must be greater than 1, "
-                             f"but got {self.group_device_num}.")
-
-    def _check_data_dim(self, x):
-        if x.dim == 0:
-            pass
+def _syncbatchnorm_group_dict():
+    global SYNCBN_GROUP_DICT
+    if SYNCBN_GROUP_DICT is None:
+        SYNCBN_GROUP_DICT = dict()
+    return SYNCBN_GROUP_DICT


 class SyncBatchNorm(_BatchNorm):
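The module-level `SYNC_BN_GROUP_NAME` string is gone: group names are now cached in a lazily created module-level dict, so cells that name the same rank list share one communication group. The caching pattern in isolation (plain Python; the names here are generic stand-ins):

    _CACHE = None

    def _get_cache():
        global _CACHE
        if _CACHE is None:   # created on first use, then shared
            _CACHE = dict()
        return _CACHE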
@@ -677,15 +466,16 @@ class SyncBatchNorm(_BatchNorm):
     Currently, SyncBatchNorm only supports 2D and 4D inputs.

     Args:
-        num_features (int): `C` from an expected input of size (N, C, H, W)
-        eps (float):
+        num_features (int): `C` from an expected input of size :math:`(N, C, H, W)`.
+        eps (float): :math:`\epsilon`, a value added to the denominator for numerical stability. Default: 1e-5.
         momentum (float): A floating hyperparameter of the momentum for the
             running_mean and running_var computation. Default: 0.9.
-        affine (bool): A bool value. When set to True, gamma and beta can be learned.
-
+        affine (bool): A bool value. When set to True, :math:`\gamma` and :math:`\beta` can be learned.
+            Default: True.
+        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
             The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
             'he_uniform', etc. Default: 'ones'.
-        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
+        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
             The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
             'he_uniform', etc. Default: 'zeros'.
         moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
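Judging from the constructor logic later in this diff, `process_groups` keeps a three-way contract: the sentinel `0` skips group setup entirely, `None` synchronizes over the whole world group, and a list of rank lists synchronizes within each sublist (a rank in no sublist, or in a singleton, falls back to local statistics). A hedged sketch on an assumed 4-device job:

    # assumes ranks 0..3 and an initialized communication context
    bn_all = nn.SyncBatchNorm(num_features=4)                   # one world group
    bn_sub = nn.SyncBatchNorm(num_features=4,
                              process_groups=[[0, 1], [2, 3]])  # two sub-groups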
@@ -721,10 +511,19 @@ class SyncBatchNorm(_BatchNorm):
         ``Ascend``

     Examples:
-
-
-
-
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
+            Please see the `Ascend tutorial
+            <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_ascend.html#preparations>`_
+            for more details.
+
+            For the GPU devices, users need to prepare the host file and mpi, please see the `GPU tutorial
+            <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_gpu.html#preparation>`_ .
+
+            This example should be run with multiple devices.
+
         >>> import numpy as np
         >>> import mindspore as ms
         >>> from mindspore.communication import init
@@ -747,7 +546,7 @@ class SyncBatchNorm(_BatchNorm):
           [[ 0.999995  0.999995 ]
            [ 0.999995  0.999995 ]]]]
     """
-
+    @cell_attr_register(attrs=['num_features', 'process_groups'])
     def __init__(self,
                  num_features,
                  eps=1e-5,
@@ -768,13 +567,71 @@ class SyncBatchNorm(_BatchNorm):
                                             beta_init,
                                             moving_mean_init,
                                             moving_var_init,
-                                            use_batch_statistics
-
-
+                                            use_batch_statistics)
+        self.is_global = False
+        self.group_name = None
+        self.process_groups = process_groups
+        if self.process_groups != 0:
+            self.rank_id = get_rank()
+            self.rank_size = get_group_size()
+            if self.process_groups is not None:
+                validator.check_isinstance("process_groups", self.process_groups, list)
+                self._check_rank_ids(self.process_groups, self.rank_size)
+                self._create_sync_groups()
+            elif self.rank_size > 1:
+                self.is_global = True
+                self.group_device_num = self.rank_size
+                if context.get_context("device_target") == "Ascend":
+                    self.group_name = "hccl_world_group"
+                elif context.get_context("device_target") == "GPU":
+                    self.group_name = "nccl_world_group"

-
-
-
+        if self.is_global:
+            self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
+                                                momentum=self.momentum,
+                                                group=self.group_name,
+                                                device_num=self.group_device_num)
+
+    def _create_sync_groups(self):
+        """ create groups by process groups. """
+        for sub_group in self.process_groups:
+            validator.check_isinstance("sub group", sub_group, list)
+            self.group_device_num = len(sub_group)
+            if self.rank_id in sub_group and self.group_device_num > 1:
+                self.is_global = True
+                rank_list_name = '_'.join('%s' % id for id in sub_group)
+                group_dict = _syncbatchnorm_group_dict()
+                if rank_list_name not in group_dict:
+                    md5 = hashlib.md5()
+                    md5.update(rank_list_name.encode('utf-8'))
+                    hash_name = md5.hexdigest()
+                    self.group_name = str(self.group_device_num) + '_' + hash_name
+                    group_dict[rank_list_name] = self.group_name
+                    management.create_group(self.group_name, sub_group)
+                    logger.info("create group for sync batchnorm, the rank list is {}, the group name is {}".format(
+                        rank_list_name, self.group_name))
+                else:
+                    self.group_name = group_dict[rank_list_name]
+                    logger.info("the group for {} already exists, no need to create".format(rank_list_name))
+
+    @staticmethod
+    @_primexpr
+    def _check_input_dim(shape, cls_name):
+        def _check(dim):
+            if dim not in (2, 4):
+                raise ValueError(f"For '{cls_name}', the input must have 2 dims or 4 dims, but got {dim}.")
+        dim = len(shape)
+        _check(dim)
+
+
+    def _check_rank_ids(self, process_groups, rank_size):
+        seen = set()
+        for rid in itertools.chain(*process_groups):
+            validator.check_int_range(rid, 0, rank_size, validator.INC_LEFT, "rank id in process_groups", self.cls_name)
+            if rid in seen:
+                raise ValueError(f"For '{self.cls_name}', rank id in 'process_groups' must not be duplicated, "
+                                 f"but got {process_groups}.")
+            seen.add(rid)


 class LayerNorm(Cell):
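Group creation is now deterministic: each sublist is serialized to a rank-list string, hashed, and prefixed with the group size, so every member rank derives the same group name without extra communication. The naming scheme in isolation:

    import hashlib

    sub_group = [0, 1, 2, 3]
    rank_list_name = '_'.join('%s' % rid for rid in sub_group)  # "0_1_2_3"
    group_name = str(len(sub_group)) + '_' + hashlib.md5(
        rank_list_name.encode('utf-8')).hexdigest()
    # every rank in sub_group computes this same group_name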
@@ -799,13 +656,13 @@ class LayerNorm(Cell):
         begin_params_axis (int): The first parameter (beta, gamma) dimension: scale and centering parameters
             will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
             the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
-        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
+        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
             The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
             'he_uniform', etc. Default: 'ones'.
-        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
+        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
             The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
             'he_uniform', etc. Default: 'zeros'.
-        epsilon (float):
+        epsilon (float): :math:`\epsilon` added to the denominator for numerical stability. Default: 1e-7.

     Inputs:
         - **x** (Tensor) - The shape of `x` is :math:`(x_1, x_2, ..., x_R)`,
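To recall how the two axis arguments interact, a typical call normalizes over the trailing axes while learning parameters of the same trailing shape (a sketch against the public `nn.LayerNorm` API; shapes are illustrative):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    x = Tensor(np.ones([20, 5, 10, 10]), ms.float32)
    ln = nn.LayerNorm(normalized_shape=(10, 10), begin_norm_axis=2, begin_params_axis=2)
    print(ln(x).shape)   # (20, 5, 10, 10)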
@@ -867,7 +724,6 @@ class LayerNorm(Cell):

 class _InstanceNorm(Cell):
     """Instance Normalization base class."""
-
     @cell_attr_register
     def __init__(self,
                  num_features,
@@ -875,8 +731,7 @@ class _InstanceNorm(Cell):
                  momentum=0.1,
                  affine=True,
                  gamma_init='ones',
-                 beta_init='zeros',
-                 input_dims='2d'):
+                 beta_init='zeros'):
         """Initialize Normalization base class."""
         super(_InstanceNorm, self).__init__()
         validator.check_value_type('num_features', num_features, [int], self.cls_name)
@@ -893,7 +748,6 @@ class _InstanceNorm(Cell):
                              f"but got {momentum}.")
         self.num_features = num_features
         self.eps = eps
-        self.input_dims = input_dims
         self.moving_mean = Parameter(initializer('zeros', num_features), name="mean", requires_grad=False)
         self.moving_variance = Parameter(initializer('ones', num_features), name="variance", requires_grad=False)
         self.gamma = Parameter(initializer(
@@ -906,7 +760,7 @@ class _InstanceNorm(Cell):
         self.instance_bn = P.InstanceNorm(epsilon=self.eps, momentum=self.momentum)

     def construct(self, x):
-
+        self._check_input_dim(self.shape(x), self.cls_name)
         return self.instance_bn(x,
                                 self.gamma,
                                 self.beta,
@@ -952,7 +806,7 @@ class InstanceNorm1d(_InstanceNorm):
     where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.

     Args:
-        num_features (int): `C` from an expected input of size (N, C, L)
+        num_features (int): `C` from an expected input of size :math:`(N, C, L)`.
         eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
         momentum (float): A floating hyperparameter of the momentum for the
             running_mean and running_var computation. Default: 0.1.
@@ -999,21 +853,12 @@ class InstanceNorm1d(_InstanceNorm):
     (2, 3, 5)
     """

-
-
-
-
-
-
-                 beta_init='zeros'):
-        """Initialize InstanceNorm2d."""
-        super(InstanceNorm1d, self).__init__(num_features,
-                                             eps,
-                                             momentum,
-                                             affine,
-                                             gamma_init,
-                                             beta_init,
-                                             input_dims='1d')
+    @staticmethod
+    @_primexpr
+    def _check_input_dim(shape, cls_name):
+        dim = len(shape)
+        _check_dim(dim, 3, cls_name)
+


 class InstanceNorm2d(_InstanceNorm):
@@ -1040,7 +885,7 @@ class InstanceNorm2d(_InstanceNorm):
     where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.

     Args:
-        num_features (int): `C` from an expected input of size (N, C, H, W)
+        num_features (int): `C` from an expected input of size :math:`(N, C, H, W)`.
         eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
         momentum (float): A floating hyperparameter of the momentum for the
             running_mean and running_var computation. Default: 0.1.
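The docstring example reduces to the following shape check (a sketch mirroring the `(2, 3, 2, 2)` output shown in the class docstring; InstanceNorm2d is documented for GPU):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    net = nn.InstanceNorm2d(3)
    x = Tensor(np.ones([2, 3, 2, 2]), ms.float32)
    print(net(x).shape)   # (2, 3, 2, 2)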
@@ -1087,21 +932,11 @@ class InstanceNorm2d(_InstanceNorm):
     (2, 3, 2, 2)
     """

-
-
-
-
-
-                 gamma_init='ones',
-                 beta_init='zeros'):
-        """Initialize InstanceNorm2d."""
-        super(InstanceNorm2d, self).__init__(num_features,
-                                             eps,
-                                             momentum,
-                                             affine,
-                                             gamma_init,
-                                             beta_init,
-                                             input_dims='2d')
+    @staticmethod
+    @_primexpr
+    def _check_input_dim(shape, cls_name):
+        dim = len(shape)
+        _check_dim(dim, 4, cls_name)


 class InstanceNorm3d(_InstanceNorm):
@@ -1128,7 +963,7 @@ class InstanceNorm3d(_InstanceNorm):
     where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.

     Args:
-        num_features (int): `C` from an expected input of size (N, C, D, H, W)
+        num_features (int): `C` from an expected input of size :math:`(N, C, D, H, W)`.
         eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
         momentum (float): A floating hyperparameter of the momentum for the
             running_mean and running_var computation. Default: 0.1.
@@ -1175,21 +1010,11 @@ class InstanceNorm3d(_InstanceNorm):
     (2, 3, 5, 2, 2)
     """

-
-
-
-
-
-                 gamma_init='ones',
-                 beta_init='zeros'):
-        """Initialize InstanceNorm2d."""
-        super(InstanceNorm3d, self).__init__(num_features,
-                                             eps,
-                                             momentum,
-                                             affine,
-                                             gamma_init,
-                                             beta_init,
-                                             input_dims='3d')
+    @staticmethod
+    @_primexpr
+    def _check_input_dim(shape, cls_name):
+        dim = len(shape)
+        _check_dim(dim, 5, cls_name)


 class GroupNorm(Cell):
@@ -1212,10 +1037,10 @@ class GroupNorm(Cell):
         affine (bool): A bool value, this layer will have learnable affine parameters when set to true. Default: True.
         gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
             The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
-            'he_uniform', etc. Default: 'ones'. If gamma_init is a Tensor, the shape must be
+            'he_uniform', etc. Default: 'ones'. If gamma_init is a Tensor, the shape must be :math:`(num\_channels)`.
         beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
             The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
-            'he_uniform', etc. Default: 'zeros'. If beta_init is a Tensor, the shape must be
+            'he_uniform', etc. Default: 'zeros'. If beta_init is a Tensor, the shape must be :math:`(num\_channels)`.

     Inputs:
         - **x** (Tensor) - The input feature with shape :math:`(N, C, H, W)` .
@@ -1273,7 +1098,7 @@ class GroupNorm(Cell):
     def _cal_output(self, x):
         """calculate groupnorm output"""
         batch, channel, height, width = self.shape(x)
-        _channel_check(channel, self.num_channels, self.cls_name)
+        self._channel_check(channel, self.num_channels, self.cls_name)
         x = self.reshape(x, (batch, self.num_groups, -1))
         mean = self.reduce_mean(x, 2)
         var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups)
@@ -1283,11 +1108,32 @@ class GroupNorm(Cell):
         output = x * self.reshape(self.gamma, (-1, 1, 1)) + self.reshape(self.beta, (-1, 1, 1))
         return output

-
-
-
-
-
+    @staticmethod
+    @_primexpr
+    def _check_input_dim(shape, cls_name):
+        dim = len(shape)
+        _check_dim(dim, 4, cls_name)
+
+    @staticmethod
+    @_primexpr
+    def _channel_check(channel, num_channel, prim_name=None):
+        def _check():
+            msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
+            if channel != num_channel:
+                raise ValueError(f"{msg_prefix} channel(the second dim of the input 'x') must be equal to "
+                                 f"num_channels, but got channel: {channel}, num_channels: {num_channel}.")
+        _check()
+
+    @staticmethod
+    @constexpr
+    def _check_dtype(dtype, valid_dtypes, prim_name=None):
+        validator.check_type_name("input", dtype, valid_dtypes, prim_name)

     def extend_repr(self):
         return 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)
+
+    def construct(self, x):
+        self._check_input_dim(self.shape(x), self.cls_name)
+        self._check_dtype(x.dtype, [mstype.float16, mstype.float32], self.cls_name)
+        output = self._cal_output(x)
+        return output