mindspore-2.1.0-cp37-cp37m-manylinux1_x86_64.whl → mindspore-2.2.10-cp37-cp37m-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore has been flagged as possibly problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +46 -19
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +38 -0
- mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +74 -104
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +141 -88
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +84 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +28 -19
- mindspore/ops/operations/_grad_ops.py +72 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +189 -141
- mindspore/ops/operations/nn_ops.py +794 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +488 -528
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/dense.py
CHANGED
@@ -77,6 +77,7 @@ class BiDense(Cell):
         bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter.
             The values of str refer to the function `initializer`. Default: ``None`` .
         has_bias (bool): Specifies whether the layer uses :math:`\text{bias}` vector. Default: ``True`` .
+        dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .

     Shape:
         - **input1** - :math:`(*, H_{in1})` where :math:`H_{in1}=\text{in1_channels}` and
@@ -90,8 +91,8 @@ class BiDense(Cell):
           are the same shape as the inputs.

     Dtype:
-        - **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input2
-        - **
+        - **input1** (Tensor) - The dtype must be float16 or float32 and be same as **input2** .
+        - **input2** (Tensor) - The dtype must be float16 or float32 and be same as **input1** .
         - **output** (Tensor) - With the same dtype as the inputs.

     Weights:
@@ -133,7 +134,8 @@ class BiDense(Cell):
                  out_channels,
                  weight_init=None,
                  bias_init=None,
-                 has_bias=True
+                 has_bias=True,
+                 dtype=mstype.float32):
         super().__init__()
         self.in_channels = Validator.check_positive_int(in1_channels, "in1_channels", self.cls_name)
         self.in_channels = Validator.check_positive_int(in2_channels, "in2_channels", self.cls_name)
@@ -156,7 +158,8 @@ class BiDense(Cell):
                              f"equal to 'in2_channels'. But got 'weight_init': {weight_init}, "
                              f"'out_channels': {out_channels}, 'in_channels': {in1_channels}, "
                              f"'in2_channels': {in2_channels}")
-        self.weight = Parameter(initializer(weight_init, (out_channels, in1_channels, in2_channels)
+        self.weight = Parameter(initializer(weight_init, (out_channels, in1_channels, in2_channels), dtype=dtype),
+                                'weight')

         if self.has_bias:
             if bias_init is None:
@@ -166,7 +169,7 @@ class BiDense(Cell):
                 raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' should "
                                  f"be equal to 1, and the first dim must be equal to 'out_channels'. But got "
                                  f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
-            self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
+            self.bias = Parameter(initializer(bias_init, [out_channels], dtype=dtype), name="bias")
         self.bias_add = P.BiasAdd()
         self.matmul = P.MatMul()
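The net effect of the dense.py change is a new `dtype` keyword on `BiDense` that is forwarded to the weight and bias initializers, so the parameters are created directly in the requested precision. A minimal usage sketch against 2.2.10 (channel sizes are illustrative, not from the diff):

import mindspore as ms
from mindspore import nn

# `dtype` is passed through to initializer() for both weight and bias
# (new in 2.2.10 per the diff above).
layer = nn.BiDense(in1_channels=20, in2_channels=30, out_channels=40, dtype=ms.float16)
print(layer.weight.dtype)  # expected: Float16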
mindspore/nn/layer/embedding.py
CHANGED
@@ -64,11 +64,13 @@ class Embedding(Cell):
         embedding_size (int): The size of each embedding vector.
         use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: ``False`` .
         embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
-            Refer to class `initializer
-
+            Refer to class `mindspore.common.initializer
+            <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.common.initializer.html>`_
+            for the values of string when a string is specified. Default: ``'normal'`` .
         dtype (:class:`mindspore.dtype`): Data type of `x`. Default: ``mstype.float32`` .
         padding_idx (int, None): When the padding_idx encounters index, the output embedding vector of this index
             will be initialized to zero. Default: ``None`` . The feature is inactivated.
+
     Inputs:
         - **x** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{x_length})`. The elements of
           the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
@@ -145,9 +147,8 @@ class Embedding(Cell):
         return output

     def extend_repr(self):
-
-        self.
-        return s
+        return f'vocab_size={self.vocab_size}, embedding_size={self.embedding_size}, use_one_hot={self.use_one_hot}, ' \
+               f'embedding_table={self.embedding_table}, dtype={self.dtype}, padding_idx={self.padding_idx}'


 @_primexpr
@@ -190,6 +191,7 @@ class EmbeddingLookup(Cell):
             parameter server trainning mode and 'DEVICE' target. And the moment parameter of corresponding
             optimizer will also be set to the cache size. In addition, it should be noted that it will cost the 'DEVICE'
             memory, so suggests setting a reasonable value to avoid insufficient memory.
+        dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .

     Inputs:
         - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
@@ -231,7 +233,7 @@ class EmbeddingLookup(Cell):

     def __init__(self, vocab_size, embedding_size, param_init='normal',
                  target='CPU', slice_mode='batch_slice', manual_shapes=None,
-                 max_norm=None, sparse=True, vocab_cache_size=0):
+                 max_norm=None, sparse=True, vocab_cache_size=0, dtype=mstype.float32):
         """Initialize EmbeddingLookup."""
         super(EmbeddingLookup, self).__init__()
         Validator.check_value_type('sparse', sparse, [bool], self.cls_name)
@@ -255,8 +257,8 @@ class EmbeddingLookup(Cell):
         if enable_ps:
             self._process_vocab_cache(slice_mode)
         self.embedding_size = Validator.check_positive_int(embedding_size, 'embedding_size', self.cls_name)
-        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]
-
+        self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size],
+                                                     dtype=dtype), name='embedding_table')
         parallel_mode = _get_parallel_mode()
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.gather_revert = P.Gather()
@@ -267,7 +269,7 @@ class EmbeddingLookup(Cell):
         if is_auto_parallel:
             self.unique = P.Unique().shard(((1,),))
         if self.cache_enable and enable_ps:
-            self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size, param_init)
+            self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size, param_init, dtype=dtype)
             if is_auto_parallel:
                 self.unique.add_prim_attr('cache_enable', True)
         indices_shape_size = 2
@@ -310,8 +312,8 @@ class EmbeddingLookup(Cell):
         else:
             if is_auto_parallel:
                 support_mode = ["field_slice", "table_row_slice", "table_column_slice", "batch_slice"]
-                raise ValueError("For '{}', the 'slice_mode' must be in {}, "
-                                 "but got \"{}\"."
+                raise ValueError(f"For '{self.cls_name}', the 'slice_mode' must be in {support_mode}, "
+                                 f"but got \"{slice_mode}\".")
             if self.cache_enable and not enable_ps:
                 raise ValueError(f"For '{self.cls_name}', haven't supported cache enable for not ps mode.")
         self.embedding_table.unique = self.forward_unique
@@ -354,7 +356,8 @@ class EmbeddingLookup(Cell):
         if _is_role_worker():
             self.vocab_size = self.vocab_cache_size

-    def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size, param_init
+    def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size, param_init,
+                                       dtype=mstype.float32):
         """PS embeddingLookup cache enable set."""
         if self.sparse:
             self.forward_unique = True
@@ -368,10 +371,10 @@ class EmbeddingLookup(Cell):
         if _enable_distributed_mindrt():
             self.rank_id = get_rank()
             if self.is_ps_server:
-                self._slice_pserver_embeddings("zeros")
+                self._slice_pserver_embeddings("zeros", dtype=dtype)
                 self._set_cache_enable_and_key_for_pserver(param_key)

-    def _slice_pserver_embeddings(self, param_init):
+    def _slice_pserver_embeddings(self, param_init, dtype=mstype.float32):
         '''
         Method to slice embedding tables on Parameter Servers.
         It helps to train with a large scale embedding table and is used only in Parameter Server training mode.
@@ -399,7 +402,7 @@ class EmbeddingLookup(Cell):
         for i in range(server_num):
             self.embedding_table_list.append(Parameter(initializer(param_init,
                                                                    [self.embedding_table_vocab_dim_list[i],
-                                                                    self.embedding_size]),
+                                                                    self.embedding_size], dtype=dtype),
                                                        name="embedding_table_server_" + str(i)))

             self.embedding_offset.append(offset)
@@ -505,12 +508,13 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
             :class:`mindspore.nn.EmbeddingLookup`. Default: ``'batch_slice'``.
         feature_num_list (tuple): The accompaniment array in field slice mode. This is unused currently.
             Default: ``None`` .
-        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
-
+        max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32.
+            Default: ``None`` .
         sparse (bool): Using sparse mode. When 'target' is set to ``'CPU'`` , 'sparse' has to be true.
             Default: ``True`` .
         operator (str): The pooling method for the features in one field. Support ``'SUM'`` , ``'MEAN'`` and
             ``'MAX'`` . Default: ``'SUM'`` .
+        dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .

     Inputs:
         - **input_indices** (Tensor) - The shape of tensor is :math:`(batch\_size, seq\_length)`.
@@ -529,12 +533,12 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
         TypeError: If `vocab_size` or `embedding_size` or `field_size` is not an int.
         TypeError: If `sparse` is not a bool or `feature_num_list` is not a tuple.
         ValueError: If `vocab_size` or `embedding_size` or `field_size` is less than 1.
-        ValueError: If `target` is neither 'CPU' nor 'DEVICE'
-        ValueError: If `slice_mode` is not one of 'batch_slice'
-            'table_column_slice'.
-        ValueError: If `sparse` is False and `target` is 'CPU'.
-        ValueError: If `slice_mode` is 'field_slice' and `feature_num_list` is None.
-        ValueError: If `operator` is not one of 'SUM'
+        ValueError: If `target` is neither ``'CPU'`` nor ``'DEVICE'``.
+        ValueError: If `slice_mode` is not one of ``'batch_slice'``, ``'field_slice'``, ``'table_row_slice'``,
+            ``'table_column_slice'`` .
+        ValueError: If `sparse` is False and `target` is ``'CPU'`` .
+        ValueError: If `slice_mode` is ``'field_slice'`` and `feature_num_list` is None.
+        ValueError: If `operator` is not one of ``'SUM'``, ``'MAX'``, ``'MEAN'`` .

     Supported Platforms:
         ``Ascend`` ``GPU``
@@ -555,10 +559,11 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
     OPERATOR_MAX = 'MAX'

     def __init__(self, vocab_size, embedding_size, field_size, param_init='normal', target='CPU',
-                 slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM'
+                 slice_mode='batch_slice', feature_num_list=None, max_norm=None, sparse=True, operator='SUM',
+                 dtype=mstype.float32):
         """Initialize MultiFieldEmbeddingLookup."""
         super(MultiFieldEmbeddingLookup, self).__init__(vocab_size, embedding_size, param_init, target,
-                                                        slice_mode, feature_num_list, max_norm, sparse)
+                                                        slice_mode, feature_num_list, max_norm, sparse, dtype=dtype)
         self.field_size = Validator.check_positive_int(field_size, 'field_size', self.cls_name)
         self.operator = operator

@@ -622,8 +627,9 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
                 self.inf_add.shard(((1, 1, get_group_size()), (1, 1, 1)))
         else:
             if is_auto_parallel:
-                raise ValueError(
-
+                raise ValueError(
+                    f"For '{self.cls_name}', the 'slice_mode' must be in ['table_row_slice', 'batch_slice' "
+                    f"and 'table_column_slice'], but got {str(slice_mode)}.")

         # Min value for fp32
         self.negative_inf_value = -3.402823466E+38
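The embedding.py changes thread the same `dtype` keyword from `EmbeddingLookup.__init__` down through `_set_voacb_cache_enable_for_ps` and `_slice_pserver_embeddings`, so parameter-server table slices inherit the requested dtype, and they rewrite two error messages as f-strings. A hedged sketch of the public-facing part (vocabulary and index values are illustrative):

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

# The embedding table is now created with the requested dtype instead
# of the former implicit float32 (per the diff above).
lookup = nn.EmbeddingLookup(vocab_size=2000, embedding_size=16, target='CPU', dtype=ms.float16)
indices = Tensor(np.array([[1, 5], [3, 9]]), ms.int32)
print(lookup(indices).dtype)  # expected: Float16, matching the table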
mindspore/nn/layer/flash_attention.py
CHANGED
@@ -17,12 +17,13 @@ A FlashAttention Layer.
 """
 import math

-import mindspore.
-from mindspore import ops
-from mindspore.common import dtype as mstype
+import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
+from mindspore import ops
 from mindspore.nn.cell import Cell
 from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_impl import get_flash_attention
+from mindspore.ops.operations.nn_ops import FlashAttentionScore
+from mindspore._c_expression import MSContext

 __all__ = ['FlashAttention']

@@ -56,14 +57,15 @@ class FlashAttention(Cell):
             Default True
         alibi(bool): This parameter indicates whether the flashattention supports the Alibi.
             Default: False
+        use_mqa(bool): Using MHA if True, only take effect under 910B. Default: False.


     Inputs:
         - **query** (Tensor) - Tensor query (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
         - **key** (Tensor) - Tensor key (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
         - **value** (Tensor) - Tensor value (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
-        - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16`
-
+        - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16` `mstype.uint8`
+          [batch_size, seq_length, seq_length]): A matrix to pass masked information.

     Outputs:
         A Tensor. The output of the attention with shape [batch_size, head_num, seq_length, head_dim]
@@ -92,6 +94,7 @@ class FlashAttention(Cell):

     def __init__(self,
                  head_dim,
+                 head_num,
                  dropout_rate=0.0,
                  prev_block_num=65536,
                  next_block_num=65536,
@@ -100,27 +103,63 @@ class FlashAttention(Cell):
                  mp=1,
                  high_precision=False,
                  have_attention_mask_batch=True,
-                 alibi=False
+                 alibi=False,
+                 use_mqa=False
                  ):
         super(FlashAttention, self).__init__()

-        self.flash_attention = get_flash_attention(
-            prev_block_num=prev_block_num,
-            next_block_num=next_block_num,
-            tiling_stgy_name=tiling_stgy_name,
-            high_precision=high_precision
-        )
-        self.flash_attention.add_prim_attr("primitive_target", "Ascend")
         scaling_constant = math.sqrt(head_dim)
-        if scaling_constant
-            self.scale_factor = Tensor([1. / scaling_constant], dtype=mstype.float16)
-        else:
+        if scaling_constant == 0:
             raise ValueError("the scaling constant must not be 0.")
-        self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
-        self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
         self.dropout_rate = dropout_rate
-        self.
-        self.
+        self.is_910A = MSContext.get_instance().get_ascend_soc_version() == "ascend910"
+        if self.is_910A:
+            self.scale_factor = Tensor([1. / math.sqrt(scaling_constant)], dtype=mstype.float16)
+            self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
+            self.ones = ops.Ones()
+            self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
+            self.have_attention_mask_batch = have_attention_mask_batch
+            self.alibi = alibi
+            self.flash_attention = get_flash_attention(
+                prev_block_num=prev_block_num,
+                next_block_num=next_block_num,
+                tiling_stgy_name=tiling_stgy_name,
+                high_precision=high_precision
+            )
+            self.flash_attention.add_prim_attr("primitive_target", "Ascend")
+            fa_strategies = ((dp, mp, 1, 1),
+                             (dp, mp, 1, 1),
+                             (dp, mp, 1, 1))
+            self.shard(fa_strategies)
+        else:
+            if alibi:
+                raise ValueError(f"When soc_version is not Ascend910A, alibi must be False")
+            self.transpose_4d_pre = ops.Transpose().shard(((dp, mp, 1, 1),))
+            self.transpose_4d_post = ops.Transpose().shard(((dp, 1, mp, 1),))
+            self.reshape = ops.Reshape()
+            self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
+            self.zeros = ops.Zeros()
+            self.attn_cast = ops.Cast()
+            if use_mqa:
+                fa_strategies = ((dp, mp, 1, 1),
+                                 (dp, 1, 1, 1),
+                                 (dp, 1, 1, 1),
+                                 (dp, 1, 1, 1))
+            else:
+                fa_strategies = ((dp, mp, 1, 1),
+                                 (dp, mp, 1, 1),
+                                 (dp, mp, 1, 1),
+                                 (dp, 1, 1, 1))
+            if dropout_rate > 1e-5:
+                fa_strategies += ((dp, mp, 1, 1),)
+            self.flash_attention = FlashAttentionScore(head_num=head_num, pre_tokens=prev_block_num,
+                                                       next_tokens=next_block_num,
+                                                       keep_prob=1 - dropout_rate,
+                                                       scale_value=1. / scaling_constant,
+                                                       inner_precise=0 if high_precision else 1,
+                                                       input_layout="BNSD").shard(fa_strategies)
+
+        self.dropout_rate = dropout_rate
         if self.dropout_rate > 1e-5:
             self.keep_prob = Tensor(1 - self.dropout_rate, dtype=mstype.float16)
             self.fill_v2 = ops.FillV2().shard(((dp, mp, 1, 1), ()))
@@ -136,49 +175,49 @@ class FlashAttention(Cell):
             such as MatMul. Default: None.
         :return:
         """
-        if
-
-
-
-
-
-
-
-
-
-
-        )
-
-
-
-
-
-
-
-
-
-        ]
-        if self.have_attention_mask_batch:
-            inputs_tensor_map.append([3, 1, 0])
-        else:
-            inputs_tensor_map.append([-1, 1, 0])
+        if self.is_910A:
+            if in_strategy is None:
+                # default: dp=1, mp=1, construct inputs only contain query, key, value
+                in_strategy = (
+                    (1, 1, 1, 1),
+                    (1, 1, 1, 1),
+                    (1, 1, 1, 1),
+                )
+            self.flash_attention.shard(in_strategy)
+            dp = in_strategy[0][0]
+            mp = in_strategy[0][1]
+            self.flash_attention.add_prim_attr("dev_matrix_shape", [dp, mp, 1, 1])
+            inputs_tensor_map = [
+                [3, 2, 1, 0],
+                [3, 2, 1, 0],
+                [3, 2, 1, 0],
+            ]
+            if self.have_attention_mask_batch:
+                inputs_tensor_map.append([3, 1, 0])
+            else:
+                inputs_tensor_map.append([-1, 1, 0])

-
-
-
+            input_empty_args_num = 2
+            # dropout_mask
+            if self.dropout_rate > 1e-5:
+                input_empty_args_num -= 1
+                inputs_tensor_map.append([3, 2, 1, 0])

-
-
+            if self.alibi:
+                input_empty_args_num -= 1
+                inputs_tensor_map.append([3, 2, 1, 0])

-
+            self.flash_attention.add_prim_attr("inputs_tensor_map", inputs_tensor_map)

-
-
-
-
-
-
-
+            self.flash_attention.add_prim_attr("outputs_tensor_map", [
+                [3, 2, 1, 0],  # O
+                [3, 2, 1],  # L
+                [3, 2, 1]  # M
+            ])
+            self.flash_attention.add_prim_attr("as_loss_divisor", 0)
+            self.flash_attention.add_prim_attr("empty_mirror_ops", input_empty_args_num)
+        else:
+            self.flash_attention.shard(in_strategy)

     def construct(self, query, key, value, attn_mask=None, alibi_mask=None):
         """FlashAttention forward
@@ -189,35 +228,49 @@ class FlashAttention(Cell):
         :param alibi_mask: [bsz, head_num, 1, seq_len], if not None
         :return: output [bsz, head_num, seq_len, head_dim]
         """
-        query = self.scale_mul(query, self.scale_factor)
         bsz, head_num, seq_len, head_dim = query.shape
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        output =
+        if self.is_910A:
+            _, k_head_num, k_seq_len, _ = key.shape
+            _, v_head_num, v_seq_len, _ = value.shape
+            if head_num != k_head_num or head_num != v_head_num:
+                raise ValueError(
+                    "the head_num of query, key and value must be the same, "
+                    "If different head_num are used, users need to change themselves to be same by tile.")
+            if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
+                raise ValueError(
+                    "query, key, value seq_len must be a multiple of 16, "
+                    "and the seq_len between key and value must be equal.")
+            # 910A -- FlashAttentionPrimtive
+            if head_dim > 304:
+                raise ValueError(
+                    "the head_dim must be less than 304, otherwise the ub would be OOM.")
+            if self.dropout_rate > 1e-5:
+                drop_mask_bits = self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob)
+                tensor_shape = Tensor((bsz, head_num, seq_len, seq_len), mstype.int32)
+                ones = self.fill_v2(tensor_shape, self.tensor_one)
+                ones = self.depend(ones, query)
+                drop_mask = self.do_dropout(ones, drop_mask_bits, self.keep_prob)
+            else:
+                drop_mask = None
+            query = self.scale_mul(query, self.scale_factor)
+            key = self.scale_mul(key, self.scale_factor)
+            attn_mask = self.cast(attn_mask, mstype.float16)
+            output, _, _ = self.flash_attention(query, key, value, attn_mask, drop_mask, alibi_mask)
         else:
-
+            # 910B -- FlashAttentionScore
+            if self.dropout_rate > 1e-5:
+                drop_mask_bits = self.reshape(self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob),
+                                              (bsz, head_num, seq_len, seq_len // 8))
+            else:
+                drop_mask_bits = None
+            # (B, S, S) -> (B, 1, S, S)
+            attn_mask = self.cast(self.reshape(attn_mask, (bsz, 1, seq_len, seq_len)), mstype.uint8)
+            output, _, _ = self.flash_attention(query,
+                                                key,
+                                                value,
+                                                attn_mask,
+                                                drop_mask_bits,
+                                                None,
+                                                None,
+                                                None)
         return output
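In short, flash_attention.py now branches on the SoC version at construction time: Ascend 910A keeps the custom TIK primitive from `get_flash_attention`, while other SoCs (910B) build the `FlashAttentionScore` operator instead, which is why the constructor gains a required `head_num` and an optional `use_mqa`. A hedged construction sketch (runs only on an Ascend device; all values are illustrative):

from mindspore.nn.layer.flash_attention import FlashAttention

# `head_num` feeds FlashAttentionScore(head_num=...) on 910B;
# `use_mqa` only changes the key/value sharding strategy there.
fa = FlashAttention(head_dim=128, head_num=16, dropout_rate=0.0, use_mqa=False)
# construct(query, key, value, attn_mask) expects fp16 tensors laid out
# as BNSD: [batch_size, head_num, seq_length, head_dim].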
mindspore/nn/layer/image.py
CHANGED
@@ -83,17 +83,17 @@ class ImageGradients(Cell):
         _check_input_4d(F.shape(images), "images", self.cls_name)
         batch_size, depth, height, width = P.Shape()(images)
         if height == 1:
-            dy =
+            dy = F.fill(P.DType()(images), (batch_size, depth, 1, width), 0)
         else:
             dy = images[:, :, 1:, :] - images[:, :, :height - 1, :]
-            dy_last =
+            dy_last = F.fill(P.DType()(images), (batch_size, depth, 1, width), 0)
             dy = P.Concat(2)((dy, dy_last))

         if width == 1:
-            dx =
+            dx = F.fill(P.DType()(images), (batch_size, depth, height, 1), 0)
         else:
             dx = images[:, :, :, 1:] - images[:, :, :, :width - 1]
-            dx_last =
+            dx_last = F.fill(P.DType()(images), (batch_size, depth, height, 1), 0)
             dx = P.Concat(3)((dx, dx_last))
         return dy, dx
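The four changed assignments all use `F.fill` to zero-pad the last finite difference, so `dy` and `dx` keep the input's :math:`(N, C, H, W)` shape. A small sketch reproducing the behavior documented for `ImageGradients`:

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

# dy/dx are forward differences along H and W, zero-padded by F.fill
# so outputs match the input shape.
net = nn.ImageGradients()
image = Tensor(np.array([[[[1, 2], [3, 4]]]]), ms.int32)
dy, dx = net(image)
print(dy)  # [[[[2, 2], [0, 0]]]]
print(dx)  # [[[[1, 0], [1, 0]]]]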
@@ -571,7 +571,8 @@ class PixelShuffle(Cell):
     <https://arxiv.org/abs/1609.05158>`_ .

     Typically, the input is of shape :math:`(*, C \times r^2, H, W)` , and the output is of shape
-    :math:`(*, C, H \times r, W \times r)`,
+    :math:`(*, C, H \times r, W \times r)`,
+    where :math:`r` is an upscale factor and :math:`*` is zero or more batch dimensions.

     Note:
         The dimension of input Tensor on Ascend should be less than 7.
@@ -621,7 +622,8 @@ class PixelUnshuffle(Cell):
     <https://arxiv.org/abs/1609.05158>`_ .

     Typically, the input is of shape :math:`(*, C, H \times r, W \times r)` , and the output is of shape
-    :math:`(*, C \times r^2, H, W)` ,
+    :math:`(*, C \times r^2, H, W)` ,
+    where :math:`r` is a downscale factor and :math:`*` is zero or more batch dimensions.

     Args:
         downscale_factor (int): factor to unshuffle the input, and is a positive integer.
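The two docstring fixes only re-wrap the shape relation: `PixelShuffle` maps :math:`(*, C \times r^2, H, W)` to :math:`(*, C, H \times r, W \times r)`, and `PixelUnshuffle` inverts it. A quick sketch of that relation (shapes are illustrative):

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

# With upscale_factor r=2: channels shrink by r^2, H and W grow by r.
pixel_shuffle = nn.PixelShuffle(upscale_factor=2)
x = Tensor(np.zeros((1, 9 * 4, 4, 4)), ms.float32)
print(pixel_shuffle(x).shape)  # expected: (1, 9, 8, 8)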