mindspore 2.4.10__cp311-cp311-win_amd64.whl → 2.6.0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +47 -198
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +229 -99
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +480 -372
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +5 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +975 -1981
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +324 -573
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +183 -117
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +179 -120
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +798 -761
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +933 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1373 -192
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +19 -15
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +52 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
- mindspore/ops/auto_generate/gen_extend_func.py +757 -185
- mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
- mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4351 -3813
- mindspore/ops/function/nn_func.py +1712 -637
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +452 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +23 -7
- mindspore/ops/functional_overload.py +1548 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +23 -15
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +298 -87
- mindspore/ops/operations/debug_ops.py +157 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +212 -531
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1895 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +159 -40
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +700 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -59
- mindspore/parallel/transform_safetensors.py +364 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +416 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +96 -27
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +269 -136
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +552 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright 2023-2025 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -23,20 +23,21 @@ import numpy as np
|
|
|
23
23
|
from mindspore.ops import signature as sig
|
|
24
24
|
from mindspore.ops.primitive import Primitive, prim_attr_register, prim_arg_register, PrimitiveWithInfer
|
|
25
25
|
from mindspore.ops._primitive_cache import _get_cache_prim
|
|
26
|
-
from mindspore.ops.
|
|
26
|
+
from mindspore.ops._utils import arg_handler as handler
|
|
27
|
+
from mindspore.ops._utils.arg_dtype_cast import DtypeToEnum
|
|
27
28
|
from mindspore.common import Tensor, CSRTensor, COOTensor
|
|
28
29
|
from mindspore.common._stub_tensor import _convert_stub
|
|
29
30
|
from mindspore._c_expression import typing
|
|
30
|
-
from mindspore._c_expression import
|
|
31
|
-
from mindspore._c_expression import pyboost_cast, pyboost_tile, pyboost_zeros, pyboost_ones
|
|
31
|
+
from mindspore._c_expression import TensorPy as Tensor_
|
|
32
|
+
from mindspore._c_expression import pyboost_cast, pyboost_tile, pyboost_zeros, pyboost_ones, pyboost_type_as
|
|
32
33
|
from mindspore.common import dtype as mstype
|
|
33
34
|
from mindspore.common._utils import is_shape_unknown
|
|
34
35
|
from mindspore import _checkparam as validator
|
|
35
36
|
from mindspore.ops.operations.manually_defined._inner import ScalarCast
|
|
36
|
-
from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum
|
|
37
37
|
from mindspore.common.initializer import Zero
|
|
38
38
|
from mindspore.common.parameter import Parameter
|
|
39
|
-
from mindspore.ops.auto_generate.gen_ops_prim import FlashAttentionScore
|
|
39
|
+
from mindspore.ops.auto_generate.gen_ops_prim import FlashAttentionScore, FusedInferAttentionScore
|
|
40
|
+
from mindspore.common.jit_context import jit_context
|
|
40
41
|
|
|
41
42
|
|
|
42
43
|
dtype_to_type_id = DtypeToEnum()
|
|
@@ -527,6 +528,64 @@ class ScalarBool(Primitive):
|
|
|
527
528
|
return bool(x)
|
|
528
529
|
|
|
529
530
|
|
|
531
|
+
class ScalarMax(Primitive):
    r"""
    Return the larger of two input scalars.

    .. note::
        The inputs can be constant/variable value. Usage is the same as 'max' in Python.
        This primitive only has a 'CPU' implementation; on other platforms it is executed heterogeneously.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.
        - **y** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, and the data type is the one with higher precision or higher digits among the two inputs.

    Raises:
        TypeError: If `x` and `y` are not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize ScalarMax"""

    def __call__(self, x, y):
        # Eager path: defer to Python's builtin so ties and comparison rules match `max` exactly.
        larger = max(x, y)
        return larger
|
|
558
|
+
|
|
559
|
+
|
|
560
|
+
class ScalarMin(Primitive):
    r"""
    Return the smaller of two input scalars.

    .. note::
        The inputs can be constant/variable value. Usage is the same as 'min' in Python.
        This primitive only has a 'CPU' implementation; on other platforms it is executed heterogeneously.

    Inputs:
        - **x** (Scalar) - A constant or variable scalar.
        - **y** (Scalar) - A constant or variable scalar.

    Outputs:
        Scalar, and the data type is the one with higher precision or higher digits among the two inputs.

    Raises:
        TypeError: If `x` and `y` are not scalar.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    @prim_attr_register
    def __init__(self):
        """Initialize ScalarMin"""

    def __call__(self, x, y):
        # Eager path: defer to Python's builtin so ties and comparison rules match `min` exactly.
        smaller = min(x, y)
        return smaller
|
|
587
|
+
|
|
588
|
+
|
|
530
589
|
scalar_div = ScalarDiv()
|
|
531
590
|
scalar_mod = ScalarMod()
|
|
532
591
|
scalar_add = ScalarAdd()
|
|
@@ -543,6 +602,8 @@ scalar_log = ScalarLog()
|
|
|
543
602
|
scalar_pow = ScalarPow()
|
|
544
603
|
scalar_uadd = ScalarUadd()
|
|
545
604
|
scalar_usub = ScalarUsub()
|
|
605
|
+
scalar_max = ScalarMax()
|
|
606
|
+
scalar_min = ScalarMin()
|
|
546
607
|
|
|
547
608
|
|
|
548
609
|
class BatchNorm(Primitive):
|
|
@@ -570,31 +631,28 @@ class BatchNorm(Primitive):
|
|
|
570
631
|
- For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction.
|
|
571
632
|
|
|
572
633
|
Args:
|
|
573
|
-
is_training (bool): If `is_training` is ``True`` ,
|
|
634
|
+
is_training (bool, optional): If `is_training` is ``True`` ,
|
|
635
|
+
`mean` and `variance` are computed during training.
|
|
574
636
|
If `is_training` is ``False`` , they're loaded from checkpoint during inference. Default: ``False`` .
|
|
575
|
-
epsilon (float): A small value added for numerical stability.
|
|
576
|
-
|
|
637
|
+
epsilon (float, optional): A small value added for numerical stability.
|
|
638
|
+
Default: ``1e-5``, value must be (0, 1] .
|
|
639
|
+
momentum (float, optional): The hyper parameter to compute moving average for running_mean and running_var
|
|
577
640
|
(e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
|
|
578
641
|
Momentum value must be [0, 1]. Default: ``0.1`` .
|
|
579
|
-
data_format (str): The optional value for data format, is ``'NHWC'`` or ``'NCHW'``,
|
|
642
|
+
data_format (str, optional): The optional value for data format, is ``'NHWC'`` or ``'NCHW'``,
|
|
643
|
+
and the ``'NHWC'`` format
|
|
580
644
|
is only supported in GPU target. Default: ``"NCHW"`` .
|
|
581
645
|
|
|
582
646
|
Inputs:
|
|
583
|
-
If `is_training` is ``False`` , inputs are Tensors.
|
|
584
|
-
|
|
585
|
-
- **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
|
|
586
|
-
- **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
|
|
587
|
-
- **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
|
|
588
|
-
- **mean** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
|
|
589
|
-
- **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
|
|
590
|
-
|
|
591
|
-
If `is_training` is ``True`` , `scale`, `bias`, `mean` and `variance` are Parameters.
|
|
592
|
-
|
|
593
647
|
- **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
|
|
594
|
-
- **scale** (Parameter) - Parameter of shape :math:`(C,)`,
|
|
595
|
-
|
|
596
|
-
- **
|
|
597
|
-
|
|
648
|
+
- **scale** (Union[Parameter, Tensor]) - Tensor or Parameter of shape :math:`(C,)`,
|
|
649
|
+
with float16 or float32 data type.
|
|
650
|
+
- **bias** (Union[Parameter, Tensor]) - Tensor or Parameter of shape :math:`(C,)`,
|
|
651
|
+
has the same data type with `scale`.
|
|
652
|
+
- **mean** (Union[Parameter, Tensor]) - Tensor or Parameter of shape :math:`(C,)`,
|
|
653
|
+
has the same data type with `scale`.
|
|
654
|
+
- **variance** (Union[Parameter, Tensor]) - Tensor or Parameter of shape :math:`(C,)`,
|
|
655
|
+
has the same data type with `scale`.
|
|
598
656
|
|
|
599
657
|
Outputs:
|
|
600
658
|
Tuple of 5 Tensors, the normalized inputs and the updated parameters.
|
|
@@ -794,29 +852,21 @@ class Rank(Primitive):
|
|
|
794
852
|
|
|
795
853
|
def rank(input_x):
|
|
796
854
|
"""
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
Returns a 0-D int32 Tensor representing the rank of input; the rank of a tensor
|
|
800
|
-
is the number of indices required to uniquely select each element of the tensor.
|
|
855
|
+
Return the rank of a tensor.
|
|
801
856
|
|
|
802
857
|
Args:
|
|
803
|
-
input_x (Tensor): The
|
|
858
|
+
input_x (Tensor): The input tensor.
|
|
804
859
|
|
|
805
860
|
Returns:
|
|
806
|
-
Tensor
|
|
807
|
-
|
|
808
|
-
Raises:
|
|
809
|
-
TypeError: If `input_x` is not a Tensor.
|
|
861
|
+
Tensor
|
|
810
862
|
|
|
811
863
|
Supported Platforms:
|
|
812
864
|
``Ascend`` ``GPU`` ``CPU``
|
|
813
865
|
|
|
814
866
|
Examples:
|
|
815
867
|
>>> import mindspore
|
|
816
|
-
>>>
|
|
817
|
-
>>>
|
|
818
|
-
>>> input_tensor = Tensor(np.array([[2, 2], [2, 2]]), mindspore.float32)
|
|
819
|
-
>>> output = ops.rank(input_tensor)
|
|
868
|
+
>>> input_tensor = mindspore.tensor([[2, 2], [2, 2]], mindspore.float32)
|
|
869
|
+
>>> output = mindspore.ops.rank(input_tensor)
|
|
820
870
|
>>> print(output)
|
|
821
871
|
2
|
|
822
872
|
>>> print(type(output))
|
|
@@ -938,10 +988,6 @@ class Tile(Primitive):
|
|
|
938
988
|
|
|
939
989
|
Refer to :func:`mindspore.ops.tile` for more details.
|
|
940
990
|
|
|
941
|
-
Note:
|
|
942
|
-
On Ascend, the number of `dims` should not exceed 8, and currently does not support scenarios
|
|
943
|
-
where more than 4 dimensions are repeated simultaneously.
|
|
944
|
-
|
|
945
991
|
Inputs:
|
|
946
992
|
- **input** (Tensor) - The tensor whose elements need to be repeated. Set the shape of input tensor as
|
|
947
993
|
:math:`(x_1, x_2, ..., x_S)` .
|
|
@@ -949,6 +995,10 @@ class Tile(Primitive):
|
|
|
949
995
|
the parameter type is tuple, and the data type is int, i.e., :math:`(y_1, y_2, ..., y_S)`.
|
|
950
996
|
Only constant value is allowed.
|
|
951
997
|
|
|
998
|
+
.. note::
|
|
999
|
+
On Ascend, the number of `dims` should not exceed 8, and currently does not support scenarios
|
|
1000
|
+
where more than 4 dimensions are repeated simultaneously.
|
|
1001
|
+
|
|
952
1002
|
Outputs:
|
|
953
1003
|
Tensor, has the same data type as the `input`. Suppose the length of `dims` is `d`,
|
|
954
1004
|
the dimension of `input` is `input.dim`, and the shape of `input` is :math:`(x_1, x_2, ..., x_S)`.
|
|
@@ -1005,7 +1055,16 @@ class Tile(Primitive):
|
|
|
1005
1055
|
"""Initialize."""
|
|
1006
1056
|
|
|
1007
1057
|
def __call__(self, input, dims):
    """Run Tile eagerly through pyboost, cooperating with an active jit context.

    Returns None when a jit context exists and is already compiled. Otherwise the
    pyboost result is returned directly, or, when a jit context is active, the
    (synchronised) result is handed to the context's `run_op`.
    """
    # Add for jit context: nothing to execute once the context is compiled.
    if jit_context() and jit_context().compiled:
        return None

    tiled = _convert_stub(pyboost_tile(self, [input, dims]))

    # No jit context: plain eager result.
    if not jit_context():
        return tiled

    # Add for jit context: stub tensors are synchronised before being recorded.
    if validator.is_stub_tensor(tiled):
        tiled = tiled.stub_sync()
    return jit_context().run_op(self, tiled, input, dims)
|
|
1009
1068
|
|
|
1010
1069
|
# pylint: disable=missing-docstring
|
|
1011
1070
|
def check_elim(self, *args):
|
|
@@ -1026,26 +1085,14 @@ class Tile(Primitive):
|
|
|
1026
1085
|
|
|
1027
1086
|
def tile(input, dims):
|
|
1028
1087
|
r"""
|
|
1029
|
-
Creates a new tensor by
|
|
1030
|
-
output tensor has `input.shape[i] * dims[i]` elements, and the values of `input`
|
|
1031
|
-
are replicated `dims[i]` times along the i'th dimension.
|
|
1088
|
+
Creates a new tensor by repeating the elements in the input tensor `dims` times.
|
|
1032
1089
|
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
where more than 4 dimensions are repeated simultaneously.
|
|
1036
|
-
|
|
1037
|
-
Args:
|
|
1038
|
-
input (Tensor): The tensor whose elements need to be repeated. Set the shape of input tensor as
|
|
1039
|
-
:math:`(x_1, x_2, ..., x_S)` .
|
|
1040
|
-
|
|
1041
|
-
dims (tuple[int]): The parameter that specifies the number of replications,
|
|
1042
|
-
the parameter type is tuple, and the data type is int, i.e., :math:`(y_1, y_2, ..., y_S)`.
|
|
1043
|
-
Only constant value is allowed.
|
|
1044
|
-
|
|
1045
|
-
Returns:
|
|
1046
|
-
Tensor, has the same data type as the `input`. Suppose the length of `dims` is `d`,
|
|
1047
|
-
the dimension of `input` is `input.dim`, and the shape of `input` is :math:`(x_1, x_2, ..., x_S)`.
|
|
1090
|
+
The i'th dimension of output tensor has `input.shape[i] * dims[i]` elements, and the values of `input`
|
|
1091
|
+
are repeated `dims[i]` times along the i'th dimension.
|
|
1048
1092
|
|
|
1093
|
+
Note:
|
|
1094
|
+
- On Ascend, the number of `dims` should not exceed 8, and currently does not support scenarios
|
|
1095
|
+
where more than 4 dimensions are repeated simultaneously.
|
|
1049
1096
|
- If `input.dim = d`, then the shape of their corresponding positions can be multiplied, and
|
|
1050
1097
|
the shape of Outputs is :math:`(x_1*y_1, x_2*y_2, ..., x_S*y_S)`.
|
|
1051
1098
|
- If `input.dim < d`, prepend 1 to the shape of `input` until their lengths are consistent.
|
|
@@ -1056,40 +1103,39 @@ def tile(input, dims):
|
|
|
1056
1103
|
`dims` as :math:`(1, ..., y_1, y_2, ..., y_S)`, then the shape of their corresponding positions
|
|
1057
1104
|
can be multiplied, and the shape of Outputs is :math:`(x_1*1, ..., x_R*y_R, x_S*y_S)`.
|
|
1058
1105
|
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1106
|
+
Args:
|
|
1107
|
+
input (Tensor): The input tensor.
|
|
1108
|
+
dims (tuple[int]): The specified number of repetitions in each dimension.
|
|
1109
|
+
|
|
1110
|
+
Returns:
|
|
1111
|
+
Tensor
|
|
1062
1112
|
|
|
1063
1113
|
Supported Platforms:
|
|
1064
1114
|
``Ascend`` ``GPU`` ``CPU``
|
|
1065
1115
|
|
|
1066
1116
|
Examples:
|
|
1067
1117
|
>>> import mindspore
|
|
1068
|
-
>>>
|
|
1069
|
-
>>>
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
|
|
1082
|
-
[3
|
|
1083
|
-
|
|
1084
|
-
[3
|
|
1085
|
-
[1
|
|
1086
|
-
[3
|
|
1087
|
-
|
|
1088
|
-
[3
|
|
1089
|
-
[1. 2. 1. 2.]
|
|
1090
|
-
[3. 4. 3. 4.]
|
|
1091
|
-
[1. 2. 1. 2.]
|
|
1092
|
-
[3. 4. 3. 4.]]]
|
|
1118
|
+
>>> input = mindspore.tensor([[1, 2], [3, 4]])
|
|
1119
|
+
>>> mindspore.ops.tile(input, (2, 3))
|
|
1120
|
+
Tensor(shape=[4, 6], dtype=Int64, value=
|
|
1121
|
+
[[1, 2, 1, 2, 1, 2],
|
|
1122
|
+
[3, 4, 3, 4, 3, 4],
|
|
1123
|
+
[1, 2, 1, 2, 1, 2],
|
|
1124
|
+
[3, 4, 3, 4, 3, 4]])
|
|
1125
|
+
>>> mindspore.ops.tile(input, (2, 3, 2))
|
|
1126
|
+
Tensor(shape=[2, 6, 4], dtype=Int64, value=
|
|
1127
|
+
[[[1, 2, 1, 2],
|
|
1128
|
+
[3, 4, 3, 4],
|
|
1129
|
+
[1, 2, 1, 2],
|
|
1130
|
+
[3, 4, 3, 4],
|
|
1131
|
+
[1, 2, 1, 2],
|
|
1132
|
+
[3, 4, 3, 4]],
|
|
1133
|
+
[[1, 2, 1, 2],
|
|
1134
|
+
[3, 4, 3, 4],
|
|
1135
|
+
[1, 2, 1, 2],
|
|
1136
|
+
[3, 4, 3, 4],
|
|
1137
|
+
[1, 2, 1, 2],
|
|
1138
|
+
[3, 4, 3, 4]]])
|
|
1093
1139
|
"""
|
|
1094
1140
|
tile_op = _get_cache_prim(Tile)()
|
|
1095
1141
|
return tile_op(input, dims)
|
|
@@ -1176,17 +1222,78 @@ class Cast(Primitive):
|
|
|
1176
1222
|
if data.dtype == dtype:
|
|
1177
1223
|
return (True, x)
|
|
1178
1224
|
if isinstance(x, Tensor) and x.dtype == dtype:
|
|
1179
|
-
x.set_cast_dtype()
|
|
1180
1225
|
return (True, x)
|
|
1181
1226
|
if isinstance(x, numbers.Number):
|
|
1182
1227
|
return (True, Tensor(x, dtype=dtype))
|
|
1183
1228
|
return (False, None)
|
|
1184
1229
|
|
|
1185
1230
|
def __call__(self, input_x, dtype):
    """Run Cast eagerly through pyboost, cooperating with an active jit context.

    Returns None when a jit context exists and is already compiled. When
    `check_elim` reports that the cast can be skipped, its output is returned
    unchanged. Otherwise the pyboost result is returned directly, or, when a jit
    context is active, the (synchronised) result is handed to the context's `run_op`.
    """
    # Add for jit context: nothing to execute once the context is compiled.
    if jit_context() and jit_context().compiled:
        return None

    can_skip, passthrough = self.check_elim(input_x, dtype)
    if can_skip:
        return passthrough

    type_id = dtype_to_type_id('Cast', 'dtype', dtype)
    casted = _convert_stub(pyboost_cast(self, [input_x, type_id]))

    # No jit context: plain eager result.
    if not jit_context():
        return casted

    # Add for jit context: stub tensors are synchronised before being recorded.
    if validator.is_stub_tensor(casted):
        casted = casted.stub_sync()
    return jit_context().run_op(self, casted, input_x, dtype)
|
|
1244
|
+
|
|
1245
|
+
|
|
1246
|
+
class TypeAs(Primitive):
    """
    Returns the first input tensor cast to the data type of the second input tensor.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Note:
        When converting complex numbers to boolean type, the imaginary part of the complex number is not
        taken into account. As long as the real part is non-zero, it returns True; otherwise, it returns False.

    Inputs:
        - **input** (Tensor) - The shape of tensor is :math:`(x_0, x_1, ..., x_R)`.
          The tensor whose data type is to be converted.
        - **other** (Tensor) - The shape of tensor is :math:`(x_0, x_1, ..., x_R)`.
          The tensor whose data type is specified.

    Outputs:
        Tensor, the shape of tensor is the same as `input`, :math:`(x_0, x_1, ..., x_R)`.

    Raises:
        TypeError: If `input` is not a Tensor.
        TypeError: If `other` is not a Tensor.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
        >>> input = Tensor(input_np)
        >>> other_np = np.random.randn(2, 3, 4).astype(np.int32)
        >>> other = Tensor(other_np)
        >>> type_as = ops.TypeAs()
        >>> output = type_as(input, other)
        >>> print(output.dtype)
        Int32
        >>> print(output.shape)
        (2, 3, 4, 5)
    """

    @prim_attr_register
    def __init__(self):
        pass

    def __call__(self, input, other):
        # Only dispatch to the kernel when the dtypes actually differ;
        # otherwise `input` itself is returned (no copy is made).
        if input.dtype != other.dtype:
            return _convert_stub(pyboost_type_as(self, [input, other]))
        return input
|
|
1190
1297
|
|
|
1191
1298
|
|
|
1192
1299
|
def to_sequence(val):
|
|
@@ -1506,15 +1613,52 @@ def infer_value_for_Tile(input, dims):
|
|
|
1506
1613
|
return Tensor(np.tile(input.asnumpy(), dims))
|
|
1507
1614
|
|
|
1508
1615
|
|
|
1616
|
+
def infer_value_for_EqualExt(x, y):
    """Infer value for EqualExt op.

    Args:
        x: First operand exposing `asnumpy()`, or None when its value is unknown.
        y: Second operand exposing `asnumpy()`, or None when its value is unknown.

    Returns:
        None if either operand is None; otherwise a scalar bool Tensor that is True
        only when both operands have the same shape and all elements compare equal.
    """
    if x is None or y is None:
        return None
    # np.array_equal checks the shapes before comparing elements, so operands that
    # are merely broadcastable (e.g. (1,) vs (3,)) are reported as different and
    # incompatible shapes give False instead of raising from np.equal.
    # NOTE(review): assumes EqualExt follows "same shape and same elements" semantics - confirm.
    same = bool(np.array_equal(x.asnumpy(), y.asnumpy()))
    return Tensor(same)
|
|
1625
|
+
|
|
1626
|
+
|
|
1509
1627
|
def infer_value_for_Concat(tensors, axis):
    """Infer value for Concat op.

    Returns None when `tensors` is empty, contains None, or `axis` is None;
    otherwise the numpy concatenation along `axis`, typed like the first tensor.
    """
    inputs_known = bool(tensors) and None not in tensors and axis is not None
    if not inputs_known:
        return None

    arrays = [item.asnumpy() for item in tensors]
    joined = np.concatenate(arrays, axis)
    return Tensor(joined, dtype=tensors[0].dtype)
|
|
1516
1634
|
|
|
1517
1635
|
|
|
1636
|
+
def infer_value_for_GatherD(input, dim, index):
    """Infer value for GatherD op.

    Args:
        input: Source tensor exposing `asnumpy()` and `dtype`, or None if unknown.
        dim: Axis along which the values are gathered, or None if unknown.
        index: Tensor of indices exposing `asnumpy()`, or None if unknown.

    Returns:
        None if any argument is None; otherwise a Tensor shaped like `index`
        whose entries are read from `input` along `dim`, with `input`'s dtype.
    """
    if input is None or dim is None or index is None:
        return None

    input_np = input.asnumpy()
    index_np = index.asnumpy()

    # One coordinate grid per dimension of `index`. np.indices is evaluated once
    # here instead of once per dimension as the previous comprehension did.
    multi_index = list(np.indices(index_np.shape))
    # Along `dim` the coordinates come from `index`; all other axes keep their own position.
    multi_index[dim] = index_np

    output = input_np[tuple(multi_index)]
    return Tensor(output, dtype=input.dtype)
|
|
1650
|
+
|
|
1651
|
+
|
|
1652
|
+
def infer_value_for_Softmax(input, axis):
    """Infer value for Softmax op.

    Args:
        input: Tensor exposing `asnumpy()` and `dtype`, or None if unknown.
        axis: Axis (or axes) over which the softmax is normalised, or None if unknown.

    Returns:
        None if `input` or `axis` is None; otherwise a Tensor with `input`'s dtype
        holding exp(x) / sum(exp(x)) along `axis`.
    """
    if input is None or axis is None:
        return None

    logits = input.asnumpy()
    # Softmax is invariant to a constant shift along the reduced axis. Subtracting
    # the maximum keeps np.exp from overflowing to inf (and the result from turning
    # into NaN) for large inputs. Empty arrays are skipped because np.max would raise.
    if logits.size:
        logits = logits - np.max(logits, axis=axis, keepdims=True)
    e_input = np.exp(logits)
    output = e_input / np.sum(e_input, axis=axis, keepdims=True)
    return Tensor(output, dtype=input.dtype)
|
|
1660
|
+
|
|
1661
|
+
|
|
1518
1662
|
def infer_value_for_ReduceSum(input_x, axis, keep_dims, skip_mode):
|
|
1519
1663
|
"""Infer value for ReduceSum op."""
|
|
1520
1664
|
value = None
|
|
@@ -1562,6 +1706,20 @@ def _infer_value_for_Reduce(input_x, axis, keep_dims, prim_name):
|
|
|
1562
1706
|
return value
|
|
1563
1707
|
|
|
1564
1708
|
|
|
1709
|
+
def infer_value_for_Arange(start, end, step, dtype=None):
    """Infer value for Arange op.

    Returns None when `start`, `end` or `step` is None. Without an explicit
    `dtype` the result is float32 if any bound is a Python float and int64
    otherwise; with a `dtype` (type id) the matching numpy type is used.
    """
    if start is None or end is None or step is None:
        return None

    if dtype is not None:
        # Explicit type id: translate it to the corresponding numpy dtype.
        np_dtype = mstype.dtype_to_nptype(typing.type_id_to_type(dtype))
    elif any(isinstance(bound, float) for bound in [start, end, step]):
        np_dtype = np.float32
    else:
        np_dtype = np.int64

    values = np.arange(start, end, step, dtype=np_dtype)
    return Tensor(values)
|
|
1721
|
+
|
|
1722
|
+
|
|
1565
1723
|
def _infer_value_for_ReduceExtand(input_x, axis, keep_dims, dtype, prim_name):
|
|
1566
1724
|
"""Infer value for Common ReduceExtand op."""
|
|
1567
1725
|
value = None
|
|
@@ -1633,6 +1791,95 @@ def infer_value_for_Cast(x, dst_type_enum=None):
|
|
|
1633
1791
|
return value
|
|
1634
1792
|
|
|
1635
1793
|
|
|
1794
|
+
def infer_value_for_LinalgVectorNorm(input_x, ord, dim, keepdim, dtype):
    """Infer value for linalg_vector_norm op.

    The norm is computed by hand with numpy reductions rather than through
    numpy.linalg.vector_norm, which the original author notes is not available
    in the numpy version in use.

    Args:
        input_x: Tensor exposing `asnumpy()`, or None if unknown.
        ord: Order of the norm (0, a finite non-zero number, or +/-inf), or None if unknown.
        dim: Axis or axes to reduce over; passed straight to numpy as `axis`.
        keepdim: Whether the reduced axes are kept with length one.
        dtype: Type id of the requested output dtype, or None to keep numpy's result type.

    Returns:
        None if `input_x` or `ord` is None; otherwise a Tensor holding the vector norm.
    """
    if input_x is None or ord is None:
        return None

    magnitudes = np.abs(input_x.asnumpy())
    if ord == 0:
        # 0-"norm": number of non-zero entries.
        out = np.sum(magnitudes != 0, axis=dim, keepdims=keepdim)
    elif np.isposinf(ord):
        # The generic formula degenerates for infinite orders (1/ord == 0),
        # so the limit values are computed directly.
        out = np.max(magnitudes, axis=dim, keepdims=keepdim)
    elif np.isneginf(ord):
        out = np.min(magnitudes, axis=dim, keepdims=keepdim)
    else:
        out = np.power(np.sum(np.power(magnitudes, ord), axis=dim, keepdims=keepdim), 1 / ord)

    if dtype is None:
        return Tensor(out)
    dtype_for_ms = typing.type_id_to_type(dtype)
    return Tensor(out, dtype=dtype_for_ms)
|
|
1809
|
+
|
|
1810
|
+
|
|
1811
|
+
def infer_value_for_LpNormV2(input_x, p=2, dim=None, keepdim=False, eps=1e-12):
    """Infer value for LpNormV2 op.

    The value is computed with numpy.linalg.norm, forwarding `p` as `ord`,
    `dim` as `axis` and `keepdim` as `keepdims`. `eps` is accepted but not used
    by this inference routine.

    Returns None when `input_x` is None, otherwise a Tensor with the norm.
    """
    if input_x is None:
        return None

    # NOTE(review): numpy.linalg.norm switches to a matrix norm for 2-D input when
    # axis is None or a 2-tuple - confirm this matches the operator for such calls.
    norm_value = np.linalg.norm(input_x.asnumpy(), ord=p, axis=dim, keepdims=keepdim)
    return Tensor(norm_value)
|
|
1820
|
+
|
|
1821
|
+
|
|
1822
|
+
def infer_value_for_Svd(input_x, full_matrices, compute_uv):
    """Infer value for Svd op.

    Args:
        input_x: Tensor exposing `asnumpy()`, or None if unknown.
        full_matrices: Forwarded to numpy.linalg.svd.
        compute_uv: Truthy to also compute the singular vectors.

    Returns:
        None if `input_x` is None; otherwise a 3-tuple in the order (s, u, v).
        When `compute_uv` is falsy, the second and third entries are one-element
        zero placeholders.
    """
    if input_x is None:
        return None

    matrix = input_x.asnumpy()
    if bool(compute_uv):
        # numpy returns (U, S, Vh) - in that order - so the results must be
        # unpacked accordingly and then re-ordered to (s, u, v).
        u, s, vh = np.linalg.svd(matrix, full_matrices=full_matrices, compute_uv=True)
        # numpy yields V^H; convert it to V.
        # NOTE(review): assumes the operator's `v` satisfies A = U diag(S) V^T - confirm.
        v = np.ascontiguousarray(np.conj(np.swapaxes(vh, -1, -2)))
        return Tensor(s), Tensor(u), Tensor(v)

    s = np.linalg.svd(matrix, full_matrices=full_matrices, compute_uv=False)
    # NOTE(review): placeholders are plain numpy arrays, unlike `s` - kept as in the
    # original; confirm whether callers expect Tensors here.
    return Tensor(s), np.zeros(1), np.zeros(1)
|
|
1831
|
+
|
|
1832
|
+
|
|
1833
|
+
def infer_value_for_Div(input_x, other_x):
    """Infer value for Div op.

    Returns None when either operand is None, otherwise a Tensor with the
    element-wise true division of the two tensors' numpy values.
    """
    if input_x is None or other_x is None:
        return None

    numerator = input_x.asnumpy()
    denominator = other_x.asnumpy()
    return Tensor(np.true_divide(numerator, denominator))
|
|
1838
|
+
|
|
1839
|
+
|
|
1840
|
+
def infer_value_for_Divs(input_x, other_x):
    """Infer value for Divs op.

    `input_x` is a tensor and `other_x` is passed to numpy as-is. Returns None
    when either operand is None, otherwise a Tensor with the true division.
    """
    if input_x is None or other_x is None:
        return None

    quotient = np.true_divide(input_x.asnumpy(), other_x)

    # Tensors with a non-empty shape keep whatever dtype numpy produced.
    if input_x.shape:
        return Tensor(quotient)

    # tensor scalar has a special rule for data type promote:
    # integer inputs become float32, everything else keeps the input dtype.
    integer_dtypes = (mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64,
                      mstype.int8, mstype.int16, mstype.int32, mstype.int64)
    result_dtype = mstype.float32 if input_x.dtype in integer_dtypes else input_x.dtype
    return Tensor(quotient, dtype=result_dtype)
|
|
1855
|
+
|
|
1856
|
+
|
|
1857
|
+
def infer_value_for_DivMod(input_x, other_x, rounding_mode):
    """Infer value for DivMod op.

    `rounding_mode` 1 truncates the true quotient and 2 uses floor division.
    Returns None when an operand is None or for any other rounding mode.
    """
    if input_x is None or other_x is None:
        return None

    result = None
    if rounding_mode == 1:
        # trunc
        quotient = np.true_divide(input_x.asnumpy(), other_x.asnumpy())
        result = Tensor(np.trunc(quotient))
    elif rounding_mode == 2:
        # floor
        result = Tensor(np.floor_divide(input_x.asnumpy(), other_x.asnumpy()))
    return result
|
|
1868
|
+
|
|
1869
|
+
|
|
1870
|
+
def infer_value_for_DivMods(input_x, other_x, rounding_mode):
    """Infer value for DivMods op.

    `input_x` is a tensor and `other_x` is passed to numpy as-is.
    `rounding_mode` 1 truncates the true quotient and 2 uses floor division.
    Returns None when an operand is None or for any other rounding mode.
    """
    if input_x is None or other_x is None:
        return None

    result = None
    if rounding_mode == 1:
        # trunc
        quotient = np.true_divide(input_x.asnumpy(), other_x)
        result = Tensor(np.trunc(quotient))
    elif rounding_mode == 2:
        # floor
        result = Tensor(np.floor_divide(input_x.asnumpy(), other_x))
    return result
|
|
1881
|
+
|
|
1882
|
+
|
|
1636
1883
|
def infer_value_for_ReduceMax(input_x, axis, keep_dims):
|
|
1637
1884
|
"""Infer value for ReduceMax op."""
|
|
1638
1885
|
return _infer_value_for_Reduce(input_x, axis, keep_dims, 'ReduceMax')
|
|
@@ -1791,7 +2038,7 @@ class Ones(Primitive):
|
|
|
1791
2038
|
Tensor, whose dtype and size are defined by input.
|
|
1792
2039
|
|
|
1793
2040
|
Raises:
|
|
1794
|
-
TypeError: If `shape` is neither an int nor
|
|
2041
|
+
TypeError: If `shape` is neither an int nor a tuple/list/Tensor of int.
|
|
1795
2042
|
|
|
1796
2043
|
Supported Platforms:
|
|
1797
2044
|
``Ascend`` ``GPU`` ``CPU``
|
|
@@ -1821,13 +2068,23 @@ class Ones(Primitive):
|
|
|
1821
2068
|
pass
|
|
1822
2069
|
|
|
1823
2070
|
def __call__(self, size, type=None):
    """Run Ones eagerly through pyboost, cooperating with an active jit context.

    Args:
        size: Shape argument forwarded unchanged to the kernel.
        type: Requested dtype, or None to pass None through to the kernel.

    Returns:
        None when a jit context exists and is already compiled. Otherwise the
        pyboost result, or, when a jit context is active, whatever the
        context's `run_op` returns for the (synchronised) result.
    """
    # Add for jit context.
    if jit_context() and jit_context().compiled:
        return None

    # Convert the dtype once; the same value is needed for the kernel call and
    # for the jit-context recording below.
    type_id = None if type is None else handler.dtype_to_type_id('Ones', 'type', type)
    res = _convert_stub(pyboost_ones(self, [size, type_id]))

    # Add for jit context.
    if jit_context():
        if validator.is_stub_tensor(res):
            res = res.stub_sync()
        return jit_context().run_op(self, res, size, type_id)
    return res
|
|
1826
2083
|
|
|
1827
2084
|
|
|
1828
2085
|
class Zeros(Primitive):
|
|
1829
2086
|
r"""
|
|
1830
|
-
Zeros will be deprecated in the future. Please use class
|
|
2087
|
+
Zeros will be deprecated in the future. Please use class :func:`mindspore.ops.zeros` instead.
|
|
1831
2088
|
|
|
1832
2089
|
Creates a tensor filled with value zeros.
|
|
1833
2090
|
|
|
@@ -1845,7 +2102,7 @@ class Zeros(Primitive):
|
|
|
1845
2102
|
Tensor, whose dtype and size are defined by input.
|
|
1846
2103
|
|
|
1847
2104
|
Raises:
|
|
1848
|
-
TypeError: If `shape` is neither an int nor
|
|
2105
|
+
TypeError: If `shape` is neither an int nor a tuple/list/Tensor of int.
|
|
1849
2106
|
|
|
1850
2107
|
Supported Platforms:
|
|
1851
2108
|
``Ascend`` ``GPU`` ``CPU``
|
|
@@ -1871,8 +2128,18 @@ class Zeros(Primitive):
|
|
|
1871
2128
|
pass
|
|
1872
2129
|
|
|
1873
2130
|
def __call__(self, size, type=None):
    """Run Zeros eagerly through pyboost, cooperating with an active jit context.

    Args:
        size: Shape argument forwarded unchanged to the kernel.
        type: Requested dtype, or None to pass None through to the kernel.

    Returns:
        None when a jit context exists and is already compiled. Otherwise the
        pyboost result, or, when a jit context is active, whatever the
        context's `run_op` returns for the (synchronised) result.
    """
    # Add for jit context.
    if jit_context() and jit_context().compiled:
        return None

    # Convert the dtype once; the same value is needed for the kernel call and
    # for the jit-context recording below.
    type_id = None if type is None else handler.dtype_to_type_id('Zeros', 'type', type)
    res = _convert_stub(pyboost_zeros(self, [size, type_id]))

    # Add for jit context.
    if jit_context():
        if validator.is_stub_tensor(res):
            res = res.stub_sync()
        return jit_context().run_op(self, res, size, type_id)
    return res
|
|
1876
2143
|
|
|
1877
2144
|
|
|
1878
2145
|
def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mask=None, padding_mask=None,
|
|
@@ -1880,116 +2147,132 @@ def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mas
|
|
|
1880
2147
|
scalar_value=1.0, pre_tokens=2147483647, next_tokens=2147483647, inner_precise=0,
|
|
1881
2148
|
input_layout='BSH', sparse_mode=0):
|
|
1882
2149
|
r"""
|
|
1883
|
-
|
|
2150
|
+
Implement self-attention calculations in training scenarios.
|
|
2151
|
+
|
|
2152
|
+
- B: Batch size. Value range 1 to 2k.
|
|
2153
|
+
- S1: Sequence length of `query`. Value range 1 to 512k.
|
|
2154
|
+
- S2: Sequence length of `key` and `value`. Value range 1 to 512k.
|
|
2155
|
+
- N1: Num heads of `query`. Value range 1 to 256.
|
|
2156
|
+
- N2: Num heads of `key` and `value`, and N2 must be a factor of N1.
|
|
2157
|
+
- D: Head size. The value ranges is a multiple of 16, with the max value of 512.
|
|
2158
|
+
- H1: Hidden size of `query`, which equals to N1 * D.
|
|
2159
|
+
- H2: Hidden size of `key` and `value`, which equals to N2 * D.
|
|
2160
|
+
|
|
2161
|
+
The self attention calculation formula is defined as:
|
|
1884
2162
|
|
|
1885
2163
|
.. math::
|
|
1886
2164
|
\begin{array}{ll} \\
|
|
1887
|
-
|
|
1888
|
-
\
|
|
2165
|
+
\text { attention_out }=\operatorname{Dropout}\left(\operatorname{Softmax}\left(\text
|
|
2166
|
+
{ Mask(scale } *\left(\text { query } * \mathrm{key}^{\top}\right)+\text { pse }\right)\text
|
|
2167
|
+
{, atten_mask), keep_prob) } *\right. \text { value }
|
|
1889
2168
|
\end{array}
|
|
1890
2169
|
|
|
1891
|
-
B -- Batch size. Value range 1 to 2k.
|
|
1892
|
-
S1 -- Sequence length of query. Value range 1 to 512k.
|
|
1893
|
-
S2 -- Sequence length of key and value. Value range 1 to 512k.
|
|
1894
|
-
N1 -- Num heads of query. Value range 1 to 256.
|
|
1895
|
-
N2 -- Num heads of key and value, and N2 must be a factor of N1.
|
|
1896
|
-
D -- Head size. The value ranges is a multiple of 16, with the max value of 512.
|
|
1897
|
-
H1 -- Hidden size of query, which equals to N1 * D.
|
|
1898
|
-
H2 -- Hidden size of key and value, which equals to N2 * D.
|
|
1899
|
-
|
|
1900
2170
|
.. warning::
|
|
1901
|
-
This is an experimental API that is subject to change or deletion.
|
|
2171
|
+
- This is an experimental API that is subject to change or deletion.
|
|
2172
|
+
- Only support on Atlas A2 training series.
|
|
1902
2173
|
|
|
1903
2174
|
Args:
|
|
1904
|
-
query (Tensor
|
|
1905
|
-
|
|
1906
|
-
|
|
1907
|
-
|
|
1908
|
-
|
|
1909
|
-
|
|
1910
|
-
head_num (int): The head num of query
|
|
1911
|
-
real_shift (
|
|
1912
|
-
|
|
1913
|
-
|
|
1914
|
-
|
|
1915
|
-
|
|
1916
|
-
|
|
1917
|
-
|
|
1918
|
-
-
|
|
1919
|
-
|
|
1920
|
-
|
|
1921
|
-
`
|
|
1922
|
-
|
|
1923
|
-
|
|
1924
|
-
|
|
1925
|
-
|
|
1926
|
-
|
|
1927
|
-
|
|
1928
|
-
|
|
1929
|
-
|
|
1930
|
-
|
|
1931
|
-
|
|
1932
|
-
|
|
2175
|
+
query (Tensor): The query tensor. Input tensor of shape :math:`(B, S1, H1)`,
|
|
2176
|
+
:math:`(B, N1, S1, D)`, :math:`(S1, B, H1)`, :math:`(B, S1, N1, D)` or :math:`(T1, N1, D)`.
|
|
2177
|
+
The supported dtype is float16 and bfloat16.
|
|
2178
|
+
key (Tensor): The key tensor with the same dtype as `query`. Supported shape: :math:`(B, S2, H2)`,
|
|
2179
|
+
:math:`(B, N2, S2, D)`, :math:`(S2, B, H2)`, :math:`(B, S2, N2, D)` or :math:`(T2, N2, D)`.
|
|
2180
|
+
value (Tensor): The value tensor with the same dtype and shape as `key`.
|
|
2181
|
+
head_num (int): The head num of `query`, equal to N1.
|
|
2182
|
+
real_shift (Tensor, optional): The position embedding code which is also known as pse, it has the same
|
|
2183
|
+
dtype as `query`.
|
|
2184
|
+
Default: ``None``.
|
|
2185
|
+
If S is greater than 1024 and the mask of the lower triangle is used, only the inverse 1024 lines of
|
|
2186
|
+
the lower triangle is used for memory optimization. Input tensor of shape :math:`(B, N1, S1, S2)`,
|
|
2187
|
+
:math:`(1, N1, S1, S2)`, :math:`(B, N1, 1024, S2)`, :math:`(1, N1, 1024, S2)`.
|
|
2188
|
+
|
|
2189
|
+
- ALiBi scenario: `real_shift` must meet the ALiBi rule, and sparse_mode is 2 or 3 for the lower triangle.
|
|
2190
|
+
In this scenario, `real_shift` is :math:`(B, N1, 1024, S2)`, :math:`(1, N1, 1024, S2)`.
|
|
2191
|
+
- Non-ALiBi scenario: `real_shift` is :math:`(B, N1, S1, S2)`, :math:`(1, N1, S1, S2)`.
|
|
2192
|
+
- input_layout is TND: shape should be :math:`(B, N1, 1024, S2)` and :math:`(1, N1, 1024, S2)`.
|
|
2193
|
+
|
|
2194
|
+
drop_mask (Tensor, optional): The dropout mask tensor of uint8. Input tensor of shape
|
|
2195
|
+
:math:`(B, N1, S1, S2 // 8) or None`. `S2` is a multiple of 8 when not None. Default: ``None``.
|
|
2196
|
+
padding_mask (Tensor, optional): Reserved parameter. Not implemented yet. Default: ``None``.
|
|
2197
|
+
attn_mask (Tensor, optional): The attention mask tensor of bool or uint8. For each element, 0/False
|
|
2198
|
+
indicates retention and 1/True indicates discard. Input tensor of shape :math:`(B, N1, S1, S2)`,
|
|
2199
|
+
:math:`(B, 1, S1, S2)`, :math:`(S1, S2)` or :math:`(2048, 2048)`.
|
|
2200
|
+
Default: ``None``.
|
|
2201
|
+
|
|
2202
|
+
- In compression scenario, `sparse_mode` is 2, 3, or 4, `attn_mask` must be :math:`(2048, 2048)`.
|
|
2203
|
+
- When `sparse_mode` is 5, `attn_mask` should be :math:`(B, N1, S1, S2)`, :math:`(B, 1, S1, S2)`.
|
|
2204
|
+
- When `sparse_mode` is 0 and 1, `attn_mask` should be :math:`(B, N1, S1, S2)`, :math:`(B, 1, S1, S2)`,
|
|
2205
|
+
:math:`(S1, S2)`.
|
|
2206
|
+
|
|
2207
|
+
prefix (Union[Tensor, tuple[int], list[int]], optional): N value of each Batch in the prefix sparse calculation
|
|
2208
|
+
scenario. Input tensor of shape :math:`(B,)`. B max value 32. Not none only when `sparse_mode` is 5.
|
|
2209
|
+
Default: ``None``.
|
|
1933
2210
|
If S1 > S2, N ranges from 0 to S2. If S1 <= S2, N ranges from S2 - S1 to S2.
|
|
1934
|
-
actual_seq_qlen (Union[
|
|
1935
|
-
with increasing values and the last value equal to T1.
|
|
1936
|
-
|
|
1937
|
-
|
|
1938
|
-
|
|
1939
|
-
|
|
1940
|
-
|
|
1941
|
-
|
|
1942
|
-
|
|
1943
|
-
|
|
1944
|
-
|
|
1945
|
-
|
|
1946
|
-
|
|
2211
|
+
actual_seq_qlen (Union[Tensor, tuple[int], list[int]], optional): Size of query corresponding to each batch,
|
|
2212
|
+
array with increasing values and the last value equal to T1.
|
|
2213
|
+
Default: ``None``.
|
|
2214
|
+
actual_seq_kvlen (Union[Tensor, tuple[int], list[int]], optional): Size of key and value corresponding
|
|
2215
|
+
to each batch, array with increasing values and the last value equal to T2.
|
|
2216
|
+
Default: ``None``.
|
|
2217
|
+
keep_prob (double, optional): The keep probability of dropout. Value range is (0.0, 1.0]. When `keep_prob`
|
|
2218
|
+
is 1.0, `drop_mask` should be None.
|
|
2219
|
+
Default: ``1.0``.
|
|
2220
|
+
scalar_value (double, optional): The scale factor of score. Generally, the value is 1.0 / (D ** 0.5).
|
|
2221
|
+
Default: ``1.0``.
|
|
2222
|
+
pre_tokens (int, optional): Parameter for sparse computation, represents how many tokens are counted forward.
|
|
2223
|
+
When `sparse_mode` is set to 1, 2, 3, or 5, this parameter does not take effect.
|
|
2224
|
+
Default: ``2147483647``.
|
|
2225
|
+
next_tokens (int, optional): Parameter for sparse computation, represents how many tokens are counted backward.
|
|
2226
|
+
When `sparse_mode` is set to 1, 2, 3, or 5, this parameter does not take effect. Default: ``2147483647``.
|
|
2227
|
+
The value of `pre_tokens` corresponds to S1, and the value of `next_tokens` corresponds to S2.
|
|
2228
|
+
They define the valid area on the `attn_mask` matrix. It must ensure that the band is not empty.
|
|
1947
2229
|
The following values are not allowed:
|
|
1948
2230
|
|
|
1949
2231
|
- pre_tokens < 0 and next_tokens < 0.
|
|
1950
2232
|
- (pre_tokens < 0 and next_tokens >= 0) and (next_tokens < abs(pre_tokens) or abs(pre_tokens) >= S2).
|
|
1951
2233
|
- (pre_tokens >= 0 and next_tokens < 0) and (abs(next_tokens) > pre_tokens or abs(next_tokens) >= S1).
|
|
1952
2234
|
|
|
1953
|
-
inner_precise (int): The parameter is reserved and not implemented yet. Default
|
|
1954
|
-
input_layout (str): Specifies the layout of input `query`, key and value
|
|
1955
|
-
"SBH", "BSND" or "TND". "TND" is an experimental format. Default: "BSH"
|
|
2235
|
+
inner_precise (int, optional): The parameter is reserved and not implemented yet. Default: ``0``.
|
|
2236
|
+
input_layout (str, optional): Specifies the layout of input `query`, `key` and `value`. The value can be
|
|
2237
|
+
"BSH", "BNSD", "SBH", "BSND" or "TND". "TND" is an experimental format. Default: ``"BSH"``.
|
|
1956
2238
|
When input_layout is "TND", the following restrictions must be met.
|
|
1957
|
-
|
|
2239
|
+
Assume there are two lists that represent the length of the input sequence: list_seq_q and list_seq_k. Each
|
|
1958
2240
|
value in the list indicates the length of the sequence in the batch. For example, list_seq_q = [4, 2, 6],
|
|
1959
2241
|
list_seq_k = [10, 3, 9]. Each element of the lists indicates S. T1 is sum(list_seq_q) = 12, T2 is
|
|
1960
2242
|
sum(list_seq_k) = 22.
|
|
1961
2243
|
max_seqlen_q = max(list_seq_q), max_seqlen_k = max(list_seq_k).
|
|
1962
2244
|
qk_pointer = sum(list_seq_q * list_seq_k), which is the sum of the element multiplication.
|
|
1963
2245
|
|
|
1964
|
-
- The lengths of two lists
|
|
1965
|
-
|
|
2246
|
+
- The lengths of the two lists must be the same and equal to the batch size, which is less than or equal to
|
|
2247
|
+
1024.
|
|
2248
|
+
- When `input_layout` is "TND", `actual_seq_qlen` and `actual_seq_kvlen` must not be None.
|
|
1966
2249
|
Otherwise, they are none.
|
|
1967
|
-
- The actual_seq_qlen and actual_seq_kvlen are the cumulative sum of sequence of key/value, so they must
|
|
2250
|
+
- `actual_seq_qlen` and `actual_seq_kvlen` are cumulative sums of the query and key/value sequence lengths, so they must
|
|
1968
2251
|
be non-decreasing.
|
|
1969
|
-
- If real_shift is not none, list_seq_q and list_seq_k must be same. The maximum value of list_seq_q and
|
|
1970
|
-
list_seq_k is greater than 1024.
|
|
1971
|
-
S2 is equal to max_seqlen_k.
|
|
1972
|
-
-
|
|
1973
|
-
should be
|
|
1974
|
-
- The shape of drop_mask is (
|
|
1975
|
-
-
|
|
1976
|
-
-
|
|
1977
|
-
- When sparse_mode is 3, S1 of each batch should be less than or equal to S2.
|
|
2252
|
+
- If `real_shift` is not none, list_seq_q and list_seq_k must be same. The maximum value of list_seq_q and
|
|
2253
|
+
list_seq_k is greater than 1024. `real_shift` should be :math:`(B, N1, 1024, S2)` and
|
|
2254
|
+
:math:`(1, N1, 1024, S2)`, and S2 is equal to max_seqlen_k.
|
|
2255
|
+
- `attn_mask` must be a lower triangular matrix, so `sparse_mode` should be 2 or 3. The shape of `attn_mask`
|
|
2256
|
+
should be :math:`(2048, 2048)`.
|
|
2257
|
+
- The shape of `drop_mask` is :math:`(qk\_pointer * N1 // 8,)`.
|
|
2258
|
+
- `prefix` is none.
|
|
2259
|
+
- `next_tokens` is 0, and `pre_tokens` is not less than max_seqlen_q.
|
|
2260
|
+
- When `sparse_mode` is 3, S1 of each batch should be less than or equal to S2.
|
|
1978
2261
|
- 0 should not exist in list_seq_k.
|
|
1979
2262
|
|
|
1980
|
-
sparse_mode (int): Indicates sparse mode. Default 0
|
|
2263
|
+
sparse_mode (int, optional): Indicates sparse mode. Default: ``0``.
|
|
1981
2264
|
|
|
1982
|
-
- 0: Indicates the defaultMask mode. If attn_mask is not passed, the mask operation is not performed,
|
|
1983
|
-
and
|
|
1984
|
-
attn_mask matrix (S1 * S2) needs to be passed in, indicating that the part between
|
|
1985
|
-
|
|
1986
|
-
- 1: Represents allMask, that is, passing in the complete attn_mask matrix.
|
|
2265
|
+
- 0: Indicates the defaultMask mode. If `attn_mask` is not passed, the mask operation is not performed,
|
|
2266
|
+
`next_tokens` and `pre_tokens` (internally assigned as INT_MAX) are ignored. If passed in, the full
|
|
2267
|
+
`attn_mask` matrix (S1 * S2) needs to be passed in, indicating that the part between `next_tokens` and
|
|
2268
|
+
`pre_tokens` needs to be calculated.
|
|
2269
|
+
- 1: Represents allMask, that is, passing in the complete `attn_mask` matrix.
|
|
1987
2270
|
- 2: Represents the leftUpCausal mode, which corresponds to the lower triangle scenario divided by the left
|
|
1988
|
-
vertex, and the optimized attn_mask matrix (2048*2048) is required.
|
|
2271
|
+
vertex, and the optimized `attn_mask` matrix (2048*2048) is required.
|
|
1989
2272
|
- 3: Represents the rightDownCausal mode, which corresponds to the lower triangle scenario divided by the lower
|
|
1990
|
-
right vertex, and the optimized attn_mask matrix (2048*2048) is required.
|
|
1991
|
-
- 4: Represents the band scenario, that is, the part between counting
|
|
1992
|
-
optimized attn_mask matrix (2048*2048) is required.
|
|
2273
|
+
right vertex, and the optimized `attn_mask` matrix (2048*2048) is required.
|
|
2274
|
+
- 4: Represents the band scenario, that is, the part between counting `next_tokens` and `pre_tokens`,
|
|
2275
|
+
and the optimized `attn_mask` matrix (2048*2048) is required.
|
|
1993
2276
|
- 5: Represents the prefix scenario, that is, on the basis of rightDownCausal, a matrix with length S1 and
|
|
1994
2277
|
width N is added to the left side. The value of N is obtained by the new input prefix, and the N value
|
|
1995
2278
|
of each Batch axis is different, not implemented yet.
|
|
@@ -1998,8 +2281,27 @@ def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mas
|
|
|
1998
2281
|
- 8: Represents the block_local scenario, not implemented yet.
|
|
1999
2282
|
|
|
2000
2283
|
Returns:
|
|
2001
|
-
attention_out (Tensor
|
|
2002
|
-
|
|
2284
|
+
attention_out (Tensor) - The output of attention, it has the same shape and dtype as `query`.
|
|
2285
|
+
|
|
2286
|
+
Raises:
|
|
2287
|
+
TypeError: Dtype of `query` is not float16 or bfloat16.
|
|
2288
|
+
TypeError: `query`, `key` and `value` don't have the same dtype.
|
|
2289
|
+
TypeError: Dtype of `attn_mask` is not bool or uint8.
|
|
2290
|
+
TypeError: `real_shift` has a different dtype from `query`.
|
|
2291
|
+
TypeError: `scalar_value` or `keep_prob` is not a double number.
|
|
2292
|
+
TypeError: `input_layout` is not a string.
|
|
2293
|
+
TypeError: `num_key_value_heads` is not an int.
|
|
2294
|
+
TypeError: `sparse_mode` is not an int.
|
|
2295
|
+
TypeError: `real_shift` is not Tensor type.
|
|
2296
|
+
TypeError: `drop_mask` is not Tensor type.
|
|
2297
|
+
TypeError: `padding_mask` is not Tensor type.
|
|
2298
|
+
TypeError: `attn_mask` is not Tensor type.
|
|
2299
|
+
ValueError: `input_layout` is a string but not valid.
|
|
2300
|
+
RuntimeError: `head_num` is not divisible by `N2`.
|
|
2301
|
+
RuntimeError: `head_num` is not greater than 0.
|
|
2302
|
+
RuntimeError: `attn_mask` shape is not valid.
|
|
2303
|
+
RuntimeError: The specified value of `sparse_mode` is invalid.
|
|
2304
|
+
RuntimeError: D-axis of `query`, `key` and `value` is not the same.
|
|
2003
2305
|
|
|
2004
2306
|
Supported Platforms:
|
|
2005
2307
|
``Ascend``
|
|
@@ -2023,6 +2325,452 @@ def flash_attention_score(query, key, value, head_num, real_shift=None, drop_mas
|
|
|
2023
2325
|
actual_seq_kvlen)[3]
|
|
2024
2326
|
|
|
2025
2327
|
|
|
2328
|
+
def fused_infer_attention_score(query, key, value, *, pse_shift=None, atten_mask=None, actual_seq_lengths=None,
|
|
2329
|
+
actual_seq_lengths_kv=None, dequant_scale1=None, quant_scale1=None, dequant_scale2=None,
|
|
2330
|
+
quant_scale2=None, quant_offset2=None, antiquant_scale=None, antiquant_offset=None,
|
|
2331
|
+
key_antiquant_scale=None, key_antiquant_offset=None, value_antiquant_scale=None,
|
|
2332
|
+
value_antiquant_offset=None, block_table=None, query_padding_size=None,
|
|
2333
|
+
kv_padding_size=None, key_shared_prefix=None, value_shared_prefix=None,
|
|
2334
|
+
actual_shared_prefix_len=None, num_heads=1, scale=1.0, pre_tokens=2147483647,
|
|
2335
|
+
next_tokens=2147483647, input_layout='BSH', num_key_value_heads=0, sparse_mode=0,
|
|
2336
|
+
inner_precise=1, block_size=0, antiquant_mode=0, key_antiquant_mode=0,
|
|
2337
|
+
value_antiquant_mode=0, softmax_lse_flag=False):
|
|
2338
|
+
r"""
|
|
2339
|
+
This is a FlashAttention function designed for both incremental and full inference scenarios. It supports full
|
|
2340
|
+
inference scenarios (PromptFlashAttention) as well as incremental inference scenarios (IncreFlashAttention).
|
|
2341
|
+
When the S dimension of the query tensor (Q_S) equals 1, it enters the IncreFlashAttention branch; otherwise,
|
|
2342
|
+
it enters the PromptFlashAttention branch.
|
|
2343
|
+
|
|
2344
|
+
.. math::
|
|
2345
|
+
|
|
2346
|
+
Attention(Q,K,V) = Softmax(\frac{QK^{T}}{\sqrt{d}})V
|
|
2347
|
+
|
|
2348
|
+
.. warning::
|
|
2349
|
+
- This is an experimental API that is subject to change or deletion.
|
|
2350
|
+
- For Ascend, only the Atlas A2 training series products and Atlas 800I A2 inference products are currently
|
|
2351
|
+
supported.
|
|
2352
|
+
|
|
2353
|
+
Note:
|
|
2354
|
+
- The data layout formats of query, key and value can be interpreted from multiple dimensions, as shown below:
|
|
2355
|
+
|
|
2356
|
+
- B, Batch size. Represents the batch size of the input samples.
|
|
2357
|
+
- S, Sequence length. Represents the sequence length of the input samples. S1 represents the sequence length
|
|
2358
|
+
of the query, and S2 represents the sequence length of the key/value.
|
|
2359
|
+
- H, Head size. Represents the size of the hidden layer.
|
|
2360
|
+
- N, Head nums. Represents the number of attention heads.
|
|
2361
|
+
- D, Head dims. Represents the smallest unit size of the hidden layer, satisfying :math:`D = H / N`.
|
|
2362
|
+
|
|
2363
|
+
Args:
|
|
2364
|
+
query (Tensor): The query input of the attention structure, with data type of float16, bfloat16 or int8.
|
|
2365
|
+
Input tensor of shape :math:`(B, S, H)`, :math:`(B, N, S, D)`, or :math:`(B, S, N, D)`.
|
|
2366
|
+
key (Union[Tensor, tuple[Tensor], list[Tensor]]): The key input of the attention structure, with data type
|
|
2367
|
+
of float16, bfloat16 or int8. Input tensor of shape :math:`(B, S, H)`, :math:`(B, N, S, D)`, or
|
|
2368
|
+
:math:`(B, S, N, D)`.
|
|
2369
|
+
value (Union[Tensor, tuple[Tensor], list[Tensor]]): The value input of the attention structure, with data
|
|
2370
|
+
type of float16, bfloat16 or int8. Input tensor of shape :math:`(B, S, H)`, :math:`(B, N, S, D)`, or
|
|
2371
|
+
:math:`(B, S, N, D)`.
|
|
2372
|
+
|
|
2373
|
+
Keyword Args:
|
|
2374
|
+
pse_shift (Tensor, optional): The padding mask tensor with data type of float16 or bfloat16.
|
|
2375
|
+
Default: ``None``.
|
|
2376
|
+
|
|
2377
|
+
- When Q_S is not 1, if pse_shift is of type float16, the query must be of type float16 or int8.
|
|
2378
|
+
If pse_shift is of type bfloat16, the query must also be of type bfloat16. The input shape
|
|
2379
|
+
must be either :math:`(B, N, Q\_S, KV\_S)` or :math:`(1, N, Q\_S, KV\_S)`, where Q_S corresponds to the
|
|
2380
|
+
S dimension of the query shape, and KV_S corresponds to the S dimension of the key and value shapes.
|
|
2381
|
+
For scenarios where the KV_S of pse_shift is not 32-aligned, it is recommended to pad it
|
|
2382
|
+
to 32 bytes to improve performance. The padding values for the extra portions are not restricted.
|
|
2383
|
+
- When Q_S is 1, if pse_shift is of type float16, the query must also be of type float16.
|
|
2384
|
+
If pse_shift is of type bfloat16, the query must be of type bfloat16. The input shape must be
|
|
2385
|
+
:math:`(B, N, 1, KV\_S)` or :math:`(1, N, 1, KV\_S)`, where KV_S corresponds to the S dimension of the
|
|
2386
|
+
key/value shapes. For scenarios where the KV_S of pse_shift is not 32-aligned, it is recommended
|
|
2387
|
+
to pad it to 32 bytes to improve performance. The padding values for the extra portions are not
|
|
2388
|
+
restricted.
|
|
2389
|
+
|
|
2390
|
+
atten_mask (Tensor, optional): The attention mask tensor for the result of query*key with data type of int8,
|
|
2391
|
+
uint8 or bool. For each element, 0 indicates retention and 1 indicates discard.
|
|
2392
|
+
Default: ``None``.
|
|
2393
|
+
|
|
2394
|
+
- When Q_S is not 1, the recommended input shapes are Q_S,KV_S; B,Q_S,KV_S; 1,Q_S,KV_S; B,1,Q_S,KV_S
|
|
2395
|
+
or 1,1,Q_S,KV_S.
|
|
2396
|
+
- When Q_S is 1, the recommended input shapes are B,KV_S; B,1,KV_S or B,1,1,KV_S.
|
|
2397
|
+
|
|
2398
|
+
actual_seq_lengths (Union[tuple[int], list[int], Tensor], optional): Describe actual sequence length of the
|
|
2399
|
+
query with data type of int64. If this parameter is not specified, it can be set to None, indicating that
|
|
2400
|
+
it matches the S dimension of the query shape. Constraint: The effective sequence length for each batch in
|
|
2401
|
+
this parameter should not exceed the corresponding batch's sequence length in the query. When Q_S is 1, this
|
|
2402
|
+
parameter is ignored.
|
|
2403
|
+
Default: ``None``.
|
|
2404
|
+
actual_seq_lengths_kv (Union[tuple[int], list[int], Tensor], optional): Describe actual sequence length of the
|
|
2405
|
+
key and value with data type of int64. If this parameter is not specified, it can be set to None,
|
|
2406
|
+
indicating that it matches the S dimension of the key and value shape. Constraint: The effective sequence
|
|
2407
|
+
length for each batch in this parameter should not exceed the corresponding batch's sequence length in the
|
|
2408
|
+
key and value.
|
|
2409
|
+
Default: ``None``.
|
|
2410
|
+
dequant_scale1 (Tensor, optional): Quantization factors for inverse quantization after BMM1 with data type of
|
|
2411
|
+
uint64. Supports per-tensor mode. If not used, set it to None.
|
|
2412
|
+
Default: ``None``.
|
|
2413
|
+
quant_scale1 (Tensor, optional): Quantization factors for quantization before BMM2 with data type of float32.
|
|
2414
|
+
Supports per-tensor mode. If not used, set it to None.
|
|
2415
|
+
Default: ``None``.
|
|
2416
|
+
dequant_scale2 (Tensor, optional): Quantization factors for inverse quantization after BMM2 with data type of
|
|
2417
|
+
uint64. Supports per-tensor mode. If not used, set it to None.
|
|
2418
|
+
Default: ``None``.
|
|
2419
|
+
quant_scale2 (Tensor, optional): Quantization factors for output quantization with data type of float32,
|
|
2420
|
+
bfloat16. Supports per-tensor and per-channel modes. If not used, set it to None.
|
|
2421
|
+
Default: ``None``.
|
|
2422
|
+
quant_offset2 (Tensor, optional): Quantization offset for output quantization with data type of float32,
|
|
2423
|
+
bfloat16. Supports per-tensor and per-channel modes. If not used, set it to None.
|
|
2424
|
+
Default: ``None``.
|
|
2425
|
+
|
|
2426
|
+
For scenarios where the input is int8 and the output is int8: the parameters dequant_scale1, quant_scale1,
|
|
2427
|
+
dequant_scale2, and quant_scale2 must all be provided. The parameter quant_offset2 is optional and defaults
|
|
2428
|
+
to 0 if not specified.
|
|
2429
|
+
|
|
2430
|
+
- When the output is int8 and quant_scale2 and quant_offset2 are per-channel, left padding, Ring Attention,
|
|
2431
|
+
or D-axis misalignment (not 32-aligned) scenarios are not supported.
|
|
2432
|
+
- When the output is int8, scenarios with sparse_mode as band and pre_tokens/next_tokens being negative are
|
|
2433
|
+
not supported.
|
|
2434
|
+
- When the output is int8, if quant_offset2 is neither None nor an empty tensor, and the sparse_mode, pre_tokens,
|
|
2435
|
+
and next_tokens meet the following conditions, certain rows of the matrix may not participate in
|
|
2436
|
+
calculations, leading to errors. This scenario will be intercepted (solution: if this scenario should
|
|
2437
|
+
not be intercepted, quantization should be performed outside the FIA interface, not enabled inside the
|
|
2438
|
+
FIA interface):
|
|
2439
|
+
|
|
2440
|
+
- sparse_mode = 0, if atten_mask is not None and each batch's
|
|
2441
|
+
actual_seq_lengths - actual_seq_lengths_kv - pre_tokens > 0 or next_tokens < 0, it will meet the
|
|
2442
|
+
interception condition.
|
|
2443
|
+
- sparse_mode = 1 or 2, no interception condition will occur.
|
|
2444
|
+
- sparse_mode = 3, if each batch's actual_seq_lengths - actual_seq_lengths_kv < 0, it will meet the
|
|
2445
|
+
interception condition.
|
|
2446
|
+
- sparse_mode = 4, if pre_tokens < 0 or each batch's
|
|
2447
|
+
next_tokens + actual_seq_lengths - actual_seq_lengths_kv < 0, it will meet the interception
|
|
2448
|
+
condition.
|
|
2449
|
+
|
|
2450
|
+
For scenarios where the input is int8 and the output is float16: the parameters dequant_scale1,
|
|
2451
|
+
quant_scale1, and dequant_scale2 must all be provided.
|
|
2452
|
+
|
|
2453
|
+
For scenarios where the input is entirely float16 or bfloat16 and the output is int8: the parameter
|
|
2454
|
+
quant_scale2 must be provided. The parameter quant_offset2 is optional and defaults to 0 if not specified.
|
|
2455
|
+
|
|
2456
|
+
The parameters quant_scale2 and quant_offset2 support both per-tensor and per-channel modes and two data
|
|
2457
|
+
types: float32 and bfloat16. If quant_offset2 is provided, its type and shape must match those of
|
|
2458
|
+
quant_scale2. When the input is bfloat16, both float32 and bfloat16 are supported; otherwise, only float32
|
|
2459
|
+
is supported. For per-channel mode: When the output layout is BSH, the product of all dimensions in
|
|
2460
|
+
quant_scale2 must equal H. For other layouts, the product must equal N * D. When the output layout is BSH,
|
|
2461
|
+
it is recommended to set the shape of quant_scale2 as :math:`(1, 1, H)` or :math:`(H)`. When the output
|
|
2462
|
+
layout is BNSD, it is recommended to set the shape as :math:`(1, N, 1, D)` or :math:`(N, D)`. When the
|
|
2463
|
+
output layout is BSND, it is recommended to set the shape as :math:`(1, 1, N, D)` or :math:`(N, D)`.
|
|
2464
|
+
|
|
2465
|
+
antiquant_scale (Tensor, optional): Inverse quantization factors with data type of float16, float32 or bfloat16.
|
|
2466
|
+
Only support float16 when Q_S > 1. Supports per-tensor, per-channel and per-token modes.
|
|
2467
|
+
Default: ``None``.
|
|
2468
|
+
antiquant_offset (Tensor, optional): Inverse quantization offset with data type of float16, float32 or bfloat16.
|
|
2469
|
+
Only support float16 when Q_S > 1. Supports per-tensor, per-channel and per-token modes.
|
|
2470
|
+
Default: ``None``.
|
|
2471
|
+
|
|
2472
|
+
Constraints for antiquant_scale and antiquant_offset parameters:
|
|
2473
|
+
|
|
2474
|
+
- Supports three modes: per-channel, per-tensor, and per-token:
|
|
2475
|
+
|
|
2476
|
+
- Per-channel mode: The shape of both parameters in the BNSD layout is :math:`(2, N, 1, D)`, the shape
|
|
2477
|
+
in the BSND layout is :math:`(2, N, D)`, and the shape in the BSH layout is :math:`(2, H)`, where 2
|
|
2478
|
+
corresponds to the key and value, and N represents num_key_value_heads. The parameter data type is
|
|
2479
|
+
the same as the query data type, and antiquant_mode should be set to 0.
|
|
2480
|
+
- Per-tensor mode: The shape of both parameters is :math:`(2)`, the data type is the same as the query
|
|
2481
|
+
data type, and antiquant_mode should be set to 0.
|
|
2482
|
+
- Per-token mode: The shape of both parameters is :math:`(2, B, S)`, the data type is fixed to float32,
|
|
2483
|
+
and antiquant_mode should be set to 1.
|
|
2484
|
+
|
|
2485
|
+
- Supports both symmetric and asymmetric quantization:
|
|
2486
|
+
|
|
2487
|
+
- Asymmetric quantization mode: Both antiquant_scale and antiquant_offset must be provided.
|
|
2488
|
+
- Symmetric quantization mode: antiquant_offset can be empty (``None``). If antiquant_offset is empty,
|
|
2489
|
+
symmetric quantization is performed. If antiquant_offset is provided, asymmetric quantization is
|
|
2490
|
+
performed.
|
|
2491
|
+
|
|
2492
|
+
key_antiquant_scale (Tensor, optional): Inverse quantization factors for the key, with data type of float16,
|
|
2493
|
+
float32 or bfloat16, when the KV fake quantization parameters are separated.
|
|
2494
|
+
Supports per-tensor, per-channel and per-token modes.
|
|
2495
|
+
Default: ``None``. Invalid when Q_S > 1.
|
|
2496
|
+
key_antiquant_offset (Tensor, optional): Inverse quantization offset for the key, with data type of float16,
|
|
2497
|
+
float32 or bfloat16, when the KV fake quantization parameters are separated.
|
|
2498
|
+
Supports per-tensor, per-channel and per-token modes.
|
|
2499
|
+
Default: ``None``. Invalid when Q_S > 1.
|
|
2500
|
+
value_antiquant_scale (Tensor, optional): Inverse quantization factors for the value, with data type of
|
|
2501
|
+
float16, float32 or bfloat16, when the KV fake quantization parameters are separated.
|
|
2502
|
+
Supports per-tensor, per-channel and per-token modes.
|
|
2503
|
+
Default: ``None``. Invalid when Q_S > 1.
|
|
2504
|
+
value_antiquant_offset (Tensor, optional): Inverse quantization offset for the value, with data type of
|
|
2505
|
+
float16, float32 or bfloat16, when the KV fake quantization parameters are separated.
|
|
2506
|
+
Supports per-tensor, per-channel and per-token modes.
|
|
2507
|
+
Default: ``None``. Invalid when Q_S > 1.
|
|
2508
|
+
block_table (Tensor, optional): Block mapping table in KV cache for PageAttention, with data type of int32.
|
|
2509
|
+
If not used, set it to None.
|
|
2510
|
+
Default: ``None``. Invalid when Q_S > 1.
|
|
2511
|
+
query_padding_size (Tensor, optional): The query padding size with data type of int64. Indicates whether the
|
|
2512
|
+
data in each batch of the query is right-aligned, and how many elements are right-aligned.
|
|
2513
|
+
Default: ``None``. Invalid when Q_S is 1.
|
|
2514
|
+
kv_padding_size (Tensor, optional): The key and value padding size with data type of int64. Indicates whether
|
|
2515
|
+
the data in each batch of the key and value is right-aligned, and how many elements are right-aligned.
|
|
2516
|
+
Default: ``None``. Invalid when Q_S is 1.
|
|
2517
|
+
key_shared_prefix (Tensor, optional): Shared prefix of the key. This is a reserved parameter and is not yet
|
|
2518
|
+
enabled. Default: ``None``.
|
|
2519
|
+
value_shared_prefix (Tensor, optional): Shared prefix of the value. This is a reserved parameter and is not yet
|
|
2520
|
+
enabled. Default: ``None``.
|
|
2521
|
+
actual_shared_prefix_len (Union[tuple[int], list[int], Tensor], optional): Describe the actual length of shared
|
|
2522
|
+
prefix. This is a reserved parameter and is not yet enabled.
|
|
2523
|
+
Default: ``None``.
|
|
2524
|
+
num_heads (int, optional): The number of heads in the query, equal to N when input_layout is BNSD.
|
|
2525
|
+
Default: ``1``.
|
|
2526
|
+
scale (double, optional): The scale value indicating the scale coefficient, which serves as the scalar value for
|
|
2527
|
+
the Muls in the calculation. Generally, the value is :math:`1.0 / \sqrt{d}`. Default: ``1.0``.
|
|
2528
|
+
pre_tokens (int, optional): Parameter for sparse computation, represents how many tokens are counted forward.
|
|
2529
|
+
Default: ``2147483647``. Invalid when Q_S is 1.
|
|
2530
|
+
next_tokens (int, optional): Parameter for sparse computation, represents how many tokens are counted backward.
|
|
2531
|
+
Default: ``2147483647``. Invalid when Q_S is 1.
|
|
2532
|
+
input_layout (str, optional): Specifies the layout of input query, key and value. BSH, BNSD, BSND or
|
|
2533
|
+
BNSD_BSND is supported. When the layout is BNSD_BSND, it means the input is in the BNSD format and
|
|
2534
|
+
the output is in the BSND format; this layout is only supported when Q_S > 1.
|
|
2535
|
+
Default: ``'BSH'``.
|
|
2536
|
+
num_key_value_heads (int, optional): Head numbers of key/value which are used in GQA (Grouped-Query Attention)
|
|
2537
|
+
scenario. Default: ``0``. A value of 0 means it is equal to num_heads (the number of query heads). The num_heads
|
|
2538
|
+
must be divisible by num_key_value_heads, and the ratio of num_heads to num_key_value_heads must not be
|
|
2539
|
+
greater than 64. When the layout is BNSD, the num_key_value_heads must also equal the N dimension of
|
|
2540
|
+
the key/value shapes, otherwise, an execution error will occur.
|
|
2541
|
+
sparse_mode (int, optional): Indicates sparse mode. Default: ``0``. Invalid when Q_S is 1.
|
|
2542
|
+
|
|
2543
|
+
- 0: Indicates the defaultMask mode. If atten_mask is not passed, the mask operation is not performed,
|
|
2544
|
+
and pre_tokens and next_tokens (internally assigned as INT_MAX) are ignored. If passed in, the complete
|
|
2545
|
+
atten_mask matrix (S1 * S2) also must be passed in, indicating that the part between pre_tokens and
|
|
2546
|
+
next_tokens needs to be calculated.
|
|
2547
|
+
- 1: Represents allMask. The complete atten_mask matrix (S1 * S2) is required.
|
|
2548
|
+
- 2: Represents the mask in leftUpCausal mode. The optimized atten_mask matrix (2048*2048) is required.
|
|
2549
|
+
- 3: Represents the mask in rightDownCausal mode, corresponding to the lower triangular scenario divided by
|
|
2550
|
+
the right vertex. The optimized atten_mask matrix (2048*2048) is required.
|
|
2551
|
+
- 4: Represents the mask in band mode, that is, the part between counting pre_tokens and next_tokens. The
|
|
2552
|
+
optimized atten_mask matrix (2048*2048) is required.
|
|
2553
|
+
- 5: Represents the prefix scenario, not implemented yet.
|
|
2554
|
+
- 6: Represents the global scenario, not implemented yet.
|
|
2555
|
+
- 7: Represents the dilated scenario, not implemented yet.
|
|
2556
|
+
- 8: Represents the block_local scenario, not implemented yet.
|
|
2557
|
+
|
|
2558
|
+
inner_precise (int, optional): There are four modes: 0, 1, 2, and 3, represented by 2 bits: bit 0 (bit0)
|
|
2559
|
+
represents the choice for high precision or high performance, and bit 1 (bit1) indicates whether row-wise
|
|
2560
|
+
invalidity correction is applied.
|
|
2561
|
+
|
|
2562
|
+
- 0: Enable high-precise mode, without row-wise invalidity correction.
|
|
2563
|
+
- 1: High-performance mode, without row-wise invalidity correction.
|
|
2564
|
+
- 2: Enable high-precise mode, with row-wise invalidity correction.
|
|
2565
|
+
- 3: High-performance mode, with row-wise invalidity correction.
|
|
2566
|
+
|
|
2567
|
+
When Q_S > 1, if sparse_mode is 0 or 1 and a user-defined mask is provided, it is recommended to enable
|
|
2568
|
+
row-wise invalidity correction. Only support 0 and 1 when Q_S is 1. Default: ``1``.
|
|
2569
|
+
|
|
2570
|
+
High-precise and high-performance are only effective for float16 inputs; Row invalidity correction
|
|
2571
|
+
is effective for float16, bfloat16, and int8 inputs.
|
|
2572
|
+
Currently, 0 and 1 are reserved configuration values. If there is a situation where an entire row in the
|
|
2573
|
+
"mask portion involved in computation" is all 1s, precision may degrade. In such cases, you can try
|
|
2574
|
+
setting this parameter to 2 or 3 to enable row invalidity correction for improved precision. However,
|
|
2575
|
+
this configuration will result in decreased performance.
|
|
2576
|
+
If the function can detect the presence of invalid row scenarios, e.g. in cases where sparse_mode is 3
|
|
2577
|
+
and Q_S > KV_S, it will automatically enable row invalidity computation.
|
|
2578
|
+
|
|
2579
|
+
block_size (int, optional): Maximum number of tokens per block in the KV cache block for PageAttention.
|
|
2580
|
+
Default: ``0``. Invalid when Q_S > 1.
|
|
2581
|
+
antiquant_mode (int, optional): Fake-quantization mode, 0: per-channel (per-channel includes per-tensor),
|
|
2582
|
+
1: per-token. The per-channel and per-tensor modes can be distinguished by the dimension of the input
|
|
2583
|
+
shape. When the dimension is 1, it runs in per-tensor mode; otherwise, it runs in per-channel mode.
|
|
2584
|
+
Default: ``0``. Invalid when Q_S > 1.
|
|
2585
|
+
key_antiquant_mode (int, optional): Fake-quantization mode for the key. 0: per-channel (per-channel includes
|
|
2586
|
+
per-tensor), 1: per-token. Default: ``0``. Invalid when Q_S > 1.
|
|
2587
|
+
value_antiquant_mode (int, optional): Fake-quantization mode for the value. 0: per-channel (per-channel includes
|
|
2588
|
+
per-tensor), 1: per-token. Default: ``0``. Invalid when Q_S > 1.
|
|
2589
|
+
softmax_lse_flag (bool, optional): Whether to output softmax_lse. Default: ``False``.
|
|
2590
|
+
|
|
2591
|
+
Returns:
|
|
2592
|
+
attention_out (Tensor), the attention score with data type of float16, bfloat16 or int8. When the input_layout
|
|
2593
|
+
is BNSD_BSND, the shape is :math:`(B, S, N, D)`. In all other cases, the shape is consistent with the
|
|
2594
|
+
input query shape.
|
|
2595
|
+
|
|
2596
|
+
softmax_lse (Tensor), the softmax_lse with data type of float32, obtained by taking the lse (log, sum and exp)
|
|
2597
|
+
of the result of query*key. Specifically, the Ring Attention algorithm first takes the max of the result of
|
|
2598
|
+
query*key, obtaining softmax_max. The result of query*key is then subtracted by softmax_max, followed by
|
|
2599
|
+
taking exp, and then the sum is computed to obtain softmax_sum. Finally, the log of softmax_sum is taken,
|
|
2600
|
+
and softmax_max is added to obtain softmax_lse. The softmax_lse is only calculated when softmax_lse_flag
|
|
2601
|
+
is True, and the shape would be :math:`(B, N, Q\_S, 1)`. If softmax_lse_flag is False, then a tensor with
|
|
2602
|
+
shape :math:`(1)` filled with zeros would be returned. In graph mode with JitConfig set to O2, please ensure
|
|
2603
|
+
that the softmax_lse_flag is enabled before using softmax_lse; otherwise, an exception will occur.
|
|
2604
|
+
|
|
2605
|
+
Constraints:
|
|
2606
|
+
- Full Inference Scenario (Q_S > 1):
|
|
2607
|
+
|
|
2608
|
+
- Query, key, and value inputs functional usage restrictions:
|
|
2609
|
+
|
|
2610
|
+
- The B axis supports values less than or equal to 65535. If the input type includes int8, or
|
|
2611
|
+
if the input type is float16 or bfloat16 and the D axis is not 16-aligned, the B axis is only
|
|
2612
|
+
supported up to 128.
|
|
2613
|
+
- The N axis supports values less than or equal to 256, and the D axis supports values less than
|
|
2614
|
+
or equal to 512.
|
|
2615
|
+
- The S axis supports values less than or equal to 20,971,520 (20M). In some long sequence
|
|
2616
|
+
scenarios, if the computation load is too large, it may cause a timeout in the PFA operator
|
|
2617
|
+
(AICore error type with errorStr: "timeout or trap error"). In this case, it is recommended to
|
|
2618
|
+
perform an S split. Note: The computational load is affected by B, S, N, D, etc.; the larger the
|
|
2619
|
+
values, the greater the computational load. Typical long sequence timeout scenarios (where the
|
|
2620
|
+
product of B, S, N, and D is large) include, but are not limited to:
|
|
2621
|
+
|
|
2622
|
+
1. B=1, Q_N=20, Q_S=2097152, D=256, KV_N=1, KV_S=2097152;
|
|
2623
|
+
2. B=1, Q_N=2, Q_S=20971520, D=256, KV_N=2, KV_S=20971520;
|
|
2624
|
+
3. B=20, Q_N=1, Q_S=2097152, D=256, KV_N=1, KV_S=2097152;
|
|
2625
|
+
4. B=1, Q_N=10, Q_S=2097152, D=512, KV_N=1, KV_S=2097152.
|
|
2626
|
+
|
|
2627
|
+
- When the query, key, value, or attention_out type includes int8, the D axis must be 32-aligned.
|
|
2628
|
+
If all types are float16 or bfloat16, the D axis must be 16-aligned.
|
|
2629
|
+
|
|
2630
|
+
- The sparse_mode parameter currently only supports values 0, 1, 2, 3, and 4. Using any other values
|
|
2631
|
+
will result in an error.
|
|
2632
|
+
|
|
2633
|
+
- When sparse_mode = 0, if the atten_mask is None, or if the atten_mask is provided in the left
|
|
2634
|
+
padding scenario, the input parameters pre_tokens and next_tokens are ignored.
|
|
2635
|
+
- When sparse_mode = 2, 3, or 4, the shape of the atten_mask must be (S, S), (1, S, S) or (1, 1, S, S), where
|
|
2636
|
+
S must be fixed at 2048, and the user must ensure the atten_mask is a lower triangular matrix. If
|
|
2637
|
+
no atten_mask is provided or if the shape is incorrect, an error will occur.
|
|
2638
|
+
- In sparse_mode = 1, 2, 3 scenarios, the pre_tokens and next_tokens inputs are ignored and assigned
|
|
2639
|
+
according to the relevant rules.
|
|
2640
|
+
|
|
2641
|
+
- The KV cache de-quantization only supports queries of type float16, where int8 keys and values are
|
|
2642
|
+
de-quantized to float16. When the product of the data ranges of the input key/value and the
|
|
2643
|
+
antiquant_scale is within (-1, 1), high-performance mode can guarantee precision; otherwise,
|
|
2644
|
+
high-precision mode should be enabled to ensure accuracy.
|
|
2645
|
+
|
|
2646
|
+
- Query left padding scenario:
|
|
2647
|
+
|
|
2648
|
+
- In the query left padding scenario, the formula for calculating the starting point of the query
|
|
2649
|
+
transport is: Q_S - query_padding_size - actual_seq_lengths. The formula for the
|
|
2650
|
+
ending point of the query transport is: Q_S - query_padding_size. The query transport
|
|
2651
|
+
starting point must not be less than 0, and the ending point must not exceed Q_S; otherwise,
|
|
2652
|
+
the results will be incorrect.
|
|
2653
|
+
- If the kv_padding_size in the query left padding scenario is less than 0, it will be set to 0.
|
|
2654
|
+
- The query left padding scenario must be enabled together with the actual_seq_lengths parameter,
|
|
2655
|
+
otherwise, the default is the query right padding scenario.
|
|
2656
|
+
- The query left padding scenario does not support PageAttention and cannot be enabled together with
|
|
2657
|
+
the block_table parameter.
|
|
2658
|
+
|
|
2659
|
+
- KV left padding scenario:
|
|
2660
|
+
|
|
2661
|
+
- In the KV left padding scenario, the formula for calculating the starting point of the key and
|
|
2662
|
+
value transport is: KV_S - kv_padding_size - actual_seq_lengths_kv. The formula
|
|
2663
|
+
for the ending point of the key and value transport is: KV_S - kv_padding_size. The
|
|
2664
|
+
key and value transport starting point must not be less than 0, and the ending point must not
|
|
2665
|
+
exceed KV_S; otherwise, the results will be incorrect.
|
|
2666
|
+
- If the kv_padding_size in the KV left padding scenario is less than 0, it will be set to 0.
|
|
2667
|
+
- The KV left padding scenario must be enabled together with the actual_seq_lengths_kv parameter,
|
|
2668
|
+
otherwise, the default is the KV right padding scenario.
|
|
2669
|
+
- The KV left padding scenario does not support PageAttention and cannot be enabled together with
|
|
2670
|
+
the block_table parameter.
|
|
2671
|
+
|
|
2672
|
+
- pse_shift functional usage restrictions:
|
|
2673
|
+
|
|
2674
|
+
- This function is supported when the query data type is float16, bfloat16, or int8.
|
|
2675
|
+
- If the query data type is float16 and pse_shift is enabled, it will force high-precision mode,
|
|
2676
|
+
inheriting the limitations of high-precision mode.
|
|
2677
|
+
- Q_S must be greater than or equal to the length of the query S, and KV_S must be greater than
|
|
2678
|
+
or equal to the length of the key S.
|
|
2679
|
+
|
|
2680
|
+
- KV fake quantization parameter separation is not currently supported.
|
|
2681
|
+
|
|
2682
|
+
- Incremental Inference Scenario (Q_S is 1):
|
|
2683
|
+
|
|
2684
|
+
- Query, key, and value inputs functional usage restrictions:
|
|
2685
|
+
|
|
2686
|
+
- The B axis supports values less than or equal to 65,536.
|
|
2687
|
+
- The N axis supports values less than or equal to 256.
|
|
2688
|
+
- The D axis supports values less than or equal to 512.
|
|
2689
|
+
- Scenarios where the input types of query, key, and value are all int8 are not supported.
|
|
2690
|
+
|
|
2691
|
+
- Page attention scenario:
|
|
2692
|
+
|
|
2693
|
+
- The necessary condition to enable page attention is that the block_table exists and is valid.
|
|
2694
|
+
The key and value are arranged in contiguous memory according to the indices in the block_table.
|
|
2695
|
+
The key and value dtypes supported are float16, bfloat16, and int8. In this scenario, the
|
|
2696
|
+
input_layout parameter for key and value is invalid.
|
|
2697
|
+
- block_size is a user-defined parameter, and its value will affect the performance of page
|
|
2698
|
+
attention. When enabling page attention, a non-zero value for block_size must be provided, and
|
|
2699
|
+
the maximum value for block_size is 512.
|
|
2700
|
+
- If the input types of key and value are float16 or bfloat16, they must be 16-aligned. If the
|
|
2701
|
+
input types are int8, they must be 32-aligned, with 128 being recommended. In general, page
|
|
2702
|
+
attention can increase throughput but may lead to a performance decrease.
|
|
2703
|
+
- In the page attention enabled scenario, when the KV cache layout is (blocknum, block_size, H) and
|
|
2704
|
+
num_key_value_heads * D exceeds 64K, an error will be reported due to hardware
|
|
2705
|
+
instruction constraints. This can be resolved by enabling GQA (reducing num_key_value_heads) or
|
|
2706
|
+
adjusting the KV cache layout to (blocknum, num_key_value_heads, block_size, D).
|
|
2707
|
+
- The product of all dimensions of the shape of the key and value tensors in the page attention
|
|
2708
|
+
scenario must not exceed the representable range of int32.
|
|
2709
|
+
|
|
2710
|
+
- In the page attention enabled scenario, the input S must be greater than or equal to
|
|
2711
|
+
max_block_num_per_seq * block_size.
|
|
2712
|
+
|
|
2713
|
+
- Enabling attention mask (e.g., mask shape = (B, 1, 1, S))
|
|
2714
|
+
- Enabling pse_shift (e.g., pse_shift shape = (B, N, 1, S))
|
|
2715
|
+
- Enabling fake quantization in per-token mode (e.g., antiquant_scale and antiquant_offset shapes =
|
|
2716
|
+
(2, B, S)) are also supported.
|
|
2717
|
+
|
|
2718
|
+
- KV left padding scenario:
|
|
2719
|
+
|
|
2720
|
+
- In the KV left padding scenario, the formula for calculating the starting point of the KV cache
|
|
2721
|
+
transport is: KV_S - kv_padding_size - actual_seq_lengths. The formula for the endpoint of the
|
|
2722
|
+
KV cache transport is: KV_S - kv_padding_size. If the starting point or endpoint of the KV cache
|
|
2723
|
+
is less than 0, the returned data result will be all zeros.
|
|
2724
|
+
- If kv_padding_size is less than 0 in the KV left padding scenario, it will be set to 0.
|
|
2725
|
+
- The KV left padding scenario must be enabled together with the actual_seq_lengths parameter,
|
|
2726
|
+
otherwise, it defaults to the KV right padding scenario.
|
|
2727
|
+
- The KV left padding scenario must be enabled together with the atten_mask parameter, and the
|
|
2728
|
+
atten_mask must be correctly applied to hide invalid data. Otherwise, accuracy issues may arise.
|
|
2729
|
+
|
|
2730
|
+
- pse_shift functional usage restrictions:
|
|
2731
|
+
|
|
2732
|
+
- The data type of pse_shift must match the data type of the query.
|
|
2733
|
+
- Only inputs with an aligned D axis are supported, i.e. the D axis must be divisible by 16.
|
|
2734
|
+
|
|
2735
|
+
- KV fake quantization parameter separation:
|
|
2736
|
+
|
|
2737
|
+
- key_antiquant_mode and value_antiquant_mode must be consistent.
|
|
2738
|
+
- key_antiquant_scale and value_antiquant_scale must either both be empty or both non-empty.
|
|
2739
|
+
- key_antiquant_offset and value_antiquant_offset must either both be empty or both non-empty.
|
|
2740
|
+
- When both key_antiquant_scale and value_antiquant_scale are non-empty, their shapes must be
|
|
2741
|
+
consistent.
|
|
2742
|
+
- When both key_antiquant_offset and value_antiquant_offset are non-empty, their shapes must be
|
|
2743
|
+
consistent.
|
|
2744
|
+
|
|
2745
|
+
|
|
2746
|
+
Supported Platforms:
|
|
2747
|
+
``Ascend``
|
|
2748
|
+
|
|
2749
|
+
Examples:
|
|
2750
|
+
>>> from mindspore import ops
|
|
2751
|
+
>>> from mindspore import Tensor
|
|
2752
|
+
>>> import numpy as np
|
|
2753
|
+
>>> B, N, S, D = 1, 8, 1024, 128
|
|
2754
|
+
>>> query = Tensor(np.random.rand(B, N, S, D).astype(np.float16))
|
|
2755
|
+
>>> key = Tensor(np.random.rand(B, N, S, D).astype(np.float16))
|
|
2756
|
+
>>> value = Tensor(np.random.rand(B, N, S, D).astype(np.float16))
|
|
2757
|
+
>>> out = ops.fused_infer_attention_score(query, key, value, num_heads=N, input_layout='BNSD')
|
|
2758
|
+
>>> print(out[0].shape)
|
|
2759
|
+
(1, 8, 1024, 128)
|
|
2760
|
+
"""
|
|
2761
|
+
fias_op = _get_cache_prim(FusedInferAttentionScore)(num_heads, scale, pre_tokens, next_tokens, input_layout,
|
|
2762
|
+
num_key_value_heads, sparse_mode, inner_precise, block_size,
|
|
2763
|
+
antiquant_mode, softmax_lse_flag, key_antiquant_mode,
|
|
2764
|
+
value_antiquant_mode)
|
|
2765
|
+
key_list = key if isinstance(key, (tuple, list)) else [key]
|
|
2766
|
+
value_list = value if isinstance(value, (tuple, list)) else [value]
|
|
2767
|
+
return fias_op(query, key_list, value_list, pse_shift, atten_mask, actual_seq_lengths, actual_seq_lengths_kv,
|
|
2768
|
+
dequant_scale1, quant_scale1, dequant_scale2, quant_scale2, quant_offset2, antiquant_scale,
|
|
2769
|
+
antiquant_offset, block_table, query_padding_size, kv_padding_size, key_antiquant_scale,
|
|
2770
|
+
key_antiquant_offset, value_antiquant_scale, value_antiquant_offset, key_shared_prefix,
|
|
2771
|
+
value_shared_prefix, actual_shared_prefix_len)
|
|
2772
|
+
|
|
2773
|
+
|
|
2026
2774
|
class WhileLoop(Primitive):
|
|
2027
2775
|
"""
|
|
2028
2776
|
Provide a useful op for reducing compilation times of while loop.
|
|
@@ -2195,7 +2943,7 @@ class Scan(Primitive):
|
|
|
2195
2943
|
|
|
2196
2944
|
class ForiLoop(Primitive):
|
|
2197
2945
|
"""
|
|
2198
|
-
|
|
2946
|
+
Performs a loop operation within the specified range.
|
|
2199
2947
|
The execution logic of the ForiLoop operator can be roughly represented by the following code:
|
|
2200
2948
|
|
|
2201
2949
|
.. code-block:: python
|