mindspore 2.5.0__cp310-cp310-win_amd64.whl → 2.6.0rc1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: the registry flags this version of mindspore as possibly problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +6 -4
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -33
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +19 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +22 -3
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +24 -193
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +97 -74
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +16 -11
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +4 -4
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +4 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -1
- mindspore/common/_stub_tensor.py +5 -10
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +1915 -3287
- mindspore/common/api.py +341 -354
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +5 -2
- mindspore/common/dump.py +7 -5
- mindspore/common/file_system.py +3 -0
- mindspore/common/hook_handle.py +5 -3
- mindspore/common/initializer.py +10 -6
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +2 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +106 -39
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +297 -714
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +47 -2
- mindspore/communication/comm_func.py +70 -53
- mindspore/communication/management.py +83 -17
- mindspore/context.py +214 -560
- mindspore/dataset/__init__.py +44 -20
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +102 -120
- mindspore/dataset/engine/datasets_audio.py +22 -22
- mindspore/dataset/engine/datasets_standard_format.py +43 -24
- mindspore/dataset/engine/datasets_text.py +78 -85
- mindspore/dataset/engine/datasets_user_defined.py +108 -76
- mindspore/dataset/engine/datasets_vision.py +111 -108
- mindspore/dataset/engine/iterators.py +5 -3
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/samplers.py +279 -57
- mindspore/dataset/engine/serializer_deserializer.py +2 -1
- mindspore/dataset/engine/validators.py +10 -0
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/device_context/ascend/op_debug.py +60 -1
- mindspore/device_context/ascend/op_tuning.py +0 -4
- mindspore/device_manager.py +39 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +22 -26
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +4 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +40 -22
- mindspore/experimental/optim/radam.py +5 -5
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -81
- mindspore/hal/event.py +38 -55
- mindspore/hal/memory.py +93 -144
- mindspore/hal/stream.py +81 -125
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +40 -2
- mindspore/mindrecord/__init__.py +20 -7
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +131 -700
- mindspore/mint/distributed/__init__.py +5 -1
- mindspore/mint/distributed/distributed.py +194 -109
- mindspore/mint/linalg/__init__.py +2 -0
- mindspore/mint/nn/__init__.py +280 -18
- mindspore/mint/nn/functional.py +282 -64
- mindspore/mint/nn/layer/__init__.py +4 -0
- mindspore/mint/nn/layer/_functions.py +7 -3
- mindspore/mint/nn/layer/activation.py +120 -13
- mindspore/mint/nn/layer/conv.py +218 -24
- mindspore/mint/nn/layer/normalization.py +15 -16
- mindspore/mint/nn/layer/padding.py +1 -1
- mindspore/mint/nn/layer/pooling.py +66 -1
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1250 -176
- mindspore/nn/layer/activation.py +23 -21
- mindspore/nn/layer/basic.py +22 -16
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +22 -17
- mindspore/nn/layer/embedding.py +9 -8
- mindspore/nn/layer/normalization.py +48 -42
- mindspore/nn/layer/pooling.py +75 -31
- mindspore/nn/layer/transformer.py +11 -10
- mindspore/nn/learning_rate_schedule.py +4 -2
- mindspore/nn/loss/loss.py +27 -19
- mindspore/nn/optim/ada_grad.py +6 -5
- mindspore/nn/optim/adadelta.py +9 -7
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +16 -12
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +1 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +9 -7
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +178 -117
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +3 -3
- mindspore/numpy/array_creations.py +3 -3
- mindspore/numpy/array_ops.py +1 -1
- mindspore/numpy/math_ops.py +4 -4
- mindspore/numpy/utils.py +1 -2
- mindspore/numpy/utils_const.py +1 -2
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +3 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +18 -3
- mindspore/ops/_grad_experimental/grad_debug_ops.py +8 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -4
- mindspore/ops/_vmap/vmap_array_ops.py +7 -6
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +2 -1
- mindspore/ops/_vmap/vmap_math_ops.py +4 -7
- mindspore/ops/_vmap/vmap_nn_ops.py +9 -8
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +102 -49
- mindspore/ops/auto_generate/gen_extend_func.py +281 -135
- mindspore/ops/auto_generate/gen_ops_def.py +2574 -2326
- mindspore/ops/auto_generate/gen_ops_prim.py +8566 -2755
- mindspore/ops/auto_generate/pyboost_inner_prim.py +106 -76
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +19 -24
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -3
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +28 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +1629 -2345
- mindspore/ops/function/clip_func.py +38 -45
- mindspore/ops/function/debug_func.py +36 -44
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +46 -78
- mindspore/ops/function/math_func.py +3035 -3705
- mindspore/ops/function/nn_func.py +676 -241
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +17 -30
- mindspore/ops/function/random_func.py +204 -361
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +5 -5
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +24 -17
- mindspore/ops/functional.py +6 -4
- mindspore/ops/functional_overload.py +547 -4
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +10 -5
- mindspore/ops/operations/_custom_ops_utils.py +247 -0
- mindspore/ops/operations/_grad_ops.py +1 -10
- mindspore/ops/operations/_inner_ops.py +5 -76
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +37 -22
- mindspore/ops/operations/comm_ops.py +150 -107
- mindspore/ops/operations/custom_ops.py +221 -23
- mindspore/ops/operations/debug_ops.py +115 -16
- mindspore/ops/operations/inner_ops.py +1 -1
- mindspore/ops/operations/linalg_ops.py +1 -58
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +746 -79
- mindspore/ops/operations/math_ops.py +21 -18
- mindspore/ops/operations/nn_ops.py +65 -191
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +43 -32
- mindspore/ops/tensor_method.py +232 -13
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/{aclnn_kernel_register_auto_cc_generator.py → aclnn/aclnn_kernel_register_auto_cc_generator.py} +43 -18
- mindspore/ops_generate/{gen_aclnn_implement.py → aclnn/gen_aclnn_implement.py} +49 -51
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/{add_tensor_docs_generator.py → api/add_tensor_docs_generator.py} +9 -7
- mindspore/ops_generate/{cpp_create_prim_instance_helper_generator.py → api/cpp_create_prim_instance_helper_generator.py} +6 -9
- mindspore/ops_generate/{functional_map_cpp_generator.py → api/functional_map_cpp_generator.py} +25 -12
- mindspore/ops_generate/{functional_overload_py_generator.py → api/functional_overload_py_generator.py} +8 -6
- mindspore/ops_generate/{functions_cc_generator.py → api/functions_cc_generator.py} +14 -10
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/{op_api_proto.py → api/op_api_proto.py} +98 -69
- mindspore/ops_generate/{tensor_func_reg_cpp_generator.py → api/tensor_func_reg_cpp_generator.py} +82 -43
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/{gen_utils.py → common/gen_utils.py} +72 -19
- mindspore/ops_generate/{op_proto.py → common/op_proto.py} +64 -1
- mindspore/ops_generate/{template.py → common/template.py} +96 -84
- mindspore/ops_generate/gen_ops.py +23 -325
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/{lite_ops_cpp_generator.py → op_def/lite_ops_cpp_generator.py} +47 -11
- mindspore/ops_generate/{ops_def_cc_generator.py → op_def/ops_def_cc_generator.py} +18 -7
- mindspore/ops_generate/{ops_def_h_generator.py → op_def/ops_def_h_generator.py} +5 -5
- mindspore/ops_generate/{ops_name_h_generator.py → op_def/ops_name_h_generator.py} +30 -15
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/{op_def_py_generator.py → op_def_py/op_def_py_generator.py} +6 -5
- mindspore/ops_generate/{op_prim_py_generator.py → op_def_py/op_prim_py_generator.py} +24 -15
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/{auto_grad_impl_cc_generator.py → pyboost/auto_grad_impl_cc_generator.py} +11 -7
- mindspore/ops_generate/{auto_grad_reg_cc_generator.py → pyboost/auto_grad_reg_cc_generator.py} +7 -7
- mindspore/ops_generate/{gen_pyboost_func.py → pyboost/gen_pyboost_func.py} +40 -16
- mindspore/ops_generate/{op_template_parser.py → pyboost/op_template_parser.py} +105 -24
- mindspore/ops_generate/{pyboost_functions_cpp_generator.py → pyboost/pyboost_functions_cpp_generator.py} +55 -18
- mindspore/ops_generate/{pyboost_functions_h_generator.py → pyboost/pyboost_functions_h_generator.py} +42 -10
- mindspore/ops_generate/{pyboost_functions_py_generator.py → pyboost/pyboost_functions_py_generator.py} +6 -6
- mindspore/ops_generate/{pyboost_grad_function_cpp_generator.py → pyboost/pyboost_grad_function_cpp_generator.py} +11 -10
- mindspore/ops_generate/{pyboost_inner_prim_generator.py → pyboost/pyboost_inner_prim_generator.py} +8 -7
- mindspore/ops_generate/{pyboost_native_grad_functions_generator.py → pyboost/pyboost_native_grad_functions_generator.py} +14 -10
- mindspore/ops_generate/{pyboost_op_cpp_code_generator.py → pyboost/pyboost_op_cpp_code_generator.py} +140 -53
- mindspore/ops_generate/{pyboost_overload_functions_cpp_generator.py → pyboost/pyboost_overload_functions_cpp_generator.py} +28 -15
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +88 -4
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +6 -2
- mindspore/parallel/_auto_parallel_context.py +133 -6
- mindspore/parallel/_cell_wrapper.py +130 -15
- mindspore/parallel/_parallel_serialization.py +95 -4
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +198 -25
- mindspore/parallel/algo_parameter_config.py +3 -3
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +656 -37
- mindspore/parallel/cluster/process_entity/_api.py +151 -19
- mindspore/parallel/cluster/run.py +1 -1
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +259 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +24 -13
- mindspore/parallel/shard.py +137 -61
- mindspore/parallel/transform_safetensors.py +287 -95
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +9 -5
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +6 -2
- mindspore/profiler/analysis/parser/ms_framework_parser.py +4 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -4
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +22 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +241 -86
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +41 -2
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +33 -35
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +7 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +8 -3
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +141 -30
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +5 -6
- mindspore/profiler/common/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/common/constant.py +12 -0
- mindspore/profiler/common/msprof_cmd_tool.py +42 -23
- mindspore/profiler/common/path_manager.py +24 -0
- mindspore/profiler/common/profiler_context.py +26 -2
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_parameters.py +59 -18
- mindspore/profiler/common/profiler_path_manager.py +66 -7
- mindspore/profiler/dynamic_profiler.py +112 -79
- mindspore/profiler/envprofiler.py +26 -1
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +57 -14
- mindspore/profiler/platform/npu_profiler.py +33 -7
- mindspore/profiler/profiler.py +541 -45
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +4 -0
- mindspore/profiler/schedule.py +57 -22
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +25 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +2 -2
- mindspore/runtime/executor.py +40 -11
- mindspore/runtime/memory.py +25 -8
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +35 -7
- mindspore/train/amp.py +1 -1
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +24 -40
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_flops_collector.py +2 -3
- mindspore/train/callback/_history.py +7 -4
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +8 -13
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +179 -103
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +4 -5
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +8 -6
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +19 -12
- mindspore/train/model.py +176 -103
- mindspore/train/serialization.py +246 -988
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +4 -2
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/METADATA +2 -1
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/RECORD +483 -438
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -136
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_constants.py +0 -190
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/ops_primitive_h_generator.py +0 -81
- /mindspore/ops_generate/{base_generator.py → common/base_generator.py} +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.5.0.dist-info → mindspore-2.6.0rc1.dist-info}/top_level.txt +0 -0
Expanded diff for one of the changed files (almost certainly mindspore/ops/operations/manually_defined/ops_def.py, the +746 -79 entry in the file list above). Removed lines whose changed text was cut off by the diff viewer are reproduced below exactly as shown, i.e. truncated.

@@ -1,4 +1,4 @@
-# Copyright 2023 Huawei Technologies Co., Ltd
+# Copyright 2023-2025 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,20 +23,21 @@ import numpy as np
 from mindspore.ops import signature as sig
 from mindspore.ops.primitive import Primitive, prim_attr_register, prim_arg_register, PrimitiveWithInfer
 from mindspore.ops._primitive_cache import _get_cache_prim
-from mindspore.ops.
+from mindspore.ops._utils import arg_handler as handler
+from mindspore.ops._utils.arg_dtype_cast import DtypeToEnum
 from mindspore.common import Tensor, CSRTensor, COOTensor
 from mindspore.common._stub_tensor import _convert_stub
 from mindspore._c_expression import typing
-from mindspore._c_expression import
+from mindspore._c_expression import TensorPy as Tensor_
 from mindspore._c_expression import pyboost_cast, pyboost_tile, pyboost_zeros, pyboost_ones, pyboost_type_as
 from mindspore.common import dtype as mstype
 from mindspore.common._utils import is_shape_unknown
 from mindspore import _checkparam as validator
 from mindspore.ops.operations.manually_defined._inner import ScalarCast
-from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum
 from mindspore.common.initializer import Zero
 from mindspore.common.parameter import Parameter
-from mindspore.ops.auto_generate.gen_ops_prim import FlashAttentionScore
+from mindspore.ops.auto_generate.gen_ops_prim import FlashAttentionScore, FusedInferAttentionScore
+from mindspore.common.jit_context import jit_context


 dtype_to_type_id = DtypeToEnum()
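The import hunk swaps the old `mindspore.ops_generate` entry points for `mindspore.ops._utils`, consistent with the `ops_generate → ops/_utils` moves of `arg_dtype_cast.py` and `arg_handler.py` in the file list. A minimal sketch of how the relocated helpers are exercised later in this file (mirroring the Cast and Ones hunks below); it assumes mindspore 2.6.0rc1 is installed, and the returned ids are internal enum values:

```python
# Sketch only: the relocated helpers, used the same way the Cast/Ones hunks below use them.
import mindspore as ms
from mindspore.ops._utils import arg_handler as handler          # moved from mindspore.ops_generate
from mindspore.ops._utils.arg_dtype_cast import DtypeToEnum      # moved from mindspore.ops_generate

dtype_to_type_id = DtypeToEnum()
# Both helpers map a mindspore dtype to the integer type id that the pyboost kernels expect.
print(dtype_to_type_id('Cast', 'dtype', ms.float16))
print(handler.dtype_to_type_id('Ones', 'type', ms.float32))
```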
@@ -527,6 +528,64 @@ class ScalarBool(Primitive):
         return bool(x)


+class ScalarMax(Primitive):
+    r"""
+    Return the maximum of two input scalars.
+
+    .. note::
+        The inputs can be constant/variable value. Usage is the same as 'max' in Python.
+        This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
+
+    Inputs:
+        - **x** (Scalar) - A constant or variable scalar.
+        - **y** (Scalar) - A constant or variable scalar.
+
+    Outputs:
+        Scalar, and the data type is the one with higher precision or higher digits among the two inputs.
+
+    Raises:
+        TypeError: If `x` and `y` are not scalar.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+    """
+    @prim_attr_register
+    def __init__(self):
+        """Initialize ScalarMax"""
+
+    def __call__(self, x, y):
+        return max(x, y)
+
+
+class ScalarMin(Primitive):
+    r"""
+    Return the minimum of two input scalars.
+
+    .. note::
+        The inputs can be constant/variable value. Usage is the same as 'min' in Python.
+        This primitive only have 'CPU' implementation, for other platform, it runs using heterogeneous.
+
+    Inputs:
+        - **x** (Scalar) - A constant or variable scalar.
+        - **y** (Scalar) - A constant or variable scalar.
+
+    Outputs:
+        Scalar, and the data type is the one with higher precision or higher digits among the two inputs.
+
+    Raises:
+        TypeError: If `x` and `y` are not scalar.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+    """
+    @prim_attr_register
+    def __init__(self):
+        """Initialize ScalarMin"""
+
+    def __call__(self, x, y):
+        return min(x, y)
+
+
 scalar_div = ScalarDiv()
 scalar_mod = ScalarMod()
 scalar_add = ScalarAdd()
@@ -543,6 +602,8 @@ scalar_log = ScalarLog()
 scalar_pow = ScalarPow()
 scalar_uadd = ScalarUadd()
 scalar_usub = ScalarUsub()
+scalar_max = ScalarMax()
+scalar_min = ScalarMin()


 class BatchNorm(Primitive):
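In eager (PyNative) execution the two new scalar primitives simply defer to Python's built-ins, as their `__call__` bodies above show; plain Python reproduces that behaviour:

```python
# ScalarMax.__call__ / ScalarMin.__call__ forward to the built-ins (see the hunk above).
x, y = 2, 3.5
print(max(x, y))  # 3.5
print(min(x, y))  # 2
```

The documented output dtype (the higher-precision of the two inputs) appears to describe the primitive's declared type; the eager path above just returns whichever operand wins the comparison.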
@@ -570,13 +631,16 @@ class BatchNorm(Primitive):
         - For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction.

     Args:
-        is_training (bool): If `is_training` is ``True`` ,
+        is_training (bool, optional): If `is_training` is ``True`` ,
+            `mean` and `variance` are computed during training.
             If `is_training` is ``False`` , they're loaded from checkpoint during inference. Default: ``False`` .
-        epsilon (float): A small value added for numerical stability.
-
+        epsilon (float, optional): A small value added for numerical stability.
+            Default: ``1e-5``, value must be (0, 1] .
+        momentum (float, optional): The hyper parameter to compute moving average for running_mean and running_var
             (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
             Momentum value must be [0, 1]. Default: ``0.1`` .
-        data_format (str): The optional value for data format, is ``'NHWC'`` or ``'NCHW'``,
+        data_format (str, optional): The optional value for data format, is ``'NHWC'`` or ``'NCHW'``,
+            and the ``'NHWC'`` format
             is only supported in GPU target. Default: ``"NCHW"`` .

     Inputs:
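A one-line numeric check of the moving-average update documented above, with arbitrary example values:

```python
# new_running_mean = (1 - momentum) * running_mean + momentum * current_mean
momentum, running_mean, current_mean = 0.1, 0.0, 1.0
print((1 - momentum) * running_mean + momentum * current_mean)  # 0.1
```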
@@ -788,29 +852,21 @@ class Rank(Primitive):

 def rank(input_x):
     """
-
-
-    Returns a 0-D int32 Tensor representing the rank of input; the rank of a tensor
-    is the number of indices required to uniquely select each element of the tensor.
+    Return the rank of a tensor.

     Args:
-        input_x (Tensor): The
+        input_x (Tensor): The input tensor.

     Returns:
-        Tensor
-
-    Raises:
-        TypeError: If `input_x` is not a Tensor.
+        Tensor

     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> import mindspore
-        >>>
-        >>>
-        >>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
-        >>> output = ops.rank(input_tensor)
+        >>> input_tensor = mindspore.tensor([[2, 2], [2, 2]], mindspore.float32)
+        >>> output = mindspore.ops.rank(input_tensor)
         >>> print(output)
         2
         >>> print(type(output))
@@ -932,10 +988,6 @@ class Tile(Primitive):

     Refer to :func:`mindspore.ops.tile` for more details.

-    Note:
-        On Ascend, the number of `dims` should not exceed 8, and currently does not support scenarios
-        where more than 4 dimensions are repeated simultaneously.
-
     Inputs:
         - **input** (Tensor) - The tensor whose elements need to be repeated. Set the shape of input tensor as
           :math:`(x_1, x_2, ..., x_S)` .
@@ -943,6 +995,10 @@ class Tile(Primitive):
           the parameter type is tuple, and the data type is int, i.e., :math:`(y_1, y_2, ..., y_S)`.
           Only constant value is allowed.

+    .. note::
+        On Ascend, the number of `dims` should not exceed 8, and currently does not support scenarios
+        where more than 4 dimensions are repeated simultaneously.
+
     Outputs:
         Tensor, has the same data type as the `input`. Suppose the length of `dims` is `d`,
         the dimension of `input` is `input.dim`, and the shape of `input` is :math:`(x_1, x_2, ..., x_S)`.
@@ -999,7 +1055,16 @@ class Tile(Primitive):
         """Initialize."""

     def __call__(self, input, dims):
-
+        # Add for jit context.
+        if jit_context() and jit_context().compiled:
+            return None
+        res = _convert_stub(pyboost_tile(self, [input, dims]))
+        # Add for jit context.
+        if jit_context():
+            if validator.is_stub_tensor(res):
+                res = res.stub_sync()
+            return jit_context().run_op(self, res, input, dims)
+        return res

     # pylint: disable=missing-docstring
     def check_elim(self, *args):
@@ -1020,26 +1085,14 @@ class Tile(Primitive):

 def tile(input, dims):
     r"""
-    Creates a new tensor by repeating
-
+    Creates a new tensor by repeating the elements in the input tensor `dims` times.
+
+    The i'th dimension of output tensor has `input.shape[i] * dims[i]` elements, and the values of `input`
     are repeated `dims[i]` times along the i'th dimension.

     Note:
-        On Ascend, the number of `dims` should not exceed 8, and currently does not support scenarios
-
-
-    Args:
-        input (Tensor): The tensor whose elements need to be repeated. Set the shape of input tensor as
-            :math:`(x_1, x_2, ..., x_S)` .
-
-        dims (tuple[int]): The parameter that specifies the number of replications,
-            the parameter type is tuple, and the data type is int, i.e., :math:`(y_1, y_2, ..., y_S)`.
-            Only constant value is allowed.
-
-    Returns:
-        Tensor, has the same data type as the `input`. Suppose the length of `dims` is `d`,
-        the dimension of `input` is `input.dim`, and the shape of `input` is :math:`(x_1, x_2, ..., x_S)`.
-
+        - On Ascend, the number of `dims` should not exceed 8, and currently does not support scenarios
+          where more than 4 dimensions are repeated simultaneously.
         - If `input.dim = d`, then the shape of their corresponding positions can be multiplied, and
           the shape of Outputs is :math:`(x_1*y_1, x_2*y_2, ..., x_S*y_S)`.
         - If `input.dim < d`, prepend 1 to the shape of `input` until their lengths are consistent.
@@ -1050,40 +1103,39 @@ def tile(input, dims):
           `dims` as :math:`(1, ..., y_1, y_2, ..., y_S)`, then the shape of their corresponding positions
           can be multiplied, and the shape of Outputs is :math:`(x_1*1, ..., x_R*y_R, x_S*y_S)`.

-
-
-
+    Args:
+        input (Tensor): The input tensor.
+        dims (tuple[int]): The specified number of repetitions in each dimension.
+
+    Returns:
+        Tensor

     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> import mindspore
-        >>>
-        >>>
-        ... (remaining removed doctest lines were truncated by the diff viewer)
-        [1. 2. 1. 2.]
-        [3. 4. 3. 4.]
-        [1. 2. 1. 2.]
-        [3. 4. 3. 4.]]]
+        >>> input = mindspore.tensor([[1, 2], [3, 4]])
+        >>> mindspore.ops.tile(input, (2, 3))
+        Tensor(shape=[4, 6], dtype=Int64, value=
+        [[1, 2, 1, 2, 1, 2],
+         [3, 4, 3, 4, 3, 4],
+         [1, 2, 1, 2, 1, 2],
+         [3, 4, 3, 4, 3, 4]])
+        >>> mindspore.ops.tile(input, (2, 3, 2))
+        Tensor(shape=[2, 6, 4], dtype=Int64, value=
+        [[[1, 2, 1, 2],
+          [3, 4, 3, 4],
+          [1, 2, 1, 2],
+          [3, 4, 3, 4],
+          [1, 2, 1, 2],
+          [3, 4, 3, 4]],
+         [[1, 2, 1, 2],
+          [3, 4, 3, 4],
+          [1, 2, 1, 2],
+          [3, 4, 3, 4],
+          [1, 2, 1, 2],
+          [3, 4, 3, 4]]])
     """
     tile_op = _get_cache_prim(Tile)()
     return tile_op(input, dims)
@@ -1176,10 +1228,19 @@ class Cast(Primitive):
             return (False, None)

     def __call__(self, input_x, dtype):
+        # Add for jit context.
+        if jit_context() and jit_context().compiled:
+            return None
         should_elim, output = self.check_elim(input_x, dtype)
         if should_elim:
             return output
-
+        res = _convert_stub(pyboost_cast(self, [input_x, dtype_to_type_id('Cast', 'dtype', dtype)]))
+        # Add for jit context.
+        if jit_context():
+            if validator.is_stub_tensor(res):
+                res = res.stub_sync()
+            return jit_context().run_op(self, res, input_x, dtype)
+        return res


 class TypeAs(Primitive):
@@ -1552,15 +1613,52 @@ def infer_value_for_Tile(input, dims):
     return Tensor(np.tile(input.asnumpy(), dims))


+def infer_value_for_EqualExt(x, y):
+    """Infer value for EqualExt op."""
+    if x is None or y is None:
+        return None
+    result = np.equal(x.asnumpy(), y.asnumpy())
+    value = False
+    if result.all():
+        value = True
+    return Tensor(value)
+
+
 def infer_value_for_Concat(tensors, axis):
     """Infer value for Concat op."""
     if not tensors or None in tensors or axis is None:
         return None

-    tensor_to_concat = [x.asnumpy()
+    tensor_to_concat = [x.asnumpy() for x in tensors]
     return Tensor(np.concatenate(tensor_to_concat, axis), dtype=tensors[0].dtype)


+def infer_value_for_GatherD(input, dim, index):
+    """Infer value for GatherD op."""
+    if input is None or dim is None or index is None:
+        return None
+
+    input_np = input.asnumpy()
+    index_np = index.asnumpy()
+
+    index_shape = index_np.shape
+    multi_index = [np.indices(index_shape)[i] for i in range(len(index_shape))]
+    multi_index[dim] = index_np
+
+    output = input_np[tuple(multi_index)]
+    return Tensor(output, dtype=input.dtype)
+
+
+def infer_value_for_Softmax(input, axis):
+    """Infer value for Softmax op."""
+    if input is None or axis is None:
+        return None
+
+    e_input = np.exp(input.asnumpy())
+    output = e_input / np.sum(e_input, axis=axis, keepdims=True)
+    return Tensor(output, dtype=input.dtype)
+
+
 def infer_value_for_ReduceSum(input_x, axis, keep_dims, skip_mode):
     """Infer value for ReduceSum op."""
     value = None
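`infer_value_for_GatherD` folds a constant gather-along-dimension by building a full fancy index with `np.indices` and then substituting the `dim` axis with the index array, and `infer_value_for_Softmax` folds a constant softmax by exponentiating and normalising along `axis`. A standalone NumPy check of both tricks (the test data is arbitrary):

```python
import numpy as np

# GatherD with dim = 1: output[i, j] = input[i, index[i, j]].
input_np = np.array([[1, 2], [3, 4]])
index_np = np.array([[0, 0], [1, 0]])
dim = 1
multi_index = [np.indices(index_np.shape)[i] for i in range(index_np.ndim)]
multi_index[dim] = index_np          # replace the gathered axis with the index array
print(input_np[tuple(multi_index)])  # [[1 1]
                                     #  [4 3]]

# Softmax folding: exponentiate, then normalise along the axis; the result sums to 1.
x = np.array([[1.0, 2.0, 3.0]])
e = np.exp(x)
print(np.isclose((e / e.sum(axis=1, keepdims=True)).sum(), 1.0))  # True
```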
@@ -1608,6 +1706,20 @@ def _infer_value_for_Reduce(input_x, axis, keep_dims, prim_name):
     return value


+def infer_value_for_Arange(start, end, step, dtype=None):
+    """Infer value for Arange op."""
+    if start is None or end is None or step is None:
+        return None
+    np_dtype = np.int64
+    if dtype is None:
+        has_float = any(isinstance(i, float) for i in [start, end, step])
+        if has_float:
+            np_dtype = np.float32
+    else:
+        np_dtype = mstype.dtype_to_nptype(typing.type_id_to_type(dtype))
+    return Tensor(np.arange(start, end, step, dtype=np_dtype))
+
+
 def _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, prim_name):
     """Infer value for Common ReduceExtand op."""
     value = None
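The dtype fallback in `infer_value_for_Arange` (all-integer inputs fold to int64, any float input to float32, otherwise the explicitly requested dtype) is easy to reproduce with NumPy alone:

```python
import numpy as np

# Mirrors the fallback: ints -> int64, any float among (start, end, step) -> float32.
print(np.arange(0, 5, 1, dtype=np.int64))           # [0 1 2 3 4]
print(np.arange(0.0, 1.0, 0.25, dtype=np.float32))  # [0.   0.25 0.5  0.75]
```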
@@ -1679,6 +1791,95 @@ def infer_value_for_Cast(x, dst_type_enum=None):
     return value


+def infer_value_for_LinalgVectorNorm(input_x, ord, dim, keepdim, dtype):
+    """Infer value for linalg_vector_norm op.
+    Current version numpy is not support numpy.linalg.vector_norm.
+    So using numpy.linalg.norm.
+    """
+    if input_x is None or ord is None:
+        return None
+    if ord != 0:
+        out = np.power(np.sum(np.power(np.abs(input_x.asnumpy()), ord), axis=dim, keepdims=keepdim), 1/ord)
+    else:
+        out = np.sum(input_x.asnumpy() != 0, axis=dim, keepdims=keepdim)
+    if dtype is None:
+        return Tensor(out)
+    dtype_for_ms = typing.type_id_to_type(dtype)
+    return Tensor(out, dtype=dtype_for_ms)
+
+
+def infer_value_for_LpNormV2(input_x, p=2, dim=None, keepdim=False, eps=1e-12):
+    """Infer value for linalg_vector_norm op.
+    Current version numpy is not support numpy.linalg.vector_norm.
+    So using numpy.linalg.norm.
+    """
+    if input_x is None:
+        return None
+    return Tensor(np.linalg.norm(input_x.asnumpy(), axis=dim, keepdims=keepdim,
+                                 ord=p))
+
+
+def infer_value_for_Svd(input_x, full_matrices, compute_uv):
+    """Infer value for Svd op."""
+    if input_x is None:
+        return None
+    if bool(compute_uv):
+        s, u, v = np.linalg.svd(input_x.asnumpy(), full_matrices=full_matrices, compute_uv=True)
+        return Tensor(s), Tensor(u), Tensor(v)
+    s = np.linalg.svd(input_x.asnumpy(), full_matrices=full_matrices, compute_uv=False)
+    return Tensor(s), np.zeros(1), np.zeros(1)
+
+
+def infer_value_for_Div(input_x, other_x):
+    """Infer value for Div op."""
+    if input_x is None or other_x is None:
+        return None
+    return Tensor(np.true_divide(input_x.asnumpy(), other_x.asnumpy()))
+
+
+def infer_value_for_Divs(input_x, other_x):
+    """Infer value for Divs op."""
+    if input_x is None or other_x is None:
+        return None
+    tmp = np.true_divide(input_x.asnumpy(), other_x)
+    if not input_x.shape:
+        # tensor scalar has a special rule for data type promote
+        if input_x.dtype in (mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64, mstype.int8, mstype.int16,
+                             mstype.int32, mstype.int64):
+            res = Tensor(tmp, dtype=mstype.float32)
+        else:
+            res = Tensor(tmp, dtype=input_x.dtype)
+    else:
+        res = Tensor(tmp)
+    return res
+
+
+def infer_value_for_DivMod(input_x, other_x, rounding_mode):
+    """Infer value for DivMod op."""
+    if input_x is None or other_x is None:
+        return None
+    if rounding_mode == 1:
+        # trunc
+        return Tensor(np.trunc(np.true_divide(input_x.asnumpy(), other_x.asnumpy())))
+    if rounding_mode == 2:
+        # floor
+        return Tensor(np.floor_divide(input_x.asnumpy(), other_x.asnumpy()))
+    return None
+
+
+def infer_value_for_DivMods(input_x, other_x, rounding_mode):
+    """Infer value for DivMods op."""
+    if input_x is None or other_x is None:
+        return None
+    if rounding_mode == 1:
+        # trunc
+        return Tensor(np.trunc(np.true_divide(input_x.asnumpy(), other_x)))
+    if rounding_mode == 2:
+        # floor
+        return Tensor(np.floor_divide(input_x.asnumpy(), other_x))
+    return None
+
+
 def infer_value_for_ReduceMax(input_x, axis, keep_dims):
     """Infer value for ReduceMax op."""
     return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceMax')
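`infer_value_for_LinalgVectorNorm` computes the p-norm as a power-sum-power expression because, as its docstring notes, the bundled NumPy has no `numpy.linalg.vector_norm`. A quick check that the expression agrees with `numpy.linalg.norm` on a flattened array (arbitrary data, `ord = 3`, the `dim=None` case):

```python
import numpy as np

x = np.array([[3.0, -4.0], [1.0, 2.0]])
p = 3
# Power-sum-power expression used above (nonzero ord, no dim/keepdim reduction).
manual = np.power(np.sum(np.power(np.abs(x), p)), 1 / p)
print(np.allclose(manual, np.linalg.norm(x.ravel(), ord=p)))  # True
```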
@@ -1867,8 +2068,18 @@ class Ones(Primitive):
         pass

     def __call__(self, size, type=None):
-
+        # Add for jit context.
+        if jit_context() and jit_context().compiled:
+            return None
+        res = _convert_stub(pyboost_ones(self, [size, type if type is None \
             else handler.dtype_to_type_id('Ones', 'type', type)]))
+        # Add for jit context.
+        if jit_context():
+            if validator.is_stub_tensor(res):
+                res = res.stub_sync()
+            return jit_context().run_op(self, res, size, type if type is None \
+                else handler.dtype_to_type_id('Ones', 'type', type))
+        return res


 class Zeros(Primitive):
@@ -1917,8 +2128,18 @@ class Zeros(Primitive):
         pass

     def __call__(self, size, type=None):
-
+        # Add for jit context.
+        if jit_context() and jit_context().compiled:
+            return None
+        res = _convert_stub(pyboost_zeros(self, [size, type if type is None else \
             handler.dtype_to_type_id('Zeros', 'type', type)]))
+        # Add for jit context.
+        if jit_context():
+            if validator.is_stub_tensor(res):
+                res = res.stub_sync()
+            return jit_context().run_op(self, res, size, type if type is None else \
+                handler.dtype_to_type_id('Zeros', 'type', type))
+        return res


 def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mask=None, padding_mask=None,
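The same three-step guard now appears verbatim in `Tile.__call__`, `Cast.__call__`, `Ones.__call__` and `Zeros.__call__`: bail out once the jit trace reports `compiled`, run the pyboost kernel, then sync any stub tensor and hand the result to `run_op`. A hypothetical refactoring sketch of that shared pattern; `_run_with_jit_context` is illustrative only and not part of MindSpore:

```python
# Hypothetical helper; the calls it wraps (jit_context, is_stub_tensor, stub_sync,
# run_op) are exactly the ones used in the hunks above.
from mindspore.common.jit_context import jit_context
from mindspore import _checkparam as validator


def _run_with_jit_context(prim, compute, *op_args):
    """compute() runs the pyboost kernel; op_args are forwarded to run_op for tracing."""
    if jit_context() and jit_context().compiled:
        # Skip eager execution once the jit trace has been compiled.
        return None
    res = compute()
    if jit_context():
        if validator.is_stub_tensor(res):
            res = res.stub_sync()
        return jit_context().run_op(prim, res, *op_args)
    return res
```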
@@ -1996,7 +2217,7 @@ def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mas
         keep_prob (double, optional): The keep probability of dropout. Value range is (0.0, 1.0]. When `keep_prob`
             is 1.0, `drop_mask` should be None.
             Default: ``1.0``.
-
+        scalar_value (double, optional): The scale factor of score. Generally, the value is 1.0 / (D ** 0.5).
             Default: ``1.0``.
         pre_tokens (int, optional): Parameter for sparse computation, represents how many tokens are counted forward.
             When `sparse_mode` is set to 1, 2, 3, or 5, this parameter does not take effect.
@@ -2067,7 +2288,7 @@ def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mas
         TypeError: `query`, `key` and `value` don't have the same dtype.
         TypeError: Dtype of `attn_mask` is not bool or uint8.
         TypeError: Dtype of `real_shift` has a different dtype as `query`.
-        TypeError: `
+        TypeError: `scalar_value` or `keep_prob` is not a double number.
         TypeError: `input_layout` is not a string.
         TypeError: `num_key_value_heads` is not an int.
         TypeError: `sparse_mode` is not an int.
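`scalar_value` is the usual attention scaling factor; for a head dimension of, say, D = 128 the documented expression works out to:

```python
# scalar_value = 1.0 / (D ** 0.5), as documented for flash_attention_score.
D = 128
print(1.0 / (D ** 0.5))  # ~0.0884
```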
@@ -2104,6 +2325,452 @@ def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mas
|
|
|
2104
2325
|
actual_seq_kvlen)[3]
|
|
2105
2326
|
|
|
2106
2327
|
|
|
2328
|
+
def fused_infer_attention_score(query, key, value, *, pse_shift=None, atten_mask=None, actual_seq_lengths=None,
|
|
2329
|
+
actual_seq_lengths_kv=None, dequant_scale1=None, quant_scale1=None, dequant_scale2=None,
|
|
2330
|
+
quant_scale2=None, quant_offset2=None, antiquant_scale=None, antiquant_offset=None,
|
|
2331
|
+
key_antiquant_scale=None, key_antiquant_offset=None, value_antiquant_scale=None,
|
|
2332
|
+
value_antiquant_offset=None, block_table=None, query_padding_size=None,
|
|
2333
|
+
kv_padding_size=None, key_shared_prefix=None, value_shared_prefix=None,
|
|
2334
|
+
actual_shared_prefix_len=None, num_heads=1, scale=1.0, pre_tokens=2147483647,
|
|
2335
|
+
next_tokens=2147483647, input_layout='BSH', num_key_value_heads=0, sparse_mode=0,
|
|
2336
|
+
inner_precise=1, block_size=0, antiquant_mode=0, key_antiquant_mode=0,
|
|
2337
|
+
value_antiquant_mode=0, softmax_lse_flag=False):
|
|
2338
|
+
r"""
|
|
2339
|
+
This is a FlashAttention function designed for both incremental and full inference scenarios. It supports full
|
|
2340
|
+
inference scenarios (PromptFlashAttention) as well as incremental inference scenarios (IncreFlashAttention).
|
|
2341
|
+
When the S dimension of the query tensor (Q_S) equals 1, it enters the IncreFlashAttention branch; otherwise,
|
|
2342
|
+
it enters the PromptFlashAttention branch.
|
|
2343
|
+
|
|
2344
|
+
.. math::
|
|
2345
|
+
|
|
2346
|
+
Attention(Q,K,V) = Softmax(\frac{QK^{T}}{\sqrt{d}})V
|
|
2347
|
+
|
|
2348
|
+
.. warning::
|
|
2349
|
+
- This is an experimental API that is subject to change or deletion.
|
|
2350
|
+
- For Ascend, only the Atlas A2 training series products and Atlas 800I A2 inference products are currently
|
|
2351
|
+
supported.
|
|
2352
|
+
|
|
2353
|
+
Note:
|
|
2354
|
+
- The data layout formats of query, key and value can be interpreted from multiple dimensions, as shown below:
|
|
2355
|
+
|
|
2356
|
+
- B, Batch size. Represents the batch size of the input samples.
|
|
2357
|
+
- S, Sequence length. Represents the sequence length of the input samples. S1 represents the sequence length
|
|
2358
|
+
of the query, and S2 represents the sequence length of the key/value.
|
|
2359
|
+
- H, Head size. Represents the size of the hidden layer.
|
|
2360
|
+
- N, Head nums. Represents the number of attention heads.
|
|
2361
|
+
- D, Head dims. Represents the smallest unit size of the hidden layer, satisfying :math:`D = H / N`.
|
|
2362
|
+
|
|
2363
|
+
Args:
|
|
2364
|
+
query (Tensor): The query input of the attention structure, with data type of float16, bfloat16 or int8.
|
|
2365
|
+
Input tensor of shape :math:`(B, S, H)`, :math:`(B, N, S, D)`, or :math:`(B, S, N, D)`.
|
|
2366
|
+
key (Union[Tensor, tuple[Tensor], list[Tensor]]): The key input of the attention structure, with data type
|
|
2367
|
+
of float16, bfloat16 or int8. Input tensor of shape :math:`(B, S, H)`, :math:`(B, N, S, D)`, or
|
|
2368
|
+
:math:`(B, S, N, D)`.
|
|
2369
|
+
value (Union[Tensor, tuple[Tensor], list[Tensor]]): The value input of the attention structure, with data
|
|
2370
|
+
type of float16, bfloat16 or int8. Input tensor of shape :math:`(B, S, H)`, :math:`(B, N, S, D)`, or
|
|
2371
|
+
:math:`(B, S, N, D)`.
|
|
2372
|
+
|
|
2373
|
+
Keyword Args:
|
|
2374
|
+
pse_shift (Tensor, optional): The padding mask tensor with data type of float16 or bfloat16.
|
|
2375
|
+
Default: ``None``.
|
|
2376
|
+
|
|
2377
|
+
- When Q_S is not 1, if pse_shift is of type float16, the query must be of type float16 or int8.
|
|
2378
|
+
If pse_shift is of type bfloat16, the query must also be of type bfloat16. The input shape
|
|
2379
|
+
must be either :math:`(B, N, Q\_S, KV\_S)` or :math:`(1, N, Q\_S, KV\_S)`, where Q_S corresponds to the
|
|
2380
|
+
S dimension of the query shape, and KV_S corresponds to the S dimension of the key and value shapes.
|
|
2381
|
+
For scenarios where the KV_S of pse_shift is not 32-aligned, it is recommended to pad it
|
|
2382
|
+
to 32 bytes to improve performance. The padding values for the extra portions are not restricted.
|
|
2383
|
+
- When Q_S is 1, if pse_shift is of type float16, the query must also be of type float16.
|
|
2384
|
+
If pse_shift is of type bfloat16, the query must be of type bfloat16. The input shape must be
|
|
2385
|
+
:math:`(B, N, 1, KV\_S)` or :math:`(1, N, 1, KV\_S)`, where KV_S corresponds to the S dimension of the
|
|
2386
|
+
key/value shapes. For scenarios where the KV\_S of pse_shift is not 32-aligned, it is recommended
|
|
2387
|
+
to pad it to 32 bytes to improve performance. The padding values for the extra portions are not
|
|
2388
|
+
restricted.
|
|
2389
|
+
|
|
2390
|
+
atten_mask (Tensor, optional): The attention mask tensor for the result of query*key with data type of int8,
|
|
2391
|
+
uint8 or bool. For each element, 0 indicates retention and 1 indicates discard.
|
|
2392
|
+
Default: ``None``.
|
|
2393
|
+
|
|
2394
|
+
- When Q_S is not 1, the recommended input shapes are Q_S,KV_S; B,Q_S,KV_S; 1,Q_S,KV_S; B,1,Q_S,KV_S
|
|
2395
|
+
or 1,1,Q_S,KV_S.
|
|
2396
|
+
- When Q_S is 1, the recommended input shapes are B,KV_S; B,1,KV_S or B,1,1,KV_S.
|
|
2397
|
+
|
|
2398
|
+
actual_seq_lengths (Union[tuple[int], list[int], Tensor], optional): Describe actual sequence length of the
|
|
2399
|
+
query with data type of int64. If this parameter is not specified, it can be set to None, indicating that
|
|
2400
|
+
it matches the S dimension of the query shape. Constraint: The effective sequence length for each batch in
|
|
2401
|
+
this parameter should not exceed the corresponding batch's sequence length in the query. When Q_S is 1, this
|
|
2402
|
+
parameter is ignored.
|
|
2403
|
+
Default: ``None``.
|
|
2404
|
+
actual_seq_lengths_kv (Union[tuple[int], list[int], Tensor], optional): Describe actual sequence length of the
|
|
2405
|
+
key and value with data type of int64. If this parameter is not specified, it can be set to None,
|
|
2406
|
+
indicating that it matches the S dimension of the key and value shape. Constraint: The effective sequence
|
|
2407
|
+
length for each batch in this parameter should not exceed the corresponding batch's sequence length in the
|
|
2408
|
+
key and value.
|
|
2409
|
+
Default: ``None``.
|
|
2410
|
+
dequant_scale1 (Tensor, optional): Quantization factors for inverse quantization after BMM1 with data type of
    uint64. Supports per-tensor mode. If not used, set it to None.
    Default: ``None``.
quant_scale1 (Tensor, optional): Quantization factors for quantization before BMM2 with data type of float32.
    Supports per-tensor mode. If not used, set it to None.
    Default: ``None``.
dequant_scale2 (Tensor, optional): Quantization factors for inverse quantization after BMM2 with data type of
    uint64. Supports per-tensor mode. If not used, set it to None.
    Default: ``None``.
quant_scale2 (Tensor, optional): Quantization factors for output quantization with data type of float32 or
    bfloat16. Supports per-tensor and per-channel modes. If not used, set it to None.
    Default: ``None``.
quant_offset2 (Tensor, optional): Quantization offset for output quantization with data type of float32 or
    bfloat16. Supports per-tensor and per-channel modes. If not used, set it to None.
    Default: ``None``.

For scenarios where the input is int8 and the output is int8: the parameters dequant_scale1, quant_scale1,
dequant_scale2, and quant_scale2 must all be provided. The parameter quant_offset2 is optional and defaults
to 0 if not specified.

    - When the output is int8 and quant_scale2 and quant_offset2 are per-channel, left padding, Ring Attention,
      or D-axis misalignment (not 32-aligned) scenarios are not supported.
    - When the output is int8, scenarios with sparse_mode as band and pre_tokens/next_tokens being negative are
      not supported.
    - When the output is int8, if quant_offset2 is not None and not an empty tensor, and the sparse_mode,
      pre_tokens, and next_tokens meet the following conditions, certain rows of the matrix may not participate
      in the calculation, leading to errors. This scenario will be intercepted (solution: if this scenario should
      not be intercepted, perform quantization outside the FIA interface instead of enabling it inside the
      FIA interface):

      - sparse_mode = 0: if atten_mask is not None and, for any batch,
        actual_seq_lengths - actual_seq_lengths_kv - pre_tokens > 0 or next_tokens < 0, the
        interception condition is met.
      - sparse_mode = 1 or 2: no interception condition occurs.
      - sparse_mode = 3: if, for any batch, actual_seq_lengths - actual_seq_lengths_kv < 0, the
        interception condition is met.
      - sparse_mode = 4: if pre_tokens < 0 or, for any batch,
        next_tokens + actual_seq_lengths - actual_seq_lengths_kv < 0, the interception
        condition is met.

For scenarios where the input is int8 and the output is float16: the parameters dequant_scale1,
quant_scale1, and dequant_scale2 must all be provided.

For scenarios where the input is entirely float16 or bfloat16 and the output is int8: the parameter
quant_scale2 must be provided. The parameter quant_offset2 is optional and defaults to 0 if not specified.

The parameters quant_scale2 and quant_offset2 support both per-tensor and per-channel modes and two data
types: float32 and bfloat16. If quant_offset2 is provided, its type and shape must match those of
quant_scale2. When the input is bfloat16, both float32 and bfloat16 are supported; otherwise, only float32
is supported. For per-channel mode: when the output layout is BSH, the product of all dimensions in
quant_scale2 must equal H; for other layouts, the product must equal N * D. When the output layout is BSH,
it is recommended to set the shape of quant_scale2 as :math:`(1, 1, H)` or :math:`(H)`. When the output
layout is BNSD, it is recommended to set the shape as :math:`(1, N, 1, D)` or :math:`(N, D)`. When the
output layout is BSND, it is recommended to set the shape as :math:`(1, 1, N, D)` or :math:`(N, D)`.

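A minimal sketch of per-channel output quantization shapes (the N, D, H values are assumptions for
illustration; they are not prescribed by this docstring)::

    import numpy as np
    from mindspore import Tensor
    N, D = 8, 128
    H = N * D
    # BSH output layout: product of dimensions equals H.
    quant_scale2_bsh = Tensor(np.ones((1, 1, H), dtype=np.float32))
    # BNSD output layout: product of dimensions equals N * D.
    quant_scale2_bnsd = Tensor(np.ones((1, N, 1, D), dtype=np.float32))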
antiquant_scale (Tensor, optional): Inverse quantization factors with data type of float16, float32 or bfloat16.
    Only float16 is supported when Q_S > 1. Supports per-tensor, per-channel and per-token modes.
    Default: ``None``.
antiquant_offset (Tensor, optional): Inverse quantization offset with data type of float16, float32 or bfloat16.
    Only float16 is supported when Q_S > 1. Supports per-tensor, per-channel and per-token modes.
    Default: ``None``.

    Constraints for the antiquant_scale and antiquant_offset parameters:

    - Three modes are supported: per-channel, per-tensor, and per-token:

      - Per-channel mode: The shape of both parameters in the BNSD layout is :math:`(2, N, 1, D)`, the shape
        in the BSND layout is :math:`(2, N, D)`, and the shape in the BSH layout is :math:`(2, H)`, where 2
        corresponds to the key and value, and N represents num_key_value_heads. The parameter data type is
        the same as the query data type, and antiquant_mode should be set to 0.
      - Per-tensor mode: The shape of both parameters is :math:`(2)`, the data type is the same as the query
        data type, and antiquant_mode should be set to 0.
      - Per-token mode: The shape of both parameters is :math:`(2, B, S)`, the data type is fixed to float32,
        and antiquant_mode should be set to 1.

    - Both symmetric and asymmetric quantization are supported:

      - Asymmetric quantization mode: Both antiquant_scale and antiquant_offset must be provided.
      - Symmetric quantization mode: antiquant_offset can be empty (``None``). If antiquant_offset is empty,
        symmetric quantization is performed. If antiquant_offset is provided, asymmetric quantization is
        performed.

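    A minimal per-token sketch (the B and S values are assumptions for illustration)::

        import numpy as np
        from mindspore import Tensor
        B, S = 2, 1024
        # Per-token mode: shape (2, B, S), dtype fixed to float32, antiquant_mode=1.
        antiquant_scale = Tensor(np.ones((2, B, S), dtype=np.float32))
        antiquant_offset = Tensor(np.zeros((2, B, S), dtype=np.float32))  # omit (None) for symmetric quantization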
key_antiquant_scale (Tensor, optional): Inverse quantization factors for the key, with data type of float16,
    float32 or bfloat16, when the KV fake quantization parameters are separated.
    Supports per-tensor, per-channel and per-token modes.
    Default: ``None``. Invalid when Q_S > 1.
key_antiquant_offset (Tensor, optional): Inverse quantization offset for the key, with data type of float16,
    float32 or bfloat16, when the KV fake quantization parameters are separated.
    Supports per-tensor, per-channel and per-token modes.
    Default: ``None``. Invalid when Q_S > 1.
value_antiquant_scale (Tensor, optional): Inverse quantization factors for the value, with data type of
    float16, float32 or bfloat16, when the KV fake quantization parameters are separated.
    Supports per-tensor, per-channel and per-token modes.
    Default: ``None``. Invalid when Q_S > 1.
value_antiquant_offset (Tensor, optional): Inverse quantization offset for the value, with data type of
    float16, float32 or bfloat16, when the KV fake quantization parameters are separated.
    Supports per-tensor, per-channel and per-token modes.
    Default: ``None``. Invalid when Q_S > 1.
block_table (Tensor, optional): Block mapping table in the KV cache for PageAttention, with data type of int32.
    If not used, set it to None.
    Default: ``None``. Invalid when Q_S > 1.
query_padding_size (Tensor, optional): The query padding size with data type of int64. Indicates whether the
    data in each batch of the query is right-aligned, and how many elements are right-aligned.
    Default: ``None``. Invalid when Q_S is 1.
kv_padding_size (Tensor, optional): The key and value padding size with data type of int64. Indicates whether
    the data in each batch of the key and value is right-aligned, and how many elements are right-aligned.
    Default: ``None``. Invalid when Q_S is 1.
key_shared_prefix (Tensor, optional): Shared prefix of the key. This is a reserved parameter and is not yet
    enabled. Default: ``None``.
value_shared_prefix (Tensor, optional): Shared prefix of the value. This is a reserved parameter and is not yet
    enabled. Default: ``None``.
actual_shared_prefix_len (Union[tuple[int], list[int], Tensor], optional): Describe the actual length of the
    shared prefix. This is a reserved parameter and is not yet enabled.
    Default: ``None``.
num_heads (int, optional): The number of heads in the query, equal to N when input_layout is BNSD.
    Default: ``1``.
scale (double, optional): The scale value indicating the scale coefficient, which serves as the scalar value for
    the Muls in the calculation. Generally, the value is :math:`1.0 / \sqrt{d}`. Default: ``1.0``.
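    A minimal sketch of the usual choice of scale (the D value is an assumption for illustration)::

        import math
        D = 128                      # head dimension
        scale = 1.0 / math.sqrt(D)   # the commonly used 1/sqrt(d) scaling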
pre_tokens (int, optional): Parameter for sparse computation, represents how many tokens are counted forward.
    Default: ``2147483647``. Invalid when Q_S is 1.
next_tokens (int, optional): Parameter for sparse computation, represents how many tokens are counted backward.
    Default: ``2147483647``. Invalid when Q_S is 1.
input_layout (str, optional): Specifies the layout of the input query, key and value. BSH, BNSD, BSND or
    BNSD_BSND is supported. When the layout is BNSD_BSND, the input is in the BNSD format and
    the output is in the BSND format; this is only supported when Q_S > 1.
    Default: ``BSH``.
num_key_value_heads (int, optional): Head numbers of key/value which are used in the GQA (Grouped-Query
    Attention) scenario. Default: ``0``. A value of 0 means it is equal to the number of key/value heads.
    num_heads must be divisible by num_key_value_heads, and the ratio of num_heads to num_key_value_heads
    must not be greater than 64. When the layout is BNSD, num_key_value_heads must also equal the N
    dimension of the key/value shapes; otherwise, an execution error will occur.
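    A minimal GQA shape sketch (all sizes below are assumptions for illustration only)::

        import numpy as np
        from mindspore import Tensor, ops
        B, S, D = 1, 256, 128
        Q_N, KV_N = 32, 4            # 32 query heads share 4 key/value heads; 32 % 4 == 0, ratio <= 64
        query = Tensor(np.random.rand(B, Q_N, S, D).astype(np.float16))
        key = Tensor(np.random.rand(B, KV_N, S, D).astype(np.float16))
        value = Tensor(np.random.rand(B, KV_N, S, D).astype(np.float16))
        out = ops.fused_infer_attention_score(query, key, value, num_heads=Q_N,
                                              num_key_value_heads=KV_N, input_layout='BNSD')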
sparse_mode (int, optional): Indicates the sparse mode. Default: ``0``. Invalid when Q_S is 1.

    - 0: Indicates the defaultMask mode. If atten_mask is not passed, the mask operation is not performed,
      and pre_tokens and next_tokens (internally assigned as INT_MAX) are ignored. If passed in, the complete
      atten_mask matrix (S1 * S2) must also be passed in, indicating that the part between pre_tokens and
      next_tokens needs to be calculated.
    - 1: Represents allMask. The complete atten_mask matrix (S1 * S2) is required.
    - 2: Represents the mask in leftUpCausal mode. The optimized atten_mask matrix (2048*2048) is required.
    - 3: Represents the mask in rightDownCausal mode, corresponding to the lower triangular scenario divided by
      the right vertex. The optimized atten_mask matrix (2048*2048) is required.
    - 4: Represents the mask in band mode, that is, the part between pre_tokens and next_tokens. The
      optimized atten_mask matrix (2048*2048) is required.
    - 5: Represents the prefix scenario, not implemented yet.
    - 6: Represents the global scenario, not implemented yet.
    - 7: Represents the dilated scenario, not implemented yet.
    - 8: Represents the block_local scenario, not implemented yet.

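    A minimal sketch of constructing the 2048*2048 optimized mask used with sparse_mode 2, 3 and 4 (the
    triangular convention below, 1 = discard above the diagonal, is an assumption and should be checked
    against the constraint that the atten_mask must be a lower triangular matrix)::

        import numpy as np
        from mindspore import Tensor
        S = 2048
        # 1 marks positions to discard; the strict upper triangle masks out future tokens.
        opt_mask = Tensor(np.triu(np.ones((S, S), dtype=np.bool_), k=1))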
inner_precise (int, optional): There are four modes: 0, 1, 2, and 3, represented by 2 bits: bit 0 (bit0)
    selects between high precision and high performance, and bit 1 (bit1) indicates whether row-wise
    invalidity correction is applied.

    - 0: High-precision mode, without row-wise invalidity correction.
    - 1: High-performance mode, without row-wise invalidity correction.
    - 2: High-precision mode, with row-wise invalidity correction.
    - 3: High-performance mode, with row-wise invalidity correction.

    When Q_S > 1, if sparse_mode is 0 or 1 and a user-defined mask is provided, it is recommended to enable
    row-wise invalidity correction. Only 0 and 1 are supported when Q_S is 1. Default: ``1``.

    The high-precision and high-performance settings are only effective for float16 inputs; row invalidity
    correction is effective for float16, bfloat16, and int8 inputs.
    Currently, 0 and 1 are reserved configuration values. If an entire row in the
    "mask portion involved in computation" is all 1s, precision may degrade. In such cases, you can try
    setting this parameter to 2 or 3 to enable row invalidity correction for improved precision. However,
    this configuration will result in decreased performance.
    If the function can detect the presence of invalid row scenarios, e.g. in cases where sparse_mode is 3
    and S_q > S_kv, it will automatically enable row invalidity computation.

block_size (int, optional): Maximum number of tokens per block in the KV cache block for PageAttention.
    Default: ``0``. Invalid when Q_S > 1.
antiquant_mode (int, optional): Fake-quantization mode, 0: per-channel (per-channel includes per-tensor),
    1: per-token. The per-channel and per-tensor modes can be distinguished by the dimension of the input
    shape. When the dimension is 1, it runs in per-tensor mode; otherwise, it runs in per-channel mode.
    Default: ``0``. Invalid when Q_S > 1.
key_antiquant_mode (int, optional): Fake-quantization mode for the key. 0: per-channel (per-channel includes
    per-tensor), 1: per-token. Default: ``0``. Invalid when Q_S > 1.
value_antiquant_mode (int, optional): Fake-quantization mode for the value. 0: per-channel (per-channel includes
    per-tensor), 1: per-token. Default: ``0``. Invalid when Q_S > 1.
softmax_lse_flag (bool, optional): Whether to output softmax_lse. Default: ``False``.

Returns:
    attention_out (Tensor), the attention score with data type of float16, bfloat16 or int8. When the input_layout
    is BNSD_BSND, the shape is :math:`(B, S, N, D)`. In all other cases, the shape is consistent with the
    input query shape.

    softmax_lse (Tensor), the softmax_lse with data type of float32, obtained by taking the lse (log, sum and exp)
    of the result of query*key. Specifically, the Ring Attention algorithm first takes the max of the result of
    query*key, obtaining softmax_max. The result of query*key is then subtracted by softmax_max, followed by
    taking exp, and then the sum is computed to obtain softmax_sum. Finally, the log of softmax_sum is taken,
    and softmax_max is added to obtain softmax_lse. The softmax_lse is only calculated when softmax_lse_flag
    is True, and the shape would be :math:`(B, N, Q\_S, 1)`. If softmax_lse_flag is False, then a tensor with
    shape :math:`(1)` filled with zeros would be returned. In graph mode with JitConfig set to O2, please ensure
    that the softmax_lse_flag is enabled before using softmax_lse; otherwise, an exception will occur.

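    A minimal NumPy sketch of the softmax_lse computation described above (a reference illustration only,
    not the operator's implementation; the tensor sizes are assumptions)::

        import numpy as np
        qk = np.random.rand(1, 8, 16, 32).astype(np.float32)      # (B, N, Q_S, KV_S)
        softmax_max = qk.max(axis=-1, keepdims=True)               # max over the KV_S axis
        softmax_sum = np.exp(qk - softmax_max).sum(axis=-1, keepdims=True)
        softmax_lse = np.log(softmax_sum) + softmax_max            # shape (B, N, Q_S, 1)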
Constraints:
    - Full Inference Scenario (Q_S > 1):

      - Query, key, and value inputs functional usage restrictions:

        - The B axis supports values less than or equal to 65535. If the input type includes int8, or
          if the input type is float16 or bfloat16 and the D axis is not 16-aligned, the B axis is only
          supported up to 128.
        - The N axis supports values less than or equal to 256, and the D axis supports values less than
          or equal to 512.
        - The S axis supports values less than or equal to 20,971,520 (20M). In some long sequence
          scenarios, if the computation load is too large, it may cause a timeout in the PFA operator
          (AICore error type with errorStr: "timeout or trap error"). In this case, it is recommended to
          perform an S split. Note: The computational load is affected by B, S, N, D, etc.; the larger the
          values, the greater the computational load. Typical long sequence timeout scenarios (where the
          product of B, S, N, and D is large) include, but are not limited to:

          1. B=1, Q_N=20, Q_S=2097152, D=256, KV_N=1, KV_S=2097152;
          2. B=1, Q_N=2, Q_S=20971520, D=256, KV_N=2, KV_S=20971520;
          3. B=20, Q_N=1, Q_S=2097152, D=256, KV_N=1, KV_S=2097152;
          4. B=1, Q_N=10, Q_S=2097152, D=512, KV_N=1, KV_S=2097152.

        - When the query, key, value, or attention_out type includes int8, the D axis must be 32-aligned.
          If all types are float16 or bfloat16, the D axis must be 16-aligned.

      - The sparse_mode parameter currently only supports values 0, 1, 2, 3, and 4. Using any other values
        will result in an error.

        - When sparse_mode = 0, if the atten_mask is None, or if the atten_mask is provided in the left
          padding scenario, the input parameters pre_tokens and next_tokens are ignored.
        - When sparse_mode = 2, 3, or 4, the shape of the atten_mask must be S,S or 1,S,S or 1,1,S,S, where
          S must be fixed at 2048, and the user must ensure the atten_mask is a lower triangular matrix. If
          no atten_mask is provided or if the shape is incorrect, an error will occur.
        - In sparse_mode = 1, 2, 3 scenarios, the pre_tokens and next_tokens inputs are ignored and assigned
          according to the relevant rules.

      - The KV cache de-quantization only supports queries of type float16, where int8 keys and values are
        de-quantized to float16. The product of the data range of the input key/value and the antiquant_scale
        must be within the range of (-1, 1); in that case, high-performance mode can guarantee precision.
        Otherwise, high-precision mode should be enabled to ensure accuracy.

      - Query left padding scenario:

        - In the query left padding scenario, the formula for calculating the starting point of the query
          transport is: Q_S - query_padding_size - actual_seq_lengths. The formula for the
          ending point of the query transport is: Q_S - query_padding_size. The query transport
          starting point must not be less than 0, and the ending point must not exceed Q_S; otherwise,
          the results will be incorrect (a worked example appears after the KV left padding items below).
        - If the kv_padding_size in the query left padding scenario is less than 0, it will be set to 0.
        - The query left padding scenario must be enabled together with the actual_seq_lengths parameter;
          otherwise, the default is the query right padding scenario.
        - The query left padding scenario does not support PageAttention and cannot be enabled together with
          the block_table parameter.

      - KV left padding scenario:

        - In the KV left padding scenario, the formula for calculating the starting point of the key and
          value transport is: KV_S - kv_padding_size - actual_seq_lengths_kv. The formula
          for the ending point of the key and value transport is: KV_S - kv_padding_size. The
          key and value transport starting point must not be less than 0, and the ending point must not
          exceed KV_S; otherwise, the results will be incorrect.
        - If the kv_padding_size in the KV left padding scenario is less than 0, it will be set to 0.
        - The KV left padding scenario must be enabled together with the actual_seq_lengths_kv parameter;
          otherwise, the default is the KV right padding scenario.
        - The KV left padding scenario does not support PageAttention and cannot be enabled together with
          the block_table parameter.

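      A small worked example of the left padding formulas above (the concrete numbers are assumptions for
      illustration only): with Q_S = 1024, query_padding_size = 100, and actual_seq_lengths = [512], the
      query transport covers indices [1024 - 100 - 512, 1024 - 100) = [412, 924), which satisfies the
      requirement that the starting point is not less than 0 and the ending point does not exceed Q_S.
      The same arithmetic applies to KV_S, kv_padding_size, and actual_seq_lengths_kv for the key and value.
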
      - pse_shift functional usage restrictions:

        - This function is supported when the query data type is float16, bfloat16, or int8.
        - If the query data type is float16 and pse_shift is enabled, it will force high-precision mode,
          inheriting the limitations of high-precision mode.
        - Q_S must be greater than or equal to the length of the query S, and KV_S must be greater than
          or equal to the length of the key S.

      - KV fake quantization parameter separation is not currently supported.

    - Incremental Inference Scenario (Q_S is 1):

      - Query, key, and value inputs functional usage restrictions:

        - The B axis supports values less than or equal to 65536.
        - The N axis supports values less than or equal to 256.
        - The D axis supports values less than or equal to 512.
        - Scenarios where the input types of query, key, and value are all int8 are not supported.

      - Page attention scenario:

        - The necessary condition to enable page attention is that the block_table exists and is valid.
          The key and value are arranged in contiguous memory according to the indices in the block_table.
          The key and value dtypes supported are float16, bfloat16, and int8. In this scenario, the
          input_layout parameter for key and value is invalid.
        - block_size is a user-defined parameter, and its value will affect the performance of page
          attention. When enabling page attention, a non-zero value for block_size must be provided, and
          the maximum value for block_size is 512.
        - If the input types of key and value are float16 or bfloat16, they must be 16-aligned. If the
          input types are int8, they must be 32-aligned, with 128 being recommended. In general, page
          attention can increase throughput but may lead to a performance decrease.
        - In the page attention enabled scenario, when the KV cache layout is (blocknum, block_size, H) and
          num_key_value_heads * D exceeds 64K, an error will be reported due to hardware
          instruction constraints. This can be resolved by enabling GQA (reducing num_key_value_heads) or
          adjusting the KV cache layout to (blocknum, num_key_value_heads, block_size, D).
        - The product of all dimensions of the shape of the key and value tensors in the page attention
          scenario must not exceed the representable range of int32.
        - In the page attention enabled scenario, the input S must be greater than or equal to
          max_block_num_per_seq * block_size.
        - The following are also supported in the page attention scenario: enabling attention mask (e.g.,
          mask shape = (B, 1, 1, S)), enabling pse_shift (e.g., pse_shift shape = (B, N, 1, S)), and
          enabling fake quantization in per-token mode (e.g., antiquant_scale and antiquant_offset shapes =
          (2, B, S)).

      - KV left padding scenario:

        - In the KV left padding scenario, the formula for calculating the starting point of the KV cache
          transport is: KV_S - kv_padding_size - actual_seq_lengths. The formula for the endpoint of the
          KV cache transport is: KV_S - kv_padding_size. If the starting point or endpoint of the KV cache
          is less than 0, the returned data result will be all zeros.
        - If kv_padding_size is less than 0 in the KV left padding scenario, it will be set to 0.
        - The KV left padding scenario must be enabled together with the actual_seq_lengths parameter;
          otherwise, it defaults to the KV right padding scenario.
        - The KV left padding scenario must be enabled together with the atten_mask parameter, and the
          atten_mask must be correctly applied to hide invalid data. Otherwise, accuracy issues may arise.

      - pse_shift functional usage restrictions:

        - The data type of pse_shift must match the data type of the query.
        - Only the D axis alignment is supported, meaning the D axis must be divisible by 16.

      - KV fake quantization parameter separation:

        - key_antiquant_mode and value_antiquant_mode must be consistent.
        - key_antiquant_scale and value_antiquant_scale must either both be empty or both non-empty.
        - key_antiquant_offset and value_antiquant_offset must either both be empty or both non-empty.
        - When both key_antiquant_scale and value_antiquant_scale are non-empty, their shapes must be
          consistent.
        - When both key_antiquant_offset and value_antiquant_offset are non-empty, their shapes must be
          consistent.

Supported Platforms:
    ``Ascend``

Examples:
    >>> from mindspore import ops
    >>> from mindspore import Tensor
    >>> import numpy as np
    >>> B, N, S, D = 1, 8, 1024, 128
    >>> query = Tensor(np.random.rand(B, N, S, D).astype(np.float16))
    >>> key = Tensor(np.random.rand(B, N, S, D).astype(np.float16))
    >>> value = Tensor(np.random.rand(B, N, S, D).astype(np.float16))
    >>> out = ops.fused_infer_attention_score(query, key, value, num_heads=N, input_layout='BNSD')
    >>> print(out[0].shape)
    (1, 8, 1024, 128)
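    A variant of the example above that also requests softmax_lse (added here as an illustrative sketch,
    not part of the original example; per the Returns section, the second output should have shape
    (B, N, S, 1) when the flag is set):

    >>> outs = ops.fused_infer_attention_score(query, key, value, num_heads=N, input_layout='BNSD',
    ...                                        softmax_lse_flag=True)
    >>> attention_out, softmax_lse = outs[0], outs[1]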
"""
|
|
2761
|
+
fias_op = _get_cache_prim(FusedInferAttentionScore)(num_heads, scale, pre_tokens, next_tokens, input_layout,
|
|
2762
|
+
num_key_value_heads, sparse_mode, inner_precise, block_size,
|
|
2763
|
+
antiquant_mode, softmax_lse_flag, key_antiquant_mode,
|
|
2764
|
+
value_antiquant_mode)
|
|
2765
|
+
key_list = key if isinstance(key, (tuple, list)) else [key]
|
|
2766
|
+
value_list = value if isinstance(value, (tuple, list)) else [value]
|
|
2767
|
+
return fias_op(query, key_list, value_list, pse_shift, atten_mask, actual_seq_lengths, actual_seq_lengths_kv,
|
|
2768
|
+
dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale,
|
|
2769
|
+
antiquant_offset, block_table, query_padding_size, kv_padding_size, key_antiquant_scale,
|
|
2770
|
+
key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix,
|
|
2771
|
+
value_shared_prefix, actual_shared_prefix_len)
|
|
2772
|
+
|
|
2773
|
+
|
|
2107
2774
|
class WhileLoop(Primitive):
    """
    Provide a useful op for reducing compilation times of while loop.

@@ -2276,7 +2943,7 @@ class Scan(Primitive):

class ForiLoop(Primitive):
    """
    Performs a loop operation within the specified range.
    The execution logic of the ForiLoop operator can be roughly represented by the following code:

    .. code-block:: python