mindspore 2.1.0__cp37-cp37m-manylinux1_x86_64.whl → 2.2.11__cp37-cp37m-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +139 -22
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +56 -1
- mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
|
@@ -62,7 +62,8 @@ class _BatchNorm(Cell):
|
|
|
62
62
|
moving_mean_init='zeros',
|
|
63
63
|
moving_var_init='ones',
|
|
64
64
|
use_batch_statistics=None,
|
|
65
|
-
data_format='NCHW'
|
|
65
|
+
data_format='NCHW',
|
|
66
|
+
dtype=mstype.float32):
|
|
66
67
|
"""Initialize _BatchNorm."""
|
|
67
68
|
super(_BatchNorm, self).__init__()
|
|
68
69
|
validator.check_value_type('num_features', num_features, [int], self.cls_name)
|
|
@@ -87,13 +88,13 @@ class _BatchNorm(Cell):
|
|
|
87
88
|
self.moving_mean_init = moving_mean_init
|
|
88
89
|
self.moving_var_init = moving_var_init
|
|
89
90
|
self.moving_mean = Parameter(initializer(
|
|
90
|
-
moving_mean_init, num_features), name="mean", requires_grad=False)
|
|
91
|
+
moving_mean_init, num_features, dtype=dtype), name="mean", requires_grad=False)
|
|
91
92
|
self.moving_variance = Parameter(initializer(
|
|
92
|
-
moving_var_init, num_features), name="variance", requires_grad=False)
|
|
93
|
+
moving_var_init, num_features, dtype=dtype), name="variance", requires_grad=False)
|
|
93
94
|
self.gamma = Parameter(initializer(
|
|
94
|
-
gamma_init, num_features), name="gamma", requires_grad=affine)
|
|
95
|
+
gamma_init, num_features, dtype=dtype), name="gamma", requires_grad=affine)
|
|
95
96
|
self.beta = Parameter(initializer(
|
|
96
|
-
beta_init, num_features), name="beta", requires_grad=affine)
|
|
97
|
+
beta_init, num_features, dtype=dtype), name="beta", requires_grad=affine)
|
|
97
98
|
|
|
98
99
|
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
|
99
100
|
|
|
@@ -194,23 +195,28 @@ class BatchNorm1d(_BatchNorm):
|
|
|
194
195
|
affine (bool): A bool value. When set to ``True`` , :math:`\gamma` and :math:`\beta` can be learned.
|
|
195
196
|
Default: ``True`` .
|
|
196
197
|
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
|
|
197
|
-
The values of str refer to the function `initializer
|
|
198
|
-
|
|
198
|
+
The values of str refer to the function `mindspore.common.initializer
|
|
199
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
|
|
200
|
+
including ``'zeros'`` , ``'ones'`` , etc. Default: ``'ones'`` .
|
|
199
201
|
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
|
|
200
|
-
The values of str refer to the function `initializer
|
|
201
|
-
|
|
202
|
+
The values of str refer to the function `mindspore.common.initializer
|
|
203
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
|
|
204
|
+
including ``'zeros'`` , ``'ones'``, etc. Default: ``'zeros'`` .
|
|
202
205
|
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
|
|
203
|
-
The values of str refer to the function `initializer
|
|
204
|
-
|
|
206
|
+
The values of str refer to the function `mindspore.common.initializer
|
|
207
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
|
|
208
|
+
including ``'zeros'`` , ``'ones'`` , etc. Default: ``'zeros'`` .
|
|
205
209
|
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
|
|
206
|
-
The values of str refer to the function `initializer
|
|
207
|
-
|
|
210
|
+
The values of str refer to the function `mindspore.common.initializer
|
|
211
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
|
|
212
|
+
including ``'zeros'`` , ``'ones'`` , etc. Default: ``'ones'`` .
|
|
208
213
|
use_batch_statistics (bool): If ``true`` , use the mean value and variance value of current batch data. If
|
|
209
214
|
``false`` , use the mean value and variance value of specified value. If ``None`` , the training process
|
|
210
215
|
will use the mean and variance of current batch data and track the running mean and variance, the
|
|
211
216
|
evaluation process will use the running mean and variance. Default: ``None`` .
|
|
212
217
|
data_format (str): The optional value for data format, is ``'NHWC'`` or ``'NCHW'`` .
|
|
213
218
|
Default: ``'NCHW'`` .
|
|
219
|
+
dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
|
|
214
220
|
|
|
215
221
|
Inputs:
|
|
216
222
|
- **x** (Tensor) - Tensor of shape :math:`(N, C)` or :math:`(N, C, L)` ,
|
|
@@ -234,7 +240,7 @@ class BatchNorm1d(_BatchNorm):
|
|
|
234
240
|
>>> import mindspore as ms
|
|
235
241
|
>>> net = ms.nn.BatchNorm1d(num_features=4)
|
|
236
242
|
>>> x = ms.Tensor(np.array([[0.7, 0.5, 0.5, 0.6],
|
|
237
|
-
...
|
|
243
|
+
... [0.5, 0.4, 0.6, 0.9]]).astype(np.float32))
|
|
238
244
|
>>> output = net(x)
|
|
239
245
|
>>> print(output)
|
|
240
246
|
[[ 0.6999965 0.4999975 0.4999975 0.59999704 ]
|
|
@@ -285,17 +291,21 @@ class BatchNorm2d(_BatchNorm):
|
|
|
285
291
|
affine (bool): A bool value. When set to ``True`` , :math:`\gamma` and :math:`\beta` can be learned.
|
|
286
292
|
Default: ``True`` .
|
|
287
293
|
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
|
|
288
|
-
The values of str refer to the function `initializer
|
|
289
|
-
|
|
294
|
+
The values of str refer to the function `mindspore.common.initializer
|
|
295
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
|
|
296
|
+
including ``'zeros'`` , ``'ones'`` , etc. Default: ``'ones'`` .
|
|
290
297
|
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
|
|
291
|
-
The values of str refer to the function `initializer
|
|
292
|
-
|
|
298
|
+
The values of str refer to the function `mindspore.common.initializer
|
|
299
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
|
|
300
|
+
including ``'zeros'`` , ``'ones'`` , etc. Default: ``'zeros'`` .
|
|
293
301
|
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
|
|
294
|
-
The values of str refer to the function `initializer
|
|
295
|
-
|
|
302
|
+
The values of str refer to the function `mindspore.common.initializer
|
|
303
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
|
|
304
|
+
including ``'zeros'`` , ``'ones'`` , etc. Default: ``'zeros'`` .
|
|
296
305
|
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
|
|
297
|
-
The values of str refer to the function `initializer
|
|
298
|
-
|
|
306
|
+
The values of str refer to the function `mindspore.common.initializer
|
|
307
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
|
|
308
|
+
including ``'zeros'`` , ``'ones'`` , etc. Default: ``'ones'`` .
|
|
299
309
|
use_batch_statistics (bool): Default: ``None`` .
|
|
300
310
|
|
|
301
311
|
- If ``true`` , use the mean value and variance value of current batch data and track running mean
|
|
@@ -307,6 +317,7 @@ class BatchNorm2d(_BatchNorm):
|
|
|
307
317
|
|
|
308
318
|
data_format (str): The optional value for data format, is ``'NHWC'`` or ``'NCHW'`` .
|
|
309
319
|
Default: ``'NCHW'`` .
|
|
320
|
+
dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
|
|
310
321
|
|
|
311
322
|
Inputs:
|
|
312
323
|
- **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`. Supported types: float16, float32.
|
|
@@ -369,21 +380,26 @@ class BatchNorm3d(Cell):
|
|
|
369
380
|
running_mean and running_var computation. Default: ``0.9`` .
|
|
370
381
|
affine (bool): A bool value. When set to ``True`` , gamma and beta can be learned. Default: ``True`` .
|
|
371
382
|
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
|
|
372
|
-
The values of str refer to the function `initializer
|
|
373
|
-
|
|
383
|
+
The values of str refer to the function `mindspore.common.initializer
|
|
384
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
|
|
385
|
+
including ``'zeros'`` , ``'ones'`` , etc. Default: ``'ones'`` .
|
|
374
386
|
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
|
|
375
|
-
The values of str refer to the function `initializer
|
|
376
|
-
|
|
387
|
+
The values of str refer to the function `mindspore.common.initializer
|
|
388
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
|
|
389
|
+
including ``'zeros'`` , ``'ones'`` , etc. Default: ``'zeros'`` .
|
|
377
390
|
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
|
|
378
|
-
The values of str refer to the function `initializer
|
|
379
|
-
|
|
391
|
+
The values of str refer to the function `mindspore.common.initializer
|
|
392
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
|
|
393
|
+
including ``'zeros'`` , ``'ones'`` , etc. Default: ``'zeros'`` .
|
|
380
394
|
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
|
|
381
|
-
The values of str refer to the function `initializer
|
|
382
|
-
|
|
395
|
+
The values of str refer to the function `mindspore.common.initializer
|
|
396
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
|
|
397
|
+
including ``'zeros'`` , ``'ones'`` , etc. Default: ``'ones'`` .
|
|
383
398
|
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If
|
|
384
399
|
``false``, use the mean value and variance value of specified value. If ``None`` , the training process
|
|
385
400
|
will use the mean and variance of current batch data and track the running mean and variance, the
|
|
386
401
|
evaluation process will use the running mean and variance. Default: ``None`` .
|
|
402
|
+
dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
|
|
387
403
|
|
|
388
404
|
Inputs:
|
|
389
405
|
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
|
|
@@ -420,7 +436,8 @@ class BatchNorm3d(Cell):
|
|
|
420
436
|
beta_init='zeros',
|
|
421
437
|
moving_mean_init='zeros',
|
|
422
438
|
moving_var_init='ones',
|
|
423
|
-
use_batch_statistics=None
|
|
439
|
+
use_batch_statistics=None,
|
|
440
|
+
dtype=mstype.float32):
|
|
424
441
|
"""Initialize BatchNorm3d."""
|
|
425
442
|
super(BatchNorm3d, self).__init__()
|
|
426
443
|
self.bn2d = BatchNorm2d(num_features=num_features,
|
|
@@ -432,7 +449,8 @@ class BatchNorm3d(Cell):
|
|
|
432
449
|
moving_mean_init=moving_mean_init,
|
|
433
450
|
moving_var_init=moving_var_init,
|
|
434
451
|
use_batch_statistics=use_batch_statistics,
|
|
435
|
-
data_format="NCHW"
|
|
452
|
+
data_format="NCHW",
|
|
453
|
+
dtype=dtype)
|
|
436
454
|
self.shape = P.Shape()
|
|
437
455
|
self.reshape = P.Reshape()
|
|
438
456
|
|
|
@@ -477,6 +495,7 @@ class SyncBatchNorm(_BatchNorm):
|
|
|
477
495
|
|
|
478
496
|
Note:
|
|
479
497
|
Currently, SyncBatchNorm only supports 2D and 4D inputs.
|
|
498
|
+
:math:`\gamma` and :math:`\beta` are trainable scale and shift.
|
|
480
499
|
|
|
481
500
|
Args:
|
|
482
501
|
num_features (int): `C` from an expected input of size :math:`(N, C, H, W)`.
|
|
@@ -505,6 +524,7 @@ class SyncBatchNorm(_BatchNorm):
|
|
|
505
524
|
Each subtraction list contains int numbers identifying rank ids which need to be synchronized in the same
|
|
506
525
|
group. All int values must be in [0, rank_size) and different from each other. Default: ``None`` ,
|
|
507
526
|
indicating synchronization across all devices.
|
|
527
|
+
dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
|
|
508
528
|
|
|
509
529
|
Inputs:
|
|
510
530
|
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
|
|
@@ -529,11 +549,14 @@ class SyncBatchNorm(_BatchNorm):
|
|
|
529
549
|
|
|
530
550
|
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
|
|
531
551
|
Please see the `Ascend tutorial
|
|
532
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
552
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/rank_table.html>`_
|
|
533
553
|
for more details.
|
|
534
554
|
|
|
535
|
-
For the GPU devices, users need to prepare the host file and mpi, please see the `
|
|
536
|
-
<https://www.mindspore.cn/tutorials/experts/en/r2.
|
|
555
|
+
For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
|
|
556
|
+
<https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/mpirun.html>`_ .
|
|
557
|
+
|
|
558
|
+
For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
|
|
559
|
+
Startup <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/dynamic_cluster.html>`_ .
|
|
537
560
|
|
|
538
561
|
This example should be run with multiple devices.
|
|
539
562
|
|
|
@@ -567,7 +590,8 @@ class SyncBatchNorm(_BatchNorm):
|
|
|
567
590
|
moving_mean_init='zeros',
|
|
568
591
|
moving_var_init='ones',
|
|
569
592
|
use_batch_statistics=None,
|
|
570
|
-
process_groups=None
|
|
593
|
+
process_groups=None,
|
|
594
|
+
dtype=mstype.float32):
|
|
571
595
|
"""Initialize SyncBatchNorm."""
|
|
572
596
|
super(SyncBatchNorm, self).__init__(num_features,
|
|
573
597
|
eps,
|
|
@@ -577,7 +601,8 @@ class SyncBatchNorm(_BatchNorm):
|
|
|
577
601
|
beta_init,
|
|
578
602
|
moving_mean_init,
|
|
579
603
|
moving_var_init,
|
|
580
|
-
use_batch_statistics
|
|
604
|
+
use_batch_statistics,
|
|
605
|
+
dtype=dtype)
|
|
581
606
|
self.is_global = False
|
|
582
607
|
self.group_name = None
|
|
583
608
|
self.process_groups = process_groups
|
|
@@ -652,27 +677,28 @@ class LayerNorm(Cell):
|
|
|
652
677
|
normalization on a mini-batch of inputs for each single training case as described
|
|
653
678
|
in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike Batch
|
|
654
679
|
Normalization, Layer Normalization performs exactly the same computation at training and
|
|
655
|
-
testing time. It is applied across all channels
|
|
656
|
-
and
|
|
680
|
+
testing time. It is applied across all channels and pixel but only one batch size.
|
|
681
|
+
:math:`\gamma` and :math:`\beta` are trainable scale and shift.
|
|
682
|
+
It can be described using the following formula:
|
|
657
683
|
|
|
658
684
|
.. math::
|
|
659
685
|
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
|
|
660
686
|
|
|
661
687
|
Args:
|
|
662
688
|
normalized_shape (Union(tuple[int], list[int])): The normalization is performed over axis
|
|
663
|
-
`begin_norm_axis ... R - 1`.
|
|
689
|
+
`begin_norm_axis ... R - 1`. R is the dimension size of input `x`.
|
|
664
690
|
begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
|
|
665
|
-
`begin_norm_axis:
|
|
666
|
-
begin_params_axis (int): The
|
|
667
|
-
|
|
668
|
-
the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: ``-1`` .
|
|
691
|
+
`begin_norm_axis: R`, the value should be in [-1, R). Default: ``-1`` .
|
|
692
|
+
begin_params_axis (int): The begin axis of the parameter input :math:`(\gamma, \beta)` to
|
|
693
|
+
apply LayerNorm, the value should be in [-1, R). Default: ``-1`` .
|
|
669
694
|
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
|
|
670
695
|
The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
|
|
671
696
|
``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'ones'`` .
|
|
672
697
|
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
|
|
673
698
|
The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
|
|
674
699
|
``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'zeros'`` .
|
|
675
|
-
epsilon (float):
|
|
700
|
+
epsilon (float): A value added to the denominator for numerical stability(:math:`\epsilon`). Default: ``1e-7`` .
|
|
701
|
+
dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
|
|
676
702
|
|
|
677
703
|
Inputs:
|
|
678
704
|
- **x** (Tensor) - The shape of `x` is :math:`(x_1, x_2, ..., x_R)`,
|
|
@@ -706,21 +732,27 @@ class LayerNorm(Cell):
|
|
|
706
732
|
begin_params_axis=-1,
|
|
707
733
|
gamma_init='ones',
|
|
708
734
|
beta_init='zeros',
|
|
709
|
-
epsilon=1e-7
|
|
735
|
+
epsilon=1e-7,
|
|
736
|
+
dtype=mstype.float32
|
|
710
737
|
):
|
|
711
738
|
"""Initialize LayerNorm."""
|
|
712
739
|
super(LayerNorm, self).__init__()
|
|
713
740
|
if not isinstance(normalized_shape, (tuple, list)):
|
|
714
741
|
raise TypeError(f"For '{self.cls_name}', the type of 'normalized_shape' must be tuple[int] or list[int], "
|
|
715
742
|
f"but got {normalized_shape} and the type is {type(normalized_shape)}.")
|
|
743
|
+
if not normalized_shape:
|
|
744
|
+
raise ValueError(
|
|
745
|
+
f"Expected normalized_shape to be at least 1-dimensional, i.e., containing at "
|
|
746
|
+
f"least one element, but got normalized_shape = {normalized_shape}"
|
|
747
|
+
)
|
|
716
748
|
self.normalized_shape = normalized_shape
|
|
717
749
|
self.begin_norm_axis = begin_norm_axis
|
|
718
750
|
self.begin_params_axis = begin_params_axis
|
|
719
751
|
self.epsilon = epsilon
|
|
720
752
|
self.gamma = Parameter(initializer(
|
|
721
|
-
gamma_init, normalized_shape), name="gamma")
|
|
753
|
+
gamma_init, normalized_shape, dtype=dtype), name="gamma")
|
|
722
754
|
self.beta = Parameter(initializer(
|
|
723
|
-
beta_init, normalized_shape), name="beta")
|
|
755
|
+
beta_init, normalized_shape, dtype=dtype), name="beta")
|
|
724
756
|
self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis,
|
|
725
757
|
begin_params_axis=self.begin_params_axis,
|
|
726
758
|
epsilon=self.epsilon)
|
|
@@ -743,7 +775,8 @@ class _InstanceNorm(Cell):
|
|
|
743
775
|
momentum=0.1,
|
|
744
776
|
affine=True,
|
|
745
777
|
gamma_init='ones',
|
|
746
|
-
beta_init='zeros'
|
|
778
|
+
beta_init='zeros',
|
|
779
|
+
dtype=mstype.float32):
|
|
747
780
|
"""Initialize Normalization base class."""
|
|
748
781
|
super(_InstanceNorm, self).__init__()
|
|
749
782
|
validator.check_value_type('num_features', num_features, [int], self.cls_name)
|
|
@@ -760,12 +793,13 @@ class _InstanceNorm(Cell):
|
|
|
760
793
|
f"but got {momentum}.")
|
|
761
794
|
self.num_features = num_features
|
|
762
795
|
self.eps = eps
|
|
763
|
-
self.moving_mean = Parameter(initializer('zeros', num_features), name="mean", requires_grad=False)
|
|
764
|
-
self.moving_variance = Parameter(initializer('ones', num_features), name="variance",
|
|
796
|
+
self.moving_mean = Parameter(initializer('zeros', num_features, dtype=dtype), name="mean", requires_grad=False)
|
|
797
|
+
self.moving_variance = Parameter(initializer('ones', num_features, dtype=dtype), name="variance",
|
|
798
|
+
requires_grad=False)
|
|
765
799
|
self.gamma = Parameter(initializer(
|
|
766
|
-
gamma_init, num_features), name="gamma", requires_grad=affine)
|
|
800
|
+
gamma_init, num_features, dtype=dtype), name="gamma", requires_grad=affine)
|
|
767
801
|
self.beta = Parameter(initializer(
|
|
768
|
-
beta_init, num_features), name="beta", requires_grad=affine)
|
|
802
|
+
beta_init, num_features, dtype=dtype), name="beta", requires_grad=affine)
|
|
769
803
|
|
|
770
804
|
self.shape = P.Shape()
|
|
771
805
|
self.momentum = momentum
|
|
@@ -829,6 +863,7 @@ class InstanceNorm1d(_InstanceNorm):
|
|
|
829
863
|
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
|
|
830
864
|
The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` , etc.
|
|
831
865
|
When initialized with Tensor, the shape should be :math:`(C)`. Default: ``'zeros'`` .
|
|
866
|
+
dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
|
|
832
867
|
|
|
833
868
|
Inputs:
|
|
834
869
|
- **x** (Tensor) - Tensor of shape :math:`(N, C, L)`. Data type: float16 or float32.
|
|
@@ -906,6 +941,7 @@ class InstanceNorm2d(_InstanceNorm):
|
|
|
906
941
|
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
|
|
907
942
|
The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` , etc.
|
|
908
943
|
When initialized with Tensor, the shape should be :math:`(C)`. Default: ``'zeros'`` .
|
|
944
|
+
dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
|
|
909
945
|
|
|
910
946
|
Inputs:
|
|
911
947
|
- **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`. Data type: float16 or float32.
|
|
@@ -982,6 +1018,7 @@ class InstanceNorm3d(_InstanceNorm):
|
|
|
982
1018
|
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
|
|
983
1019
|
The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` , etc.
|
|
984
1020
|
When initialized with Tensor, the shape should be :math:`(C)`. Default: ``'zeros'`` .
|
|
1021
|
+
dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
|
|
985
1022
|
|
|
986
1023
|
Inputs:
|
|
987
1024
|
- **x** (Tensor) - Tensor of shape :math:`(N, C, D, H, W)`. Data type: float16 or float32.
|
|
@@ -1031,7 +1068,9 @@ class GroupNorm(Cell):
|
|
|
1031
1068
|
normalization on a mini-batch of inputs for each single training case as described
|
|
1032
1069
|
in the paper `Group Normalization <https://arxiv.org/pdf/1803.08494.pdf>`_. Group Normalization
|
|
1033
1070
|
divides the channels into groups and computes within each group the mean and variance for normalization,
|
|
1034
|
-
and it performs very stable over a wide range of batch size.
|
|
1071
|
+
and it performs very stable over a wide range of batch size. :math:`\gamma` and :math:`\beta` are trainable scale
|
|
1072
|
+
and shift.
|
|
1073
|
+
It can be described using the following formula:
|
|
1035
1074
|
|
|
1036
1075
|
.. math::
|
|
1037
1076
|
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
|
|
@@ -1050,6 +1089,7 @@ class GroupNorm(Cell):
|
|
|
1050
1089
|
The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
|
|
1051
1090
|
``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'zeros'`` . If beta_init is a Tensor, the shape
|
|
1052
1091
|
must be :math:`(num\_channels)`.
|
|
1092
|
+
dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
|
|
1053
1093
|
|
|
1054
1094
|
Inputs:
|
|
1055
1095
|
- **x** (Tensor) - The input feature with shape :math:`(N, C, H, W)` .
|
|
@@ -1084,7 +1124,8 @@ class GroupNorm(Cell):
|
|
|
1084
1124
|
[0. 0. 0. 0.]]]]
|
|
1085
1125
|
"""
|
|
1086
1126
|
|
|
1087
|
-
def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'
|
|
1127
|
+
def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros',
|
|
1128
|
+
dtype=mstype.float32):
|
|
1088
1129
|
"""Initialize GroupNorm."""
|
|
1089
1130
|
super(GroupNorm, self).__init__()
|
|
1090
1131
|
self.num_groups = validator.check_positive_int(num_groups, "num_groups", self.cls_name)
|
|
@@ -1096,27 +1137,27 @@ class GroupNorm(Cell):
|
|
|
1096
1137
|
self.affine = validator.check_bool(affine, arg_name="affine", prim_name=self.cls_name)
|
|
1097
1138
|
|
|
1098
1139
|
self.gamma = Parameter(initializer(
|
|
1099
|
-
gamma_init, num_channels), name="gamma", requires_grad=affine)
|
|
1140
|
+
gamma_init, num_channels, dtype=dtype), name="gamma", requires_grad=affine)
|
|
1100
1141
|
self.beta = Parameter(initializer(
|
|
1101
|
-
beta_init, num_channels), name="beta", requires_grad=affine)
|
|
1142
|
+
beta_init, num_channels, dtype=dtype), name="beta", requires_grad=affine)
|
|
1143
|
+
self.reduce_mean = P.ReduceMean(keep_dims=True)
|
|
1144
|
+
self.reduce_sum = P.ReduceSum(keep_dims=True)
|
|
1102
1145
|
self.shape = F.shape
|
|
1103
1146
|
self.reshape = F.reshape
|
|
1104
|
-
self.reduce_mean = P.ReduceMean(keep_dims=True)
|
|
1105
1147
|
self.square = F.square
|
|
1106
|
-
self.reduce_sum = P.ReduceSum(keep_dims=True)
|
|
1107
1148
|
self.sqrt = P.Sqrt()
|
|
1108
1149
|
|
|
1109
1150
|
def _cal_output(self, x):
|
|
1110
1151
|
"""calculate groupnorm output"""
|
|
1111
|
-
batch, channel, height, width =
|
|
1152
|
+
batch, channel, height, width = F.shape(x)
|
|
1112
1153
|
self._channel_check(channel, self.num_channels, self.cls_name)
|
|
1113
|
-
x =
|
|
1154
|
+
x = F.reshape(x, (batch, self.num_groups, -1))
|
|
1114
1155
|
mean = self.reduce_mean(x, 2)
|
|
1115
|
-
var = self.reduce_sum(
|
|
1156
|
+
var = F.div(self.reduce_sum(F.square(F.sub(x, mean)), 2), (channel * height * width / self.num_groups))
|
|
1116
1157
|
std = self.sqrt(var + self.eps)
|
|
1117
|
-
x = (x
|
|
1118
|
-
x =
|
|
1119
|
-
output = x *
|
|
1158
|
+
x = F.div(F.sub(x, mean), std)
|
|
1159
|
+
x = F.reshape(x, (batch, channel, height, width))
|
|
1160
|
+
output = F.add(x * F.reshape(self.gamma, (-1, 1, 1)), F.reshape(self.beta, (-1, 1, 1)))
|
|
1120
1161
|
return output
|
|
1121
1162
|
|
|
1122
1163
|
@staticmethod
|
|
@@ -1144,7 +1185,7 @@ class GroupNorm(Cell):
|
|
|
1144
1185
|
return 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)
|
|
1145
1186
|
|
|
1146
1187
|
def construct(self, x):
|
|
1147
|
-
self._check_input_dim(
|
|
1188
|
+
self._check_input_dim(F.shape(x), self.cls_name)
|
|
1148
1189
|
self._check_dtype(x.dtype, [mstype.float16, mstype.float32], self.cls_name)
|
|
1149
1190
|
output = self._cal_output(x)
|
|
1150
1191
|
return output
|
mindspore/nn/layer/padding.py
CHANGED
|
@@ -220,7 +220,7 @@ class _ConstantPadNd(Cell):
|
|
|
220
220
|
output = ops.Pad(new_padding)(x)
|
|
221
221
|
mask = ops.Pad(new_padding)(mask)
|
|
222
222
|
ones = ops.OnesLike()(output)
|
|
223
|
-
value = ops.
|
|
223
|
+
value = ops.fill(output.dtype, output.shape, self.value)
|
|
224
224
|
output = ops.Add()(ops.Mul()(mask, output), ops.Mul()(ops.Sub()(ones, mask), value))
|
|
225
225
|
slice_op = ops.Slice()
|
|
226
226
|
begin, size = _get_begin_size(output.shape, start, end)
|