mindspore 2.0.0rc1__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +647 -818
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/dense.py
CHANGED
|
@@ -41,8 +41,8 @@ def check_dense_inputs_same_shape(input1, input2, prim_name=None):
|
|
|
41
41
|
@constexpr(check=False)
|
|
42
42
|
def _check_is_tensor(param_name, input_data, cls_name):
|
|
43
43
|
"""Internal function, used to check whether the input data is Tensor."""
|
|
44
|
-
if input_data is not None and not isinstance(P.typeof(input_data), mstype.
|
|
45
|
-
raise TypeError(f"For '{cls_name}', the '{param_name}' must be '{mstype.
|
|
44
|
+
if input_data is not None and not isinstance(P.typeof(input_data), mstype.TensorType):
|
|
45
|
+
raise TypeError(f"For '{cls_name}', the '{param_name}' must be '{mstype.TensorType}', "
|
|
46
46
|
f"but got '{P.typeof(input_data)}'")
|
|
47
47
|
|
|
48
48
|
|
|
@@ -66,17 +66,18 @@ class BiDense(Cell):
|
|
|
66
66
|
where :math:`x_{1}` is the first input tensor, :math:`x_{2}` is the second input tensor
|
|
67
67
|
, :math:`A` is a weight matrix with the same data type as the :math:`x_{*}` created by the layer
|
|
68
68
|
, and :math:`b` is a bias vector with the same data type as the :math:`x_{*}` created by the layer
|
|
69
|
-
(only if has_bias is True).
|
|
69
|
+
(only if has_bias is ``True`` ).
|
|
70
70
|
|
|
71
71
|
Args:
|
|
72
72
|
in1_channels (int): The number of channels in the input1 space.
|
|
73
73
|
in2_channels (int): The number of channels in the input2 space.
|
|
74
74
|
out_channels (int): The number of channels in the output space.
|
|
75
75
|
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter.
|
|
76
|
-
The values of str refer to the function `initializer`. Default: None.
|
|
76
|
+
The values of str refer to the function `initializer`. Default: ``None`` .
|
|
77
77
|
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter.
|
|
78
|
-
The values of str refer to the function `initializer`. Default: None.
|
|
79
|
-
has_bias (bool): Specifies whether the layer uses :math:`\text{bias}` vector. Default: True.
|
|
78
|
+
The values of str refer to the function `initializer`. Default: ``None`` .
|
|
79
|
+
has_bias (bool): Specifies whether the layer uses :math:`\text{bias}` vector. Default: ``True`` .
|
|
80
|
+
dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
|
|
80
81
|
|
|
81
82
|
Shape:
|
|
82
83
|
- **input1** - :math:`(*, H_{in1})` where :math:`H_{in1}=\text{in1_channels}` and
|
|
@@ -90,17 +91,17 @@ class BiDense(Cell):
|
|
|
90
91
|
are the same shape as the inputs.
|
|
91
92
|
|
|
92
93
|
Dtype:
|
|
93
|
-
- **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input2
|
|
94
|
-
- **
|
|
94
|
+
- **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input2** .
|
|
95
|
+
- **input2** (Tensor) - The dtype must be float16 or float32 and be same as **input1** .
|
|
95
96
|
- **output** (Tensor) - With the same dtype as the inputs.
|
|
96
97
|
|
|
97
98
|
Weights:
|
|
98
99
|
- **weight** (Parameter) - The learnable weights with shape
|
|
99
100
|
:math:`(\text{out_channels}, \text{in1_channels}, \text{in2_channels})`.
|
|
100
|
-
When `weight_init` is
|
|
101
|
+
When `weight_init` is ``None`` , the values are initialized from
|
|
101
102
|
:math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where :math:`k = \frac{1}{\text{in1_channels}}`.
|
|
102
103
|
- **bias** (Parameter) - The learnable bias of shape :math:`(\text{out_channels})`.
|
|
103
|
-
If `has_bias` is
|
|
104
|
+
If `has_bias` is ``True`` and `bias_init` is ``None`` , the values are initialized from
|
|
104
105
|
:math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where :math:`k = \frac{1}{\text{in1_channels}}`.
|
|
105
106
|
|
|
106
107
|
Raises:
|
|
@@ -116,6 +117,9 @@ class BiDense(Cell):
|
|
|
116
117
|
``Ascend`` ``GPU`` ``CPU``
|
|
117
118
|
|
|
118
119
|
Examples:
|
|
120
|
+
>>> import mindspore
|
|
121
|
+
>>> from mindspore import Tensor, nn
|
|
122
|
+
>>> import numpy as np
|
|
119
123
|
>>> x1 = Tensor(np.random.randn(128, 20), mindspore.float32)
|
|
120
124
|
>>> x2 = Tensor(np.random.randn(128, 30), mindspore.float32)
|
|
121
125
|
>>> net = nn.BiDense(20, 30, 40)
|
|
@@ -130,7 +134,8 @@ class BiDense(Cell):
|
|
|
130
134
|
out_channels,
|
|
131
135
|
weight_init=None,
|
|
132
136
|
bias_init=None,
|
|
133
|
-
has_bias=True
|
|
137
|
+
has_bias=True,
|
|
138
|
+
dtype=mstype.float32):
|
|
134
139
|
super().__init__()
|
|
135
140
|
self.in_channels = Validator.check_positive_int(in1_channels, "in1_channels", self.cls_name)
|
|
136
141
|
self.in_channels = Validator.check_positive_int(in2_channels, "in2_channels", self.cls_name)
|
|
@@ -153,7 +158,8 @@ class BiDense(Cell):
|
|
|
153
158
|
f"equal to 'in2_channels'. But got 'weight_init': {weight_init}, "
|
|
154
159
|
f"'out_channels': {out_channels}, 'in_channels': {in1_channels}, "
|
|
155
160
|
f"'in2_channels': {in2_channels}")
|
|
156
|
-
self.weight = Parameter(initializer(weight_init, (out_channels, in1_channels, in2_channels)
|
|
161
|
+
self.weight = Parameter(initializer(weight_init, (out_channels, in1_channels, in2_channels), dtype=dtype),
|
|
162
|
+
'weight')
|
|
157
163
|
|
|
158
164
|
if self.has_bias:
|
|
159
165
|
if bias_init is None:
|
|
@@ -163,7 +169,7 @@ class BiDense(Cell):
|
|
|
163
169
|
raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' should "
|
|
164
170
|
f"be equal to 1, and the first dim must be equal to 'out_channels'. But got "
|
|
165
171
|
f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
|
|
166
|
-
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
|
|
172
|
+
self.bias = Parameter(initializer(bias_init, [out_channels], dtype=dtype), name="bias")
|
|
167
173
|
self.bias_add = P.BiasAdd()
|
|
168
174
|
self.matmul = P.MatMul()
|
|
169
175
|
|
mindspore/nn/layer/embedding.py
CHANGED
|
@@ -62,13 +62,15 @@ class Embedding(Cell):
|
|
|
62
62
|
Args:
|
|
63
63
|
vocab_size (int): Size of the dictionary of embeddings.
|
|
64
64
|
embedding_size (int): The size of each embedding vector.
|
|
65
|
-
use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: False.
|
|
65
|
+
use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: ``False`` .
|
|
66
66
|
embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
|
|
67
|
-
Refer to class `initializer
|
|
68
|
-
|
|
69
|
-
|
|
67
|
+
Refer to class `mindspore.common.initializer
|
|
68
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
|
|
69
|
+
for the values of string when a string is specified. Default: ``'normal'`` .
|
|
70
|
+
dtype (:class:`mindspore.dtype`): Data type of `x`. Default: ``mstype.float32`` .
|
|
70
71
|
padding_idx (int, None): When the padding_idx encounters index, the output embedding vector of this index
|
|
71
|
-
will be initialized to zero. Default: None. The feature is inactivated.
|
|
72
|
+
will be initialized to zero. Default: ``None`` . The feature is inactivated.
|
|
73
|
+
|
|
72
74
|
Inputs:
|
|
73
75
|
- **x** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{x_length})`. The elements of
|
|
74
76
|
the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
|
|
@@ -86,6 +88,9 @@ class Embedding(Cell):
|
|
|
86
88
|
``Ascend`` ``GPU`` ``CPU``
|
|
87
89
|
|
|
88
90
|
Examples:
|
|
91
|
+
>>> import mindspore
|
|
92
|
+
>>> from mindspore import Tensor, nn
|
|
93
|
+
>>> import numpy as np
|
|
89
94
|
>>> net = nn.Embedding(20000, 768, True)
|
|
90
95
|
>>> x = Tensor(np.ones([8, 128]), mindspore.int32)
|
|
91
96
|
>>> # Maps the input word IDs to word embedding.
|
|
@@ -126,13 +131,10 @@ class Embedding(Cell):
|
|
|
126
131
|
self.array_mul = P.MatMul()
|
|
127
132
|
self.reshape = P.Reshape()
|
|
128
133
|
self.get_shp = P.Shape()
|
|
129
|
-
self.get_tensor_shp = P.TensorShape()
|
|
130
134
|
self.concat = P.Concat()
|
|
131
135
|
|
|
132
136
|
def construct(self, ids):
|
|
133
137
|
out_shape = self.get_shp(ids) + (self.embedding_size,)
|
|
134
|
-
if F.is_sequence_value_unknown(self.get_shp(ids)):
|
|
135
|
-
out_shape = self.concat((self.get_tensor_shp(ids), Tensor([self.embedding_size])))
|
|
136
138
|
flat_ids = self.reshape_flat(ids, self.shp_flat)
|
|
137
139
|
|
|
138
140
|
if self.use_one_hot:
|
|
@@ -145,12 +147,11 @@ class Embedding(Cell):
|
|
|
145
147
|
return output
|
|
146
148
|
|
|
147
149
|
def extend_repr(self):
|
|
148
|
-
|
|
149
|
-
self.
|
|
150
|
-
return s
|
|
150
|
+
return f'vocab_size={self.vocab_size}, embedding_size={self.embedding_size}, use_one_hot={self.use_one_hot}, ' \
|
|
151
|
+
f'embedding_table={self.embedding_table}, dtype={self.dtype}, padding_idx={self.padding_idx}'
|
|
151
152
|
|
|
152
153
|
|
|
153
|
-
@
|
|
154
|
+
@_primexpr
|
|
154
155
|
def _make_axis_range(start, end):
|
|
155
156
|
axis = tuple(range(start, end))
|
|
156
157
|
return axis
|
|
@@ -177,19 +178,20 @@ class EmbeddingLookup(Cell):
|
|
|
177
178
|
embedding_size (int): The size of each embedding vector.
|
|
178
179
|
param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
|
|
179
180
|
Refer to class `initializer` for the values of string when a string
|
|
180
|
-
is specified. Default: 'normal'.
|
|
181
|
+
is specified. Default: ``'normal'`` .
|
|
181
182
|
target (str): Specifies the target where the op is executed. The value must in
|
|
182
|
-
['DEVICE', 'CPU']. Default: 'CPU'.
|
|
183
|
+
[ ``'DEVICE'`` , ``'CPU'`` ]. Default: ``'CPU'`` .
|
|
183
184
|
slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
|
|
184
|
-
:class:`mindspore.nn.EmbeddingLookup`. Default: '
|
|
185
|
-
manual_shapes (tuple): The accompaniment array in field slice mode. Default: None.
|
|
185
|
+
:class:`mindspore.nn.EmbeddingLookup`. Default: ``'batch_slice'`` .
|
|
186
|
+
manual_shapes (tuple): The accompaniment array in field slice mode. Default: ``None`` .
|
|
186
187
|
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
|
|
187
|
-
or None. Default: None
|
|
188
|
-
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
|
|
189
|
-
vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only in
|
|
188
|
+
or None. Default: ``None`` .
|
|
189
|
+
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: ``True`` .
|
|
190
|
+
vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: ``0`` . It is valid only in
|
|
190
191
|
parameter server trainning mode and 'DEVICE' target. And the moment parameter of corresponding
|
|
191
192
|
optimizer will also be set to the cache size. In addition, it should be noted that it will cost the 'DEVICE'
|
|
192
193
|
memory, so suggests setting a reasonable value to avoid insufficient memory.
|
|
194
|
+
dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
|
|
193
195
|
|
|
194
196
|
Inputs:
|
|
195
197
|
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
|
@@ -216,6 +218,9 @@ class EmbeddingLookup(Cell):
|
|
|
216
218
|
``Ascend`` ``GPU`` ``CPU``
|
|
217
219
|
|
|
218
220
|
Examples:
|
|
221
|
+
>>> import mindspore
|
|
222
|
+
>>> from mindspore import Tensor, nn
|
|
223
|
+
>>> import numpy as np
|
|
219
224
|
>>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
|
|
220
225
|
>>> result = nn.EmbeddingLookup(4,2)(input_indices)
|
|
221
226
|
>>> print(result.shape)
|
|
@@ -228,7 +233,7 @@ class EmbeddingLookup(Cell):
|
|
|
228
233
|
|
|
229
234
|
def __init__(self, vocab_size, embedding_size, param_init='normal',
|
|
230
235
|
target='CPU', slice_mode='batch_slice', manual_shapes=None,
|
|
231
|
-
max_norm=None, sparse=True, vocab_cache_size=0):
|
|
236
|
+
max_norm=None, sparse=True, vocab_cache_size=0, dtype=mstype.float32):
|
|
232
237
|
"""Initialize EmbeddingLookup."""
|
|
233
238
|
super(EmbeddingLookup, self).__init__()
|
|
234
239
|
Validator.check_value_type('sparse', sparse, [bool], self.cls_name)
|
|
@@ -252,8 +257,8 @@ class EmbeddingLookup(Cell):
|
|
|
252
257
|
if enable_ps:
|
|
253
258
|
self._process_vocab_cache(slice_mode)
|
|
254
259
|
self.embedding_size = Validator.check_positive_int(embedding_size, 'embedding_size', self.cls_name)
|
|
255
|
-
self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]
|
|
256
|
-
|
|
260
|
+
self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size],
|
|
261
|
+
dtype=dtype), name='embedding_table')
|
|
257
262
|
parallel_mode = _get_parallel_mode()
|
|
258
263
|
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
|
259
264
|
self.gather_revert = P.Gather()
|
|
@@ -264,7 +269,7 @@ class EmbeddingLookup(Cell):
|
|
|
264
269
|
if is_auto_parallel:
|
|
265
270
|
self.unique = P.Unique().shard(((1,),))
|
|
266
271
|
if self.cache_enable and enable_ps:
|
|
267
|
-
self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size, param_init)
|
|
272
|
+
self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size, param_init, dtype=dtype)
|
|
268
273
|
if is_auto_parallel:
|
|
269
274
|
self.unique.add_prim_attr('cache_enable', True)
|
|
270
275
|
indices_shape_size = 2
|
|
@@ -300,15 +305,15 @@ class EmbeddingLookup(Cell):
|
|
|
300
305
|
self.embeddinglookup.shard(((1, get_group_size()), indices_strategy))
|
|
301
306
|
elif slice_mode == "batch_slice" and is_auto_parallel:
|
|
302
307
|
indices_strategy = [get_group_size()]
|
|
303
|
-
indices_strategy.extend([1]*(indices_shape_size - 1))
|
|
308
|
+
indices_strategy.extend([1] * (indices_shape_size - 1))
|
|
304
309
|
indices_strategy = tuple(indices_strategy)
|
|
305
310
|
self.gatherv2.shard(((1, 1), indices_strategy))
|
|
306
311
|
self.embeddinglookup.shard(((1, 1), indices_strategy))
|
|
307
312
|
else:
|
|
308
313
|
if is_auto_parallel:
|
|
309
314
|
support_mode = ["field_slice", "table_row_slice", "table_column_slice", "batch_slice"]
|
|
310
|
-
raise ValueError("For '{}', the 'slice_mode' must be in {}, "
|
|
311
|
-
"but got \"{}\"."
|
|
315
|
+
raise ValueError(f"For '{self.cls_name}', the 'slice_mode' must be in {support_mode}, "
|
|
316
|
+
f"but got \"{slice_mode}\".")
|
|
312
317
|
if self.cache_enable and not enable_ps:
|
|
313
318
|
raise ValueError(f"For '{self.cls_name}', haven't supported cache enable for not ps mode.")
|
|
314
319
|
self.embedding_table.unique = self.forward_unique
|
|
@@ -351,7 +356,8 @@ class EmbeddingLookup(Cell):
|
|
|
351
356
|
if _is_role_worker():
|
|
352
357
|
self.vocab_size = self.vocab_cache_size
|
|
353
358
|
|
|
354
|
-
def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size, param_init
|
|
359
|
+
def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size, param_init,
|
|
360
|
+
dtype=mstype.float32):
|
|
355
361
|
"""PS embeddingLookup cache enable set."""
|
|
356
362
|
if self.sparse:
|
|
357
363
|
self.forward_unique = True
|
|
@@ -365,10 +371,10 @@ class EmbeddingLookup(Cell):
|
|
|
365
371
|
if _enable_distributed_mindrt():
|
|
366
372
|
self.rank_id = get_rank()
|
|
367
373
|
if self.is_ps_server:
|
|
368
|
-
self._slice_pserver_embeddings("zeros")
|
|
374
|
+
self._slice_pserver_embeddings("zeros", dtype=dtype)
|
|
369
375
|
self._set_cache_enable_and_key_for_pserver(param_key)
|
|
370
376
|
|
|
371
|
-
def _slice_pserver_embeddings(self, param_init):
|
|
377
|
+
def _slice_pserver_embeddings(self, param_init, dtype=mstype.float32):
|
|
372
378
|
'''
|
|
373
379
|
Method to slice embedding tables on Parameter Servers.
|
|
374
380
|
It helps to train with a large scale embedding table and is used only in Parameter Server training mode.
|
|
@@ -396,7 +402,7 @@ class EmbeddingLookup(Cell):
|
|
|
396
402
|
for i in range(server_num):
|
|
397
403
|
self.embedding_table_list.append(Parameter(initializer(param_init,
|
|
398
404
|
[self.embedding_table_vocab_dim_list[i],
|
|
399
|
-
self.embedding_size]),
|
|
405
|
+
self.embedding_size], dtype=dtype),
|
|
400
406
|
name="embedding_table_server_" + str(i)))
|
|
401
407
|
|
|
402
408
|
self.embedding_offset.append(offset)
|
|
@@ -495,17 +501,20 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
|
|
495
501
|
field_size (int): The field size of the final outputs.
|
|
496
502
|
param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
|
|
497
503
|
Refer to class `initializer` for the values of string when a string
|
|
498
|
-
is specified. Default: 'normal'.
|
|
504
|
+
is specified. Default: ``'normal'`` .
|
|
499
505
|
target (str): Specifies the target where the op is executed. The value must in
|
|
500
|
-
['DEVICE', 'CPU']. Default: 'CPU'.
|
|
506
|
+
[ ``'DEVICE'`` , ``'CPU'`` ]. Default: ``'CPU'`` .
|
|
501
507
|
slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
|
|
502
|
-
:class:`mindspore.nn.EmbeddingLookup`. Default: '
|
|
503
|
-
feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently.
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
Default:
|
|
508
|
+
:class:`mindspore.nn.EmbeddingLookup`. Default: ``'batch_slice'``.
|
|
509
|
+
feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently.
|
|
510
|
+
Default: ``None`` .
|
|
511
|
+
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32.
|
|
512
|
+
Default: ``None`` .
|
|
513
|
+
sparse (bool): Using sparse mode. When 'target' is set to ``'CPU'`` , 'sparse' has to be true.
|
|
514
|
+
Default: ``True`` .
|
|
515
|
+
operator (str): The pooling method for the features in one field. Support ``'SUM'`` , ``'MEAN'`` and
|
|
516
|
+
``'MAX'`` . Default: ``'SUM'`` .
|
|
517
|
+
dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
|
|
509
518
|
|
|
510
519
|
Inputs:
|
|
511
520
|
- **input_indices** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.
|
|
@@ -524,17 +533,19 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
|
|
524
533
|
TypeError: If `vocab_size` or `embedding_size` or `field_size` is not an int.
|
|
525
534
|
TypeError: If `sparse` is not a bool or `feature_num_list` is not a tuple.
|
|
526
535
|
ValueError: If `vocab_size` or `embedding_size` or `field_size` is less than 1.
|
|
527
|
-
ValueError: If `target` is neither 'CPU' nor 'DEVICE'
|
|
528
|
-
ValueError: If `slice_mode` is not one of 'batch_slice'
|
|
529
|
-
'table_column_slice'.
|
|
530
|
-
ValueError: If `sparse` is False and `target` is 'CPU'.
|
|
531
|
-
ValueError: If `slice_mode` is 'field_slice' and `feature_num_list` is None.
|
|
532
|
-
ValueError: If `operator` is not one of 'SUM'
|
|
536
|
+
ValueError: If `target` is neither ``'CPU'`` nor ``'DEVICE'``.
|
|
537
|
+
ValueError: If `slice_mode` is not one of ``'batch_slice'``, ``'field_slice'``, ``'table_row_slice'``,
|
|
538
|
+
``'table_column_slice'`` .
|
|
539
|
+
ValueError: If `sparse` is False and `target` is ``'CPU'`` .
|
|
540
|
+
ValueError: If `slice_mode` is ``'field_slice'`` and `feature_num_list` is None.
|
|
541
|
+
ValueError: If `operator` is not one of ``'SUM'``, ``'MAX'``, ``'MEAN'`` .
|
|
533
542
|
|
|
534
543
|
Supported Platforms:
|
|
535
544
|
``Ascend`` ``GPU``
|
|
536
545
|
|
|
537
546
|
Examples:
|
|
547
|
+
>>> import mindspore
|
|
548
|
+
>>> from mindspore import Tensor, nn
|
|
538
549
|
>>> input_indices = Tensor([[2, 4, 6, 0, 0], [1, 3, 5, 0, 0]], mindspore.int32)
|
|
539
550
|
>>> input_values = Tensor([[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]], mindspore.float32)
|
|
540
551
|
>>> field_ids = Tensor([[0, 1, 1, 0, 0], [0, 0, 1, 0, 0]], mindspore.int32)
|
|
@@ -548,10 +559,11 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
|
|
548
559
|
OPERATOR_MAX = 'MAX'
|
|
549
560
|
|
|
550
561
|
def __init__(self, vocab_size, embedding_size, field_size, param_init='normal', target='CPU',
|
|
551
|
-
slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'
|
|
562
|
+
slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM',
|
|
563
|
+
dtype=mstype.float32):
|
|
552
564
|
"""Initialize MultiFieldEmbeddingLookup."""
|
|
553
565
|
super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
|
|
554
|
-
slice_mode, feature_num_list, max_norm, sparse)
|
|
566
|
+
slice_mode, feature_num_list, max_norm, sparse, dtype=dtype)
|
|
555
567
|
self.field_size = Validator.check_positive_int(field_size, 'field_size', self.cls_name)
|
|
556
568
|
self.operator = operator
|
|
557
569
|
|
|
@@ -615,8 +627,9 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
|
|
615
627
|
self.inf_add.shard(((1, 1, get_group_size()), (1, 1, 1)))
|
|
616
628
|
else:
|
|
617
629
|
if is_auto_parallel:
|
|
618
|
-
raise ValueError(
|
|
619
|
-
|
|
630
|
+
raise ValueError(
|
|
631
|
+
f"For '{self.cls_name}', the 'slice_mode' must be in ['table_row_slice', 'batch_slice' "
|
|
632
|
+
f"and 'table_column_slice'], but got {str(slice_mode)}.")
|
|
620
633
|
|
|
621
634
|
# Min value for fp32
|
|
622
635
|
self.negative_inf_value = -3.402823466E+38
|
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""
|
|
16
|
+
A FlashAttention Layer.
|
|
17
|
+
"""
|
|
18
|
+
import math
|
|
19
|
+
|
|
20
|
+
import mindspore.common.dtype as mstype
|
|
21
|
+
from mindspore.common.tensor import Tensor
|
|
22
|
+
from mindspore import ops
|
|
23
|
+
from mindspore.nn.cell import Cell
|
|
24
|
+
from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_impl import get_flash_attention
|
|
25
|
+
from mindspore.ops.operations.nn_ops import FlashAttentionScore
|
|
26
|
+
from mindspore._c_expression import MSContext
|
|
27
|
+
|
|
28
|
+
__all__ = ['FlashAttention']
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class FlashAttention(Cell):
    """Flash Attention Layer.

    This layer wraps the flash attention primitives described in
    `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
    <https://arxiv.org/pdf/2205.14135.pdf>`_.

    Specifically, it provides:

    1. An interface for calling the flash attention operation.
    2. Two configuration parameters (`prev_block_num`, `next_block_num`) for enabling local block sparse
       attention.

    When the Ascend SoC version is ``"Ascend910"`` the custom primitive returned by ``get_flash_attention`` is
    used; otherwise the layer falls back to ``FlashAttentionScore``.

    Args:
        head_dim (int): The hidden size of each attention head. The query is scaled by ``1 / sqrt(head_dim)``.
        head_num (int): The number of attention heads. Only forwarded to ``FlashAttentionScore``
            (non-Ascend910 SoCs).
        dropout_rate (float): The dropout rate of the attention score. Dropout is only applied when it is
            greater than 1e-5. Default: ``0.0``.
        prev_block_num (int): The number of previous blocks used for local block sparse attention.
            On non-Ascend910 SoCs it is passed to ``FlashAttentionScore`` as ``pre_tokens``.
            Default: ``65536``.
        next_block_num (int): The number of next blocks used for local block sparse attention.
            On non-Ascend910 SoCs it is passed to ``FlashAttentionScore`` as ``next_tokens``.
            Default: ``65536``.
        tiling_stgy_name (str): The tiling strategy name of flash attention. Only used on Ascend910.
            Default: ``"sparse"``.
        dp (int): Data parallel degree used in the operator shard strategies. Default: ``1``.
        mp (int): Model parallel degree used in the operator shard strategies. Default: ``1``.
        high_precision (bool): Use a higher-precision mode at the cost of some performance.
            Default: ``False``.
        have_attention_mask_batch (bool): Whether `attn_mask` contains the batch dimension. Only used by
            :meth:`shard`. Default: ``True``.
        alibi (bool): Whether the flash attention supports ALiBi. Must be ``False`` on non-Ascend910 SoCs.
            Default: ``False``.

    Inputs:
        - **query** (Tensor) - float16 tensor of shape [batch_size, head_num, q_seq_length, head_dim].
        - **key** (Tensor) - float16 tensor of shape [batch_size, head_num, kv_seq_length, head_dim].
        - **value** (Tensor) - float16 tensor of shape [batch_size, head_num, kv_seq_length, head_dim].
        - **attn_mask** (Tensor, optional) - float16 mask of shape
          [batch_size or 1, q_seq_length, kv_seq_length] carrying the masked information. Default: ``None``.
        - **alibi_mask** (Tensor, optional) - ALiBi mask of shape [batch_size, head_num, 1, kv_seq_length].
          Only consumed on Ascend910. Default: ``None``.

    Outputs:
        A Tensor of shape [batch_size, head_num, q_seq_length, head_dim].

    Raises:
        ValueError: If `head_dim` is not positive.
        ValueError: If `alibi` is True on a non-Ascend910 SoC.
        ValueError: At call time, if the head numbers of query, key and value differ, if a sequence length is
            not a multiple of 16, if the key and value sequence lengths differ, or if `head_dim` exceeds 304.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import numpy as np
        >>> from mindspore import dtype as mstype
        >>> from mindspore.nn.layer.flash_attention import FlashAttention
        >>> from mindspore import Tensor
        >>> model = FlashAttention(head_dim=128,
        ...                        head_num=16,
        ...                        dropout_rate=0.1,
        ...                        prev_block_num=7,
        ...                        next_block_num=0
        ...                        )
        >>> query = Tensor(np.ones((2, 16, 4096, 128)), mstype.float16)
        >>> key = Tensor(np.ones((2, 16, 4096, 128)), mstype.float16)
        >>> value = Tensor(np.ones((2, 16, 4096, 128)), mstype.float16)
        >>> attention_mask = Tensor(np.ones((2, 4096, 4096)), mstype.float16)
        >>> output = model(query, key, value, attention_mask)
        >>> print(output.shape)
        (2, 16, 4096, 128)
    """

    def __init__(self,
                 head_dim,
                 head_num,
                 dropout_rate=0.0,
                 prev_block_num=65536,
                 next_block_num=65536,
                 tiling_stgy_name="sparse",
                 dp=1,
                 mp=1,
                 high_precision=False,
                 have_attention_mask_batch=True,
                 alibi=False
                 ):
        super(FlashAttention, self).__init__()

        # math.sqrt raises ValueError for negative head_dim; 0 is rejected explicitly below.
        scaling_constant = math.sqrt(head_dim)
        if scaling_constant == 0:
            raise ValueError("the scaling constant must not be 0.")
        # The query is pre-multiplied by this factor in construct, which is why the fallback kernel
        # below is configured with scale_value=1.0.
        self.scale_factor = Tensor([1. / scaling_constant], dtype=mstype.float16)

        self.is_910A = MSContext.get_instance().get_ascend_soc_version() == "Ascend910"
        if self.is_910A:
            self.flash_attention = get_flash_attention(
                prev_block_num=prev_block_num,
                next_block_num=next_block_num,
                tiling_stgy_name=tiling_stgy_name,
                high_precision=high_precision
            )
            self.flash_attention.add_prim_attr("primitive_target", "Ascend")
        else:
            if alibi:
                raise ValueError("When soc_version is not Ascend910A, alibi must be False")
            self.transpose_4d_pre = ops.Transpose().shard(((dp, mp, 1, 1),))
            self.transpose_4d_post = ops.Transpose().shard(((dp, 1, mp, 1),))
            self.reshape = ops.Reshape()
            self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
            self.zeros = ops.Zeros()
            self.attn_expand_dims = ops.ExpandDims().shard(((dp, 1, 1),))
            # Strategies for query, key, value (B, S, H) and the expanded attention mask (B, 1, S1, S2).
            fa_strategies = ((dp, 1, mp),
                             (dp, 1, mp),
                             (dp, 1, mp),
                             (dp, 1, 1, 1))
            if dropout_rate > 1e-5:
                # Extra strategy for the dropout mask input, only passed when dropout is enabled.
                fa_strategies += ((dp, mp, 1, 1),)
            # NOTE(review): prev/next_block_num are block counts on 910A but are passed here as token
            # counts (pre_tokens/next_tokens) -- confirm this mapping is intended.
            self.flash_attention = FlashAttentionScore(head_num=head_num, pre_tokens=prev_block_num,
                                                       next_tokens=next_block_num,
                                                       keep_prob=1 - dropout_rate,
                                                       scale_value=1.0,
                                                       inner_precise=0 if high_precision else 1).shard(fa_strategies)

        self.ones = ops.Ones()
        self.dim_mask = Tensor([1] * head_dim, dtype=mstype.int8)
        self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
        self.dropout_rate = dropout_rate
        self.have_attention_mask_batch = have_attention_mask_batch
        self.alibi = alibi
        if self.dropout_rate > 1e-5:
            self.keep_prob = Tensor(1 - self.dropout_rate, dtype=mstype.float16)
            self.fill_v2 = ops.FillV2().shard(((dp, mp, 1, 1), ()))
            self.tensor_one = Tensor(1.0, mstype.float16)
            self.drop_gen_mask = ops.DropoutGenMask()
            self.do_dropout = ops.DropoutDoMask().shard(((dp, mp, 1, 1),))
            self.depend = ops.Depend()

    def shard(self, in_strategy=None, out_strategy=None):
        """Distributed configuration of FlashAttention.

        Sets the shard strategy and the tensor-map attributes of the underlying flash attention primitive.

        :param in_strategy: Describe the split strategy of operator input. The data parallel and model
            parallel degrees are read from ``in_strategy[0][0]`` and ``in_strategy[0][1]``. Default: None,
            which means no split for query, key and value.
        :param out_strategy: Describe the split strategy of operator output, it is only for certain operators,
            such as MatMul. Unused here. Default: None.
        :return: None
        """
        if in_strategy is None:
            # default: dp=1, mp=1, construct inputs only contain query, key, value
            in_strategy = (
                (1, 1, 1, 1),
                (1, 1, 1, 1),
                (1, 1, 1, 1),
            )
        self.flash_attention.shard(in_strategy)
        dp = in_strategy[0][0]
        mp = in_strategy[0][1]
        self.flash_attention.add_prim_attr("dev_matrix_shape", [dp, mp, 1, 1])
        inputs_tensor_map = [
            [3, 2, 1, 0],
            [3, 2, 1, 0],
            [3, 2, 1, 0],
        ]
        # attention mask: its first axis is only mapped to the batch (dp) axis when it has one.
        if self.have_attention_mask_batch:
            inputs_tensor_map.append([3, 1, 0])
        else:
            inputs_tensor_map.append([-1, 1, 0])

        # Number of optional inputs (dropout_mask, alibi_mask) that are absent.
        input_empty_args_num = 2
        # dropout_mask
        if self.dropout_rate > 1e-5:
            input_empty_args_num -= 1
            inputs_tensor_map.append([3, 2, 1, 0])

        if self.alibi:
            input_empty_args_num -= 1
            inputs_tensor_map.append([3, 2, 1, 0])

        self.flash_attention.add_prim_attr("inputs_tensor_map", inputs_tensor_map)

        self.flash_attention.add_prim_attr("outputs_tensor_map", [
            [3, 2, 1, 0],  # O
            [3, 2, 1],  # L
            [3, 2, 1]  # M
        ])
        self.flash_attention.add_prim_attr("as_loss_divisor", 0)
        self.flash_attention.add_prim_attr("empty_mirror_ops", input_empty_args_num)

    def construct(self, query, key, value, attn_mask=None, alibi_mask=None):
        """FlashAttention forward.

        :param query: [bsz, head_num, q_seq_len, head_dim]
        :param key: [bsz, head_num, kv_seq_len, head_dim]
        :param value: [bsz, head_num, kv_seq_len, head_dim]
        :param attn_mask: [1 or bsz, q_seq_len, kv_seq_len], if not None
        :param alibi_mask: [bsz, head_num, 1, kv_seq_len], if not None (only used on Ascend910)
        :return: output [bsz, head_num, q_seq_len, head_dim]
        """
        query = self.scale_mul(query, self.scale_factor)
        bsz, head_num, seq_len, head_dim = query.shape
        _, k_head_num, k_seq_len, _ = key.shape
        _, v_head_num, v_seq_len, _ = value.shape
        if head_num != k_head_num or head_num != v_head_num:
            raise ValueError(
                "the head_num of query, key and value must be the same, "
                "If different head_num are used, users need to change themselves to be same by tile.")
        if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
            raise ValueError(
                "query, key, value seq_len must be a multiple of 16, and key seq_len, value seq_len must be the same.")

        if head_dim > 304:
            raise ValueError(
                "the head_dim must be less than or equal to 304, otherwise the ub would be OOM.")

        if self.is_910A:
            # 910A -- FlashAttentionPrimtive
            if self.dropout_rate > 1e-5:
                # The dropout mask covers the whole [q_seq_len, kv_seq_len] score matrix of every head.
                drop_mask_bits = self.drop_gen_mask((bsz, head_num, seq_len, k_seq_len), self.keep_prob)
                tensor_shape = Tensor((bsz, head_num, seq_len, k_seq_len), mstype.int32)
                ones = self.fill_v2(tensor_shape, self.tensor_one)
                # Order the mask construction after the query scaling.
                ones = self.depend(ones, query)
                drop_mask = self.do_dropout(ones, drop_mask_bits, self.keep_prob)
            else:
                drop_mask = None
            output, _, _ = self.flash_attention(query, key, value, attn_mask, drop_mask, alibi_mask)
        else:
            # FlashAttentionScore
            # Useless input, just for binary calls.
            if self.dropout_rate > 1e-5:
                # Bit-packed mask: 8 score elements per byte along the kv axis.
                drop_mask_bits = self.reshape(self.drop_gen_mask((bsz, head_num, seq_len, k_seq_len), self.keep_prob),
                                              (bsz, head_num, seq_len, k_seq_len // 8))
            else:
                drop_mask_bits = None
            # (B, N, S, D) -> (B, S, H); key/value use their own sequence length, which may differ
            # from the query's.
            query = self.reshape(self.transpose_4d_pre(query, (0, 2, 1, 3)), (bsz, seq_len, -1))
            key = self.reshape(self.transpose_4d_pre(key, (0, 2, 1, 3)), (bsz, k_seq_len, -1))
            value = self.reshape(self.transpose_4d_pre(value, (0, 2, 1, 3)), (bsz, v_seq_len, -1))
            # ExpandDims cannot take None; only an actual mask is expanded to [B, 1, S1, S2].
            if attn_mask is not None:
                attn_mask = self.attn_expand_dims(attn_mask, 1)
            output, _, _ = self.flash_attention(query,
                                                key,
                                                value,
                                                attn_mask,
                                                drop_mask_bits,
                                                None,
                                                None)
            output = self.transpose_4d_post(self.reshape(output, (bsz, seq_len, head_num, head_dim)), (0, 2, 1, 3))

        return output
|