mindspore-2.4.10-cp39-cp39-win_amd64.whl → mindspore-2.5.0-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +8 -3
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +0 -5
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/compile_config.py +64 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
- mindspore/_extends/parse/parser.py +23 -5
- mindspore/_extends/parse/standard_method.py +123 -27
- mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
- mindspore/amp.py +7 -1
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/boost_cell_wrapper.py +136 -41
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +3 -1
- mindspore/common/_register_for_tensor.py +0 -1
- mindspore/common/_stub_tensor.py +25 -4
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +6132 -0
- mindspore/common/api.py +98 -21
- mindspore/common/dtype.py +34 -34
- mindspore/common/dump.py +2 -1
- mindspore/common/file_system.py +8 -3
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +3 -1
- mindspore/common/initializer.py +3 -4
- mindspore/common/lazy_inline.py +8 -2
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/parameter.py +31 -15
- mindspore/common/tensor.py +713 -1337
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +5 -0
- mindspore/communication/comm_func.py +215 -173
- mindspore/communication/management.py +23 -20
- mindspore/context.py +285 -191
- mindspore/dataset/__init__.py +23 -19
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +84 -3
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +5 -4
- mindspore/dataset/engine/datasets.py +192 -149
- mindspore/dataset/engine/datasets_audio.py +14 -0
- mindspore/dataset/engine/datasets_standard_format.py +11 -11
- mindspore/dataset/engine/datasets_text.py +38 -1
- mindspore/dataset/engine/datasets_user_defined.py +100 -66
- mindspore/dataset/engine/datasets_vision.py +81 -8
- mindspore/dataset/engine/iterators.py +281 -63
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +26 -2
- mindspore/dataset/engine/serializer_deserializer.py +1 -1
- mindspore/dataset/engine/validators.py +43 -11
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +29 -12
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +94 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +127 -0
- mindspore/device_context/cpu/__init__.py +25 -0
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +134 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/optim/adadelta.py +26 -22
- mindspore/experimental/optim/adam.py +3 -0
- mindspore/experimental/optim/lr_scheduler.py +33 -24
- mindspore/experimental/optim/radam.py +33 -30
- mindspore/hal/device.py +28 -0
- mindspore/hal/event.py +17 -0
- mindspore/hal/memory.py +94 -3
- mindspore/hal/stream.py +91 -6
- mindspore/include/api/context.h +0 -1
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +12 -0
- mindspore/mindrecord/__init__.py +1 -1
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +824 -218
- mindspore/mint/distributed/__init__.py +66 -4
- mindspore/mint/distributed/distributed.py +2594 -44
- mindspore/mint/linalg/__init__.py +6 -0
- mindspore/mint/nn/__init__.py +473 -14
- mindspore/mint/nn/functional.py +486 -11
- mindspore/mint/nn/layer/__init__.py +17 -4
- mindspore/mint/nn/layer/_functions.py +330 -0
- mindspore/mint/nn/layer/activation.py +169 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +727 -0
- mindspore/mint/nn/layer/normalization.py +215 -19
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +170 -0
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/cell.py +126 -19
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +6 -6
- mindspore/nn/layer/basic.py +35 -25
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/embedding.py +3 -3
- mindspore/nn/layer/normalization.py +8 -7
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +47 -13
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +48 -26
- mindspore/nn/learning_rate_schedule.py +5 -3
- mindspore/nn/loss/loss.py +31 -36
- mindspore/nn/optim/ada_grad.py +1 -0
- mindspore/nn/optim/adadelta.py +2 -2
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/rprop.py +2 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/cell_wrapper.py +4 -6
- mindspore/nn/wrap/loss_scale.py +3 -4
- mindspore/numpy/array_creations.py +60 -62
- mindspore/numpy/array_ops.py +148 -143
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +16 -16
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +2 -1
- mindspore/ops/_grad_experimental/grad_comm_ops.py +94 -13
- mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_vmap/vmap_array_ops.py +20 -19
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
- mindspore/ops/_vmap/vmap_math_ops.py +11 -9
- mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
- mindspore/ops/auto_generate/gen_extend_func.py +554 -60
- mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
- mindspore/ops/auto_generate/gen_ops_prim.py +8024 -3409
- mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
- mindspore/ops/composite/base.py +1 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
- mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
- mindspore/ops/function/__init__.py +12 -0
- mindspore/ops/function/array_func.py +561 -159
- mindspore/ops/function/clip_func.py +64 -0
- mindspore/ops/function/debug_func.py +28 -20
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +5 -4
- mindspore/ops/function/math_func.py +1659 -290
- mindspore/ops/function/nn_func.py +988 -317
- mindspore/ops/function/parameter_func.py +3 -56
- mindspore/ops/function/random_func.py +243 -33
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/functional.py +18 -5
- mindspore/ops/functional_overload.py +897 -0
- mindspore/ops/operations/__init__.py +3 -2
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -34
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +38 -8
- mindspore/ops/operations/array_ops.py +45 -303
- mindspore/ops/operations/comm_ops.py +19 -16
- mindspore/ops/operations/custom_ops.py +11 -55
- mindspore/ops/operations/debug_ops.py +42 -47
- mindspore/ops/operations/inner_ops.py +6 -4
- mindspore/ops/operations/linalg_ops.py +3 -2
- mindspore/ops/operations/manually_defined/ops_def.py +185 -104
- mindspore/ops/operations/math_ops.py +11 -216
- mindspore/ops/operations/nn_ops.py +146 -308
- mindspore/ops/primitive.py +23 -21
- mindspore/ops/tensor_method.py +1669 -0
- mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
- mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
- mindspore/ops_generate/arg_handler.py +0 -61
- mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
- mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/base_generator.py +11 -0
- mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
- mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
- mindspore/ops_generate/functional_overload_py_generator.py +110 -0
- mindspore/ops_generate/functions_cc_generator.py +233 -0
- mindspore/ops_generate/gen_aclnn_implement.py +110 -114
- mindspore/ops_generate/gen_constants.py +157 -3
- mindspore/ops_generate/gen_ops.py +245 -990
- mindspore/ops_generate/gen_pyboost_func.py +97 -998
- mindspore/ops_generate/gen_utils.py +119 -33
- mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
- mindspore/ops_generate/op_api_proto.py +206 -0
- mindspore/ops_generate/op_def_py_generator.py +131 -0
- mindspore/ops_generate/op_prim_py_generator.py +480 -0
- mindspore/ops_generate/op_proto.py +373 -108
- mindspore/ops_generate/op_template_parser.py +436 -0
- mindspore/ops_generate/ops_def_cc_generator.py +288 -0
- mindspore/ops_generate/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/ops_name_h_generator.py +68 -0
- mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
- mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
- mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
- mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
- mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
- mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
- mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
- mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
- mindspore/ops_generate/pyboost_utils.py +92 -33
- mindspore/ops_generate/template.py +294 -44
- mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
- mindspore/parallel/__init__.py +3 -3
- mindspore/parallel/_auto_parallel_context.py +24 -33
- mindspore/parallel/_parallel_serialization.py +13 -2
- mindspore/parallel/_utils.py +4 -1
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +44 -0
- mindspore/parallel/cluster/process_entity/_api.py +131 -37
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +20 -3
- mindspore/parallel/parameter_broadcast.py +1 -1
- mindspore/parallel/shard.py +3 -0
- mindspore/parallel/transform_safetensors.py +119 -253
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +17 -4
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +174 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +202 -0
- mindspore/profiler/common/path_manager.py +371 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +476 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +210 -0
- mindspore/profiler/common/profiler_path_manager.py +120 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +270 -37
- mindspore/profiler/envprofiler.py +138 -0
- mindspore/profiler/mstx.py +199 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +309 -0
- mindspore/profiler/profiler.py +580 -93
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +114 -0
- mindspore/profiler/schedule.py +208 -0
- mindspore/rewrite/api/symbol_tree.py +1 -2
- mindspore/run_check/_check_version.py +2 -6
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +148 -0
- mindspore/runtime/memory.py +392 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +2 -2
- mindspore/train/_utils.py +53 -18
- mindspore/train/amp.py +8 -4
- mindspore/train/callback/_checkpoint.py +32 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +105 -69
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_summary_collector.py +44 -6
- mindspore/train/callback/_tft_register.py +31 -10
- mindspore/train/dataset_helper.py +11 -11
- mindspore/train/metrics/precision.py +4 -5
- mindspore/train/mind_ir_pb2.py +167 -46
- mindspore/train/model.py +13 -15
- mindspore/train/serialization.py +462 -76
- mindspore/train/summary/summary_record.py +1 -2
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +4 -2
- mindspore/utils/dryrun.py +138 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/METADATA +2 -3
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/RECORD +385 -261
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0
Excerpt of the line-level diff for one of the changed files (the classes shown below match mindspore/ops/operations/nn_ops.py in the listing above), reconstructed as a unified diff:

@@ -1,4 +1,4 @@
-# Copyright 2020-
+# Copyright 2020-2024 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -31,15 +31,18 @@ from mindspore.ops.primitive import PrimitiveWithInfer
 from mindspore.ops.primitive import PrimitiveWithCheck
 from mindspore.ops.primitive import prim_attr_register
 from mindspore.run_check._check_version import AscendEnvChecker
-from
+from mindspore._c_expression import pyboost_all_finite
+from mindspore.common._stub_tensor import _convert_stub
+from ..auto_generate import (CeLU, Flatten, LogSoftmax, LogSoftmaxExt, GLU, ReLU, ReLU6, Dense, Tanh,
                              Elu, Sigmoid, Softmax, SoftplusExt, HSwish, HSigmoid, AvgPool, BiasAdd,
                              NLLLoss, OneHot, GeLU, FastGeLU, PReLU, RmsNorm, IncreFlashAttention, MSELossExt,
                              GridSampler3D, GridSampler2D, LayerNorm, LayerNormExt, HShrink, AdamWeightDecay, Dropout,
                              ApplyRotaryPosEmb, PagedAttention, PagedAttentionMask, ReshapeAndCache,
-                             FlashAttentionScore, Embedding, UpsampleNearest1D, UpsampleNearest2D,
+                             FlashAttentionScore, PromptFlashAttention, Embedding, UpsampleNearest1D, UpsampleNearest2D,
                              UpsampleNearest3D, UpsampleTrilinear3D,
                              UpsampleBilinear2D, UpsampleLinear1D,
-                             BinaryCrossEntropy, BCEWithLogitsLoss, SoftShrink
+                             BinaryCrossEntropy, BCEWithLogitsLoss, SoftShrink,
+                             SmoothL1Loss)
 from .manually_defined import BatchNorm


@@ -612,12 +615,12 @@ class InstanceNorm(PrimitiveWithInfer):
     Inputs:
         - **input_x** (Tensor) - The input of InstanceNorm, Tensor of shape :math:`(N, C)`,
           data type: float16 or float32.
-        - **gamma** (Parameter) - Scale, Tensor of shape :math:`(C,)`,
+        - **gamma** (Union[Parameter, Tensor])) - Scale, Tensor of shape :math:`(C,)`,
           data type: float32.
-        - **beta** (Parameter) - Bias, Tensor of shape :math:`(C,)`,
+        - **beta** (Union[Parameter, Tensor])) - Bias, Tensor of shape :math:`(C,)`,
           data type: float32.
-        - **mean** (Parameter) - Mean value, Tensor of shape :math:`(C,)`, data type: float32.
-        - **variance** (Parameter) - Variance value, Tensor of shape :math:`(C,)`, data type: float32.
+        - **mean** (Union[Parameter, Tensor])) - Mean value, Tensor of shape :math:`(C,)`, data type: float32.
+        - **variance** (Union[Parameter, Tensor])) - Variance value, Tensor of shape :math:`(C,)`, data type: float32.

     Outputs:
         Tuple of 3 Tensors, the normalized input, the updated parameters.
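The InstanceNorm hunk above relaxes `gamma`, `beta`, `mean`, and `variance` from Parameter-only to Union[Parameter, Tensor]. A minimal sketch of what that wording permits, following the input order given in the docstring; the `(N, C, H, W)` input shape and all values are assumptions for illustration and have not been verified against 2.5.0:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> instance_norm = ops.InstanceNorm()
>>> input_x = Tensor(np.ones([2, 3, 4, 4]), mindspore.float32)   # assumed (N, C, H, W) layout
>>> gamma = Tensor(np.ones([3]), mindspore.float32)              # plain Tensors where 2.4 documented Parameter
>>> beta = Tensor(np.zeros([3]), mindspore.float32)
>>> mean = Tensor(np.zeros([3]), mindspore.float32)
>>> variance = Tensor(np.ones([3]), mindspore.float32)
>>> output = instance_norm(input_x, gamma, beta, mean, variance)  # tuple of 3 Tensors per the docstring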
@@ -2287,9 +2290,9 @@ class ApplyMomentum(Primitive):
         gradient_scale (float): The scale of the gradient. Default: ``1.0`` .

     Inputs:
-        - **variable** (Parameter) - Weights to be updated. Data type must be float64, int64, float,
-          int16, int32, int8, uint16, uint32, uint64, uint8, complex64, complex128.
-        - **accumulation** (Parameter) - Accumulated gradient value by moment weight,
+        - **variable** (Union[Parameter, Tensor]) - Weights to be updated. Data type must be float64, int64, float,
+          float16, int16, int32, int8, uint16, uint32, uint64, uint8, complex64, complex128.
+        - **accumulation** (Union[Parameter, Tensor]) - Accumulated gradient value by moment weight,
           has the same data type with `variable`.
         - **learning_rate** (Union[Number, Tensor]) - The learning rate value, must be a float64, int64, float,
           float16, int16, int32, int8, uint16, uint32, uint64, uint8, complex64, complex128 number or

@@ -2306,7 +2309,7 @@ class ApplyMomentum(Primitive):

     Raises:
         TypeError: If the `use_locking` or `use_nesterov` is not a bool or `gradient_scale` is not a float.
-        TypeError: If the data type of `var`, `accum` and `grad` conversion
+        TypeError: If the data type of `var`, `accum` and `grad` conversion is not supported.

     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
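ApplyMomentum's `variable` and `accumulation` are likewise redocumented as Union[Parameter, Tensor]. A short sketch of the usual call pattern, based on the long-standing ApplyMomentum docstring example; the values are illustrative and passing plain Tensors instead of Parameters follows the new wording but has not been verified here:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, Parameter, ops
>>> apply_momentum = ops.ApplyMomentum()
>>> variable = Parameter(Tensor(np.array([[0.6, 0.4], [0.1, 0.5]], dtype=np.float32)), name="variable")
>>> accumulation = Parameter(Tensor(np.array([[0.6, 0.5], [0.2, 0.6]], dtype=np.float32)), name="accumulation")
>>> gradient = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]], dtype=np.float32))
>>> # positional order: variable, accumulation, learning_rate, gradient, momentum
>>> output = apply_momentum(variable, accumulation, 0.001, gradient, 0.9)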
@@ -2354,55 +2357,6 @@ class ApplyMomentum(Primitive):
         self.add_prim_attr('side_effect_mem', True)


-class SmoothL1Loss(Primitive):
-    r"""
-    Calculate the smooth L1 loss, and the L1 loss function has robustness.
-
-    Refer to :func:`mindspore.ops.smooth_l1_loss` for more details.
-
-    Args:
-        beta (float, optional): A parameter used to control the point where the function will change between
-            L1 to L2 loss. The value should be greater than zero. Default: ``1.0`` .
-        reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
-            ``'sum'`` . Default: ``'none'`` .
-
-            - ``'none'``: no reduction will be applied.
-            - ``'mean'``: compute and return the mean of elements in the output.
-            - ``'sum'``: the output elements will be summed.
-
-    Inputs:
-        - **logits** (Tensor) - Input Tensor of any dimension. Data type must be float16, float32 or float64.
-        - **labels** (Tensor) - Ground truth data, has the same shape and dtype as the `logits`.
-
-    Outputs:
-        Tensor, loss float tensor, same shape and dtype as the `logits`.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import numpy as np
-        >>> from mindspore import Tensor, ops
-        >>> loss = ops.SmoothL1Loss()
-        >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
-        >>> labels = Tensor(np.array([1, 2, 2]), mindspore.float32)
-        >>> output = loss(logits, labels)
-        >>> print(output)
-        [0. 0. 0.5]
-    """
-
-    @prim_attr_register
-    def __init__(self, beta=1.0, reduction='none'):
-        """Initialize SmoothL1Loss."""
-        validator.check_value_type('beta', beta, [float], self.name)
-        validator.check('beta', beta, '', 0, validator.GT, self.name)
-        validator.check_string(
-            reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
-        self.add_prim_attr('sigma', self.beta)
-        self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output'])
-
-
 class MultiMarginLoss(Primitive):
     r"""
     Creates a loss function that minimizes the hinge loss
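The hand-written SmoothL1Loss primitive is deleted here, while the import hunk near the top of this file now pulls SmoothL1Loss from `..auto_generate`, so `ops.SmoothL1Loss` appears to remain available under the same name. The removed docstring's own example, reproduced below, shows the call pattern; whether the auto-generated version is behaviorally identical has not been verified here:

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> loss = ops.SmoothL1Loss()
>>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
>>> labels = Tensor(np.array([1, 2, 2]), mindspore.float32)
>>> output = loss(logits, labels)
>>> print(output)
[0. 0. 0.5]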
@@ -3610,11 +3564,11 @@ class Adam(Primitive):
         If ``False`` , update the gradients without using NAG. Default: ``False`` .

     Inputs:
-        - **var** (Parameter) - Weights to be updated. The shape is :math:`(N, *)` where :math:`*` means,
+        - **var** (Union[Parameter, Tensor]) - Weights to be updated. The shape is :math:`(N, *)` where :math:`*` means,
           any number of additional dimensions. The data type can be float16 or float32.
-        - **m** (Parameter) - The 1st moment vector in the updating formula,
+        - **m** (Union[Parameter, Tensor]) - The 1st moment vector in the updating formula,
           the shape should be the same as `var`.
-        - **v** (Parameter) - the 2nd moment vector in the updating formula,
+        - **v** (Union[Parameter, Tensor]) - the 2nd moment vector in the updating formula,
           the shape should be the same as `var`.
         - **beta1_power** (float) - :math:`beta_1^t(\beta_1^{t})` in the updating formula.
         - **beta2_power** (float) - :math:`beta_2^t(\beta_2^{t})` in the updating formula.

@@ -3785,8 +3739,8 @@ class AdamNoUpdateParam(Primitive):

 class FusedSparseAdam(Primitive):
     r"""
-    Merges the duplicate value of the gradient and then updates parameters by the Adaptive Moment Estimation
-    algorithm. This operator is used when the gradient is sparse.
+    Merges the duplicate value of the gradient and then updates parameters or tensors by the Adaptive Moment Estimation
+    (Adam) algorithm. This operator is used when the gradient is sparse.

     The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

@@ -3819,11 +3773,12 @@ class FusedSparseAdam(Primitive):
         If ``False`` , update the gradients without using NAG. Default: ``False`` .

     Inputs:
-        - **var** (Parameter) - Parameters to be updated with float32 data type. The shape is
-          where :math:`*` means, any number of additional dimensions.
-        - **m** (Parameter) - The 1st moment vector in the updating formula, has the same shape and data
-
-
+        - **var** (Union[Parameter, Tensor]) - Parameters or tensors to be updated with float32 data type. The shape is:
+          math:`(N, *)` where :math:`*` means, any number of additional dimensions.
+        - **m** (Union[Parameter, Tensor]) - The 1st moment vector in the updating formula, has the same shape and data
+          type as `var`.
+        - **v** (Union[Parameter, Tensor]) - The 2nd moment vector in the updating formula, has the same shape and data
+          type as `var`. Mean square gradients, has the same type as `var` with float32 data type.
         - **beta1_power** (Tensor) - :math:`beta_1^t` in the updating formula with float32 data type.
           The shape is :math:`(1, )`.
         - **beta2_power** (Tensor) - :math:`beta_2^t` in the updating formula with float32 data type.

@@ -3841,7 +3796,7 @@ class FusedSparseAdam(Primitive):
         - **indices** (Tensor) - Gradient indices with int32 data type and indices.shape[0] = gradient.shape[0].

     Outputs:
-        Tuple of 3 Tensors, this operator will update the input parameters directly, the outputs are useless.
+        Tuple of 3 Tensors, this operator will update the input parameters or tensors directly, the outputs are useless.

         - **var** (Tensor) - A Tensor with shape :math:`(N, *)`.
         - **m** (Tensor) - A Tensor with shape :math:`(1, )`.

@@ -3911,8 +3866,8 @@ class FusedSparseAdam(Primitive):

 class FusedSparseLazyAdam(Primitive):
     r"""
-    Merges the duplicate value of the gradient and then updates parameters by the Adaptive Moment Estimation
-    algorithm. This operator is used when the gradient is sparse. The behavior is not equivalent to the
+    Merges the duplicate value of the gradient and then updates parameters or tensors by the Adaptive Moment Estimation
+    (Adam) algorithm. This operator is used when the gradient is sparse. The behavior is not equivalent to the
     original Adam algorithm, as only the current indices parameters will be updated.

     The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

@@ -3946,11 +3901,12 @@ class FusedSparseLazyAdam(Primitive):
         If ``False`` , update the gradients without using NAG. Default: ``False`` .

     Inputs:
-        - **var** (Parameter) - Parameters to be updated with float32 data type. The shape is
-          where :math:`*` means, any number of additional dimensions.
-        - **m** (Parameter) - The 1st moment vector in the updating formula, has the same shape and data
-
-
+        - **var** (Union[Parameter, Tensor]) - Parameters or tensors to be updated with float32 data type. The shape is:
+          math:`(N, *)` where :math:`*` means, any number of additional dimensions.
+        - **m** (Union[Parameter, Tensor]) - The 1st moment vector in the updating formula, has the same shape and data
+          type as `var`.
+        - **v** (Union[Parameter, Tensor]) - The 2nd moment vector in the updating formula, has the same shape and data
+          type as `var`. Mean square gradients, has the same type as `var` with float32 data type.
         - **beta1_power** (Tensor) - :math:`beta_1^t` in the updating formula with float32 data type.
           The shape is :math:`(1, )`.
         - **beta2_power** (Tensor) - :math:`beta_2^t` in the updating formula with float32 data type.

@@ -3968,7 +3924,7 @@ class FusedSparseLazyAdam(Primitive):
         - **indices** (Tensor) - Gradient indices with int32 data type and indices.shape[0] = gradient.shape[0].

     Outputs:
-        Tuple of 3 Tensors, this operator will update the input parameters directly, the outputs are useless.
+        Tuple of 3 Tensors, this operator will update the input parameters or tensors directly, the outputs are useless.

         - **var** (Tensor) - A Tensor with shape :math:`(N, *)`.
         - **m** (Tensor) - A Tensor with shape :math:`(1, )`.

@@ -4054,17 +4010,18 @@ class FusedSparseFtrl(Primitive):
         use_locking (bool): Use locks for updating operation if True . Default: ``False`` .

     Inputs:
-        - **var** (Parameter) - The variable to be updated. The data type must be float32. The shape is
-          where :math:`*` means, any number of additional dimensions.
-        - **accum** (Parameter) - The accumulation to be updated, must be same type and shape as `var`.
-        - **linear** (Parameter) - the linear coefficient to be updated, must be same type and shape as
+        - **var** (Union[Parameter, Tensor]) - The variable to be updated. The data type must be float32. The shape is:
+          math:`(N, *)` where :math:`*` means, any number of additional dimensions.
+        - **accum** (Union[Parameter, Tensor]) - The accumulation to be updated, must be same type and shape as `var`.
+        - **linear** (Union[Parameter, Tensor]) - the linear coefficient to be updated, must be same type and shape as
+          `var`.
         - **grad** (Tensor) - A tensor of the same type as `var` and
           grad.shape[1:] = var.shape[1:] if var.shape > 1.
         - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
           The type must be int32 and indices.shape[0] = grad.shape[0].

     Outputs:
-        Tuple of 3 Tensor, this operator will update the input parameters directly, the outputs are useless.
+        Tuple of 3 Tensor, this operator will update the input parameters or tensors directly, the outputs are useless.

         - **var** (Tensor) - A Tensor with shape :math:`(N, *)`.
         - **accum** (Tensor) - A Tensor with shape :math:`(1, )`.

@@ -4151,9 +4108,10 @@ class FusedSparseProximalAdagrad(Primitive):
         Default: ``False`` .

     Inputs:
-        - **var** (Parameter) - Variable tensor to be updated. The data type must be float32.
+        - **var** (Union[Parameter, Tensor]) - Variable tensor to be updated. The data type must be float32.
           The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
-        - **accum** (Parameter) - Variable tensor to be updated, has the same shape and data type as
+        - **accum** (Union[Parameter, Tensor]) - Variable tensor to be updated, has the same shape and data type as
+          `var`.
         - **lr** (Tensor) - The learning rate value. The data type must be float32. The shape is :math:`(1, )`.
         - **l1** (Tensor) - l1 regularization strength. The data type must be float32. The shape is :math:`(1, )`.
         - **l2** (Tensor) - l2 regularization strength. The data type must be float32. The shape is :math:`(1, )`.

@@ -4163,7 +4121,7 @@ class FusedSparseProximalAdagrad(Primitive):
           The type must be int32 and indices.shape[0] = grad.shape[0].

     Outputs:
-        Tuple of 2 Tensors, this operator will update the input parameters directly, the outputs are useless.
+        Tuple of 2 Tensors, this operator will update the input parameters or tensors directly, the outputs are useless.

         - **var** (Tensor) - A Tensor with shape :math:`(N, *)`.
         - **accum** (Tensor) - A Tensor with shape :math:`(1, )`.

@@ -4342,11 +4300,11 @@ class ApplyAdaMax(Primitive):
         the relatively highest priority data type.

     Inputs:
-        - **var** (Parameter) - Variable to be updated. With float32 or float16 data type.
+        - **var** (Union[Parameter, Tensor]) - Variable to be updated. With float32 or float16 data type.
           The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
-        - **m** (Parameter) - The 1st moment vector in the updating formula, has the same shape as `var`.
+        - **m** (Union[Parameter, Tensor]) - The 1st moment vector in the updating formula, has the same shape as `var`.
           With float32 or float16 data type.
-        - **v** (Parameter) - The 2nd moment vector in the updating formula. Mean square gradients
+        - **v** (Union[Parameter, Tensor]) - The 2nd moment vector in the updating formula. Mean square gradients
           with the same shape as `var`. With float32 or float16 data type.
         - **beta1_power** (Union[Number, Tensor]) - :math:`beta_1^t` in the updating formula, must be a scalar.
           With float32 or float16 data type.

@@ -4362,7 +4320,7 @@ class ApplyAdaMax(Primitive):
           With float32 or float16 data type.

     Outputs:
-        Tuple of 3 Tensor, the updated parameters.
+        Tuple of 3 Tensor, the updated parameters or tensors.

         - **var** (Tensor) - The same shape and data type as `var`.
         - **m** (Tensor) - The same shape and data type as `m`.

@@ -4456,10 +4414,11 @@ class ApplyAdadelta(Primitive):
         the relatively highest priority data type.

     Inputs:
-        - **var** (Parameter) - Weights to be updated. With float32 or float16 data type.
+        - **var** (Union[Parameter, Tensor]) - Weights to be updated. With float32 or float16 data type.
           The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
-        - **accum** (Parameter) - Accumulation to be updated, has the same shape and data type as `var`.
-        - **accum_update** (Parameter) - Accum_update to be updated, has the same shape and data type as
+        - **accum** (Union[Parameter, Tensor]) - Accumulation to be updated, has the same shape and data type as `var`.
+        - **accum_update** (Union[Parameter, Tensor]) - Accum_update to be updated, has the same shape and data type as
+          `var`.
         - **lr** (Union[Number, Tensor]) - Learning rate, must be a scalar. With float32 or float16 data type.
         - **rho** (Union[Number, Tensor]) - Decay rate, must be a scalar. With float32 or float16 data type.
         - **epsilon** (Union[Number, Tensor]) - A small value added for numerical stability, must be a scalar.

@@ -4467,7 +4426,7 @@ class ApplyAdadelta(Primitive):
         - **grad** (Tensor) - Gradients, has the same shape and data type as `var`.

     Outputs:
-        Tuple of 3 Tensor, the updated parameters.
+        Tuple of 3 Tensor, the updated parameters or tensors.

         - **var** (Tensor) - The same shape and data type as `var`.
         - **accum** (Tensor) - The same shape and data type as `accum`.

@@ -4558,14 +4517,14 @@ class ApplyAdagrad(Primitive):
         update_slots (bool): If ``True`` , `accum` will be updated. Default: ``True`` .

     Inputs:
-        - **var** (Parameter) - Variable to be updated. With float or complex data type.
+        - **var** (Union[Parameter, Tensor]) - Variable to be updated. With float or complex data type.
           The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
-        - **accum** (Parameter) - Accumulation to be updated. The shape must be the same as `var`.
+        - **accum** (Union[Parameter, Tensor]) - Accumulation to be updated. The shape must be the same as `var`.
         - **lr** (Union[Number, Tensor]) - The learning rate value, must be a scalar. With float or complex data type.
         - **grad** (Tensor) - A tensor for gradient. The shape must be the same as `var`.

     Outputs:
-        Tuple of 2 Tensors, the updated parameters.
+        Tuple of 2 Tensors, the updated parameters or tensors.

         - **var** (Tensor) - The same shape and data type as `var`.
         - **accum** (Tensor) - The same shape and data type as `accum`.
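All of the optimizer primitives in the hunks above receive the same docstring change: the state inputs (`var`, `m`, `v`, `accum`, `accum_update`, `linear`, ...) are now documented as Union[Parameter, Tensor] instead of Parameter only, and the outputs are described as "updated parameters or tensors". A minimal sketch of the documented relaxation using ApplyAdagrad, passing plain Tensors for the state; values are illustrative and the snippet has not been run against 2.5.0:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> apply_adagrad = ops.ApplyAdagrad()
>>> var = Tensor(np.array([[0.6, 0.4], [0.1, 0.5]], dtype=np.float32))    # Tensor instead of Parameter
>>> accum = Tensor(np.array([[0.6, 0.5], [0.2, 0.6]], dtype=np.float32))
>>> lr = Tensor(0.001, mindspore.float32)
>>> grad = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]], dtype=np.float32))
>>> var_out, accum_out = apply_adagrad(var, accum, lr, grad)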
@@ -4645,15 +4604,15 @@ class ApplyAdagradV2(Primitive):
         update_slots (bool): If ``True`` , `accum` will be updated. Default: ``True`` .

     Inputs:
-        - **var** (Parameter) - Variable to be updated. With float16 or float32 data type.
+        - **var** (Union[Parameter, Tensor]) - Variable to be updated. With float16 or float32 data type.
           The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
-        - **accum** (Parameter) - Accumulation to be updated. The shape must be the same as `var`.
+        - **accum** (Union[Parameter, Tensor]) - Accumulation to be updated. The shape must be the same as `var`.
         - **lr** (Union[Number, Tensor]) - The learning rate value, must be a float number or
           a scalar tensor with float16 or float32 data type.
         - **grad** (Tensor) - A tensor for gradient. The shape must be the same as `var`.

     Outputs:
-        Tuple of 2 Tensors, the updated parameters.
+        Tuple of 2 Tensors, the updated parameters or tensors.

         - **var** (Tensor) - The same shape and data type as `var`.
         - **accum** (Tensor) - The same shape and data type as `accum`.

@@ -4756,9 +4715,9 @@ class SparseApplyAdagradV2(Primitive):
         update_slots (bool): If ``True`` , the computation logic will be different to `False`. Default: ``True`` .

     Inputs:
-        - **var** (Parameter) - Variable to be updated. The data type must be float16 or float32.
+        - **var** (Union[Parameter, Tensor]) - Variable to be updated. The data type must be float16 or float32.
           The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
-        - **accum** (Parameter) - Accumulation to be updated. The shape must be the same as `var`.
+        - **accum** (Union[Parameter, Tensor]) - Accumulation to be updated. The shape must be the same as `var`.
         - **grad** (Tensor) - Gradients has the same shape as `var` and
           :math:`grad.shape[1:] = var.shape[1:]` if var.shape > 1.
         - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.

@@ -4766,7 +4725,7 @@ class SparseApplyAdagradV2(Primitive):
           must be unique. Otherwise, the result is unpredictable.

     Outputs:
-        Tuple of 2 tensors, the updated parameters.
+        Tuple of 2 tensors, the updated parameters or tensors.

         - **var** (Tensor) - The same shape and data type as `var`.
         - **accum** (Tensor) - The same shape and data type as `accum`.

@@ -4846,9 +4805,10 @@ class ApplyProximalAdagrad(Primitive):
         Default: ``False`` .

     Inputs:
-        - **var** (Parameter) - Variable to be updated. The data type must be float16 or float32.
+        - **var** (Union[Parameter, Tensor]) - Variable to be updated. The data type must be float16 or float32.
           The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
-        - **accum** (Parameter) - Accumulation to be updated, must have the same shape and dtype as
+        - **accum** (Union[Parameter, Tensor]) - Accumulation to be updated, must have the same shape and dtype as
+          `var`.
         - **lr** (Union[Number, Tensor]) - The learning rate value, must be a scalar. The data type must be
           float16 or float32.
         - **l1** (Union[Number, Tensor]) - l1 regularization strength, must be a scalar. The data type must be

@@ -4858,7 +4818,7 @@ class ApplyProximalAdagrad(Primitive):
         - **grad** (Tensor) - Gradient with the same shape and dtype as `var`.

     Outputs:
-        Tuple of 2 Tensors, the updated parameters.
+        Tuple of 2 Tensors, the updated parameters or tensors.

         - **var** (Tensor) - The same shape and data type as `var`.
         - **accum** (Tensor) - The same shape and data type as `accum`.

@@ -4943,9 +4903,9 @@ class SparseApplyProximalAdagrad(Primitive):
         Default: ``False`` .

     Inputs:
-        - **var** (Parameter) - Variable tensor to be updated. The data type must be float16 or float32.
+        - **var** (Union[Parameter, Tensor]) - Variable tensor to be updated. The data type must be float16 or float32.
           The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
-        - **accum** (
+        - **accum** (Parameterv) - Variable tensor to be updated, has the same shape as `var`.
         - **lr** (Union[Number, Tensor]) - The learning rate value, must be a float number or
           a scalar tensor with float16 or float32 data type. It must be positive.
         - **l1** (Union[Number, Tensor]) - l1 regularization strength, must be a float number or

@@ -4959,7 +4919,7 @@ class SparseApplyProximalAdagrad(Primitive):
           following types: int32, int64 and :math:`indices.shape[0] = grad.shape[0]`.

     Outputs:
-        Tuple of 2 tensors, the updated parameters.
+        Tuple of 2 tensors, the updated parameters or tensors.

         - **var** (Tensor) - The same shape and data type as `var`.
         - **accum** (Tensor) - The same shape and data type as `accum`.

@@ -5045,9 +5005,9 @@ class ApplyAddSign(Primitive):
         the relatively highest priority data type.

     Inputs:
-        - **var** (Parameter) - Variable tensor to be updated.
+        - **var** (Union[Parameter, Tensor]) - Variable tensor to be updated.
           The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
-        - **m** (Parameter) - Variable tensor to be updated, has the same data type as `var`.
+        - **m** (Union[Parameter, Tensor]) - Variable tensor to be updated, has the same data type as `var`.
         - **lr** (Union[Number, Tensor]) - The learning rate value, must be a scalar.
         - **alpha** (Union[Number, Tensor]) - Must be a scalar.
         - **sign_decay** (Union[Number, Tensor]) - Must be a scalar.

@@ -5055,7 +5015,7 @@ class ApplyAddSign(Primitive):
         - **grad** (Tensor) - A tensor of the same shape as `var`, for the gradient.

     Outputs:
-        Tuple of 2 Tensors, the updated parameters.
+        Tuple of 2 Tensors, the updated parameters or tensors.

         - **var** (Tensor) - The same shape and data type as `var`.
         - **m** (Tensor) - The same shape and data type as `m`.

@@ -5144,10 +5104,10 @@ class ApplyPowerSign(Primitive):
         On Ascend, input data type of float64 is currently not supported.

     Inputs:
-        - **var** (Parameter) - Variable tensor to be updated. With float64, float32 or float16 data
-          If data type of `var` is float16, all inputs must have the same data type as `var`.
+        - **var** (Union[Parameter, Tensor]) - Variable tensor to be updated. With float64, float32 or float16 data
+          type. If data type of `var` is float16, all inputs must have the same data type as `var`.
           The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
-        - **m** (Parameter) - Variable tensor to be updated, has the same shape as `var`.
+        - **m** (Union[Parameter, Tensor]) - Variable tensor to be updated, has the same shape as `var`.
         - **lr** (Union[Number, Tensor]) - The learning rate value, should be a scalar or Tensor
           with float64, float32 or float16 data type.
         - **logbase** (Union[Number, Tensor]) - Should be a scalar or Tensor with float64, float32 or float16 data type.

@@ -5158,7 +5118,7 @@ class ApplyPowerSign(Primitive):
         - **grad** (Tensor) - A tensor of the same shape as `var`, for the gradient.

     Outputs:
-        Tuple of 2 Tensors, the updated parameters.
+        Tuple of 2 Tensors, the updated parameters or tensors.

         - **var** (Tensor) - The same shape and data type as `var`.
         - **m** (Tensor) - The same shape and data type as `m`.

@@ -5235,7 +5195,7 @@ class ApplyGradientDescent(Primitive):
         the relatively highest priority data type.

     Inputs:
-        - **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type.
+        - **var** (Union[Parameter, Tensor]) - Variable tensor to be updated. With float32 or float16 data type.
           The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
         - **alpha** (Union[Number, Tensor]) - Scaling factor, must be a scalar. With float32 or float16 data type.
         - **delta** (Tensor) - A tensor for the change, has the same shape as `var`.

@@ -5304,7 +5264,7 @@ class ApplyProximalGradientDescent(Primitive):
         the relatively highest priority data type.

     Inputs:
-        - **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type.
+        - **var** (Union[Parameter, Tensor]) - Variable tensor to be updated. With float32 or float16 data type.
           The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
         - **alpha** (Union[Number, Tensor]) - Scaling factor, must be a scalar. With float32 or float16 data type.
         - **l1** (Union[Number, Tensor]) - l1 regularization strength, must be a scalar.

@@ -5448,10 +5408,10 @@ class ApplyFtrl(Primitive):
         use_locking (bool): Use locks for updating operation if ``True`` . Default: ``False`` .

     Inputs:
-        - **var** (Parameter) - The variable to be updated. The data type must be float16 or float32.
+        - **var** (Union[Parameter, Tensor]) - The variable to be updated. The data type must be float16 or float32.
           The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
-        - **accum** (Parameter) - The accumulation to be updated, must be same shape as `var`.
-        - **linear** (Parameter) - The linear coefficient to be updated, must be same shape as `var`.
+        - **accum** (Union[Parameter, Tensor]) - The accumulation to be updated, must be same shape as `var`.
+        - **linear** (Union[Parameter, Tensor]) - The linear coefficient to be updated, must be same shape as `var`.
         - **grad** (Tensor) - Gradient. The data type must be float16 or float32.
         - **lr** (Union[Number, Tensor]) - The learning rate value, must be positive. Default: ``0.001`` .
           It must be a float number or a scalar tensor with float16 or float32 data type.

@@ -5464,16 +5424,16 @@ class ApplyFtrl(Primitive):
           Default: ``-0.5`` . It must be a float number or a scalar tensor with float16 or float32 data type.

     Outputs:
-        - **var** (Tensor) - Represents the updated `var`. As the input parameters has been updated in-place,
-          value is always zero when the platform is GPU.
+        - **var** (Tensor) - Represents the updated `var`. As the input parameters or tensors has been updated in-place,
+          this value is always zero when the platform is GPU.

     Raises:
         TypeError: If `use_locking` is not a bool.
         TypeError: If dtype of `var`, `grad`, `lr`, `l1`, `l2` or `lr_power` is neither float16 nor float32.
         TypeError: If `lr`, `l1`, `l2` or `lr_power` is neither a Number nor a Tensor.
         TypeError: If `grad` is not a Tensor.
-        TypeError: If the parameter types of `var`, `accum` and `linear` are inconsistent.
-        TypeError: If the parameter types of `grad`, `lr`, `l1`, `l2`, `lr_power` are inconsistent with `var`
+        TypeError: If the parameter or tensor types of `var`, `accum` and `linear` are inconsistent.
+        TypeError: If the parameter or tensor types of `grad`, `lr`, `l1`, `l2`, `lr_power` are inconsistent with `var`
           and the precision is greater than `var`.

     Supported Platforms:
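ApplyFtrl follows the same pattern: `var`, `accum`, and `linear` are now documented as Union[Parameter, Tensor]. A sketch of the call order listed in the docstring (var, accum, linear, grad, lr, l1, l2, lr_power); shapes and values are illustrative and not verified against 2.5.0:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, Parameter, ops
>>> apply_ftrl = ops.ApplyFtrl()
>>> var = Parameter(Tensor(np.random.rand(2, 2).astype(np.float32)), name="var")
>>> accum = Parameter(Tensor(np.random.rand(2, 2).astype(np.float32)), name="accum")
>>> linear = Parameter(Tensor(np.random.rand(2, 2).astype(np.float32)), name="linear")
>>> grad = Tensor(np.random.rand(2, 2).astype(np.float32))
>>> # lr, l1, l2, lr_power use the documented defaults
>>> output = apply_ftrl(var, accum, linear, grad, 0.001, 0.0, 0.0, -0.5)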
@@ -5548,10 +5508,10 @@ class SparseApplyFtrl(Primitive):
|
|
|
5548
5508
|
use_locking (bool, optional): Use locks for updating operation if ``True`` . Default: ``False`` .
|
|
5549
5509
|
|
|
5550
5510
|
Inputs:
|
|
5551
|
-
- **var** (Parameter) - The variable to be updated. The data type must be float16 or float32.
|
|
5511
|
+
- **var** (Union[Parameter, Tensor]) - The variable to be updated. The data type must be float16 or float32.
|
|
5552
5512
|
The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
|
|
5553
|
-
- **accum** (Parameter) - The accumulation to be updated, must be same shape as `var`.
|
|
5554
|
-
- **linear** (Parameter) - The linear coefficient to be updated, must be the same shape as `var`.
|
|
5513
|
+
- **accum** (Union[Parameter, Tensor]) - The accumulation to be updated, must be same shape as `var`.
|
|
5514
|
+
- **linear** (Union[Parameter, Tensor]) - The linear coefficient to be updated, must be the same shape as `var`.
|
|
5555
5515
|
- **grad** (Tensor) - A tensor must meet with :math:`grad.shape[1:] = var.shape[1:]`
|
|
5556
5516
|
if var.shape > 1.
|
|
5557
5517
|
- **indices** (Tensor) - A tensor of indices in the first dimension of `var` and `accum`.
|
|
@@ -6908,7 +6868,7 @@ class SparseApplyAdadelta(Primitive):
    to make the data types consistent. Besides, inputs of 'lr' and 'rho' also support implicit type conversion.
    If they have different data types, the lower priority data type will be converted to
    relatively highest priority data type.
-    RuntimeError exception will be thrown when the data type conversion of Parameter is required.
+    RuntimeError exception will be thrown when the data type conversion of Parameter or Tensor is required.

    Note:
        If there are negative values or values greater than or equal to var.shape[0] in `indices`,
@@ -6920,11 +6880,11 @@ class SparseApplyAdadelta(Primitive):
            Default: ``False`` .

    Inputs:
-        - **var** (Parameter) - Weights to be updated. With float32 or float16 data type.
-        - **accum** (Parameter) - Accumulation to be updated. Mush have the same shape and dtype as
-          With float32 or float16 data type.
-        - **accum_update** (Parameter) - Accum_update to be updated. Must have the same shape and dtype
-          With float32 or float16 data type.
+        - **var** (Union[Parameter, Tensor]) - Weights to be updated. With float32 or float16 data type.
+        - **accum** (Union[Parameter, Tensor]) - Accumulation to be updated. Mush have the same shape and dtype as
+          `var`. With float32 or float16 data type.
+        - **accum_update** (Union[Parameter, Tensor]) - Accum_update to be updated. Must have the same shape and dtype
+          as `var`. With float32 or float16 data type.
        - **lr** (Union[float, Tensor]) - Learning rate, must be a scalar. With float32 or float16 data type.
        - **rho** (Union[float, Tensor]) - Decay rate, must be a scalar. With float32 or float16 data type.
        - **grad** (Tensor) - A tensor for gradient. Must have the same shape and dtype as `var`.
@@ -6932,7 +6892,7 @@ class SparseApplyAdadelta(Primitive):
          Must be one of the following types: int32, int64 and indices.shape[0] = grad.shape[0].

    Outputs:
-        Tuple of 3 Tensor, the updated parameters.
+        Tuple of 3 Tensor, the updated parameters or tensors.

        - **var** (Tensor) - The same shape and data type as `var`.
        - **accum** (Tensor) - The same shape and data type as `accum`.
@@ -7209,12 +7169,15 @@ class Conv3DTranspose(Primitive):
    Inputs:
        - **dout** (Tensor) - The gradients with respect to the output of the convolution.
          The shape conforms to the default.
-          data_format :math:`(N, C_{in}, D_{out}, H_{out}, W_{out})`.
-
+          data_format :math:`(N, C_{in}, D_{out}, H_{out}, W_{out})`.
+          Supported dtypes:
+
+          - Ascend: float16.
+          - GPU/CPU: float16, float32.
        - **weight** (Tensor) - Set size of kernel is :math:`(K_d, K_h, K_w)`, then the shape is
          :math:`(C_{in}, C_{out}//group, K_d, K_h, K_w)`. Where :math:`group` is the Args parameter,
          :math:`//` is the symbol for integer division.
-
+          It has the same dtype as `dout`.
        - **bias** (Tensor) - Tensor of shape :math:`C_{out}`. Currently, only support none. Default: ``None`` .

    Outputs:
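For the dtype note added above, a short sketch of a float16 call, which should be accepted on Ascend as well as GPU/CPU; the shapes follow the usual documentation example and are assumptions here:

    import numpy as np
    import mindspore
    from mindspore import Tensor, ops

    # dout: (N, C_in, D_out, H_out, W_out); weight: (C_in, C_out//group, K_d, K_h, K_w)
    dout = Tensor(np.ones((32, 16, 10, 32, 32)), mindspore.float16)
    weight = Tensor(np.ones((16, 3, 4, 6, 2)), mindspore.float16)
    conv3d_transpose = ops.Conv3DTranspose(in_channel=16, out_channel=3, kernel_size=(4, 6, 2))
    output = conv3d_transpose(dout, weight)
    print(output.shape)  # expected (32, 3, 13, 37, 33) with default stride and padding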
@@ -7500,12 +7463,12 @@ class ApplyAdagradDA(Primitive):
            Otherwise the behavior is undefined, but may exhibit less contention. Default: ``False`` .

    Inputs:
-        - **var** (Parameter) - Variable to be updated. The data type must be float16 or float32.
+        - **var** (Union[Parameter, Tensor]) - Variable to be updated. The data type must be float16 or float32.
          The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
-        - **gradient_accumulator** (Parameter) - The dict of mutable tensor :math:`grad\_accum`.
-          shape as `var`.
-        - **gradient_squared_accumulator** (Parameter) - The dict of mutable tensor :math:`grad\_squared\_accum`.
+        - **gradient_accumulator** (Union[Parameter, Tensor]) - The dict of mutable tensor :math:`grad\_accum`.
          Must have the same shape as `var`.
+        - **gradient_squared_accumulator** (Union[Parameter, Tensor]) - The dict of mutable tensor
+          :math:`grad\_squared\_accum`. Must have the same shape as `var`.
        - **grad** (Tensor) - A tensor for gradient. Must have the same shape as `var`.
        - **lr** ([Number, Tensor]) - Scaling factor. Must be a scalar. With float32 or float16 data type.
        - **l1** ([Number, Tensor]) - L1 regularization. Must be a scalar. With float32 or float16 data type.
@@ -7513,12 +7476,12 @@ class ApplyAdagradDA(Primitive):
        - **global_step** ([Number, Tensor]) - Training step number. Must be a scalar. With int32 or int64 data type.

    Outputs:
-        Tuple of 1 Tensors, the updated parameters.
+        Tuple of 1 Tensors, the updated parameters or tensors.

        - **var** (Tensor) - The same shape and data type as `var`.

    Raises:
-        TypeError: If `var`, `gradient_accumulator` or `gradient_squared_accumulator`
+        TypeError: If `var`, `gradient_accumulator` or `gradient_squared_accumulator` neither a Parameter nor a Tensor.
        TypeError: If `grad` is not a Tensor.
        TypeError: If `lr`, `l1`, `l2` or `global_step` is neither a Number nor a Tensor.
        TypeError: If use_locking is not a bool.
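A rough sketch of ApplyAdagradDA with the relaxed input types, passing plain tensors for the accumulators; all values below are illustrative assumptions:

    import numpy as np
    import mindspore
    from mindspore import Tensor, ops

    apply_adagrad_da = ops.ApplyAdagradDA(use_locking=False)
    var = Tensor(np.array([[0.6, 0.4], [0.1, 0.5]], np.float32))
    grad_accum = Tensor(np.array([[0.3, 0.3], [0.1, 0.5]], np.float32))
    grad_squared_accum = Tensor(np.array([[0.2, 0.1], [0.1, 0.2]], np.float32))
    grad = Tensor(np.array([[0.3, 0.4], [0.1, 0.2]], np.float32))
    lr = Tensor(0.001, mindspore.float32)
    l1 = Tensor(0.001, mindspore.float32)
    l2 = Tensor(0.001, mindspore.float32)
    global_step = Tensor(2, mindspore.int32)
    # Returns the updated var (a one-element tuple per the docstring above).
    output = apply_adagrad_da(var, grad_accum, grad_squared_accum, grad, lr, l1, l2, global_step)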
@@ -7612,10 +7575,12 @@ class SparseApplyRMSProp(Primitive):
            otherwise the behavior is undefined, but may exhibit less contention. Default: ``False`` .

    Inputs:
-        - **var** (Parameter) - Variable to be updated. The data type must be float16 or float32.
+        - **var** (Union[Parameter, Tensor]) - Variable to be updated. The data type must be float16 or float32.
          The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
-        - **ms** (Parameter) - The dict of mutable tensor ms. Must have the same shape and dtype as
-
+        - **ms** (Union[Parameter, Tensor]) - The dict of mutable tensor ms. Must have the same shape and dtype as
+          `var`.
+        - **mom** (Union[Parameter, Tensor]) - The dict of mutable tensor mom. Must have the same shape and dtype as
+          `var`.
        - **lr** ([Number, Tensor]) - Learning rate. Must be a scalar. With float16 or float32 data type.
        - **grad** (Tensor) - A tensor for gradient. Must have the same shape and dtype as `var`.
        - **indices** (Tensor) - A tensor of indices in the first dimension of `var`, `ms` and `mom`.
@@ -7623,7 +7588,7 @@ class SparseApplyRMSProp(Primitive):
          following types: int32, int64 and indices.shape[0] = var.shape[0].

    Outputs:
-        Tuple of 3 Tensors, the updated parameters.
+        Tuple of 3 Tensors, the updated parameters or tensors.

        - **var** (Tensor) - The same shape and data type as `var`.
        - **ms** (Tensor) - The same shape and data type as `ms`.
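A sketch of SparseApplyRMSProp exercising the same relaxation; the constructor arguments, values and shapes below are assumptions for illustration only:

    import numpy as np
    import mindspore
    from mindspore import Tensor, ops

    sparse_apply_rms_prop = ops.SparseApplyRMSProp(rho=0.2, momentum=0.01, epsilon=1e-6)
    var = Tensor(np.array([[0.6, 0.3], [0.1, 0.5]], np.float32))
    ms_state = Tensor(np.array([[0.2, 0.4], [0.1, 0.3]], np.float32))   # the `ms` input
    mom = Tensor(np.array([[0.3, 0.1], [0.3, 0.6]], np.float32))
    lr = Tensor(0.01, mindspore.float32)
    grad = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]], np.float32))
    indices = Tensor(np.array([0, 1]).astype(np.int32))
    out_var, out_ms, out_mom = sparse_apply_rms_prop(var, ms_state, mom, lr, grad, indices)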
@@ -7729,12 +7694,12 @@ class SparseApplyCenteredRMSProp(Primitive):
            Default: ``False`` .

    Inputs:
-        - **var** (Parameter) - Variable tensor to be updated. The data type must be int8, int16, int32,
-          uint8, uint16, uint32, uint64, float16, float32 or float64.
+        - **var** (Union[Parameter, Tensor]) - Variable tensor to be updated. The data type must be int8, int16, int32,
+          int64, uint8, uint16, uint32, uint64, float16, float32 or float64.
          The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
-        - **mg** (Parameter) - Mean gradients. Must have the same shape and dtype as `var`.
-        - **ms** (Parameter) - Mean square gradients. Must have the same shape and dtype as `var`.
-        - **mom** (Parameter) - Delta of `var`. Must have the same shape and dtype as `var`.
+        - **mg** (Union[Parameter, Tensor]) - Mean gradients. Must have the same shape and dtype as `var`.
+        - **ms** (Union[Parameter, Tensor]) - Mean square gradients. Must have the same shape and dtype as `var`.
+        - **mom** (Union[Parameter, Tensor]) - Delta of `var`. Must have the same shape and dtype as `var`.
        - **lr** (Union[Number, Tensor]) - Learning rate. Must be a float number or a scalar tensor.
          Must have the same type as `var`.
        - **rho** (Union[Number, Tensor]) - Decay rate. Must be a float number or a scalar tensor.
@@ -7837,8 +7802,9 @@ class ApplyKerasMomentum(Primitive):
            so in the end, the var you get is actually var + momentum * accum. Default: ``False`` .

    Inputs:
-        - **var** (Parameter) - Variable to be updated. With float16 or float32 data type.
-        - **accum** (Parameter) - Must have the same shape and type as `var`. With float16 or float32
+        - **var** (Union[Parameter, Tensor]) - Variable to be updated. With float16 or float32 data type.
+        - **accum** (Union[Parameter, Tensor]) - Must have the same shape and type as `var`. With float16 or float32
+          data type.
        - **lr** (Union[Number, Tensor]) - Scaling factor. Must be a scalar. With float16 or float32 data type.
        - **grad** (Tensor) - The gradient. Must have the same shape and type as `var`.
          With float16 or float32 data type.
@@ -7989,12 +7955,12 @@ class ApplyAdamWithAmsgrad(Primitive):
            Default: ``False`` .

    Inputs:
-        - **var** (Parameter) - Variable to be updated. The data type can be float16 or float32.
-        - **m** (Parameter) - The 1st moment vector in the updating formula,
+        - **var** (Union[Parameter, Tensor]) - Variable to be updated. The data type can be float16 or float32.
+        - **m** (Union[Parameter, Tensor]) - The 1st moment vector in the updating formula,
          the shape and data type value should be the same as `var`.
-        - **v** (Parameter) - the 2nd moment vector in the updating formula,
+        - **v** (Union[Parameter, Tensor]) - the 2nd moment vector in the updating formula,
          the shape and data type value should be the same as `var`.
-        - **vhat** (Parameter) - :math:`\hat v_t` in the updating formula,
+        - **vhat** (Union[Parameter, Tensor]) - :math:`\hat v_t` in the updating formula,
          the shape and data type value should be the same as `var`.
        - **beta1_power** (Union[float, Tensor]) - :math:`beta_1^t(\beta_1^{t})` in the updating formula,
          a scalar tensor with float16 or float32 data type.
@@ -8004,7 +7970,7 @@ class ApplyAdamWithAmsgrad(Primitive):
        - **grad** (Tensor) - The gradient, has the same shape and data type as `var`.

    Outputs:
-        Tuple of 4 Tensors, the updated parameters.
+        Tuple of 4 Tensors, the updated parameters or tensors.

        - **var** (Tensor) - The same shape and data type as `var`.
        - **m** (Tensor) - The same shape and data type as `m`.
@@ -8012,7 +7978,7 @@ class ApplyAdamWithAmsgrad(Primitive):
        - **vhat** (Tensor) - The same shape and data type as `vhat`.

    Raises:
-        TypeError: If `var`, `m`, `v`, `vhat`
+        TypeError: If `var`, `m`, `v`, `vhat` neither a Parameter nor a Tensor.
        TypeError: If `beta1_power`, `beta2_power`, `lr` is neither a Number nor a Tensor.
        TypeError: If `grad` is not a Tensor.
        TypeError: If dtype of `var`, `m`, `v`, `vhat`, `beta1_power`, `beta2_power`,
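Likewise for ApplyAdamWithAmsgrad, a hedged sketch with plain-tensor optimizer state; the beta/epsilon values, constructor arguments and shapes are assumptions:

    import numpy as np
    import mindspore
    from mindspore import Tensor, ops

    apply_adam_with_amsgrad = ops.ApplyAdamWithAmsgrad(beta1=0.9, beta2=0.999, epsilon=1e-8)
    var = Tensor(np.random.rand(2, 2).astype(np.float32))
    m = Tensor(np.random.rand(2, 2).astype(np.float32))
    v = Tensor(np.random.rand(2, 2).astype(np.float32))
    vhat = Tensor(np.random.rand(2, 2).astype(np.float32))
    beta1_power = Tensor(0.9, mindspore.float32)
    beta2_power = Tensor(0.999, mindspore.float32)
    lr = Tensor(0.001, mindspore.float32)
    grad = Tensor(np.random.rand(2, 2).astype(np.float32))
    out_var, out_m, out_v, out_vhat = apply_adam_with_amsgrad(
        var, m, v, vhat, beta1_power, beta2_power, lr, grad)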
@@ -8096,12 +8062,12 @@ class ApplyAdamWithAmsgradV2(Primitive):
            Default: ``False`` .

    Inputs:
-        - **var** (Parameter) - Variable to be updated. The data type can be float16, float32 or float64.
-        - **m** (Parameter) - The 1st moment vector in the updating formula,
+        - **var** (Union[Parameter, Tensor]) - Variable to be updated. The data type can be float16, float32 or float64.
+        - **m** (Union[Parameter, Tensor]) - The 1st moment vector in the updating formula,
          the shape should be the same as `var`.
-        - **v** (Parameter) - The 2nd moment vector in the updating formula,
+        - **v** (Union[Parameter, Tensor]) - The 2nd moment vector in the updating formula,
          the shape should be the same as `var`.
-        - **vhat** (Parameter) - :math:`\hat v_t` in the updating formula,
+        - **vhat** (Union[Parameter, Tensor]) - :math:`\hat v_t` in the updating formula,
          the shape and data type value should be the same as `var`.
        - **beta1_power** (Union[float, Tensor]) - :math:`beta_1^t(\beta_1^{t})` in the updating formula,
          with float16, float32 or float64 data type.
@@ -8117,7 +8083,7 @@ class ApplyAdamWithAmsgradV2(Primitive):
        - **grad** (Tensor) - The gradient, has the same shape as `var`.

    Outputs:
-        Tuple of 4 Tensors, the updated parameters.
+        Tuple of 4 Tensors, the updated parameters or tensors.

        - **var** (Tensor) - The same shape and data type as `var`.
        - **m** (Tensor) - The same shape and data type as `m`.
@@ -8125,7 +8091,7 @@ class ApplyAdamWithAmsgradV2(Primitive):
        - **vhat** (Tensor) - The same shape and data type as `vhat`.

    Raises:
-        TypeError: If `var`, `m`, `v`, `vhat`
+        TypeError: If `var`, `m`, `v`, `vhat` neither a Parameter nor a Tensor.
        TypeError: If dtype of `var`, `m`, `v`, `vhat`, `beta1_power`, `beta2_power`,
          `lr`, `beta1` , `beta2` , `epsilon` or `grad` is not float64, float32 or float16.
        RuntimeError: If the data type of `var`, `m`, `v` , `vhat` and `grad` conversion of Parameter is not supported.
@@ -8805,11 +8771,11 @@ class SparseApplyAdagradDA(Primitive):
            Otherwise the behavior is undefined, but may exhibit less contention. Default: ``False`` .

    Inputs:
-        - **var** (Parameter) - Variable to be updated.
+        - **var** (Union[Parameter, Tensor]) - Variable to be updated.
          The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
-        - **grad_accum** (Parameter) - The dict of mutable tensor grad_accum. Must have the same
+        - **grad_accum** (Union[Parameter, Tensor]) - The dict of mutable tensor grad_accum. Must have the same
          shape and dtype as `var`.
-        - **grad_square_accum** (Parameter) - The dict of mutable tensor grad_square_accum.
+        - **grad_square_accum** (Union[Parameter, Tensor]) - The dict of mutable tensor grad_square_accum.
          Must have the same shape and dtype as `var`.
        - **grad** (Tensor) - A tensor of the same type as `var` and grad.shape[1:] = var.shape[1:] if rank(var) > 1.
        - **indices** (Tensor) - A tensor of indices in the first dimension of `var` and `accum`.
@@ -8987,8 +8953,8 @@ class SparseApplyProximalGradientDescent(Primitive):
            Default: ``False`` .

    Inputs:
-        - **var** (Parameter) - Variable tensor to be updated. The data type must be int8, int16, int32,
-          uint8, uint16, uint32, uint64, float16, float32 or float64.
+        - **var** (Union[Parameter, Tensor]) - Variable tensor to be updated. The data type must be int8, int16, int32,
+          int64, uint8, uint16, uint32, uint64, float16, float32 or float64.
          The shape is :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
        - **alpha** (Union[Number, Tensor]) - Scaling factor. Must be a scalar with same type as `var`.
        - **l1** (Union[Number, Tensor]) - L1 regularization. Must be a scalar with same type as `var`.
@@ -9003,7 +8969,7 @@ class SparseApplyProximalGradientDescent(Primitive):
        - **var** (Tensor) - Tensor, has the same shape and type as 'var'.

    Raises:
-        TypeError: If `var`
+        TypeError: If `var` neither a Parameter nor a Tensor.
        TypeError: If `alpha`, `l1`, `l2` is neither a Number nor a Tensor.
        TypeError: If `use_locking` is not a bool.
        TypeError: If dtype of `var`, `alpha`, `l1`, `l2` or `grad` is not one of int8, int16,
@@ -9139,51 +9105,6 @@ class NuclearNorm(Primitive):
        validator.check_value_type("keepdim", keepdim, [bool], self.name)


-class GLU(Primitive):
-    r"""
-    Computes GLU (Gated Linear Unit activation function) of input tensors.
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Refer to :func:`mindspore.ops.glu` for more details.
-
-    Args:
-        axis (int, optional): Axis on which to split the input.
-            The value of `axis` must be an int within range [-rank(`x`), rank(`x`)).
-            Default: ``-1`` , specifying the last dimension.
-
-    Inputs:
-        - **x** (Tensor) - Input tensor. `x.shape[axis]` must be even.
-
-    Outputs:
-        Tensor, has the same data type with `x`.
-
-    Supported Platforms:
-        ``Ascend`` ``CPU``
-
-    Examples:
-        >>> from mindspore import ops, Tensor
-        >>> from mindspore import dtype as mstype
-        >>> import numpy as np
-        >>> axis = 0
-        >>> x = Tensor(np.array([0.3220, 0.9545, 0.7879, 0.0975, 0.3698,
-        ...                      0.5135, 0.5740, 0.3435, 0.1895, 0.8764,
-        ...                      0.4980, 0.9673, 0.9879, 0.6988, 0.9022,
-        ...                      0.9304, 0.1558, 0.0153, 0.1559, 0.9852]).reshape([2, 2, 5]), mstype.float32)
-        >>> glu = ops.GLU(axis=axis)
-        >>> y = glu(x)
-        >>> print(y)
-        [[[0.20028052 0.6916126 0.57412136 0.06512236 0.26307625]
-          [0.3682598 0.3093122 0.17306386 0.10212085 0.63814086]]]
-    """
-
-    @prim_attr_register
-    def __init__(self, axis=-1):
-        """Initialize GLU"""
-        validator.check_value_type("axis", axis, [int], self.name)
-
-
 class FractionalMaxPoolWithFixedKsize(Primitive):
    r"""
    Applies a 2D fractional max pooling to an input signal composed of multiple input planes.
@@ -9267,7 +9188,8 @@ class FractionalMaxPoolWithFixedKsize(Primitive):
 class ChannelShuffle(Primitive):
    r"""
    Divide the channels in a tensor of shape :math:`(*, C, H, W)` into :math:`g` group and
-    rearrange them as :math:`(*, \frac
+    rearrange them as :math:`(*, \frac{C}{g}, g, H*W)`, while retaining the original tensor
+    shape in the final output.

    .. warning::
        This is an experimental API that is subject to change or deletion.
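The completed formula above only changes the docstring; a small usage sketch of the op as documented, with the group count and input shape chosen arbitrarily:

    import numpy as np
    import mindspore
    from mindspore import Tensor, ops

    channel_shuffle = ops.ChannelShuffle(group=2)
    x = Tensor(np.arange(1 * 4 * 2 * 2).reshape(1, 4, 2, 2), mindspore.int16)
    y = channel_shuffle(x)
    print(y.shape)  # (1, 4, 2, 2): channels are regrouped but the overall shape is kept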
@@ -9475,93 +9397,6 @@ class WKV(Primitive):
                                outputs=["output", "out_sp", "out_sq", "out_sm"])


-class PromptFlashAttention(Primitive):
-    r"""
-    The interface for fully inference.
-    B -- Batch size
-    S -- Sequence length
-    H -- Hidden size
-
-    Note:
-        experiment ops
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-
-    Args:
-        num_heads (int): The number of heads.
-        scale_value (float): The scale value indicating the scale coefficient, which is used as the scalar of
-            Muls in the calculation. Default: 1.0.
-        pre_tokens (int): Previous tokens. Default: 2147483547.
-        next_tokens (int): next tokens. Default: 0.
-            indicate the upper triangle, Indicate the number of data blocks involved in the calculation. The value 0
-            indicates that the data blocks in the upper triangle are not involved in the calculation
-        input_layout (str): the data layout of the input qkv, support `(BSH)` and `(BNSD)`, Default `BSH`.
-        num_key_value_heads (int): head numbers of key/value which are used in GQA algorithm.
-            The value o indicates if the key and value have the same head nums, use numHeads. Default: 0.
-        sparse_mode (int): Default: 0
-        inner_precise (int): 0, float16 high precision. 1, high performance. default 1
-
-    Inputs:
-        - **query** (Tensor) - The query tensor with data type of float16 or float32.
-          Input tensor of shape :math:`(B, S, H)` / `(B, N, S, D)`.
-        - **key** (Tensor) - The key tensor with data type of float16 or float32.
-          Input tensor of shape :math:`(B, S, H)` / `(B, N, S, D)`.
-        - **value** (Tensor) - The value tensor with data type of float16 or float32.
-          Input tensor of shape :math:`(B, S, H)` / `(B, N, S, D)`.
-        - **attn_mask** (Tensor) - The attention mask tensor with data type of float16 or float32.
-          For each element, 0 indicates retention and 1 indicates discard. Input tensor of shape :math:`(B, 1, S, S)`.
-        - **actual_seq_lengths** (Tensor): Describe actual sequence length of each input with data type of int64.
-        - **actual_seq_lengths_kv** (Tensor): Describe actual sequence length of each input with data type of int64.
-        - **pse_shift** (Tensor) - The position encoding tensor with data type of float16 or float32.
-        - **dep_scale1** (Tensor)
-        - **quant_scale1** (Tensor)
-        - **deq_scale2** (Tensor)
-        - **quant_scale2** (Tensor)
-        - **quant_offset2** (Tensor)
-
-    Outputs:
-        - **attention_out** (Tensor) - Input tensor of shape :math:`(B, S, H)` / `(B, N, S, D)`.
-
-    Supported Platforms:
-        ``Ascend``
-
-    Examples:
-        >>> import mindspore.ops.operations.nn_ops as P
-        >>> from mindspore import Tensor
-        >>> import numpy as np
-        >>> B = 1
-        >>> N = 16
-        >>> S = 256
-        >>> D = 16
-        >>> query = Tensor(np.ones((B, N, S, D), dtype=np.float16))
-        >>> key = Tensor(np.ones((B, N, S, D), dtype=np.float16))
-        >>> value = Tensor(np.ones((B, N, S, D), dtype=np.float16))
-        >>> attn_mask = Tensor(np.ones((B, 1, S, S), dtype=np.float16))
-        >>> pfa = P.PromptFlashAttention(N, input_layout='BNSD')
-        >>> out = pfa(query, key, value, attn_mask, None, None, None, None, None, None, None, None)
-        >>> print(out.shape)
-        (1, 16, 256, 16)
-    """
-
-    @prim_attr_register
-    def __init__(self, num_heads, scale_value=1.0, pre_tokens=214748647, next_tokens=0, input_layout='BSH',
-                 num_key_value_heads=0, sparse_mode=0, inner_precise=1):
-        """Initialize PromptFlashAttention."""
-        validator.check_value_type('num_heads', num_heads, [int], self.name)
-        validator.check_value_type('scale_value', scale_value, [float], self.name)
-        validator.check_value_type('pre_tokens', pre_tokens, [int], self.name)
-        validator.check_value_type('next_tokens', next_tokens, [int], self.name)
-        validator.check_value_type('input_layout', input_layout, [str], self.name)
-        validator.check_value_type('num_key_value_heads', num_key_value_heads, [int], self.name)
-        validator.check_value_type('sparse_mode', sparse_mode, [int], self.name)
-        validator.check_value_type('inner_precise', inner_precise, [int], self.name)
-        self.init_prim_io_names(inputs=["query", "key", "value", "attn_mask", "actual_seq_lengths",
-                                        "actual_seq_lengths_kv", "pse_shift", "deq_scale1", "quant_scale1",
-                                        "deq_scale2", "quant_scale2", "quant_offset2"],
-                                outputs=["attention_out"])
-
-
 class AllFinite(Primitive):
    r"""
    Check all gradients is finite.
@@ -9578,3 +9413,6 @@ class AllFinite(Primitive):
            raise RuntimeError(
                "The version of Ascend AI software package installed "
                "in the current environment does not support AllFinite.")
+
+    def __call__(self, *args):
+        return _convert_stub(pyboost_all_finite(self, args))