mindspore 2.7.0rc1__cp310-cp310-win_amd64.whl → 2.7.1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +5 -2
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
- mindspore/_extends/parse/parser.py +28 -22
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +23 -2
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
- mindspore/amp.py +0 -18
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/__init__.py +18 -12
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +371 -96
- mindspore/common/_utils.py +7 -43
- mindspore/common/api.py +434 -135
- mindspore/common/dtype.py +98 -57
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/hook_handle.py +82 -3
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +17 -127
- mindspore/common/recompute.py +4 -13
- mindspore/common/tensor.py +50 -217
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +20 -106
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +35 -1
- mindspore/dataset/engine/datasets.py +338 -319
- mindspore/dataset/engine/datasets_user_defined.py +38 -22
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/api/cell.h +28 -4
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +0 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +5 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +6 -1
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +4 -3
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/__init__.py +4 -0
- mindspore/mint/distributed/distributed.py +392 -69
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/_functions.py +1 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +10 -10
- mindspore/mint/nn/layer/normalization.py +11 -16
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +231 -239
- mindspore/nn/layer/activation.py +4 -2
- mindspore/nn/layer/basic.py +56 -14
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +32 -127
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +1 -4
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +2 -4
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +39 -5
- mindspore/nn/wrap/grad_reducer.py +4 -89
- mindspore/numpy/array_creations.py +4 -4
- mindspore/numpy/fft.py +9 -9
- mindspore/numpy/utils_const.py +1 -1
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +1 -5
- mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
- mindspore/ops/auto_generate/gen_extend_func.py +6 -11
- mindspore/ops/auto_generate/gen_ops_def.py +385 -154
- mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +16 -2
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +2 -0
- mindspore/ops/function/array_func.py +24 -18
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +7 -6
- mindspore/ops/function/grad/grad_func.py +4 -12
- mindspore/ops/function/math_func.py +89 -86
- mindspore/ops/function/nn_func.py +92 -313
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +4 -1
- mindspore/ops/functional_overload.py +377 -30
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +12 -50
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +5 -50
- mindspore/ops/operations/comm_ops.py +95 -17
- mindspore/ops/operations/custom_ops.py +237 -22
- mindspore/ops/operations/debug_ops.py +33 -35
- mindspore/ops/operations/manually_defined/ops_def.py +39 -318
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +3 -3
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +4 -27
- mindspore/ops/tensor_method.py +88 -10
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_auto_parallel_context.py +5 -15
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +4 -6
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_utils.py +34 -17
- mindspore/parallel/auto_parallel.py +23 -9
- mindspore/parallel/checkpoint_transform.py +20 -2
- mindspore/parallel/cluster/process_entity/_api.py +28 -33
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/parallel/cluster/run.py +5 -3
- mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/function/reshard_func.py +6 -5
- mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
- mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
- mindspore/parallel/shard.py +7 -21
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +127 -20
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +40 -4
- mindspore/profiler/common/path_manager.py +65 -24
- mindspore/profiler/common/profiler_context.py +27 -14
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +10 -6
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/dynamic_profiler.py +91 -46
- mindspore/profiler/envprofiler.py +30 -5
- mindspore/profiler/experimental_config.py +18 -2
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +34 -7
- mindspore/profiler/profiler.py +193 -145
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +9 -6
- mindspore/runtime/executor.py +35 -0
- mindspore/runtime/memory.py +113 -0
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +8 -21
- mindspore/train/amp.py +6 -7
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +1 -17
- mindspore/train/callback/_flops_collector.py +10 -6
- mindspore/train/callback/_train_fault_tolerance.py +72 -25
- mindspore/train/data_sink.py +5 -9
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +41 -230
- mindspore/train/serialization.py +160 -401
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +152 -16
- mindspore/version.py +1 -1
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/__init__.py +0 -23
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/train/memory_profiling_pb2.py +0 -298
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
|
@@ -1,765 +0,0 @@
|
|
|
1
|
-
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
"""
|
|
16
|
-
The basic layer of the Transformer Networks. This is an experimental interface that is subject to
|
|
17
|
-
change or deletion.
|
|
18
|
-
"""
|
|
19
|
-
from __future__ import absolute_import
|
|
20
|
-
|
|
21
|
-
from functools import wraps, partial
|
|
22
|
-
import inspect
|
|
23
|
-
import math
|
|
24
|
-
import numpy as np
|
|
25
|
-
|
|
26
|
-
from mindspore import nn, context
|
|
27
|
-
from mindspore.common.parameter import Parameter
|
|
28
|
-
from mindspore.common.initializer import initializer, Tensor
|
|
29
|
-
import mindspore.common.dtype as mstype
|
|
30
|
-
from mindspore.common.seed import _get_graph_seed
|
|
31
|
-
from mindspore.ops import operations as P
|
|
32
|
-
from mindspore._extends import cell_attr_register
|
|
33
|
-
from mindspore.nn.cell import Cell
|
|
34
|
-
from mindspore.nn.layer.activation import get_activation
|
|
35
|
-
from mindspore.ops import functional as F
|
|
36
|
-
from mindspore import _checkparam as Validator
|
|
37
|
-
from mindspore.ops.primitive import constexpr
|
|
38
|
-
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
|
|
39
|
-
from mindspore.context import ParallelMode
|
|
40
|
-
from mindspore.parallel._transformer.op_parallel_config import default_dpmp_config, OpParallelConfig, MoEParallelConfig
|
|
41
|
-
from mindspore import log as logger
|
|
42
|
-
|
|
43
|
-
__all__ = [
|
|
44
|
-
"FixedSparseAttention"
|
|
45
|
-
]
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
def _args_type_validator_check(*type_args, **type_kwargs):
|
|
49
|
-
"""Check whether input data type is correct."""
|
|
50
|
-
|
|
51
|
-
def type_check(func):
|
|
52
|
-
sig = inspect.signature(func)
|
|
53
|
-
bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments
|
|
54
|
-
|
|
55
|
-
@wraps(func)
|
|
56
|
-
def wrapper(*args, **kwargs):
|
|
57
|
-
nonlocal bound_types
|
|
58
|
-
bound_values = sig.bind(*args, **kwargs)
|
|
59
|
-
|
|
60
|
-
argument_dict = bound_values.arguments
|
|
61
|
-
if "kwargs" in bound_types:
|
|
62
|
-
bound_types = bound_types["kwargs"]
|
|
63
|
-
if "kwargs" in argument_dict:
|
|
64
|
-
argument_dict = argument_dict["kwargs"]
|
|
65
|
-
for name, value in argument_dict.items():
|
|
66
|
-
if name in bound_types:
|
|
67
|
-
bound_types[name](value, name)
|
|
68
|
-
return func(*args, **kwargs)
|
|
69
|
-
|
|
70
|
-
return wrapper
|
|
71
|
-
|
|
72
|
-
return type_check
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
def _valid_type_checks(types, class_name):
|
|
76
|
-
# types should be a list of types, this function check if the type is in the valid dtypes
|
|
77
|
-
def validator_check_func(value, name):
|
|
78
|
-
# The args of Validator.check_type_name is (arg_name, arg_type, valid_types, prim_name)
|
|
79
|
-
# as the input of _args_type_validator_check is fixed, so we need to manually change the input order
|
|
80
|
-
partial_check = partial(Validator.check_type_name, valid_types=types, prim_name=class_name)
|
|
81
|
-
return partial_check(name, type(value))
|
|
82
|
-
|
|
83
|
-
return validator_check_func
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
def _valid_value_checks(types, class_name):
|
|
87
|
-
# the value should be a list of types, this function check if the value is in the valid dtypes
|
|
88
|
-
def validator_check_func(value, name):
|
|
89
|
-
# The args of Validator.check_type_name is (arg_name, arg_type, valid_types, prim_name)
|
|
90
|
-
# as the input of _args_type_validator_check is fixed, so we need to manually change the input order
|
|
91
|
-
partial_check = partial(Validator.check_type_name, valid_types=types, prim_name=class_name)
|
|
92
|
-
return partial_check(name, value)
|
|
93
|
-
|
|
94
|
-
return validator_check_func
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
class _LayerInputCheck:
|
|
98
|
-
"""
|
|
99
|
-
A input check class for the inputs of the transformer model.
|
|
100
|
-
"""
|
|
101
|
-
|
|
102
|
-
@staticmethod
|
|
103
|
-
def check_shape_length(input_shape, param_name, func_name, target_len):
|
|
104
|
-
"""
|
|
105
|
-
Check the input shape's length is equal to the expected shape
|
|
106
|
-
:param input_shape(list): a list of the tensor shapes.
|
|
107
|
-
:param param_name(str): the name of the checked parameter.
|
|
108
|
-
:param func_name(str): the name of the function.
|
|
109
|
-
:param target_len: the expected length of the shape.
|
|
110
|
-
:return:
|
|
111
|
-
"""
|
|
112
|
-
if not isinstance(target_len, list):
|
|
113
|
-
target_len = [target_len]
|
|
114
|
-
matched = False
|
|
115
|
-
for item in target_len:
|
|
116
|
-
if len(input_shape) == item:
|
|
117
|
-
matched = True
|
|
118
|
-
if not matched:
|
|
119
|
-
raise ValueError(f"{func_name} {param_name} shape length must be one of {target_len} dimension, "
|
|
120
|
-
f"but got shape {input_shape}")
|
|
121
|
-
return True
|
|
122
|
-
|
|
123
|
-
@staticmethod
|
|
124
|
-
def check_shape_equal(input_shape, param_name, func_name, target_shape):
|
|
125
|
-
"""
|
|
126
|
-
Check the input shape's is equal to the expected shape
|
|
127
|
-
:param input_shape(list): a list of the tensor shapes.
|
|
128
|
-
:param param_name(str): the name of the checked parameter.
|
|
129
|
-
:param func_name(str): the name of the function.
|
|
130
|
-
:param target_shape: the expected shape.
|
|
131
|
-
:return:
|
|
132
|
-
"""
|
|
133
|
-
if not isinstance(target_shape[0], list):
|
|
134
|
-
target_shape = [target_shape]
|
|
135
|
-
if isinstance(input_shape, tuple):
|
|
136
|
-
input_shape = list(input_shape)
|
|
137
|
-
_LayerInputCheck.check_shape_length(input_shape, param_name, func_name,
|
|
138
|
-
[len(item) for item in target_shape])
|
|
139
|
-
matched = False
|
|
140
|
-
for item in target_shape:
|
|
141
|
-
if item == input_shape:
|
|
142
|
-
matched = True
|
|
143
|
-
break
|
|
144
|
-
|
|
145
|
-
if not matched:
|
|
146
|
-
raise ValueError(f"{func_name} {param_name} shape must be one of {target_shape},"
|
|
147
|
-
f"but got {input_shape}")
|
|
148
|
-
return True
|
|
149
|
-
|
|
150
|
-
@staticmethod
|
|
151
|
-
def check_shape_value_on_axis(input_shape, dim, param_name, cls_name, target_value):
|
|
152
|
-
""" Check whether the input_shape[dim] is equal to target value"""
|
|
153
|
-
if input_shape[dim] != target_value:
|
|
154
|
-
raise ValueError(f"{cls_name} {param_name} at {dim} shape must be {target_value},"
|
|
155
|
-
f"but got {input_shape[dim]}")
|
|
156
|
-
return True
|
|
157
|
-
|
|
158
|
-
@staticmethod
|
|
159
|
-
def check_shape_equal_without_batch(input_shape, param_name, func_name, target_shape):
|
|
160
|
-
"""
|
|
161
|
-
Check the input shape's is equal to the expected shape, the value on 0-th is viewed as batch, and the
|
|
162
|
-
batch size will not be checked.
|
|
163
|
-
"""
|
|
164
|
-
length, hidden = target_shape
|
|
165
|
-
if isinstance(input_shape, tuple):
|
|
166
|
-
input_shape = list(input_shape)
|
|
167
|
-
_LayerInputCheck.check_shape_length(input_shape, param_name, func_name,
|
|
168
|
-
[len(target_shape), len(target_shape) + 1])
|
|
169
|
-
if input_shape[-1] != hidden:
|
|
170
|
-
raise ValueError(f"For {func_name}, the last dimension of {param_name} shape must be {hidden},"
|
|
171
|
-
f"but got the last dimension {input_shape[-1]} in {input_shape}.")
|
|
172
|
-
if input_shape[0] == 0:
|
|
173
|
-
raise ValueError(f"For {func_name}, the first dimension of {param_name} shape greater than 0,"
|
|
174
|
-
f"but got the first dimension {input_shape[0]} in {input_shape}.")
|
|
175
|
-
if len(input_shape) == 2 and input_shape[0] % length != 0:
|
|
176
|
-
raise ValueError(f"For {func_name}, the first dimension of {param_name} shape should be divisible "
|
|
177
|
-
f"by {length}, "
|
|
178
|
-
f"but got the first dimension {input_shape[0]} in {input_shape}.")
|
|
179
|
-
return True
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
@constexpr
def _check_past_none_input_none(use_past, param_name, func_name, default_value, is_tensor, is_default):
    """
    Check that an optional input is consistent with the ``use_past`` setting.

    When ``use_past`` is False, the input must not be a tensor and must still hold its default value. When
    ``use_past`` is True, the input must be a tensor.

    Args:
        use_past (bool): Whether past inputs are used.
        param_name (str): Name of the checked parameter, used in error messages.
        func_name (str): Name of the calling function, used in error messages.
        default_value: The value the parameter must keep when ``use_past`` is False (shown in messages).
        is_tensor (bool): Whether the checked input is a tensor.
        is_default (bool): Whether the checked input equals ``default_value``.

    Returns:
        bool, True if the check passes.

    Raises:
        TypeError: If the input is inconsistent with ``use_past``.
    """
    if not use_past:
        if is_tensor:
            raise TypeError(f"{func_name} {param_name} must be {default_value}, if use_past is False, but found "
                            f"a tensor")
        if not is_default:
            raise TypeError(f"{func_name} {param_name} must be {default_value}, if use_past is False.")
    elif not is_tensor:
        raise TypeError(f"{func_name} {param_name} must be tensor, if use_past is True")
    return True
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
@constexpr
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
    """
    Check that ``input_dtype`` is one of ``allow_dtypes`` by delegating to ``Validator.check_type_name``.

    Args:
        input_dtype: The dtype to validate.
        param_name (str): Name of the checked parameter, used in error messages.
        allow_dtypes (list): The accepted dtypes.
        cls_name (str): Name of the calling cell/primitive, used in error messages.
    """
    Validator.check_type_name(param_name, input_dtype, valid_types=allow_dtypes, prim_name=cls_name)
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
class _Dropout(nn.Cell):
    r"""
    A Dropout implemented with P.DropoutGenMask and P.DropoutDoMask for parallel training.

    On the Ascend device target, the mask is built by ``P.DropoutGenMask`` and applied by ``P.DropoutDoMask``;
    both can be sharded through :meth:`shard`. On any other device target, ``P.Dropout`` is used instead.

    Args:
        keep_prob (float): Probability of keeping an element; must lie in (0, 1]. Default: 0.5.
        dtype: A subclass of ``mstype.number_type``; stored as ``self.dtype`` on Ascend. Default: mstype.float32.

    Raises:
        ValueError: If ``keep_prob`` is not in (0, 1].
    """

    def __init__(self, keep_prob=0.5, dtype=mstype.float32):
        super(_Dropout, self).__init__()
        if keep_prob <= 0 or keep_prob > 1:
            # The argument is the probability of *keeping* an element, so the message names it as such.
            raise ValueError(
                "keep_prob should be a number in range (0, 1], but got {}".format(
                    keep_prob))
        Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
        Validator.check_value_type('keep_prob', keep_prob, [float], self.cls_name)
        self.keep_prob = keep_prob
        self.is_ascend = context.get_context('device_target') in ["Ascend"]
        if self.is_ascend:
            seed0, seed1 = _get_graph_seed(0, "dropout")
            self.seed0 = seed0
            self.seed1 = seed1
            self.dtype = dtype
            self.get_shape = P.Shape()
            self.dropout_gen_mask = P.DropoutGenMask(Seed0=self.seed0, Seed1=self.seed1)
            self.dropout_do_mask = P.DropoutDoMask()
            self.cast = P.Cast()
        else:
            self.dropout = P.Dropout(keep_prob)

    def construct(self, x):
        r"""
        Apply dropout to ``x``.

        Inputs:
            - **x** (Tensor) - The input tensor.

        Returns:
            Tensor. This is ``x`` itself when the cell is not training or ``keep_prob`` is 1. Otherwise it is
            ``x`` with dropout applied.
        """
        if not self.training:
            return x

        if self.keep_prob == 1:
            return x

        if not self.is_ascend:
            # P.Dropout returns a pair; only its first element (the output) is used here.
            out, _ = self.dropout(x)
            return out

        shape = self.get_shape(x)
        dtype = P.DType()(x)
        # keep_prob is cast to the input's dtype before being passed to the mask primitives.
        keep_prob = self.cast(self.keep_prob, dtype)
        output = self.dropout_gen_mask(shape, keep_prob)
        return self.dropout_do_mask(x, output, keep_prob)

    def extend_repr(self):
        return 'keep_prob={}'.format(self.keep_prob)

    def shard(self, strategy):
        """Apply ``strategy`` to the primitives used on the current device target."""
        if self.is_ascend:
            self.dropout_gen_mask.shard(strategy)
            self.dropout_do_mask.shard(strategy)
        else:
            self.dropout.shard(strategy)
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
class _LayerNorm(Cell):
    r"""
    A layer norm over the last axis of the input.

    By default it wraps the fused `P.LayerNorm` primitive. When `is_self_defined` is True, it is composed
    instead from reduce-mean, square, sqrt and elementwise primitives.

    Args:
        normalized_shape (tuple): The shape of the learnable `gamma` (initialized to ones) and `beta`
            (initialized to zeros) parameters.
        eps (float): The epsilon value added to the variance before the square root. Default 1e-5.
        param_init_type (dtype): The dtype of `gamma` and `beta`. It must be mstype.float32 or mstype.float16.
            Default mstype.float32.
        is_self_defined (bool): If True, compute the normalization from primitive ops rather than
            `P.LayerNorm`. Default False.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(batch, seq\_length, hidden\_size)`.

    Outputs:
        Tensor of shape :math:`(batch, seq\_length, hidden\_size)`.

    Raises:
        TypeError: If `param_init_type` is neither mstype.float32 nor mstype.float16.
    """

    def __init__(self, normalized_shape, eps=1e-5, param_init_type=mstype.float32, is_self_defined=False):
        super(_LayerNorm, self).__init__()
        if param_init_type not in [mstype.float32, mstype.float16]:
            # Report the offending dtype itself, not the Python class of the dtype object.
            raise TypeError("The type of parameter 'param_init_type' should in [float32, float16], "
                            "but got the type : {}.".format(param_init_type))
        self.is_self_defined = is_self_defined
        if not self.is_self_defined:
            self.layer_norm = P.LayerNorm(begin_norm_axis=-1,
                                          begin_params_axis=-1,
                                          epsilon=eps)
        self.gamma = Parameter(initializer('ones', normalized_shape, param_init_type), name="gamma",
                               parallel_optimizer=False)
        self.beta = Parameter(initializer('zeros', normalized_shape, param_init_type), name="beta",
                              parallel_optimizer=False)
        # Primitives for the self-defined path. They are created unconditionally.
        self.mean = P.ReduceMean(keep_dims=True)
        self.square = P.Square()
        self.sqrt = P.Sqrt()
        self.sub1 = P.Sub()
        # sub2 is not used in construct but is still configured by shard(); kept for compatibility.
        self.sub2 = P.Sub()
        self.add = P.Add()
        self.eps = eps
        self.mul = P.Mul()
        self.add2 = P.Add()
        self.real_div = P.RealDiv()

    def construct(self, x):
        r"""
        Normalize `x` over its last axis and apply the affine `gamma`/`beta` transform.

        x : batch x seq_length x hidden_size
        """
        if self.is_self_defined:
            # (x - mean) / sqrt(var + eps) * gamma + beta, with statistics over the last axis.
            mean = self.mean(x, -1)
            diff = self.sub1(x, mean)
            variance = self.mean(self.square(diff), -1)
            variance_eps = self.sqrt(self.add(variance, self.eps))
            output = self.real_div(diff, variance_eps)
            output = self.add2(self.mul(output, self.gamma), self.beta)
        else:
            # P.LayerNorm returns three outputs; only the normalized tensor is used.
            output, _, _ = self.layer_norm(x, self.gamma, self.beta)
        return output

    def shard(self, strategy):
        r"""
        Set the shard for the layer norm. The strategy size should be equal to the inputs.

        Note:
            It is valid only in semi auto parallel or auto parallel mode.
            In other parallel modes, strategies set here will be ignored.

        Args:
            strategy (tuple): The strategy for the layer norm input, a tuple holding one per-axis split tuple.
                In self-defined mode, the whole `strategy` is given to the unary primitives, and `strategy[0]`
                is used for the tensor operands of the binary primitives. Otherwise `gamma` and `beta` are
                left unsplit.

        Returns:
            The cell itself.

        Examples:
            >>> net = _LayerNorm(normalized_shape=(1024,))
            >>> net = net.shard(((10, 2, 1),))
        """
        if self.is_self_defined:
            self.mean.shard(strategy)
            self.square.shard(strategy)
            self.sqrt.shard(strategy)
            self.sub1.shard((strategy[0], strategy[0]))
            self.sub2.shard((strategy[0], strategy[0]))
            self.add.shard((strategy[0], ()))
            self.mul.shard((strategy[0], (1,)))
            self.add2.shard((strategy[0], (1,)))
            self.real_div.shard((strategy[0], strategy[0]))
        else:
            self.layer_norm.shard((strategy[0], (1,), (1,)))

        return self
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
class _Linear(Cell):
    r"""
    The dense connected layer. Once the parallel mode is enabled, the input shape should be
    3-D tensor.

    Applies dense connected layer for the input. This layer implements the operation as:

    .. math::
        \text{outputs} = \text{activation}(\text{X} * \text{kernel} + \text{bias}),

    where :math:`X` is the input tensors, :math:`\text{activation}` is the activation function passed as the activation
    argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same
    data type as the :math:`X` created by the layer, and :math:`\text{bias}` is a bias vector
    with the same data type as the :math:`X` created by the layer (only if has_bias is True).

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The values
            of str refer to the function `initializer`. A Tensor must already have the stored weight shape:
            `[out_channels, in_channels]` when `transpose_b` is True, `[in_channels, out_channels]` otherwise,
            with a leading `expert_num` dimension when `expert_num` > 1. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The values
            of str refer to the function `initializer`. A Tensor must have shape `[out_channels]`, or
            `[1, expert_num, 1, out_channels]` when `expert_num` > 1. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: ``True``.
        activation (Union[str, type, None]): The activation applied to the output of the fully connected layer.
            It can be a name resolved by `get_activation`, a subclass of `nn.Cell` (instantiated with no
            arguments), or None for no activation. Default: ``None``.
        transpose_b (bool): Whether the weight is stored as `[out_channels, in_channels]` and transposed by the
            matmul. Default: ``True``.
        expert_num (int): The number of experts used in this Linear. Here, for the case expert_num > 1, BatchMatMul is
            used and the first dimension in BatchMatMul indicate expert_num. Default: 1.
        outer_batch (int): The replication number of experts. The replication is effective only when MoE is applied.
            Default: 1.
        expert_group_size (int): The number of tokens in each data parallel group. Default: ``None``.
        param_init_type (dtype.Number): The dtype of the weight and bias parameters, mstype.float32 or
            mstype.float16. Default: mstype.float32.
        compute_dtype (dtype.Number): The computation type. Default: mstype.float16
    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`. The `in_channels` in `Args` should be equal
          to :math:`in\_channels` in `Inputs`.

    Outputs:
        Tensor of shape :math:`(*, out\_channels)`.

    Raises:
        TypeError: If `in_channels` or `out_channels` is not an int.
        TypeError: If `has_bias` is not a bool.
        TypeError: If `activation` is not a str, a subclass of `nn.Cell`, or None.
        ValueError: If `weight_init` is a Tensor whose shape differs from the stored weight shape described above.
        ValueError: If `has_bias` is True and `bias_init` is a Tensor whose shape differs from the bias shape
            described above.
        ValueError: If the parallel mode is AUTO_PARALLEL without sharding propagation, `expert_num` > 1 and
            `expert_group_size` is None.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    @cell_attr_register
    @_args_type_validator_check(in_channels=Validator.check_positive_int,
                                out_channels=Validator.check_positive_int,
                                has_bias=Validator.check_bool,
                                transpose_b=Validator.check_bool,
                                expert_num=Validator.check_positive_int,
                                outer_batch=Validator.check_positive_int,
                                param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
                                                                    "Linear"),
                                compute_dtype=_valid_value_checks([mstype.float32, mstype.float16],
                                                                  "Linear"))
    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True,
                 activation=None,
                 transpose_b=True,
                 expert_num=1,
                 outer_batch=1,
                 expert_group_size=None,
                 param_init_type=mstype.float32,
                 compute_dtype=mstype.float16):
        super(_Linear, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # `activation` is instantiated below when callable, so a Cell *class* is expected. Checking
        # isinstance(..., type) first keeps non-class values from triggering issubclass()'s own, less helpful,
        # TypeError instead of the message below.
        if not (isinstance(activation, str) or activation is None
                or (isinstance(activation, type) and issubclass(activation, nn.Cell))):
            raise TypeError(f"For Linear cell, the activation should str type or nn.Cell type, but got {activation}.")

        self.expert_num = expert_num
        self.outer_batch = outer_batch
        self.expert_group_size = expert_group_size
        self.expert_flag = self.expert_num > 1
        # Stored weight layout: [out, in] when transpose_b (the matmul transposes it back), else [in, out];
        # with experts, a leading expert_num dimension is added.
        weight_shape = [out_channels, in_channels] if transpose_b else [in_channels, out_channels]
        if self.expert_flag:
            weight_shape = [self.expert_num] + weight_shape
        # Validate a Tensor initializer against the shape it will actually be used with.
        if isinstance(weight_init, Tensor) and tuple(weight_init.shape) != tuple(weight_shape):
            raise ValueError("The shape of parameter 'weight_init' is error, please check shape of 'weight_init'.")
        self.weight = Parameter(initializer(weight_init, weight_shape, param_init_type), name="weight")
        if self.expert_flag:
            self.matmul = P.BatchMatMul(transpose_b=transpose_b)
        else:
            self.matmul = P.MatMul(transpose_b=transpose_b)
        self.use_expert_group_size = _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) \
                                     and not _is_sharding_propagation() and self.expert_flag is True
        if self.use_expert_group_size is True and self.expert_group_size is None:
            raise ValueError("'expert_group_size' should be configured as an integer in MoEConfig.")
        self.bias = None
        self.has_bias = has_bias
        if self.has_bias:
            # The expert bias broadcasts against the 4-D expert activations produced in construct().
            bias_shape = [1, self.expert_num, 1, out_channels] if self.expert_flag else [out_channels]
            if isinstance(bias_init, Tensor) and tuple(bias_init.shape) != tuple(bias_shape):
                raise ValueError("The shape of parameter 'bias_init' is error, please check shape of 'bias_init'.")
            self.bias = Parameter(initializer(bias_init, bias_shape, param_init_type), name="bias")
            self.bias.parallel_optimizer = False
            self.bias_add = P.Add()
        # act_name keeps the original argument; shard() uses it to locate the activation's primitives.
        self.act_name = activation
        if callable(activation):
            self.activation = activation()
        else:
            self.activation = get_activation(activation) if isinstance(activation, str) else activation
        self.activation_flag = self.activation is not None
        self.dtype = compute_dtype
        self.cast = P.Cast()

    def construct(self, x):
        """Apply the (optionally expert-batched) dense transform, bias and activation to `x`."""
        out_shape = P.Shape()(x)[:-1] + (self.out_channels,)
        # Flatten all leading dimensions into one row axis.
        x = P.Reshape()(x, (-1, self.in_channels))
        if self.expert_flag:
            # Regroup rows into a 4-D layout with expert_num as the second axis for BatchMatMul.
            if self.use_expert_group_size is True:
                x = P.Reshape()(x, (-1, self.expert_num, self.expert_group_size, self.in_channels))
            else:
                x = P.Reshape()(x, (self.outer_batch, self.expert_num, -1, self.in_channels))
        # Compute in compute_dtype, then cast the result back to the input's dtype.
        ori_dtype = F.dtype(x)
        weight = self.cast(self.weight, self.dtype)
        x = self.cast(x, self.dtype)
        x = self.matmul(x, weight)
        if self.has_bias:
            x = self.bias_add(x, self.cast(self.bias, self.dtype))
        if self.activation_flag:
            x = self.activation(x)
        x = F.cast(x, ori_dtype)
        output = P.Reshape()(x, out_shape)
        return output

    def shard(self, strategy_matmul, strategy_bias=None, strategy_activation=None):
        r"""
        Set the shard for the linear. the strategy size should be equal to the inputs.

        Note:
            It is valid only in semi auto parallel or auto parallel mode.
            In other parallel modes, strategies set here will be ignored.

        Args:
            strategy_matmul (tuple): The strategy for the matmul. Should be the same shape as the inputs.
            strategy_bias (tuple): The strategy for the bias_add. Should be the same shape as the inputs.
            strategy_activation (tuple): The strategy for the strategy_activation. Should be the same shape as
                the inputs.

        Returns:
            The cell itself.

        Raises:
            ValueError: If the activation is 'LogSoftmax', or if a custom activation Cell with `activation_shard`
                receives a strategy whose first entry has a length other than 2 or 4.
        """
        self.matmul.shard(strategy_matmul)
        if self.has_bias:
            self.bias_add.shard(strategy_bias)
        if self.activation_flag and isinstance(self.act_name, str):
            # some operations has many primitives, need to manually set the shard
            if self.act_name.lower() == "leakyrelu":
                self.activation.select_op.shard((strategy_activation[0], strategy_activation[0]))
            elif self.act_name.lower() == "logsigmoid":
                self.activation.mul.shard((strategy_activation[0], ()))
                self.activation.exp.shard(strategy_activation)
                self.activation.add.shard((strategy_activation[0], ()))
                self.activation.rec.shard(strategy_activation)
                self.activation.log.shard(strategy_activation)
            elif self.act_name.lower() == "logsoftmax":
                raise ValueError("The 'LogSoftmax' function is not supported in semi auto parallel "
                                 "or auto parallel mode.")
            else:
                getattr(self.activation, self.act_name).shard(strategy_activation)
        elif self.activation_flag and isinstance(self.activation, Cell):
            if hasattr(self.activation, 'activation_shard') and strategy_activation:
                # Derive a parallel config from the first strategy entry: 2 entries -> (dp, mp),
                # 4 entries -> (dp, ep, mp, ...).
                shard_tuple = strategy_activation[0]
                if len(shard_tuple) == 2:
                    parallel_config = OpParallelConfig(data_parallel=shard_tuple[0],
                                                       model_parallel=shard_tuple[1])
                elif len(shard_tuple) == 4:
                    parallel_config = MoEParallelConfig(data_parallel=shard_tuple[0],
                                                        expert_parallel=shard_tuple[1],
                                                        model_parallel=shard_tuple[2])
                else:
                    raise ValueError("The user-defined activation function currently only supports the case where the "
                                     "input policy is 2 or 4, so that relevant policies can be extracted from it."
                                     "To avoid this error, you need to add the function of extracting "
                                     "'ParallelConfig' or 'OpParallelConfig' for the incoming strategy_activation ")
                self.activation.activation_shard(parallel_config)
            else:
                # Report the activation itself; activation_flag is only a boolean.
                logger.warning(f"The user passed the custom defined activation function {self.activation}. "
                               f"If the user want to enable shard for the activation cell, "
                               f"the user should set the shard for each primitives in the cell.")
        return self
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
class FixedSparseAttention(nn.Cell):
    """
    Fixed Sparse Attention Layer.

    This function contains the sparse attention primitives used in Sparse Transformers (see paper)
    `Generating Long Sequences with Sparse Transformers <https://arxiv.org/abs/1904.10509>`_.

    Specifically, it includes the following:

    1. A faster implementation of normal attention (the upper triangle is not computed, and many operations are fused).
    2. An implementation of "strided" and "fixed" attention, as in the Sparse Transformers paper.

    Args:
        batch_size(int): Number of input batch size.
        num_heads(int): Number of attention heads.
        size_per_head(int): An integer determining embedding size of each attention head,
            only supports 64, 128 for now. The hidden size is `num_heads * size_per_head`.
        block_size(int): An integer determining the block size. Current implementation of sparse self-attention
            is based on blocked sparse matrices. In which this parameter defines the size of such blocks,
            Block X Block. Only supports 64 for now.
        seq_length(int): length of input sequence, only supports 1024 for now. Default 1024.
        num_different_global_patterns(int): An integer determining the number of different global attentions layouts.
            While global attention can be fixed by which block/s are representative of
            any local window, since there are multi-heads, each head can use a
            different global representative, only supports 4 for now. Default 4.
        parallel_config(OpParallelConfig): The config of parallel setting, see `OpParallelConfig`.
            Default `default_dpmp_config`, an instance of `OpParallelConfig` with
            default args.

    Inputs:
        - **q** (Tensor) - Tensor query ( `mstype.fp16` [batch_size, seq_length, hidden_size]): Sequence of
          queries to query the context.
        - **k** (Tensor) - Tensor key ( `mstype.fp16` [batch_size, seq_length, hidden_size]): Sequence of
          keys that the queries are matched against.
        - **v** (Tensor) - Tensor value ( `mstype.fp16` [batch_size, seq_length, hidden_size]):
          Sequence of values aggregated by the attention weights.
        - **attention_mask** (Tensor) - Float Tensor the mask of ( `mstype.fp32`, `mstype.fp16`
          [batch_size, seq_length, seq_length]): Lower triangular matrix to pass masked information.

    Outputs:
        A Tensor. The output of the attention with shape [batch_size, seq_length, hidden_size]

    Raises:
        ValueError: If `num_heads` is not a multiple of `parallel_config.model_parallel`, `batch_size` is not a
            multiple of `parallel_config.data_parallel`, `seq_length` is not 1024, `block_size` is not 64,
            `num_different_global_patterns` is not 4, or `size_per_head` is not 64 or 128.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import numpy as np
        >>> from mindspore import dtype as mstype
        >>> from mindspore.nn.transformer import FixedSparseAttention
        >>> from mindspore import Tensor
        >>> model = FixedSparseAttention(batch_size=2,
        ...                              num_heads=8,
        ...                              size_per_head=64,
        ...                              block_size=64)
        >>> q = Tensor(np.ones((2, 1024, 8*64)), mstype.float16)
        >>> k = Tensor(np.ones((2, 1024, 8*64)), mstype.float16)
        >>> v = Tensor(np.ones((2, 1024, 8*64)), mstype.float16)
        >>> attention_mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
        >>> output = model(q, k, v, attention_mask)
        >>> print(output.shape)
        (2, 1024, 512)
    """

    @_args_type_validator_check(batch_size=Validator.check_positive_int,
                                num_heads=Validator.check_positive_int,
                                size_per_head=Validator.check_positive_int,
                                block_size=Validator.check_positive_int,
                                seq_length=Validator.check_positive_int,
                                num_different_global_patterns=Validator.check_positive_int,
                                parallel_config=_valid_type_checks([OpParallelConfig], "FixedSparseAttention"))
    def __init__(self,
                 batch_size,
                 num_heads,
                 size_per_head,
                 block_size,
                 seq_length=1024,
                 num_different_global_patterns=4,
                 parallel_config=default_dpmp_config):
        super(FixedSparseAttention, self).__init__()
        dp, mp = parallel_config.data_parallel, parallel_config.model_parallel
        if num_heads % mp != 0:
            raise ValueError(f"The number of heads {num_heads} must be a "
                             f"multiple of parallel_config.model_parallel {mp}.")
        if batch_size % dp != 0:
            raise ValueError(f"The batch_size {batch_size} must be a "
                             f"multiple of parallel_config.data_parallel {parallel_config.data_parallel}.")
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.hidden_size = size_per_head * num_heads
        self.num_heads = num_heads
        self.block_size = block_size
        self.block_num = seq_length // block_size
        self.size_per_head = size_per_head
        self.global_size = seq_length // 4
        self.reshape = P.Reshape()
        self.transpose = P.Transpose().shard(((dp, 1, mp, 1),))
        self.batch_matmul = P.BatchMatMul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
        self.multiply = P.Mul().shard(((dp, 1, 1, 1), (1, 1, 1)))
        self.multiply_data = Tensor([-10000.0], dtype=mstype.float32)
        self.parallel_config = parallel_config
        size_per_head_list = [64, 128]
        if self.seq_length != 1024:
            raise ValueError("For 'FixedSparseAttention', the class variable 'seq_length' must be 1024, "
                             "but got the value : {}.".format(seq_length))
        if self.block_size != 64:
            raise ValueError("For 'FixedSparseAttention', the class variable 'block_size' must be 64, "
                             "but got the value : {}.".format(block_size))
        if num_different_global_patterns != 4:
            raise ValueError("For 'FixedSparseAttention', the class variable 'num_different_global_patterns' "
                             "must be 4, but got the value : {}".format(num_different_global_patterns))
        if self.size_per_head not in size_per_head_list:
            raise ValueError("For 'FixedSparseAttention', the class variable 'size_per_head' only supports {}, "
                             "but got the value : {}.".format(size_per_head_list, self.size_per_head))
        local_ones = np.ones((self.block_size, self.block_size),
                             dtype=np.float16)
        # Global mask of shape (seq_length, global_size). An entry is zeroed when its 16-row tile index is
        # >= 4 * (its 16-column tile index + 1); all other entries stay 1. This is a vectorized form of the
        # element-wise double loop over rows i and columns j.
        row_tile = np.arange(self.seq_length)[:, np.newaxis] // 16
        col_tile = np.arange(self.global_size)[np.newaxis, :] // 16
        global_mask_original = np.ones((self.seq_length, self.global_size), dtype=np.float16)
        global_mask_original[row_tile >= (col_tile + 1) * 4] = 0.0

        # Entries left at 1 become -10000 (presumably the additive "masked" value); zeroed entries stay 0.
        global_mask_original = -10000 * global_mask_original
        # Split into 16x16 tiles, move the column-tile axis first, and replicate once per batch element.
        global_mask_fx = global_mask_original.reshape((self.seq_length // 16, 16, self.global_size // 16, 16))
        global_mask = np.transpose(global_mask_fx, (2, 0, 1, 3))
        global_mask = np.repeat(global_mask[np.newaxis, :, :, :, :], self.batch_size, axis=0)
        global_mask = global_mask.reshape((self.batch_size * self.global_size // 16, self.seq_length // 16, 16, 16))
        self.global_mask = Tensor(global_mask, mstype.float32)
        self.local_mask_triangle = Tensor(np.tril(local_ones), mstype.float32)
        # NOTE(review): construct() divides BOTH q and k by this factor, which scales the q.k scores by
        # 1/size_per_head overall rather than 1/sqrt(size_per_head). Confirm this is intended, e.g. whether the
        # fused MatmulDDS kernel compensates.
        self.scale_factor = Tensor((math.sqrt(self.size_per_head)))
        self.matmul_dds = P.MatmulDDS(self.batch_size, self.num_heads).shard(((mp, dp, 1, 1),
                                                                              (mp, dp, 1, 1),
                                                                              (1, dp, 1, 1),
                                                                              (dp, 1, 1, 1)))
        self.matmul_dsd = P.DSDMatmul().shard(((dp, mp, 1, 1, 1, 1, 1), (dp, mp, 1, 1, 1, 1, 1), (dp, mp, 1, 1)))
        self.sub1 = P.Sub().shard(((1,), (dp, 1, 1, 1)))
        self.mul1 = P.Mul().shard(((dp, 1, 1, 1), (1,)))
        self.transpose1 = P.Transpose().shard(((dp, 1, 1, 1),))
        self.transpose2 = P.Transpose().shard(((dp, 1, 1, 1),))
        self.transpose3 = P.Transpose().shard(((dp, mp, 1, 1, 1, 1),))
        self.transpose4 = P.Transpose().shard(((dp, mp, 1, 1),))
        self.div = P.RealDiv().shard(((mp, dp, 1, 1), ()))
        self.slice1 = P.StridedSlice().shard(((dp, 1, 1),))

    def construct(self, q, k, v, attention_mask):
        """Compute fixed sparse attention of `q` over `k`/`v` and return [batch, seq_length, hidden_size]."""
        _check_input_dtype(F.dtype(q), "q", [mstype.float16], self.cls_name)
        _check_input_dtype(F.dtype(k), "k", [mstype.float16], self.cls_name)
        _check_input_dtype(F.dtype(v), "v", [mstype.float16], self.cls_name)
        _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)

        q, k, v = self._transpose_inputs(q, k, v)
        local_mask, global_mask = self._generate_attention_mask(attention_mask)
        q = self.div(q, F.cast(self.scale_factor, F.dtype(q)))
        k = self.div(k, F.cast(self.scale_factor, F.dtype(k)))
        # Sparse scores (local + global), then the sparse-dense product with v.
        local_prob, global_prob = self.matmul_dds(q, k, local_mask, global_mask)
        attention = self.matmul_dsd(local_prob, global_prob, v)
        # Undo the tiled layout back to [batch, seq_length, num_heads * size_per_head].
        attention_merge = self.transpose3(attention, (0, 1, 3, 4, 2, 5))
        attention_merge = F.reshape(
            attention_merge,
            (-1, self.num_heads, self.seq_length, self.size_per_head))
        attention_merge = self.transpose4(attention_merge, (0, 2, 1, 3))
        attention_merge = F.reshape(
            attention_merge,
            (-1, self.seq_length, self.size_per_head * self.num_heads))

        return attention_merge

    def _generate_attention_mask(self, attention_mask):
        """
        generate global attention mask and local attention mask from origin attention mask

        Only the last row (index seq_length - 1) of each batch element's mask is read. The global mask
        is the constant precomputed in __init__.
        """
        attention_mask = self.reshape(attention_mask, (-1, self.seq_length, self.seq_length))
        input_mask = self.slice1(attention_mask, (0, self.seq_length - 1, 0),
                                 (self.batch_size, self.seq_length, self.seq_length), (1, 1, 1))
        input_mask = self.reshape(input_mask, (-1, self.seq_length))
        input_shape = P.Shape()(input_mask)  # bs, seq_length
        # bs, block_num, 1, block_size
        local_shape_right = (input_shape[0], self.block_num, 1, self.block_size)
        # bs, block_num, block_size, 1
        local_shape_left = (input_shape[0], self.block_num, self.block_size, 1)
        local_mask_left = self.reshape(input_mask, local_shape_left)
        local_mask_right = self.reshape(input_mask, local_shape_right)
        # bs, block_num, block_size, block_size
        local_attention_mask = self.batch_matmul(local_mask_left, local_mask_right)
        lower_triangle = P.ExpandDims()(self.local_mask_triangle, 0)
        local_attention_mask = self.multiply(local_attention_mask, lower_triangle)
        # Turn the keep-mask into an additive mask: (1 - mask) * -10000.
        local_multiplied_out = self.sub1(P.Cast()(F.tuple_to_array((1.0,)), mstype.float32),
                                         P.Cast()(local_attention_mask, mstype.float32))
        local_adder = self.mul1(local_multiplied_out, self.multiply_data)
        local_mask_original = self.transpose1(local_adder, (0, 2, 1, 3))
        local_mask_original = self.reshape(
            local_mask_original,
            (self.batch_size * self.block_size, self.block_num * self.block_size))
        local_mask_fx = self.reshape(
            local_mask_original,
            (self.batch_size * self.block_size // 16, 16,
             self.block_num * self.block_size // 16, 16))
        local_mask = self.transpose2(local_mask_fx, (2, 0, 1, 3))
        global_mask = self.global_mask

        return local_mask, global_mask

    def _transpose_inputs(self, q, k, v):
        """
        do reshape and transpose to inputs

        Each input is split into 16x16 tiles: q and k become (hidden_size // 16, -1, 16, 16) and
        v becomes (-1, hidden_size // 16, 16, 16).
        """
        q = self.transpose(
            self.reshape(
                q,
                (-1, 16, self.num_heads * self.size_per_head // 16, 16)),
            (2, 0, 1, 3))
        k = self.transpose(
            self.reshape(
                k, (-1, 16, self.num_heads * self.size_per_head // 16, 16)),
            (2, 0, 1, 3))
        v = self.transpose(
            self.reshape(
                v,
                (-1, 16, self.num_heads * self.size_per_head // 16, 16)),
            (0, 2, 3, 1))

        return q, k, v
|