mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -16,11 +16,11 @@
|
|
|
16
16
|
Transformer Cells module, include TransformerEncoderLayer, TransformerDecoderLayer,
|
|
17
17
|
TransformerEncoder, TransformerDecoder, Transformer.
|
|
18
18
|
"""
|
|
19
|
-
import copy
|
|
20
19
|
import math
|
|
21
20
|
from typing import Union, Optional
|
|
22
21
|
import mindspore
|
|
23
22
|
import mindspore.ops as ops
|
|
23
|
+
import mindspore.common.dtype as mstype
|
|
24
24
|
from mindspore.common.tensor import Tensor
|
|
25
25
|
from mindspore.common.parameter import Parameter
|
|
26
26
|
from mindspore.common.initializer import initializer, XavierNormal, XavierUniform, \
|
|
@@ -36,24 +36,17 @@ __all__ = ['MultiheadAttention', 'TransformerEncoderLayer', 'TransformerDecoderL
|
|
|
36
36
|
'TransformerEncoder', 'TransformerDecoder', 'Transformer']
|
|
37
37
|
|
|
38
38
|
|
|
39
|
-
class _Linear(Dense):
|
|
40
|
-
def __init__(self, in_channels, out_channels, has_bias=True):
|
|
41
|
-
fan_in, _ = _calculate_fan_in_and_fan_out((out_channels, in_channels))
|
|
42
|
-
bound = 1 / math.sqrt(fan_in)
|
|
43
|
-
super().__init__(in_channels, out_channels, weight_init=HeUniform(math.sqrt(5)),
|
|
44
|
-
bias_init=Uniform(bound), has_bias=has_bias, activation=None)
|
|
45
|
-
|
|
46
|
-
|
|
47
39
|
class MultiheadAttention(Cell):
|
|
48
40
|
r"""
|
|
49
41
|
This is an implementation of multihead attention in the paper `Attention is all you need
|
|
50
|
-
<https://arxiv.org/pdf/1706.03762v5.pdf>`_. Given the query vector
|
|
51
|
-
|
|
42
|
+
<https://arxiv.org/pdf/1706.03762v5.pdf>`_. Given the query vector, the key vector and value vector,
|
|
43
|
+
the attention will be performed as the following:
|
|
52
44
|
|
|
53
45
|
.. math::
|
|
54
|
-
MultiHeadAttention(query, key,
|
|
46
|
+
MultiHeadAttention(query, key, value) = Concat(head_1, \dots, head_h)W^O
|
|
55
47
|
|
|
56
|
-
where :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
|
|
48
|
+
where :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)`, and :math:`W^O` , :math:`W_i^Q` , :math:`W_i^K` ,
|
|
49
|
+
:math:`W_i^V` are weight matrices. The default input / output projection layers is with a bias.
|
|
57
50
|
|
|
58
51
|
if query, key and value tensor is same, then it will be self attention.
|
|
59
52
|
|
|
@@ -70,36 +63,37 @@ class MultiheadAttention(Cell):
|
|
|
70
63
|
vdim (int): Total number of features for values. Default: ``None`` (`vdim=embed_dim`).
|
|
71
64
|
batch_first (bool): If ``True``, then the input and output shape are :math:`(batch, seq, feature)` ,
|
|
72
65
|
else :math:`(seq, batch, feature)` . Default: ``False``.
|
|
66
|
+
dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .
|
|
73
67
|
|
|
74
68
|
Inputs:
|
|
75
69
|
- **query** (Tensor): The query embeddings. If `query` is unbatched, the shape is :math:`(L, E_q)`,
|
|
76
70
|
otherwise the shape is :math:`(L, N, E_q)` when `batch_first=False` or :math:`(N, L, E_q)` when
|
|
77
|
-
`batch_first=True
|
|
78
|
-
and :math:`E_q` is the query embedding dimension `embed_dim`.
|
|
79
|
-
key-value pairs to produce the output.
|
|
71
|
+
`batch_first=True` , where :math:`L`is the target sequence length, :math:`N` is the batch size,
|
|
72
|
+
and :math:`E_q` is the query embedding dimension `embed_dim`. Supported types: float16, float32,
|
|
73
|
+
float64. Queries are compared against key-value pairs to produce the output.
|
|
80
74
|
- **key** (Tensor): The key embeddings. If `key` is unbatched, the shape is :math:`(S, E_k)`, otherwise
|
|
81
75
|
the shape is :math:`(S, N, E_k)` when `batch_first=False` or :math:`(N, S, E_k)` when
|
|
82
|
-
`batch_first=True
|
|
83
|
-
and :math:`E_k` is the key embedding dimension `kdim`.
|
|
76
|
+
`batch_first=True` , where :math:`S` is the source sequence length, :math:`N` is the batch size,
|
|
77
|
+
and :math:`E_k` is the key embedding dimension `kdim`. Supported types: float16, float32, float64.
|
|
84
78
|
- **value** (Tensor): The value embeddings. If `value` is unbatched, the shape is :math:`(S, E_v)`,
|
|
85
79
|
otherwise the shape is :math:`(S, N, E_v)` when `batch_first=False` or :math:`(N, S, E_v)` when
|
|
86
|
-
`batch_first=True
|
|
87
|
-
and :math:`E_v` is the value embedding dimension `vdim`.
|
|
80
|
+
`batch_first=True` , where :math:`S` is the source sequence length, :math:`N` is the batch size,
|
|
81
|
+
and :math:`E_v` is the value embedding dimension `vdim`. Supported types: float16, float32, float64.
|
|
88
82
|
- **key_padding_mask** (Tensor, optional): If specified, a mask of shape :math:`(N, S)` indicating which
|
|
89
83
|
elements within `key` to ignore for the purpose of attention (i.e. treat as "padding").
|
|
90
|
-
For unbatched `query`, shape should be :math:`(S)`. Binary and
|
|
84
|
+
For unbatched `query`, shape should be :math:`(S)`. Binary and float masks are supported.
|
|
91
85
|
For a binary mask, a ``True`` value indicates that the corresponding `key` value will be ignored for
|
|
92
86
|
the purpose of attention. For a float mask, it will be directly added to the corresponding `key` value.
|
|
87
|
+
Supported float types: float16, float32, float64. Default: ``None``.
|
|
93
88
|
- **need_weights** (bool): Whether returns `attn_output_weights` in addition to `attn_outputs`.
|
|
94
89
|
Default: ``True``.
|
|
95
90
|
- **attn_mask** (Tensor, optional): If specified, a 2D or 3D mask preventing attention to certain positions.
|
|
96
|
-
Must be of shape :math:`(L, S)` or :math:`(N\cdot\text{
|
|
91
|
+
Must be of shape :math:`(L, S)` or :math:`(N\cdot\text{num_heads}, L, S)`, where :math:`N` is the
|
|
97
92
|
batch size, :math:`L` is the target sequence length, and :math:`S` is the source sequence length.
|
|
98
93
|
A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry
|
|
99
|
-
in the batch.
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
the attention weight.
|
|
94
|
+
in the batch. For a binary mask, a ``True`` value indicates that the corresponding position is not allowed
|
|
95
|
+
to attend. For a float mask, the mask values will be added to the attention weight.
|
|
96
|
+
Supported float types: float16, float32, float64. Default: ``None``.
|
|
103
97
|
- **average_attn_weights** (bool): If true, indicates that the returned `attn_weights` should be averaged
|
|
104
98
|
across heads. Otherwise, `attn_weights` are provided separately per head. Note that this flag only
|
|
105
99
|
has an effect when `need_weights=True`. Default: ``True`` (i.e. average weights across heads)
|
|
@@ -109,33 +103,39 @@ class MultiheadAttention(Cell):

  - **attn_output** - Attention outputs. If input is unbatched, the output shape is :math:`(L, E)`, otherwise
  the output shape is :math:`(L, N, E)` when `batch_first=False` or :math:`(N, L, E)` when
- `batch_first=True
+ `batch_first=True` , where :math:`L` is the target sequence length, :math:`N` is the batch size,
  and :math:`E` is the embedding dimension `embed_dim`.
  - **attn_output_weights** - Only returned when `need_weights=True`. If `average_attn_weights=True`,
  returns attention weights averaged across heads with shape :math:`(L, S)` when input is unbatched or
  :math:`(N, L, S)` when input is batched, where :math:`N` is the batch size, :math:`L` is
  the target sequence length, and :math:`S` is the source sequence length.
  If `average_attn_weights=False`, returns attention weights per
- head of shape :math:`(\text{
- :math:`(N, \text{
+ head of shape :math:`(\text{num_heads}, L, S)` when input is unbatched or
+ :math:`(N, \text{num_heads}, L, S)` when input is batched.
+
+ Raises:
+ ValueError: If the init argument `embed_dim` is not divisible by `num_heads`.
+ TypeError: If the input argument `key_padding_mask` is not bool or floating types.

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``

  Examples:
+ >>> import mindspore as ms
+ >>> import numpy as np
  >>> embed_dim, num_heads = 128, 8
  >>> seq_length, batch_size = 10, 8
- >>> query = Tensor(np.random.randn(seq_length, batch_size, embed_dim),
- >>> key = Tensor(np.random.randn(seq_length, batch_size, embed_dim),
- >>> value = Tensor(np.random.randn(seq_length, batch_size, embed_dim),
- >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
+ >>> query = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
+ >>> key = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
+ >>> value = ms.Tensor(np.random.randn(seq_length, batch_size, embed_dim), ms.float32)
+ >>> multihead_attn = ms.nn.MultiheadAttention(embed_dim, num_heads)
  >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
  >>> print(attn_output.shape)
  (10, 8, 128)
  """

- def __init__(self, embed_dim, num_heads, dropout=0
- add_zero_attn=False, kdim=None, vdim=None, batch_first=False):
+ def __init__(self, embed_dim, num_heads, dropout=0.0, has_bias=True, add_bias_kv=False,
+ add_zero_attn=False, kdim=None, vdim=None, batch_first=False, dtype=mstype.float32):
  super().__init__()
  self.embed_dim = embed_dim
  self.kdim = kdim if kdim is not None else embed_dim
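The widened signature threads a new `dtype` keyword through the layer. A minimal sketch of what that keyword is expected to control, assuming the Parameter initializations shown in the next hunk:

    import mindspore as ms
    # A hedged check, not an official recipe: the projection Parameters should follow `dtype`.
    mha = ms.nn.MultiheadAttention(embed_dim=128, num_heads=8, dtype=ms.float16)
    print(mha.in_proj_weight.dtype)  # expected to be float16
    print(mha.in_proj_bias.dtype)    # expected to be float16 (has_bias=True by default)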
@@ -149,32 +149,39 @@ class MultiheadAttention(Cell):
  if self.head_dim * num_heads != self.embed_dim:
  raise ValueError("The init argument 'embed_dim' must be divisible by 'num_heads'.")

+ if dtype is None:
+ dtype = mindspore.float32
  if not self._qkv_same_embed_dim:
- self.q_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, embed_dim)), 'q_proj_weight')
- self.k_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, self.kdim)), 'k_proj_weight')
- self.v_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, self.vdim)), 'v_proj_weight')
+ self.q_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, embed_dim), dtype), 'q_proj_weight')
+ self.k_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, self.kdim), dtype), 'k_proj_weight')
+ self.v_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, self.vdim), dtype), 'v_proj_weight')
  self.in_proj_weight = None
  else:
- self.in_proj_weight = Parameter(initializer(XavierUniform(), (3 * embed_dim, embed_dim)
+ self.in_proj_weight = Parameter(initializer(XavierUniform(), (3 * embed_dim, embed_dim), dtype),
+ 'in_proj_weight')
  self.q_proj_weight = None
  self.k_proj_weight = None
  self.v_proj_weight = None

  if has_bias:
- self.in_proj_bias = Parameter(initializer('zeros', (3 * embed_dim)), 'in_proj_bias')
+ self.in_proj_bias = Parameter(initializer('zeros', (3 * embed_dim), dtype), 'in_proj_bias')
  else:
  self.in_proj_bias = None
-
+ fan_in, _ = _calculate_fan_in_and_fan_out((embed_dim, embed_dim))
+ bound = 1 / math.sqrt(fan_in)
+ self.out_proj = Dense(embed_dim, embed_dim, has_bias=has_bias, weight_init=HeUniform(math.sqrt(5)),
+ bias_init=Uniform(bound), dtype=dtype)

  if add_bias_kv:
- self.bias_k = Parameter(initializer(XavierNormal(), (1, 1, embed_dim)), 'bias_k')
- self.bias_v = Parameter(initializer(XavierNormal(), (1, 1, embed_dim)), 'bias_v')
+ self.bias_k = Parameter(initializer(XavierNormal(), (1, 1, embed_dim), dtype), 'bias_k')
+ self.bias_v = Parameter(initializer(XavierNormal(), (1, 1, embed_dim), dtype), 'bias_v')
  else:
  self.bias_k = self.bias_v = None

  self.add_zero_attn = add_zero_attn
  self.k_is_v = False
  self.q_is_k = False
+ self.dtype = dtype

  def __call__(self, *args, **kwargs):
  query = kwargs.get('query', args[0])
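The branch above decides between one fused `in_proj_weight` and separate per-projection weights. A small sketch under the assumption that `_qkv_same_embed_dim` is driven by `kdim`/`vdim` matching `embed_dim`, as the surrounding code suggests; shapes follow the initializers in the hunk:

    import mindspore as ms
    # Default kdim/vdim: a single fused in-projection weight of shape (3 * embed_dim, embed_dim).
    same = ms.nn.MultiheadAttention(embed_dim=128, num_heads=8)
    print(same.in_proj_weight.shape, same.q_proj_weight)    # expected: (384, 128) None
    # Distinct kdim/vdim: separate q/k/v projection weights are created instead.
    split = ms.nn.MultiheadAttention(embed_dim=128, num_heads=8, kdim=64, vdim=32)
    print(split.in_proj_weight, split.q_proj_weight.shape)  # expected: None (128, 128)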
@@ -215,7 +222,7 @@ class MultiheadAttention(Cell):
  attn_mask=attn_mask, use_separate_proj_weight=True,
  q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
  v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights,
- k_is_v=self.k_is_v, q_is_k=self.q_is_k)
+ k_is_v=self.k_is_v, q_is_k=self.q_is_k, dtype=self.dtype)
  else:
  attn_output, attn_output_weights = multi_head_attention_forward(
  query, key, value, self.embed_dim, self.num_heads,
@@ -225,7 +232,7 @@ class MultiheadAttention(Cell):
  training=self.training,
  key_padding_mask=key_padding_mask,
  attn_mask=attn_mask, average_attn_weights=average_attn_weights,
- k_is_v=self.k_is_v, q_is_k=self.q_is_k)
+ k_is_v=self.k_is_v, q_is_k=self.q_is_k, dtype=self.dtype)

  if self.batch_first and is_batched:
  attn_output = attn_output.swapaxes(1, 0)
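The trailing `swapaxes` is what realizes the `batch_first` layout described in the docstring. A short sketch, assuming the constructor keyword shown earlier in this file:

    import mindspore as ms
    import numpy as np
    mha = ms.nn.MultiheadAttention(embed_dim=128, num_heads=8, batch_first=True)
    # Inputs and outputs now use the (batch, seq, feature) layout.
    x = ms.Tensor(np.random.randn(8, 10, 128), ms.float32)
    attn_output, _ = mha(x, x, x)
    print(attn_output.shape)  # (8, 10, 128)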
@@ -245,65 +252,90 @@ class TransformerEncoderLayer(Cell):
  dim_feedforward (int): The dimension of the feedforward layer. Default: ``2048``.
  dropout (float): The dropout value. Default: ``0.1``.
  activation (Union[str, callable, Cell]): The activation function of the intermediate layer,
- can be a string (
- a callable (
+ can be a string (``"relu"`` or ``"gelu"``), Cell instance (:class:`mindspore.nn.ReLU` or
+ :class:`mindspore.nn.GELU` ) or a callable ( :func:`mindspore.ops.relu` or
+ :func:`mindspore.ops.gelu` ). Default: ``"relu"``.
  layer_norm_eps (float): The epsilon value in LayerNorm modules. Default: ``1e-5``.
- batch_first (bool): If `batch_first
-
+ batch_first (bool): If `batch_first=True` , then the shape of input and output tensors is
+ :math:`(batch, seq, feature)` , otherwise the shape is :math:`(seq, batch, feature)` .
  Default: ``False``.
- norm_first (bool): If `norm_first = True`, layer norm is
- operations
+ norm_first (bool): If `norm_first = True`, layer norm is located prior to attention and feedforward
+ operations; if `norm_first = False`, layer norm is located after the attention and feedforward
+ operations. Default: ``False``.
+ dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .

  Inputs:
- - **src** (Tensor): the sequence to the encoder layer.
-
-
-
+ - **src** (Tensor): the sequence to the encoder layer. For unbatched input, the shape is
+ :math:`(S, E)` ; otherwise if `batch_first=False` , the shape is :math:`(S, N, E)` and if
+ `batch_first=True` , the shape is :math:`(S, N, E)`, where :math:`(S)` is the source sequence
+ length, :math:`(N)` is the batch number and :math:`(E)` is the feature number.
+ Supported types: float16, float32, float64.
+ - **src_mask** (Tensor, optional): the mask for the src sequence. The shape is :math:`(S, S)`
+ or :math:`(N*nhead, S, S)`. Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **src_key_padding_mask** (Tensor, optional): the mask for the src keys per batch. The shape is
+ :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
+ float64, bool. Default: ``None``.

  Outputs:
- Tensor.
+ Tensor. The shape and dtype of Tensor is the same with `src` .
+
+ Raises:
+ ValueError: If the init argument `activation` is not str, callable or Cell instance.
+ ValueError: If the init argument `activation` is not :class:`mindspore.nn.ReLU`,
+ :class:`mindspore.nn.GELU` instance, :func:`mindspore.ops.relu`,
+ :func:`mindspore.ops.gelu`, "relu" or "gelu" .

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``

  Examples:
- >>>
- >>>
+ >>> import mindspore as ms
+ >>> import numpy as np
+ >>> encoder_layer = ms.nn.TransformerEncoderLayer(d_model=512, nhead=8)
+ >>> src = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
  >>> out = encoder_layer(src)
+ >>> print(out.shape)
+ (10, 32, 512)
  >>> # Alternatively, when batch_first=True:
- >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
- >>> src = Tensor(np.random.rand(32, 10, 512),
+ >>> encoder_layer = ms.nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
+ >>> src = ms.Tensor(np.random.rand(32, 10, 512), ms.float32)
  >>> out = encoder_layer(src)
  >>> print(out.shape)
  (32, 10, 512)
  """
- __constants__ = ['batch_first', 'norm_first']

  def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
  activation: Union[str, Cell, callable] = 'relu', layer_norm_eps: float = 1e-5,
- batch_first: bool = False, norm_first: bool = False):
+ batch_first: bool = False, norm_first: bool = False, dtype=mstype.float32):
  super().__init__()
- self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, dtype=dtype)
  # feedforward layer
-
+ fan_in, _ = _calculate_fan_in_and_fan_out((dim_feedforward, d_model))
+ bound = 1 / math.sqrt(fan_in)
+ self.dense1 = Dense(d_model, dim_feedforward, weight_init=HeUniform(math.sqrt(5)),
+ bias_init=Uniform(bound), dtype=dtype)
  self.dropout = Dropout(p=dropout)
-
+ fan_in1, _ = _calculate_fan_in_and_fan_out((d_model, dim_feedforward))
+ bound1 = 1 / math.sqrt(fan_in1)
+ self.dense2 = Dense(dim_feedforward, d_model, weight_init=HeUniform(math.sqrt(5)),
+ bias_init=Uniform(bound1), dtype=dtype)

  self.norm_first = norm_first
- self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps)
- self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps)
+ self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
+ self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
  self.dropout1 = Dropout(p=dropout)
  self.dropout2 = Dropout(p=dropout)
+ self.activation1 = activation

  if not isinstance(activation, str) and not isinstance(activation, Cell) \
  and not callable(activation):
  raise ValueError(f"The argument 'activation' must be str, callable or Cell instance,"
  f" but get {activation}.")
- if isinstance(activation, Cell) and (not isinstance(activation, ReLU)
+ if isinstance(activation, Cell) and (not isinstance(activation, ReLU) and \
  not isinstance(activation, GELU)):
  raise ValueError(f"The argument 'activation' must be nn.ReLU or nn.GELU instance,"
  f" but get {activation}.")
- if callable(activation) and (activation is not ops.relu
+ if callable(activation) and (activation is not ops.relu and \
  activation is not ops.gelu):
  raise ValueError(f"The argument 'activation' must be ops.relu or ops.gelu instance,"
  f" but get {activation}.")
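Per the validation above, `activation` may be passed as a string or as one of the functional ops (Cell instances are also documented). A hedged sketch using only the string and functional forms:

    import mindspore as ms
    import numpy as np
    layer_str = ms.nn.TransformerEncoderLayer(d_model=512, nhead=8, activation="gelu")
    layer_fn = ms.nn.TransformerEncoderLayer(d_model=512, nhead=8, activation=ms.ops.relu)
    src = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
    print(layer_str(src).shape)  # (10, 32, 512)
    print(layer_fn(src).shape)   # (10, 32, 512)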
@@ -311,6 +343,14 @@ class TransformerEncoderLayer(Cell):
  if isinstance(activation, str):
  activation = _get_activation_fn(activation)
  self.activation = activation
+ self.d_model = d_model
+ self.nhead = nhead
+ self.dim_feedforward = dim_feedforward
+ self.dropout_num = dropout
+ self.layernorm_eps = layer_norm_eps
+ self.batch_first = batch_first
+ self.norm_first = norm_first
+ self.dtype = dtype

  def construct(self, src: Tensor, src_mask: Optional[Tensor] = None,
  src_key_padding_mask: Optional[Tensor] = None):
@@ -338,7 +378,7 @@ class TransformerEncoderLayer(Cell):
  return self.dropout1(x)

  def _ff_block(self, x):
- x = self.
+ x = self.dense2(self.dropout(self.activation(self.dense1(x))))
  return self.dropout2(x)

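The rebuilt `_ff_block` is the standard position-wise feedforward sub-layer: expand to `dim_feedforward`, apply the activation and dropout, then project back to `d_model`. A standalone sketch of the same computation; the names below are illustrative stand-ins, not the layer's internals:

    import mindspore as ms
    import numpy as np
    d_model, dim_feedforward = 512, 2048
    dense1 = ms.nn.Dense(d_model, dim_feedforward)
    dense2 = ms.nn.Dense(dim_feedforward, d_model)
    dropout = ms.nn.Dropout(p=0.1)
    x = ms.Tensor(np.random.rand(10, 32, d_model), ms.float32)
    y = dense2(dropout(ms.ops.relu(dense1(x))))  # same shape in, same shape out
    print(y.shape)  # (10, 32, 512)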
@@ -353,74 +393,101 @@ class TransformerDecoderLayer(Cell):
  dim_feedforward (int): The dimension of the feedforward layer. Default: ``2048``.
  dropout (float): The dropout value. Default: ``0.1``.
  activation (Union[str, callable, Cell]): The activation function of the intermediate layer,
- can be a string (
- a callable (
+ can be a string (``"relu"`` or ``"gelu"``), Cell instance (:class:`mindspore.nn.ReLU` or
+ :class:`mindspore.nn.GELU` ) or a callable ( :func:`mindspore.ops.relu` or
+ :func:`mindspore.ops.gelu` ). Default: ``"relu"``.
  layer_norm_eps (float): The epsilon value in LayerNorm modules. Default: ``1e-5``.
- batch_first (bool): If `batch_first
+ batch_first (bool): If `batch_first=True` , then the shape of input and output tensors is
  :math:`(batch, seq, feature)` , otherwise the shape is :math:`(seq, batch, feature)`.
  Default: ``False``.
- norm_first (bool): If `norm_first = True`, layer norm is
- operations
+ norm_first (bool): If `norm_first = True`, layer norm is located prior to attention and feedforward
+ operations; if `norm_first = False`, layer norm is located after the attention and feedforward
+ operations. Default: ``False``.
+ dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .

  Inputs:
- - **tgt** (Tensor): The sequence to the decoder layer.
-
-
-
- - **
-
- - **
- Default: ``None``.
+ - **tgt** (Tensor): The sequence to the decoder layer. For unbatched input, the shape is
+ :math:`(T, E)` ; otherwise if `batch_first=False` , the shape is :math:`(T, N, E)` and if
+ `batch_first=True` , the shape is :math:`(T, N, E)`, where :math:`(T)` is the target sequence
+ length. Supported types: float16, float32, float64.
+ - **memory** (Tensor): The sequence from the last layer of the encoder. Supported types: float16,
+ float32, float64.
+ - **tgt_mask** (Tensor, optional): The mask of the tgt sequence. The shape is :math:`(T, T)`
+ or :math:`(N*nhead, T, T)`. Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **memory_mask** (Tensor, optional): The mask of the memory sequence. The shape is
+ :math:`(T, S)` . Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **tgt_key_padding_mask** (Tensor, optional): The mask of the tgt keys per batch. The shape is
+ :math:`(T)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
+ float64, bool. Default: ``None``.
+ - **memory_key_padding_mask** (Tensor, optional): The mask of the memory keys per batch. The shape
+ is :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
+ float64, bool. Default: ``None``.

  Outputs:
- Tensor.
+ Tensor. The shape and dtype of Tensor is the same with `tgt` .
+
+ Raises:
+ ValueError: If the init argument `activation` is not str, callable or Cell instance.
+ ValueError: If the init argument `activation` is not :class:`mindspore.nn.ReLU`,
+ :class:`mindspore.nn.GELU` instance, :func:`mindspore.ops.relu`,
+ :func:`mindspore.ops.gelu` , "relu" or "gelu" .

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``

  Examples:
- >>>
- >>>
- >>>
+ >>> import mindspore as ms
+ >>> import numpy as np
+ >>> decoder_layer = ms.nn.TransformerDecoderLayer(d_model=512, nhead=8)
+ >>> memory = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
+ >>> tgt = ms.Tensor(np.random.rand(20, 32, 512), ms.float32)
  >>> out = decoder_layer(tgt, memory)
+ >>> print(out.shape)
+ (20, 32, 512)
  >>> # Alternatively, when `batch_first` is ``True``:
- >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
- >>> memory = Tensor(np.random.rand(32, 10, 512),
- >>> tgt = Tensor(np.random.rand(32, 20, 512),
+ >>> decoder_layer = ms.nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
+ >>> memory = ms.Tensor(np.random.rand(32, 10, 512), ms.float32)
+ >>> tgt = ms.Tensor(np.random.rand(32, 20, 512), ms.float32)
  >>> out = decoder_layer(tgt, memory)
  >>> print(out.shape)
  (32, 20, 512)
  """
- __constants__ = ['batch_first', 'norm_first']

  def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
  activation: Union[str, Cell, callable] = 'relu', layer_norm_eps: float = 1e-5,
- batch_first: bool = False, norm_first: bool = False):
+ batch_first: bool = False, norm_first: bool = False, dtype=mstype.float32):
  super().__init__()
- self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
- self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
+ self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, dtype=dtype)
+ self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, dtype=dtype)
  # feedforward layer
-
+ fan_in, _ = _calculate_fan_in_and_fan_out((dim_feedforward, d_model))
+ bound = 1 / math.sqrt(fan_in)
+ self.dense1 = Dense(d_model, dim_feedforward, weight_init=HeUniform(math.sqrt(5)),
+ bias_init=Uniform(bound), dtype=dtype)
  self.dropout = Dropout(p=dropout)
-
+ fan_in1, _ = _calculate_fan_in_and_fan_out((d_model, dim_feedforward))
+ bound1 = 1 / math.sqrt(fan_in1)
+ self.dense2 = Dense(dim_feedforward, d_model, weight_init=HeUniform(math.sqrt(5)),
+ bias_init=Uniform(bound1), dtype=dtype)

  self.norm_first = norm_first
- self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps)
- self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps)
- self.norm3 = LayerNorm((d_model,), epsilon=layer_norm_eps)
+ self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
+ self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
+ self.norm3 = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
  self.dropout1 = Dropout(p=dropout)
  self.dropout2 = Dropout(p=dropout)
  self.dropout3 = Dropout(p=dropout)
+ self.activation1 = activation

  if not isinstance(activation, str) and not isinstance(activation, Cell) \
  and not callable(activation):
  raise ValueError(f"The argument 'activation' must be str, callable or Cell instance,"
  f" but get {activation}.")
- if isinstance(activation, Cell) and (not isinstance(activation, ReLU)
+ if isinstance(activation, Cell) and (not isinstance(activation, ReLU) and \
  not isinstance(activation, GELU)):
  raise ValueError(f"The argument 'activation' must be nn.ReLU or nn.GELU instance,"
  f" but get {activation}.")
- if callable(activation) and (activation is not ops.relu
+ if callable(activation) and (activation is not ops.relu and \
  activation is not ops.gelu):
  raise ValueError(f"The argument 'activation' must be ops.relu or ops.gelu instance,"
  f" but get {activation}.")
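The decoder-layer docstring above documents `tgt_mask` with shape :math:`(T, T)` and boolean support. A hedged sketch of a causal self-attention mask, assuming the convention stated for MultiheadAttention (``True`` marks positions that are not allowed to attend):

    import mindspore as ms
    import numpy as np
    decoder_layer = ms.nn.TransformerDecoderLayer(d_model=512, nhead=8)
    memory = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
    tgt = ms.Tensor(np.random.rand(20, 32, 512), ms.float32)
    # Causal mask of shape (T, T): True above the diagonal blocks attention to future positions.
    tgt_mask = ms.Tensor(np.triu(np.ones((20, 20), dtype=bool), k=1))
    out = decoder_layer(tgt, memory, tgt_mask=tgt_mask)
    print(out.shape)  # (20, 32, 512)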
@@ -428,6 +495,14 @@ class TransformerDecoderLayer(Cell):
  if isinstance(activation, str):
  activation = _get_activation_fn(activation)
  self.activation = activation
+ self.d_model = d_model
+ self.nhead = nhead
+ self.dim_feedforward = dim_feedforward
+ self.dropout_num = dropout
+ self.layernorm_eps = layer_norm_eps
+ self.batch_first = batch_first
+ self.norm_first = norm_first
+ self.dtype = dtype

  def construct(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
  memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
@@ -459,46 +534,61 @@ class TransformerDecoderLayer(Cell):
  return self.dropout2(x)

  def _ff_block(self, x):
- x = self.
+ x = self.dense2(self.dropout(self.activation(self.dense1(x))))
  return self.dropout3(x)


  class TransformerEncoder(Cell):
  r"""
- Transformer Encoder module with multi-layer stacked of `TransformerEncoderLayer`, including multihead
+ Transformer Encoder module with multi-layer stacked of `TransformerEncoderLayer`, including multihead
  attention and feedforward layer. Users can build the
  BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.

  Args:
- encoder_layer (Cell): An instance of the TransformerEncoderLayer
+ encoder_layer (Cell): An instance of the :class:`mindspore.nn.TransformerEncoderLayer` class.
  num_layers (int): The number of encoder-layers in the encoder.
- norm (Cell, optional): The layer normalization module.
+ norm (Cell, optional): The layer normalization module. Default: ``None``.

  Inputs:
- - **src** (Tensor): The sequence to the encoder.
-
-
-
+ - **src** (Tensor): The sequence to the encoder. For unbatched input, the shape is
+ :math:`(S, E)` ; otherwise if `batch_first=False` in TransformerEncoderLayer, the shape is
+ :math:`(S, N, E)` and if `batch_first=True` , the shape is :math:`(S, N, E)`, where :math:`(S)` is the
+ source sequence length, :math:`(N)` is the batch number and :math:`(E)` is the feature number.
+ Supported types: float16, float32, float64.
+ - **src_mask** (Tensor, optional): The mask of the src sequence. The shape is :math:`(S, S)`
+ or :math:`(N*nhead, S, S)` , where `nhead` is the arguent in TransformerDecoderLayer.
+ Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **src_key_padding_mask** (Tensor, optional): the mask of the src keys per batch. The shape is
+ :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
+ float64, bool. Default: ``None``.

  Outputs:
- Tensor.
+ Tensor. The shape and dtype of Tensor is the same with `src` .
+
+ Raises:
+ AssertionError: If the input argument `src_key_padding_mask` is not bool or floating types.

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``

  Examples:
- >>>
- >>>
- >>>
+ >>> import mindspore as ms
+ >>> import numpy as np
+ >>> encoder_layer = ms.nn.TransformerEncoderLayer(d_model=512, nhead=8)
+ >>> transformer_encoder = ms.nn.TransformerEncoder(encoder_layer, num_layers=6)
+ >>> src = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
  >>> out = transformer_encoder(src)
  >>> print(out.shape)
  (10, 32, 512)
  """
- __constants__ = ['norm']

  def __init__(self, encoder_layer, num_layers, norm=None):
  super(TransformerEncoder, self).__init__()
-
+ layers = TransformerEncoderLayer(encoder_layer.d_model, encoder_layer.nhead, encoder_layer.dim_feedforward,
+ encoder_layer.dropout_num, encoder_layer.activation1,
+ encoder_layer.layernorm_eps, encoder_layer.batch_first,
+ encoder_layer.norm_first, dtype=encoder_layer.dtype)
+ self.layers = CellList([layers for _ in range(num_layers)])
  self.num_layers = num_layers
  self.norm = norm

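The rewritten constructor rebuilds the stacked layers from the hyperparameters cached on the passed `encoder_layer` (see the attribute assignments added in the TransformerEncoderLayer hunk above). A short usage sketch that also passes the optional final `norm` cell:

    import mindspore as ms
    import numpy as np
    encoder_layer = ms.nn.TransformerEncoderLayer(d_model=512, nhead=8)
    final_norm = ms.nn.LayerNorm((512,), epsilon=1e-5)
    encoder = ms.nn.TransformerEncoder(encoder_layer, num_layers=6, norm=final_norm)
    src = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
    out = encoder(src)
    print(out.shape)  # (10, 32, 512)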
@@ -527,38 +617,51 @@ class TransformerDecoder(Cell):
  Args:
  decoder_layer (Cell): An instance of the :class:`mindspore.nn.TransformerDecoderLayer` class.
  num_layers (int): The number of decoder-layers in the decoder.
- norm (Cell, optional): The layer normalization module.
+ norm (Cell, optional): The layer normalization module. Default: ``None``.

  Inputs:
- - **tgt** (Tensor): The sequence to the decoder.
-
-
-
- - **
-
- - **
-
+ - **tgt** (Tensor): The sequence to the decoder. For unbatched input, the shape is
+ :math:`(T, E)` ; otherwise if `batch_first=False` in TransformerDecoderLayer, the shape is
+ :math:`(T, N, E)` and if `batch_first=True` , the shape is :math:`(T, N, E)`, where :math:`(T)` is the
+ target sequence length. Supported types: float16, float32, float64.
+ - **memory** (Tensor): The sequence from the last layer of the encoder. Supported types: float16,
+ float32, float64.
+ - **tgt_mask** (Tensor, optional): the mask of the tgt sequence. The shape is :math:`(T, T)`
+ or :math:`(N*nhead, T, T)` , where `nhead` is the arguent in TransformerDecoderLayer.
+ Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **memory_mask** (Tensor, optional): the mask of the memory sequence. The shape is
+ :math:`(T, S)` . Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **tgt_key_padding_mask** (Tensor, optional): the mask of the tgt keys per batch. Supported
+ types: float16, float32, float64, bool. Default: ``None``.
+ - **memory_key_padding_mask** (Tensor, optional): the mask of the memory keys per batch. The shape
+ is :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
+ float64, bool. Default: ``None``.

  Outputs:
- Tensor.
+ Tensor. The shape and dtype of Tensor is the same with `tgt` .

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``

  Examples:
- >>>
- >>>
- >>>
- >>>
+ >>> import mindspore as ms
+ >>> import numpy as np
+ >>> decoder_layer = ms.nn.TransformerDecoderLayer(d_model=512, nhead=8)
+ >>> transformer_decoder = ms.nn.TransformerDecoder(decoder_layer, num_layers=6)
+ >>> memory = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
+ >>> tgt = ms.Tensor(np.random.rand(20, 32, 512), ms.float32)
  >>> out = transformer_decoder(tgt, memory)
  >>> print(out.shape)
  (20, 32, 512)
  """
- __constants__ = ['norm']

  def __init__(self, decoder_layer, num_layers, norm=None):
  super(TransformerDecoder, self).__init__()
-
+ layers = TransformerDecoderLayer(decoder_layer.d_model, decoder_layer.nhead, decoder_layer.dim_feedforward,
+ decoder_layer.dropout_num, decoder_layer.activation1,
+ decoder_layer.layernorm_eps, decoder_layer.batch_first,
+ decoder_layer.norm_first, dtype=decoder_layer.dtype)
+ self.layers = CellList([layers for _ in range(num_layers)])
  self.num_layers = num_layers
  self.norm = norm

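A matching sketch for the decoder stack, assuming the forward keywords listed in the Inputs block above; the padding mask here is a placeholder:

    import mindspore as ms
    import numpy as np
    decoder_layer = ms.nn.TransformerDecoderLayer(d_model=512, nhead=8)
    decoder = ms.nn.TransformerDecoder(decoder_layer, num_layers=6)
    memory = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
    tgt = ms.Tensor(np.random.rand(20, 32, 512), ms.float32)
    # Placeholder boolean mask of shape (N, S) over the encoder memory positions.
    memory_key_padding_mask = ms.Tensor(np.zeros((32, 10), dtype=bool))
    out = decoder(tgt, memory, memory_key_padding_mask=memory_key_padding_mask)
    print(out.shape)  # (20, 32, 512)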
@@ -566,7 +669,6 @@ class TransformerDecoder(Cell):
  memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
  memory_key_padding_mask: Optional[Tensor] = None):
  output = tgt
-
  for mod in self.layers:
  output = mod(output, memory, tgt_mask=tgt_mask,
  memory_mask=memory_mask,
@@ -582,52 +684,74 @@ class Transformer(Cell):
  class Transformer(Cell):
  r"""
  Transformer module including encoder and decoder. The difference with the original implements is the module use
- the residual addition before the layer normalization. And the default hidden
+ the residual addition before the layer normalization. And the default hidden activation is `gelu`.
  The details can be found in `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_.

  Args:
- d_model (int): The number of expected features in the inputs tensor. Default: ``512``.
+ d_model (int): The number of expected features in the inputs tensor for Encoder and Decoder. Default: ``512``.
  nhead (int): The number of heads in the MultiheadAttention modules. Default: ``8``.
  num_encoder_layers (int): The number of encoder-layers in the encoder. Default: ``6``.
  num_decoder_layers (int): The number of decoder-layers in the decoder. Default: ``6``.
  dim_feedforward (int): The dimension of the feedforward layer. Default: ``2048``.
  dropout (float): The dropout value. Default: ``0.1``.
  activation (Union[str, callable, Cell]): The activation function of the intermediate layer,
- can be a string (
- a callable (
+ can be a string (``"relu"`` or ``"gelu"``), Cell instance (:class:`mindspore.nn.ReLU` or
+ :class:`mindspore.nn.GELU` ) or a callable ( :func:`mindspore.ops.relu` or
+ :func:`mindspore.ops.gelu` ). Default: ``"relu"``.
  custom_encoder (Cell): Custom encoder. Default: ``None``.
  custom_decoder (Cell): Custom decoder. Default: ``None``.
  layer_norm_eps (float): the epsilion value in layer normalization module. Default: ``1e-5``.
- batch_first (bool): If `batch_first
+ batch_first (bool): If `batch_first=True`, then the shape of input and output tensors is
  :math:`(batch, seq, feature)` , otherwise the shape is :math:`(seq, batch, feature)` .
  Default: ``False``.
- norm_first (bool): If `norm_first = True`, layer norm is
- operations
+ norm_first (bool): If `norm_first = True`, layer norm is located prior to attention and feedforward
+ operations; if `norm_first = False`, layer norm is located after the attention and feedforward
+ operations. Default: ``False``.
+ dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .

  Inputs:
- - **src** (Tensor): The source sequence to the encoder.
-
-
-
-
-
-
-
-
-
-
-
+ - **src** (Tensor): The source sequence to the encoder. For unbatched input, the shape is
+ :math:`(S, E)` ; otherwise if `batch_first=False` , the shape is :math:`(S, N, E)` and if
+ `batch_first=True` , the shape is :math:`(S, N, E)`, where :math:`(S)` is the source sequence
+ length, :math:`(N)` is the batch number and :math:`(E)` is the feature number. Supported
+ types: float16, float32, float64.
+ - **tgt** (Tensor): The target sequence to the decoder. For unbatched input, the shape is
+ :math:`(T, E)` ; otherwise if `batch_first=False` , the shape is :math:`(T, N, E)` and if
+ `batch_first=True` , the shape is :math:`(T, N, E)`, where :math:`(T)` is the target sequence
+ length. Supported types: float16, float32, float64.
+ - **src_mask** (Tensor, optional): The mask of the src sequence. The shape is :math:`(S, S)`
+ or :math:`(N*nhead, S, S)`. Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **tgt_mask** (Tensor, optional): The mask of the tgt sequence. The shape is :math:`(T, T)`
+ or :math:`(N*nhead, T, T)`. Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **memory_mask** (Tensor, optional): The additive mask of the encoder output. The shape is
+ :math:`(T, S)` . Supported types: float16, float32, float64, bool. Default: ``None``.
+ - **src_key_padding_mask** (Tensor, optional): The mask of src keys per batch. The shape is
+ :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
+ float64, bool. Default: ``None``.
+ - **tgt_key_padding_mask** (Tensor, optional): The mask of tgt keys per batch. The shape is
+ :math:`(T)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16, float32,
+ float64, bool. Default: ``None``.
+ - **memory_key_padding_mask** (Tensor, optional): The mask of memory keys per batch. The shape
+ is :math:`(S)` for unbatched input, otherwise :math:`(N, S)` . Supported types: float16,
+ float32, float64, bool. Default: ``None``.

  Outputs:
- Tensor.
+ Tensor. The shape is :math:`(T, E)` for unbatched input, otherwise if `batch_first=False` , the shape is
+ :math:`(T, N, E)` and if `batch_first=True` , the shape is :math:`(N, T, E)`.
+
+ Raises:
+ ValueError: If the batch sizes of the init argument `src` and `tgt` are not equal.
+ ValueError: If the number of features of the init argument `src` and `tgt` is not equal to that of `d_model`.

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``

  Examples:
- >>>
- >>>
- >>>
+ >>> import mindspore as ms
+ >>> import numpy as np
+ >>> transformer_model = ms.nn.Transformer(nhead=16, num_encoder_layers=12)
+ >>> src = ms.Tensor(np.random.rand(10, 32, 512), ms.float32)
+ >>> tgt = ms.Tensor(np.random.rand(20, 32, 512), ms.float32)
  >>> out = transformer_model(src, tgt)
  >>> print(out.shape)
  (20, 32, 512)
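A hedged end-to-end sketch of the documented inputs, using `batch_first=True` and a placeholder key-padding mask; keyword names follow the Inputs block above:

    import mindspore as ms
    import numpy as np
    model = ms.nn.Transformer(d_model=512, nhead=8, batch_first=True)
    src = ms.Tensor(np.random.rand(32, 10, 512), ms.float32)
    tgt = ms.Tensor(np.random.rand(32, 20, 512), ms.float32)
    # Placeholder padding mask of shape (N, S); True would mark padded source positions.
    src_key_padding_mask = ms.Tensor(np.zeros((32, 10), dtype=bool))
    out = model(src, tgt, src_key_padding_mask=src_key_padding_mask)
    print(out.shape)  # (32, 20, 512)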
@@ -637,23 +761,23 @@ class Transformer(Cell):
  num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
  activation: Union[str, Cell, callable] = 'relu', custom_encoder: Optional[Cell] = None,
  custom_decoder: Optional[Cell] = None, layer_norm_eps: float = 1e-5,
- batch_first: bool = False, norm_first: bool = False):
+ batch_first: bool = False, norm_first: bool = False, dtype=mstype.float32):
  super(Transformer, self).__init__()

  if custom_encoder is not None:
  self.encoder = custom_encoder
  else:
  encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
- activation, layer_norm_eps, batch_first, norm_first)
- encoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps)
+ activation, layer_norm_eps, batch_first, norm_first, dtype=dtype)
+ encoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
  self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

  if custom_decoder is not None:
  self.decoder = custom_decoder
  else:
  decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
- activation, layer_norm_eps, batch_first, norm_first)
- decoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps)
+ activation, layer_norm_eps, batch_first, norm_first, dtype=dtype)
+ decoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps, dtype=dtype)
  self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

  for _, p in self.parameters_and_names():
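The new `dtype` keyword is forwarded into every sub-layer built here. A minimal check, assuming the Dense and LayerNorm parameters inherit it as the constructor code suggests:

    import mindspore as ms
    model = ms.nn.Transformer(d_model=512, nhead=8, dtype=ms.float16)
    # Expected: the feedforward weights of the first encoder layer are created in float16.
    print(model.encoder.layers[0].dense1.weight.dtype)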
@@ -695,7 +819,3 @@ def _get_activation_fn(activation: str):
  return ops.gelu

  raise ValueError(f"The activation must be relu/gelu, but get {activation}")
-
-
- def _get_clones(module, N):
- return CellList([copy.deepcopy(module) for i in range(N)])