mindspore 2.0.0a0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -2
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/amp.py +52 -57
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +610 -541
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/normalization.py

@@ -1,4 +1,4 @@
-# Copyright 2020-
+# Copyright 2020-2023 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ from __future__ import division
 
 import itertools
 import numbers
+import hashlib
 
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
@@ -25,21 +26,26 @@ from mindspore.ops.operations import _inner_ops as inner
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer, Initializer
 from mindspore.common.tensor import Tensor
-from mindspore.ops.primitive import constexpr
+from mindspore.ops.primitive import constexpr, _primexpr
 import mindspore.context as context
-from mindspore
-from mindspore._checkparam import Validator as validator
+from mindspore import _checkparam as validator
 from mindspore._extends import cell_attr_register
 from mindspore.communication.management import get_group_size, get_rank
 from mindspore.communication import management
 from mindspore.common import dtype as mstype
 from mindspore.parallel._utils import _is_in_auto_parallel_mode
 from mindspore.nn.cell import Cell
+from mindspore import log as logger
 
 __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
            'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']
 
-
+
+def _check_dim(val, target, cls_name):
+    def _check(val, target, cls_name):
+        if val != target:
+            raise ValueError(f"For '{cls_name}', the in_shape must have {target} dims, but got {val}.")
+    _check(val, target, cls_name)
 
 
 class _BatchNorm(Cell):
@@ -121,11 +127,13 @@ class _BatchNorm(Cell):
         self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
         self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
 
+
     @staticmethod
-    @
+    @_primexpr
     def _check_input_dim(shape, cls_name):
         raise NotImplementedError
 
+
     def construct(self, x):
         self._check_input_dim(self.shape(x), self.cls_name)
         if self.use_batch_statistics is None:
@@ -164,7 +172,7 @@ class _BatchNorm(Cell):
 class BatchNorm1d(_BatchNorm):
     r"""
     This layer
-    applies Batch Normalization over a 2D input (a mini-batch of 1D inputs) to
+    applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D or 2D inputs) to
     reduce internal covariate shift. Batch Normalization is widely used in convolutional networks.
     For the setailed contents, refer to `Batch Normalization: Accelerating Deep Network Training by
     Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It
@@ -179,14 +187,14 @@ class BatchNorm1d(_BatchNorm):
        recommended to be changed after net was initialized.
 
     Args:
-        num_features (int): `C`
-        eps (float):
+        num_features (int): number of features or channels `C` of the input `x` .
+        eps (float): :math:`\epsilon` added to the denominator for numerical stability. Default: 1e-5.
         momentum (float): A floating hyperparameter of the momentum for the
             running_mean and running_var computation. Default: 0.9.
-        affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
-        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
+        affine (bool): A bool value. When set to True, :math:`\gamma` and :math:`\beta` can be learned. Default: True.
+        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
-        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
+        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
@@ -200,10 +208,11 @@ class BatchNorm1d(_BatchNorm):
            Default: 'NCHW'.
 
     Inputs:
-        - **x** (Tensor) - Tensor of shape :math:`(N,
+        - **x** (Tensor) - Tensor of shape :math:`(N, C)` or :math:`(N, C, L)` ,
+          where `N` is the batch size, `C` is the number of features or channels, and `L` is the sequence length.
 
     Outputs:
-        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N,
+        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C)` or :math:`(N, C, L)` .
 
     Raises:
         TypeError: If `num_features` is not an int.
@@ -228,11 +237,13 @@ class BatchNorm1d(_BatchNorm):
     """
 
     @staticmethod
-    @
+    @_primexpr
     def _check_input_dim(shape, cls_name):
+        def _check(dim):
+            if dim not in (2, 3):
+                raise ValueError(f"For '{cls_name}', the must have 2 dims or 3 dims, but got {dim}.")
         dim = len(shape)
-
-            raise ValueError(f"For '{cls_name}', the in_shape must have 2 dims, but got {dim}.")
+        _check(dim)
 
 
 class BatchNorm2d(_BatchNorm):
@@ -254,22 +265,22 @@ class BatchNorm2d(_BatchNorm):
     Note that the formula for updating the :math:`moving\_mean` and :math:`moving\_var` is
 
     .. math::
-        \text{moving_mean}=\text{moving_mean
-        \text{moving_var}=\text{moving_var
+        \text{moving_mean}=\text{moving_mean*momentum}+μ_β\text{*(1−momentum)}\\
+        \text{moving_var}=\text{moving_var*momentum}+σ^2_β\text{*(1−momentum)}
 
     where :math:`moving\_mean` is the updated mean, :math:`moving\_var` is the updated variance,
     :math:`μ_β, σ^2_β` are the observed value (mean and variance) of each batch of data.
 
     Args:
-        num_features (int): The number of channels of the input tensor. Expected input size is (N, C, H, W)
+        num_features (int): The number of channels of the input tensor. Expected input size is :math:`(N, C, H, W)`,
            `C` represents the number of channels.
-        eps (float):
+        eps (float): :math:`\epsilon` added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter of the momentum for the
            running_mean and running_var computation. Default: 0.9.
-        affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
-        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
+        affine (bool): A bool value. When set to True, :math:`\gamma` and :math:`\beta` can be learned. Default: True.
+        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
-        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
+        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
@@ -288,10 +299,10 @@ class BatchNorm2d(_BatchNorm):
            Default: 'NCHW'.
 
     Inputs:
-        - **x** (Tensor) - Tensor of shape :math:`(N,
+        - **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`.
 
     Outputs:
-        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N,
+        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C, H, W)`.
 
     Raises:
         TypeError: If `num_features` is not an int.
@@ -320,11 +331,10 @@ class BatchNorm2d(_BatchNorm):
     """
 
     @staticmethod
-    @
+    @_primexpr
     def _check_input_dim(shape, cls_name):
         dim = len(shape)
-
-            raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.")
+        _check_dim(dim, 4, cls_name)
 
 
 class BatchNorm3d(Cell):
@@ -344,7 +354,7 @@ class BatchNorm3d(Cell):
     where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
 
     Args:
-        num_features (int): `C` from an expected input of size (N, C, D, H, W).
+        num_features (int): `C` from an expected input of size :math:`(N, C, D, H, W)` .
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter of the momentum for the
            running_mean and running_var computation. Default: 0.9.
@@ -414,11 +424,11 @@ class BatchNorm3d(Cell):
         self.reshape = P.Reshape()
 
     @staticmethod
-    @
+    @_primexpr
     def _check_input_dim(shape, cls_name):
         dim = len(shape)
-
-
+        _check_dim(dim, 5, cls_name)
+
 
     def construct(self, x):
         x_shape = self.shape(x)
@@ -429,6 +439,16 @@ class BatchNorm3d(Cell):
         return bn3d_out
 
 
+SYNCBN_GROUP_DICT = None
+
+
+def _syncbatchnorm_group_dict():
+    global SYNCBN_GROUP_DICT
+    if SYNCBN_GROUP_DICT is None:
+        SYNCBN_GROUP_DICT = dict()
+    return SYNCBN_GROUP_DICT
+
+
 class SyncBatchNorm(_BatchNorm):
     r"""
     Sync Batch Normalization layer over a N-dimension input.
@@ -446,15 +466,16 @@ class SyncBatchNorm(_BatchNorm):
     Currently, SyncBatchNorm only supports 2D and 4D inputs.
 
     Args:
-        num_features (int): `C` from an expected input of size (N, C, H, W)
-        eps (float):
+        num_features (int): `C` from an expected input of size :math:`(N, C, H, W)`.
+        eps (float): :math:`\epsilon`, a value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter of the momentum for the
            running_mean and running_var computation. Default: 0.9.
-        affine (bool): A bool value. When set to True, gamma and beta can be learned.
-
+        affine (bool): A bool value. When set to True, :math:`\gamma` and :math:`\beta` can be learned.
+            Default: True.
+        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
-        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
+        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
@@ -495,11 +516,11 @@ class SyncBatchNorm(_BatchNorm):
 
         For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
         Please see the `Ascend tutorial
-        <https://www.mindspore.cn/tutorials/experts/en/r2.0
+        <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_ascend.html#preparations>`_
         for more details.
 
         For the GPU devices, users need to prepare the host file and mpi, please see the `GPU tutorial
-        <https://www.mindspore.cn/tutorials/experts/en/r2.0
+        <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/train_gpu.html#preparation>`_ .
 
         This example should be run with multiple devices.
 
@@ -525,7 +546,7 @@ class SyncBatchNorm(_BatchNorm):
          [[ 0.999995 0.999995 ]
           [ 0.999995 0.999995 ]]]]
     """
-
+    @cell_attr_register(attrs=['num_features', 'process_groups'])
     def __init__(self,
                  num_features,
                  eps=1e-5,
@@ -548,7 +569,7 @@ class SyncBatchNorm(_BatchNorm):
                          moving_var_init,
                          use_batch_statistics)
         self.is_global = False
-
+        self.group_name = None
         self.process_groups = process_groups
         if self.process_groups != 0:
             self.rank_id = get_rank()
@@ -560,43 +581,53 @@ class SyncBatchNorm(_BatchNorm):
         elif self.rank_size > 1:
             self.is_global = True
             self.group_device_num = self.rank_size
-            self.device_list = [i for i in range(0, self.rank_size)]
             if context.get_context("device_target") == "Ascend":
-
-                SYNC_BN_GROUP_NAME = "sync_bn_group0"
-                management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
+                self.group_name = "hccl_world_group"
             elif context.get_context("device_target") == "GPU":
-
-                SYNC_BN_GROUP_NAME = "nccl_world_group"
+                self.group_name = "nccl_world_group"
 
         if self.is_global:
             self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
                                                 momentum=self.momentum,
-                                                group=
+                                                group=self.group_name,
                                                 device_num=self.group_device_num)
 
     def _create_sync_groups(self):
-
-
-
-
+        """ create groups by process groups. """
+        for sub_group in self.process_groups:
+            validator.check_isinstance("sub group", sub_group, list)
+            self.group_device_num = len(sub_group)
+            if self.rank_id in sub_group and self.group_device_num > 1:
                 self.is_global = True
-
-
-
-
+                rank_list_name = '_'.join('%s' % id for id in sub_group)
+                group_dict = _syncbatchnorm_group_dict()
+                if rank_list_name not in group_dict:
+                    md5 = hashlib.md5()
+                    md5.update(rank_list_name.encode('utf-8'))
+                    hash_name = md5.hexdigest()
+                    self.group_name = str(self.group_device_num) + '_' + hash_name
+                    group_dict[rank_list_name] = self.group_name
+                    management.create_group(self.group_name, sub_group)
+                    logger.info("create group for sync batchnorm, the rank list is {}, the group name is {}".format(
+                        rank_list_name, self.group_name))
+                else:
+                    self.group_name = group_dict[rank_list_name]
+                    logger.info("the group for {} already exists, no need to create".format(rank_list_name))
 
     @staticmethod
-    @
+    @_primexpr
     def _check_input_dim(shape, cls_name):
+        def _check(dim):
+            if dim not in (2, 4):
+                raise ValueError(f"For '{cls_name}', the must have 2 dims or 4 dims, but got {dim}.")
         dim = len(shape)
-
-
+        _check(dim)
+
 
     def _check_rank_ids(self, process_groups, rank_size):
         seen = set()
         for rid in itertools.chain(*process_groups):
-            validator.check_int_range(rid, 0, rank_size,
+            validator.check_int_range(rid, 0, rank_size, validator.INC_LEFT, "rank id in process_groups", self.cls_name)
             if rid in seen:
                 raise ValueError(f"For '{self.cls_name}', rank id in 'process_groups' must not be duplicated, "
                                  f"but got {process_groups}.")
@@ -625,13 +656,13 @@ class LayerNorm(Cell):
        begin_params_axis (int): The first parameter(beta, gamma)dimension: scale and centering parameters
            will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
            the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
-        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
+        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
-        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
+        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
-        epsilon (float):
+        epsilon (float): :math:`\epsilon` added to the denominator for numerical stability. Default: 1e-7.
 
     Inputs:
        - **x** (Tensor) - The shape of `x` is :math:`(x_1, x_2, ..., x_R)`,
@@ -775,7 +806,7 @@ class InstanceNorm1d(_InstanceNorm):
     where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
 
     Args:
-        num_features (int): `C` from an expected input of size (N, C, L)
+        num_features (int): `C` from an expected input of size :math:`(N, C, L)`.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter of the momentum for the
            running_mean and running_var computation. Default: 0.1.
@@ -823,11 +854,11 @@ class InstanceNorm1d(_InstanceNorm):
     """
 
     @staticmethod
-    @
+    @_primexpr
     def _check_input_dim(shape, cls_name):
         dim = len(shape)
-
-
+        _check_dim(dim, 3, cls_name)
+
 
 
 class InstanceNorm2d(_InstanceNorm):
@@ -854,7 +885,7 @@ class InstanceNorm2d(_InstanceNorm):
     where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
 
     Args:
-        num_features (int): `C` from an expected input of size (N, C, H, W)
+        num_features (int): `C` from an expected input of size :math:`(N, C, H, W)`.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter of the momentum for the
            running_mean and running_var computation. Default: 0.1.
@@ -902,11 +933,10 @@ class InstanceNorm2d(_InstanceNorm):
     """
 
     @staticmethod
-    @
+    @_primexpr
     def _check_input_dim(shape, cls_name):
         dim = len(shape)
-
-            raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.")
+        _check_dim(dim, 4, cls_name)
 
 
 class InstanceNorm3d(_InstanceNorm):
@@ -933,7 +963,7 @@ class InstanceNorm3d(_InstanceNorm):
     where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
 
     Args:
-        num_features (int): `C` from an expected input of size (N, C, D, H, W)
+        num_features (int): `C` from an expected input of size :math:`(N, C, D, H, W)`.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter of the momentum for the
            running_mean and running_var computation. Default: 0.1.
@@ -979,12 +1009,12 @@ class InstanceNorm3d(_InstanceNorm):
        >>> print(output.shape)
        (2, 3, 5, 2, 2)
     """
+
     @staticmethod
-    @
+    @_primexpr
     def _check_input_dim(shape, cls_name):
         dim = len(shape)
-
-            raise ValueError(f"For '{cls_name}', the in_shape must have 5 dims, but got {dim}.")
+        _check_dim(dim, 5, cls_name)
 
 
 class GroupNorm(Cell):
@@ -1007,10 +1037,10 @@ class GroupNorm(Cell):
        affine (bool): A bool value, this layer will have learnable affine parameters when set to true. Default: True.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
-            'he_uniform', etc. Default: 'ones'. If gamma_init is a Tensor, the shape must be
+            'he_uniform', etc. Default: 'ones'. If gamma_init is a Tensor, the shape must be :math:`(num\_channels)`.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
-            'he_uniform', etc. Default: 'zeros'. If beta_init is a Tensor, the shape must be
+            'he_uniform', etc. Default: 'zeros'. If beta_init is a Tensor, the shape must be :math:`(num\_channels)`.
 
     Inputs:
        - **x** (Tensor) - The input feature with shape :math:`(N, C, H, W)` .
@@ -1079,19 +1109,20 @@ class GroupNorm(Cell):
         return output
 
     @staticmethod
-    @
+    @_primexpr
     def _check_input_dim(shape, cls_name):
         dim = len(shape)
-
-            raise ValueError(f"For '{cls_name}', the in_shape must have 4 dims, but got {dim}.")
+        _check_dim(dim, 4, cls_name)
 
     @staticmethod
-    @
+    @_primexpr
     def _channel_check(channel, num_channel, prim_name=None):
-
-
-
-
+        def _check():
+            msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
+            if channel != num_channel:
+                raise ValueError(f"{msg_prefix} channel(the second dim of the input 'x') must be equal to "
+                                 f"num_channels, but got channel: {channel}, num_channels: {num_channel}.")
+        _check()
 
     @staticmethod
     @constexpr
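For reference, the SyncBatchNorm hunks above replace the fixed "sync_bn_group0" communication group with a name derived from the rank list: the ranks are joined with underscores, hashed with MD5, and prefixed with the group size. The sketch below reproduces only that naming scheme as a standalone snippet; sync_group_name is a hypothetical helper written here for illustration, not an API of the package (the released code builds the name inline in SyncBatchNorm._create_sync_groups).

import hashlib

def sync_group_name(sub_group):
    # Join the ranks, e.g. [0, 1, 2, 3] -> "0_1_2_3", mirroring the hunk above.
    rank_list_name = '_'.join('%s' % rank_id for rank_id in sub_group)
    # Hash the joined list and prefix the group size, giving names such as "4_<md5 hex digest>".
    md5 = hashlib.md5()
    md5.update(rank_list_name.encode('utf-8'))
    return str(len(sub_group)) + '_' + md5.hexdigest()

print(sync_group_name([0, 1, 2, 3]))

Deriving the name from the rank list (instead of the old hard-coded "sync_bn_group0") lets several distinct process groups coexist, and the module-level dictionary added in the same hunk reuses a group name once it has been created.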