mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.5.0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +8 -3
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +0 -5
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/compile_config.py +64 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
- mindspore/_extends/parse/parser.py +23 -5
- mindspore/_extends/parse/standard_method.py +123 -27
- mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
- mindspore/amp.py +7 -1
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/boost_cell_wrapper.py +136 -41
- mindspore/common/__init__.py +3 -1
- mindspore/common/_register_for_tensor.py +0 -1
- mindspore/common/_stub_tensor.py +25 -4
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +6132 -0
- mindspore/common/api.py +98 -21
- mindspore/common/dtype.py +34 -34
- mindspore/common/dump.py +2 -1
- mindspore/common/file_system.py +8 -3
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +3 -1
- mindspore/common/initializer.py +3 -4
- mindspore/common/lazy_inline.py +8 -2
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/parameter.py +31 -15
- mindspore/common/tensor.py +713 -1337
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +5 -0
- mindspore/communication/comm_func.py +215 -173
- mindspore/communication/management.py +23 -20
- mindspore/context.py +285 -191
- mindspore/dataset/__init__.py +23 -19
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +84 -3
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +5 -4
- mindspore/dataset/engine/datasets.py +192 -149
- mindspore/dataset/engine/datasets_audio.py +14 -0
- mindspore/dataset/engine/datasets_standard_format.py +11 -11
- mindspore/dataset/engine/datasets_text.py +38 -1
- mindspore/dataset/engine/datasets_user_defined.py +100 -66
- mindspore/dataset/engine/datasets_vision.py +81 -8
- mindspore/dataset/engine/iterators.py +281 -63
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +26 -2
- mindspore/dataset/engine/serializer_deserializer.py +1 -1
- mindspore/dataset/engine/validators.py +43 -11
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +29 -12
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +94 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +127 -0
- mindspore/device_context/cpu/__init__.py +25 -0
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +134 -0
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/optim/adadelta.py +26 -22
- mindspore/experimental/optim/adam.py +3 -0
- mindspore/experimental/optim/lr_scheduler.py +33 -24
- mindspore/experimental/optim/radam.py +33 -30
- mindspore/hal/device.py +28 -0
- mindspore/hal/event.py +17 -0
- mindspore/hal/memory.py +94 -3
- mindspore/hal/stream.py +91 -6
- mindspore/include/api/context.h +0 -1
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +12 -0
- mindspore/mindrecord/__init__.py +1 -1
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +824 -218
- mindspore/mint/distributed/__init__.py +66 -4
- mindspore/mint/distributed/distributed.py +2594 -44
- mindspore/mint/linalg/__init__.py +6 -0
- mindspore/mint/nn/__init__.py +473 -14
- mindspore/mint/nn/functional.py +486 -11
- mindspore/mint/nn/layer/__init__.py +17 -4
- mindspore/mint/nn/layer/_functions.py +330 -0
- mindspore/mint/nn/layer/activation.py +169 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +727 -0
- mindspore/mint/nn/layer/normalization.py +215 -19
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +170 -0
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/cell.py +126 -19
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +6 -6
- mindspore/nn/layer/basic.py +35 -25
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/embedding.py +3 -3
- mindspore/nn/layer/normalization.py +8 -7
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +47 -13
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +48 -26
- mindspore/nn/learning_rate_schedule.py +5 -3
- mindspore/nn/loss/loss.py +31 -36
- mindspore/nn/optim/ada_grad.py +1 -0
- mindspore/nn/optim/adadelta.py +2 -2
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/rprop.py +2 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/cell_wrapper.py +4 -6
- mindspore/nn/wrap/loss_scale.py +3 -4
- mindspore/numpy/array_creations.py +60 -62
- mindspore/numpy/array_ops.py +148 -143
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +16 -16
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +2 -1
- mindspore/ops/_grad_experimental/grad_comm_ops.py +94 -13
- mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_vmap/vmap_array_ops.py +20 -19
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
- mindspore/ops/_vmap/vmap_math_ops.py +11 -9
- mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
- mindspore/ops/auto_generate/gen_extend_func.py +554 -60
- mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
- mindspore/ops/auto_generate/gen_ops_prim.py +8024 -3409
- mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
- mindspore/ops/composite/base.py +1 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
- mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
- mindspore/ops/function/__init__.py +12 -0
- mindspore/ops/function/array_func.py +561 -159
- mindspore/ops/function/clip_func.py +64 -0
- mindspore/ops/function/debug_func.py +28 -20
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +5 -4
- mindspore/ops/function/math_func.py +1659 -290
- mindspore/ops/function/nn_func.py +988 -317
- mindspore/ops/function/parameter_func.py +3 -56
- mindspore/ops/function/random_func.py +243 -33
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/functional.py +18 -5
- mindspore/ops/functional_overload.py +897 -0
- mindspore/ops/operations/__init__.py +3 -2
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -34
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +38 -8
- mindspore/ops/operations/array_ops.py +45 -303
- mindspore/ops/operations/comm_ops.py +19 -16
- mindspore/ops/operations/custom_ops.py +11 -55
- mindspore/ops/operations/debug_ops.py +42 -47
- mindspore/ops/operations/inner_ops.py +6 -4
- mindspore/ops/operations/linalg_ops.py +3 -2
- mindspore/ops/operations/manually_defined/ops_def.py +185 -104
- mindspore/ops/operations/math_ops.py +11 -216
- mindspore/ops/operations/nn_ops.py +146 -308
- mindspore/ops/primitive.py +23 -21
- mindspore/ops/tensor_method.py +1669 -0
- mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
- mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
- mindspore/ops_generate/arg_handler.py +0 -61
- mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
- mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/base_generator.py +11 -0
- mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
- mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
- mindspore/ops_generate/functional_overload_py_generator.py +110 -0
- mindspore/ops_generate/functions_cc_generator.py +233 -0
- mindspore/ops_generate/gen_aclnn_implement.py +110 -114
- mindspore/ops_generate/gen_constants.py +157 -3
- mindspore/ops_generate/gen_ops.py +245 -990
- mindspore/ops_generate/gen_pyboost_func.py +97 -998
- mindspore/ops_generate/gen_utils.py +119 -33
- mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
- mindspore/ops_generate/op_api_proto.py +206 -0
- mindspore/ops_generate/op_def_py_generator.py +131 -0
- mindspore/ops_generate/op_prim_py_generator.py +480 -0
- mindspore/ops_generate/op_proto.py +373 -108
- mindspore/ops_generate/op_template_parser.py +436 -0
- mindspore/ops_generate/ops_def_cc_generator.py +288 -0
- mindspore/ops_generate/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/ops_name_h_generator.py +68 -0
- mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
- mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
- mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
- mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
- mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
- mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
- mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
- mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
- mindspore/ops_generate/pyboost_utils.py +92 -33
- mindspore/ops_generate/template.py +294 -44
- mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
- mindspore/parallel/__init__.py +3 -3
- mindspore/parallel/_auto_parallel_context.py +24 -33
- mindspore/parallel/_parallel_serialization.py +13 -2
- mindspore/parallel/_utils.py +4 -1
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +44 -0
- mindspore/parallel/cluster/process_entity/_api.py +131 -37
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +20 -3
- mindspore/parallel/parameter_broadcast.py +1 -1
- mindspore/parallel/shard.py +3 -0
- mindspore/parallel/transform_safetensors.py +119 -253
- mindspore/profiler/__init__.py +17 -4
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +174 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +202 -0
- mindspore/profiler/common/path_manager.py +371 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +476 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +210 -0
- mindspore/profiler/common/profiler_path_manager.py +120 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +270 -37
- mindspore/profiler/envprofiler.py +138 -0
- mindspore/profiler/mstx.py +199 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +309 -0
- mindspore/profiler/profiler.py +580 -93
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +114 -0
- mindspore/profiler/schedule.py +208 -0
- mindspore/rewrite/api/symbol_tree.py +1 -2
- mindspore/run_check/_check_version.py +2 -6
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +148 -0
- mindspore/runtime/memory.py +392 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +2 -2
- mindspore/train/_utils.py +53 -18
- mindspore/train/amp.py +8 -4
- mindspore/train/callback/_checkpoint.py +32 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +105 -69
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_summary_collector.py +44 -6
- mindspore/train/callback/_tft_register.py +31 -10
- mindspore/train/dataset_helper.py +11 -11
- mindspore/train/metrics/precision.py +4 -5
- mindspore/train/mind_ir_pb2.py +167 -46
- mindspore/train/model.py +13 -15
- mindspore/train/serialization.py +462 -76
- mindspore/train/summary/summary_record.py +1 -2
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +4 -2
- mindspore/utils/dryrun.py +138 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/METADATA +2 -3
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/RECORD +362 -238
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0
|
@@ -28,7 +28,7 @@ from mindspore.common import Tensor, CSRTensor, COOTensor
|
|
|
28
28
|
from mindspore.common._stub_tensor import _convert_stub
|
|
29
29
|
from mindspore._c_expression import typing
|
|
30
30
|
from mindspore._c_expression import Tensor as Tensor_
|
|
31
|
-
from mindspore._c_expression import pyboost_cast, pyboost_tile, pyboost_zeros, pyboost_ones
|
|
31
|
+
from mindspore._c_expression import pyboost_cast, pyboost_tile, pyboost_zeros, pyboost_ones, pyboost_type_as
|
|
32
32
|
from mindspore.common import dtype as mstype
|
|
33
33
|
from mindspore.common._utils import is_shape_unknown
|
|
34
34
|
from mindspore import _checkparam as validator
|
|
@@ -580,21 +580,15 @@ class BatchNorm(Primitive):
|
|
|
580
580
|
is only supported in GPU target. Default: ``"NCHW"`` .
|
|
581
581
|
|
|
582
582
|
Inputs:
|
|
583
|
-
If `is_training` is ``False`` , inputs are Tensors.
|
|
584
|
-
|
|
585
|
-
- **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
|
|
586
|
-
- **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
|
|
587
|
-
- **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
|
|
588
|
-
- **mean** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
|
|
589
|
-
- **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
|
|
590
|
-
|
|
591
|
-
If `is_training` is ``True`` , `scale`, `bias`, `mean` and `variance` are Parameters.
|
|
592
|
-
|
|
593
583
|
- **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
|
|
594
|
-
- **scale** (Parameter) - Parameter of shape :math:`(C,)`,
|
|
595
|
-
|
|
596
|
-
- **
|
|
597
|
-
|
|
584
|
+
- **scale** (Union[Parameter, Tensor]) - Tensor or Parameter of shape :math:`(C,)`,
|
|
585
|
+
with float16 or float32 data type.
|
|
586
|
+
- **bias** (Union[Parameter, Tensor]) - Tensor or Parameter of shape :math:`(C,)`,
|
|
587
|
+
has the same data type with `scale`.
|
|
588
|
+
- **mean** (Union[Parameter, Tensor]) - Tensor or Parameter of shape :math:`(C,)`,
|
|
589
|
+
has the same data type with `scale`.
|
|
590
|
+
- **variance** (Union[Parameter, Tensor]) - Tensor or Parameter of shape :math:`(C,)`,
|
|
591
|
+
has the same data type with `scale`.
|
|
598
592
|
|
|
599
593
|
Outputs:
|
|
600
594
|
Tuple of 5 Tensors, the normalized inputs and the updated parameters.
|
|
@@ -1026,9 +1020,9 @@ class Tile(Primitive):
|
|
|
1026
1020
|
|
|
1027
1021
|
def tile(input, dims):
|
|
1028
1022
|
r"""
|
|
1029
|
-
Creates a new tensor by
|
|
1023
|
+
Creates a new tensor by repeating `input` `dims` times. The i'th dimension of
|
|
1030
1024
|
output tensor has `input.shape[i] * dims[i]` elements, and the values of `input`
|
|
1031
|
-
are
|
|
1025
|
+
are repeated `dims[i]` times along the i'th dimension.
|
|
1032
1026
|
|
|
1033
1027
|
Note:
|
|
1034
1028
|
On Ascend, the number of `dims` should not exceed 8, and currently does not support scenarios
|
|
@@ -1176,7 +1170,6 @@ class Cast(Primitive):
|
|
|
1176
1170
|
if data.dtype == dtype:
|
|
1177
1171
|
return (True, x)
|
|
1178
1172
|
if isinstance(x, Tensor) and x.dtype == dtype:
|
|
1179
|
-
x.set_cast_dtype()
|
|
1180
1173
|
return (True, x)
|
|
1181
1174
|
if isinstance(x, numbers.Number):
|
|
1182
1175
|
return (True, Tensor(x, dtype=dtype))
|
|
@@ -1189,6 +1182,59 @@ class Cast(Primitive):
|
|
|
1189
1182
|
return _convert_stub(pyboost_cast(self, [input_x, dtype_to_type_id('Cast', 'dtype', dtype)]))
|
|
1190
1183
|
|
|
1191
1184
|
|
|
1185
|
+
class TypeAs(Primitive):
|
|
1186
|
+
"""
|
|
1187
|
+
Returns first input tensor cast to the type of the with the second input tensor.
|
|
1188
|
+
|
|
1189
|
+
.. warning::
|
|
1190
|
+
This is an experimental API that is subject to change or deletion.
|
|
1191
|
+
|
|
1192
|
+
Note:
|
|
1193
|
+
When converting complex numbers to boolean type, the imaginary part of the complex number is not
|
|
1194
|
+
taken into account. As long as the real part is non-zero, it returns True; otherwise, it returns False.
|
|
1195
|
+
|
|
1196
|
+
Inputs:
|
|
1197
|
+
- **input** (Tensor) - The shape of tensor is :math:`(x_0, x_1, ..., x_R)`.
|
|
1198
|
+
The tensor whose data type is to be converted.
|
|
1199
|
+
- **other ** (Tensor) - The shape of tensor is :math:`(x_0, x_1, ..., x_R)`.
|
|
1200
|
+
The tensor whose data type is specified.
|
|
1201
|
+
|
|
1202
|
+
Outputs:
|
|
1203
|
+
Tensor, the shape of tensor is the same as `input`, :math:`(x_0, x_1, ..., x_R)`.
|
|
1204
|
+
|
|
1205
|
+
Raises:
|
|
1206
|
+
TypeError: If `input` is not a Tensor.
|
|
1207
|
+
TypeError: If `other` is not a Tensor.
|
|
1208
|
+
|
|
1209
|
+
Supported Platforms:
|
|
1210
|
+
``Ascend``
|
|
1211
|
+
|
|
1212
|
+
Examples:
|
|
1213
|
+
>>> import mindspore
|
|
1214
|
+
>>> import numpy as np
|
|
1215
|
+
>>> from mindspore import Tensor, ops
|
|
1216
|
+
>>> input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
|
|
1217
|
+
>>> input = Tensor(input_np)
|
|
1218
|
+
>>> other_np = np.random.randn(2, 3, 4).astype(np.int32)
|
|
1219
|
+
>>> other = Tensor(other_np)
|
|
1220
|
+
>>> type_as = ops.TypeAs()
|
|
1221
|
+
>>> output = type_as(input, other)
|
|
1222
|
+
>>> print(output.dtype)
|
|
1223
|
+
Int32
|
|
1224
|
+
>>> print(output.shape)
|
|
1225
|
+
(2, 3, 4, 5)
|
|
1226
|
+
"""
|
|
1227
|
+
|
|
1228
|
+
@prim_attr_register
|
|
1229
|
+
def __init__(self):
|
|
1230
|
+
pass
|
|
1231
|
+
|
|
1232
|
+
def __call__(self, input, other):
|
|
1233
|
+
if input.dtype == other.dtype:
|
|
1234
|
+
return input
|
|
1235
|
+
return _convert_stub(pyboost_type_as(self, [input, other]))
|
|
1236
|
+
|
|
1237
|
+
|
|
1192
1238
|
def to_sequence(val):
|
|
1193
1239
|
"""
|
|
1194
1240
|
to_sequence
|
|
@@ -1791,7 +1837,7 @@ class Ones(Primitive):
|
|
|
1791
1837
|
Tensor, whose dtype and size are defined by input.
|
|
1792
1838
|
|
|
1793
1839
|
Raises:
|
|
1794
|
-
TypeError: If `shape` is neither an int nor
|
|
1840
|
+
TypeError: If `shape` is neither an int nor a tuple/list/Tensor of int.
|
|
1795
1841
|
|
|
1796
1842
|
Supported Platforms:
|
|
1797
1843
|
``Ascend`` ``GPU`` ``CPU``
|
|
@@ -1827,7 +1873,7 @@ class Ones(Primitive):
|
|
|
1827
1873
|
|
|
1828
1874
|
class Zeros(Primitive):
|
|
1829
1875
|
r"""
|
|
1830
|
-
Zeros will be deprecated in the future. Please use class
|
|
1876
|
+
Zeros will be deprecated in the future. Please use class :func:`mindspore.ops.zeros` instead.
|
|
1831
1877
|
|
|
1832
1878
|
Creates a tensor filled with value zeros.
|
|
1833
1879
|
|
|
@@ -1845,7 +1891,7 @@ class Zeros(Primitive):
|
|
|
1845
1891
|
Tensor, whose dtype and size are defined by input.
|
|
1846
1892
|
|
|
1847
1893
|
Raises:
|
|
1848
|
-
TypeError: If `shape` is neither an int nor
|
|
1894
|
+
TypeError: If `shape` is neither an int nor a tuple/list/Tensor of int.
|
|
1849
1895
|
|
|
1850
1896
|
Supported Platforms:
|
|
1851
1897
|
``Ascend`` ``GPU`` ``CPU``
|
|
@@ -1880,116 +1926,132 @@ def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mas
|
|
|
1880
1926
|
scalar_value=1.0, pre_tokens=2147483647, next_tokens=2147483647, inner_precise=0,
|
|
1881
1927
|
input_layout='BSH', sparse_mode=0):
|
|
1882
1928
|
r"""
|
|
1883
|
-
|
|
1929
|
+
Implement self-attention calculations in training scenarios.
|
|
1930
|
+
|
|
1931
|
+
- B: Batch size. Value range 1 to 2k.
|
|
1932
|
+
- S1: Sequence length of `query`. Value range 1 to 512k.
|
|
1933
|
+
- S2: Sequence length of `key` and `value`. Value range 1 to 512k.
|
|
1934
|
+
- N1: Num heads of `query`. Value range 1 to 256.
|
|
1935
|
+
- N2: Num heads of `key` and `value`, and N2 must be a factor of N1.
|
|
1936
|
+
- D: Head size. The value ranges is a multiple of 16, with the max value of 512.
|
|
1937
|
+
- H1: Hidden size of `query`, which equals to N1 * D.
|
|
1938
|
+
- H2: Hidden size of `key` and `value`, which equals to N2 * D.
|
|
1939
|
+
|
|
1940
|
+
The self attention calculation formula is defined as:
|
|
1884
1941
|
|
|
1885
1942
|
.. math::
|
|
1886
1943
|
\begin{array}{ll} \\
|
|
1887
|
-
|
|
1888
|
-
\
|
|
1944
|
+
\text { attention_out }=\operatorname{Dropout}\left(\operatorname{Softmax}\left(\text
|
|
1945
|
+
{ Mask(scale } *\left(\text { query } * \mathrm{key}^{\top}\right)+\text { pse }\right)\text
|
|
1946
|
+
{, atten_mask), keep_prob) } *\right. \text { value }
|
|
1889
1947
|
\end{array}
|
|
1890
1948
|
|
|
1891
|
-
B -- Batch size. Value range 1 to 2k.
|
|
1892
|
-
S1 -- Sequence length of query. Value range 1 to 512k.
|
|
1893
|
-
S2 -- Sequence length of key and value. Value range 1 to 512k.
|
|
1894
|
-
N1 -- Num heads of query. Value range 1 to 256.
|
|
1895
|
-
N2 -- Num heads of key and value, and N2 must be a factor of N1.
|
|
1896
|
-
D -- Head size. The value ranges is a multiple of 16, with the max value of 512.
|
|
1897
|
-
H1 -- Hidden size of query, which equals to N1 * D.
|
|
1898
|
-
H2 -- Hidden size of key and value, which equals to N2 * D.
|
|
1899
|
-
|
|
1900
1949
|
.. warning::
|
|
1901
|
-
This is an experimental API that is subject to change or deletion.
|
|
1950
|
+
- This is an experimental API that is subject to change or deletion.
|
|
1951
|
+
- Only support on Atlas A2 training series.
|
|
1902
1952
|
|
|
1903
1953
|
Args:
|
|
1904
|
-
query (Tensor
|
|
1905
|
-
|
|
1906
|
-
|
|
1907
|
-
|
|
1908
|
-
|
|
1909
|
-
|
|
1910
|
-
head_num (int): The head num of query
|
|
1911
|
-
real_shift (
|
|
1912
|
-
|
|
1913
|
-
|
|
1914
|
-
|
|
1915
|
-
|
|
1916
|
-
|
|
1917
|
-
|
|
1918
|
-
-
|
|
1919
|
-
|
|
1920
|
-
|
|
1921
|
-
`
|
|
1922
|
-
|
|
1923
|
-
|
|
1924
|
-
|
|
1925
|
-
|
|
1926
|
-
|
|
1927
|
-
|
|
1928
|
-
|
|
1929
|
-
|
|
1930
|
-
|
|
1931
|
-
|
|
1932
|
-
|
|
1954
|
+
query (Tensor): The query tensor. Input tensor of shape :math:`(B, S1, H1)`,
|
|
1955
|
+
:math:`(B, N1, S1, D)`, :math:`(S1, B, H1)`, :math:`(B, S1, N1, D)` or :math:`(T1, N1, D)`.
|
|
1956
|
+
The supported dtype is float16 and bfloat16.
|
|
1957
|
+
key (Tensor): The key tensor with the same dtype as `query`. Supported shape: :math:`(B, S2, H2)`,
|
|
1958
|
+
:math:`(B, N2, S2, D)`, :math:`(S2, B, H2)`, :math:`(B, S2, N2, D)` or :math:`(T2, N2, D)`.
|
|
1959
|
+
value (Tensor): The value tensor with the same dtype and shape as `key`.
|
|
1960
|
+
head_num (int): The head num of `query`, equal to N1.
|
|
1961
|
+
real_shift (Tensor, optional): The position embedding code which is also known as pse, it has the same
|
|
1962
|
+
dtype as `query`.
|
|
1963
|
+
Default: ``None``.
|
|
1964
|
+
If S is greater than 1024 and the mask of the lower triangle is used, only the inverse 1024 lines of
|
|
1965
|
+
the lower triangle is used for memory optimization. Input tensor of shape :math:`(B, N1, S1, S2)`,
|
|
1966
|
+
:math:`(1, N1, S1, S2)`, :math:`(B, N1, 1024, S2)`, :math:`(1, N1, 1024, S2)`.
|
|
1967
|
+
|
|
1968
|
+
- ALiBi scenario: `real_shift` must meet the ALiBi rule, and sparse_mode is 2 or 3 for the lower triangle.
|
|
1969
|
+
In this scenario, `real_shift` is :math:`(B, N1, 1024, S2)`, :math:`(1, N1, 1024, S2)`.
|
|
1970
|
+
- Non-ALiBi scenario: `real_shift` is :math:`(B, N1, S1, S2)`, :math:`(1, N1, S1, S2)`.
|
|
1971
|
+
- input_layout is TND: shape should be :math:`(B, N1, 1024, S2)` and :math:`(1, N1, 1024, S2)`.
|
|
1972
|
+
|
|
1973
|
+
drop_mask (Tensor, optional): The dropout mask tensor of uint8. Input tensor of shape
|
|
1974
|
+
:math:`(B, N1, S1, S2 // 8) or None`. `S2` is a multiple of 8 when not None. Default: ``None``.
|
|
1975
|
+
padding_mask (Tensor, optional): Reserved parameter. Not implemented yet. Default: ``None``.
|
|
1976
|
+
attn_mask (Tensor, optional): The attention mask tensor of bool or uint8. For each element, 0/False
|
|
1977
|
+
indicates retention and 1/True indicates discard. Input tensor of shape :math:`(B, N1, S1, S2)`,
|
|
1978
|
+
:math:`(B, 1, S1, S2)`, :math:`(S1, S2)` or :math:`(2048, 2048)`.
|
|
1979
|
+
Default: ``None``.
|
|
1980
|
+
|
|
1981
|
+
- In compression scenario, `sparse_mode` is 2, 3, or 4, `attn_mask` must be :math:`(2048, 2048)`.
|
|
1982
|
+
- When `sparse_mode` is 5, `attn_mask` should be :math:`(B, N1, S1, S2)`, :math:`(B, 1, S1, S2)`.
|
|
1983
|
+
- When `sparse_mode` is 0 and 1, `attn_mask` should be :math:`(B, N1, S1, S2)`, :math:`(B, 1, S1, S2)`,
|
|
1984
|
+
:math:`(S1, S2)`.
|
|
1985
|
+
|
|
1986
|
+
prefix (Union[Tensor, tuple[int], list[int]], optional): N value of each Batch in the prefix sparse calculation
|
|
1987
|
+
scenario. Input tensor of shape :math:`(B,)`. B max value 32. Not none only when `sparse_mode` is 5.
|
|
1988
|
+
Default: ``None``.
|
|
1933
1989
|
If S1 > S2, N ranges from 0 to S2. If S1 <= S2, N ranges from S2 - S1 to S2.
|
|
1934
|
-
actual_seq_qlen (Union[
|
|
1935
|
-
with increasing values and the last value equal to T1.
|
|
1936
|
-
|
|
1937
|
-
|
|
1938
|
-
|
|
1939
|
-
|
|
1940
|
-
|
|
1941
|
-
|
|
1942
|
-
|
|
1943
|
-
|
|
1944
|
-
|
|
1945
|
-
|
|
1946
|
-
|
|
1990
|
+
actual_seq_qlen (Union[Tensor, tuple[int], list[int]], optional): Size of query corresponding to each batch,
|
|
1991
|
+
array with increasing values and the last value equal to T1.
|
|
1992
|
+
Default: ``None``.
|
|
1993
|
+
actual_seq_kvlen (Union[Tensor, tuple[int], list[int]], optional): Size of key and value corresponding
|
|
1994
|
+
to each batch, array with increasing values and the last value equal to T2.
|
|
1995
|
+
Default: ``None``.
|
|
1996
|
+
keep_prob (double, optional): The keep probability of dropout. Value range is (0.0, 1.0]. When `keep_prob`
|
|
1997
|
+
is 1.0, `drop_mask` should be None.
|
|
1998
|
+
Default: ``1.0``.
|
|
1999
|
+
scale_value (double, optional): The scale factor of score. Generally, the value is 1.0 / (D ** 0.5).
|
|
2000
|
+
Default: ``1.0``.
|
|
2001
|
+
pre_tokens (int, optional): Parameter for sparse computation, represents how many tokens are counted forward.
|
|
2002
|
+
When `sparse_mode` is set to 1, 2, 3, or 5, this parameter does not take effect.
|
|
2003
|
+
Default: ``2147483647``.
|
|
2004
|
+
next_tokens (int, optional): Parameter for sparse computation, represents how many tokens are counted backward.
|
|
2005
|
+
When `sparse_mode` is set to 1, 2, 3, or 5, this parameter does not take effect. Default: ``2147483647``.
|
|
2006
|
+
The value of `pre_tokens` corresponds to S1, and the value of `next_tokens` corresponds to S2.
|
|
2007
|
+
They define the valid area on the `attn_mask` matrix. It must ensure that the band is not empty.
|
|
1947
2008
|
The following values are not allowed:
|
|
1948
2009
|
|
|
1949
2010
|
- pre_tokens < 0 and next_tokens < 0.
|
|
1950
2011
|
- (pre_tokens < 0 and next_tokens >= 0) and (next_tokens < abs(pre_tokens) or abs(pre_tokens) >= S2).
|
|
1951
2012
|
- (pre_tokens >= 0 and next_tokens < 0) and (abs(next_tokens) > pre_tokens or abs(next_tokens) >= S1).
|
|
1952
2013
|
|
|
1953
|
-
inner_precise (int): The parameter is reserved and not implemented yet. Default
|
|
1954
|
-
input_layout (str): Specifies the layout of input `query`, key and value
|
|
1955
|
-
"SBH", "BSND" or "TND". "TND" is an experimental format. Default: "BSH"
|
|
2014
|
+
inner_precise (int, optional): The parameter is reserved and not implemented yet. Default: ``0``.
|
|
2015
|
+
input_layout (str, optional): Specifies the layout of input `query`, `key` and `value`. The value can be
|
|
2016
|
+
"BSH", "BNSD", "SBH", "BSND" or "TND". "TND" is an experimental format. Default: ``"BSH"``.
|
|
1956
2017
|
When input_layout is "TND", the following restrictions must be met.
|
|
1957
|
-
|
|
2018
|
+
Assume there are two lists that represent the length of the input sequence: list_seq_q and list_seq_k. Each
|
|
1958
2019
|
value in the list indicates the length of the sequence in the batch. For example, list_seq_q = [4, 2, 6],
|
|
1959
2020
|
list_seq_k = [10, 3, 9]. Each element of the list indicates S. T1 is sum(list_seq_q) = 12, T2 is
|
|
1960
2021
|
sum(list_seq_k) = 22.
|
|
1961
2022
|
max_seqlen_q = max(list_seq_q), max_seqlen_k = max(list_seq_k).
|
|
1962
2023
|
qk_pointer = sum(list_seq_q * list_seq_k), which is the sum of the element-wise products.
|
|
1963
2024
|
|
|
1964
|
-
- The lengths of two lists
|
|
1965
|
-
|
|
2025
|
+
- The lengths of two lists must be the same, and size of list is batch. batch is less than or equal to
|
|
2026
|
+
1024.
|
|
2027
|
+
- When `input_layout` is "TND", `actual_seq_qlen` and `actual_seq_kvlen` must not be none.
|
|
1966
2028
|
Otherwise, they are none.
|
|
1967
|
-
- The actual_seq_qlen and actual_seq_kvlen are the cumulative sum of sequence of key/value, so they must
|
|
2029
|
+
- The `actual_seq_qlen` and `actual_seq_kvlen` are the cumulative sums of the query and key/value sequence lengths, so they must
|
|
1968
2030
|
be non-decreasing.
|
|
1969
|
-
- If real_shift is not none, list_seq_q and list_seq_k must be same. The maximum value of list_seq_q and
|
|
1970
|
-
list_seq_k is greater than 1024.
|
|
1971
|
-
S2 is equal to max_seqlen_k.
|
|
1972
|
-
-
|
|
1973
|
-
should be
|
|
1974
|
-
- The shape of drop_mask is (
|
|
1975
|
-
-
|
|
1976
|
-
-
|
|
1977
|
-
- When sparse_mode is 3, S1 of each batch should be less than or equal to S2.
|
|
2031
|
+
- If `real_shift` is not none, list_seq_q and list_seq_k must be the same. The maximum value of list_seq_q and
|
|
2032
|
+
list_seq_k is greater than 1024. `real_shift` should be :math:`(B, N1, 1024, S2)` and
|
|
2033
|
+
:math:`(1, N1, 1024, S2)`, and S2 is equal to max_seqlen_k.
|
|
2034
|
+
- `attn_mask` must be a lower triangular matrix, so `sparse_mode` should be 2 or 3. The shape of `attn_mask`
|
|
2035
|
+
should be :math:`(2048, 2048)`.
|
|
2036
|
+
- The shape of `drop_mask` is :math:`(qk\_pointer * N1 // 8,)`.
|
|
2037
|
+
- `prefix` is none.
|
|
2038
|
+
- `next_tokens` is 0, and `pre_tokens` is not less than max_seqlen_q.
|
|
2039
|
+
- When `sparse_mode` is 3, S1 of each batch should be less than or equal to S2.
|
|
1978
2040
|
- 0 should not exist in list_seq_k.
|
|
1979
2041
|
|
|
1980
|
-
sparse_mode (int): Indicates sparse mode. Default 0
|
|
2042
|
+
sparse_mode (int, optional): Indicates sparse mode. Default: ``0``.
|
|
1981
2043
|
|
|
1982
|
-
- 0: Indicates the defaultMask mode. If attn_mask is not passed, the mask operation is not performed,
|
|
1983
|
-
and
|
|
1984
|
-
attn_mask matrix (S1 * S2) needs to be passed in, indicating that the part between
|
|
1985
|
-
|
|
1986
|
-
- 1: Represents allMask, that is, passing in the complete attn_mask matrix.
|
|
2044
|
+
- 0: Indicates the defaultMask mode. If `attn_mask` is not passed, the mask operation is not performed,
|
|
2045
|
+
`next_tokens` and `pre_tokens` (internally assigned as INT_MAX) are ignored. If passed in, the full
|
|
2046
|
+
`attn_mask` matrix (S1 * S2) needs to be passed in, indicating that the part between `next_tokens` and
|
|
2047
|
+
`pre_tokens` needs to be calculated.
|
|
2048
|
+
- 1: Represents allMask, that is, passing in the complete `attn_mask` matrix.
|
|
1987
2049
|
- 2: Represents the leftUpCausal mode, which corresponds to the lower triangle scenario divided by the left
|
|
1988
|
-
vertex, and the optimized attn_mask matrix (2048*2048) is required.
|
|
2050
|
+
vertex, and the optimized `attn_mask` matrix (2048*2048) is required.
|
|
1989
2051
|
- 3: Represents the rightDownCausal mode, which corresponds to the lower triangle scene divided by the lower
|
|
1990
|
-
right vertex, and the optimized attn_mask matrix (2048*2048) is required.
|
|
1991
|
-
- 4: Represents the band scenario, that is, the part between counting
|
|
1992
|
-
optimized attn_mask matrix (2048*2048) is required.
|
|
2052
|
+
right vertex, and the optimized `attn_mask` matrix (2048*2048) is required.
|
|
2053
|
+
- 4: Represents the band scenario, that is, the part between counting `next_tokens` and `pre_tokens`,
|
|
2054
|
+
and the optimized `attn_mask` matrix (2048*2048) is required.
|
|
1993
2055
|
- 5: Represents the prefix scenario, that is, on the basis of rightDownCausal, a matrix with length S1 and
|
|
1994
2056
|
width N is added to the left side. The value of N is obtained by the new input prefix, and the N value
|
|
1995
2057
|
of each Batch axis is different, not implemented yet.
|
|
@@ -1998,8 +2060,27 @@ def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mas
|
|
|
1998
2060
|
- 8: Represents the block_local scenario, not implemented yet.
|
|
1999
2061
|
|
|
2000
2062
|
Returns:
|
|
2001
|
-
attention_out (Tensor
|
|
2002
|
-
|
|
2063
|
+
attention_out (Tensor) - The output of attention, it has the same shape and dtype as `query`.
|
|
2064
|
+
|
|
2065
|
+
Raises:
|
|
2066
|
+
TypeError: Dtype of `query` is not float16 or bfloat16.
|
|
2067
|
+
TypeError: `query`, `key` and `value` don't have the same dtype.
|
|
2068
|
+
TypeError: Dtype of `attn_mask` is not bool or uint8.
|
|
2069
|
+
TypeError: `real_shift` has a different dtype from `query`.
|
|
2070
|
+
TypeError: `scale_value` or `keep_prob` is not a double number.
|
|
2071
|
+
TypeError: `input_layout` is not a string.
|
|
2072
|
+
TypeError: `num_key_value_heads` is not an int.
|
|
2073
|
+
TypeError: `sparse_mode` is not an int.
|
|
2074
|
+
TypeError: `real_shift` is not Tensor type.
|
|
2075
|
+
TypeError: `drop_mask` is not Tensor type.
|
|
2076
|
+
TypeError: `padding_mask` is not Tensor type.
|
|
2077
|
+
TypeError: `attn_mask` is not Tensor type.
|
|
2078
|
+
ValueError: `input_layout` is a string but not valid.
|
|
2079
|
+
RuntimeError: `head_num` is not divisible by `N2`.
|
|
2080
|
+
RuntimeError: `head_num` is not greater than 0.
|
|
2081
|
+
RuntimeError: `attn_mask` shape is not valid.
|
|
2082
|
+
RuntimeError: The specified value of `sparse_mode` is invalid.
|
|
2083
|
+
RuntimeError: D-axis of `query`, `key` and `value` is not the same.
|
|
2003
2084
|
|
|
2004
2085
|
Supported Platforms:
|
|
2005
2086
|
``Ascend``
|