mindspore 2.3.0__cp39-cp39-win_amd64.whl → 2.4.0__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +46 -13
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +209 -29
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +310 -55
- mindspore/communication/management.py +14 -14
- mindspore/context.py +123 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +495 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +266 -21
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +28 -7
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +275 -93
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +113 -3
- mindspore/nn/layer/embedding.py +120 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +734 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
- mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +490 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +558 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +184 -8
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +6 -1
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +12 -146
- mindspore/ops/operations/comm_ops.py +42 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +265 -10
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +28 -8
- mindspore/parallel/_cell_wrapper.py +83 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +81 -11
- mindspore/parallel/_utils.py +13 -1
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +280 -412
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +36 -103
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +28 -2
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +85 -22
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +134 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/dataset_helper.py +7 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +134 -58
- mindspore/train/serialization.py +336 -112
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +258 -252
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
|
@@ -1,93 +0,0 @@
|
|
|
1
|
-
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
"""StridedSliceV2 op"""
|
|
17
|
-
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
|
18
|
-
|
|
19
|
-
strided_slice_v2_op_info = AiCPURegOp("StridedSliceV2") \
|
|
20
|
-
.fusion_type("OPAQUE") \
|
|
21
|
-
.input(0, "x", "required") \
|
|
22
|
-
.input(1, "begin", "required") \
|
|
23
|
-
.input(2, "end", "required") \
|
|
24
|
-
.input(3, "strides", "required") \
|
|
25
|
-
.output(0, "output", "required") \
|
|
26
|
-
.attr("begin_mask", "int") \
|
|
27
|
-
.attr("end_mask", "int") \
|
|
28
|
-
.attr("ellipsis_mask", "int") \
|
|
29
|
-
.attr("new_axis_mask", "int") \
|
|
30
|
-
.attr("shrink_axis_mask", "int") \
|
|
31
|
-
.dtype_format(DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,
|
|
32
|
-
DataType.I64_Default, DataType.BOOL_Default) \
|
|
33
|
-
.dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I64_Default,
|
|
34
|
-
DataType.I64_Default, DataType.I8_Default) \
|
|
35
|
-
.dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I64_Default,
|
|
36
|
-
DataType.I64_Default, DataType.I16_Default) \
|
|
37
|
-
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I64_Default,
|
|
38
|
-
DataType.I64_Default, DataType.I32_Default) \
|
|
39
|
-
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
|
|
40
|
-
DataType.I64_Default, DataType.I64_Default) \
|
|
41
|
-
.dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I64_Default,
|
|
42
|
-
DataType.I64_Default, DataType.U8_Default) \
|
|
43
|
-
.dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.I64_Default,
|
|
44
|
-
DataType.I64_Default, DataType.U16_Default) \
|
|
45
|
-
.dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.I64_Default,
|
|
46
|
-
DataType.I64_Default, DataType.U32_Default) \
|
|
47
|
-
.dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.I64_Default,
|
|
48
|
-
DataType.I64_Default, DataType.U64_Default) \
|
|
49
|
-
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default,
|
|
50
|
-
DataType.I64_Default, DataType.F16_Default) \
|
|
51
|
-
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default,
|
|
52
|
-
DataType.I64_Default, DataType.F32_Default) \
|
|
53
|
-
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default,
|
|
54
|
-
DataType.I64_Default, DataType.F64_Default) \
|
|
55
|
-
.dtype_format(DataType.C64_Default, DataType.I64_Default, DataType.I64_Default,
|
|
56
|
-
DataType.I64_Default, DataType.C64_Default) \
|
|
57
|
-
.dtype_format(DataType.C128_Default, DataType.I64_Default, DataType.I64_Default,
|
|
58
|
-
DataType.I64_Default, DataType.C128_Default) \
|
|
59
|
-
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.I32_Default,
|
|
60
|
-
DataType.I32_Default, DataType.BOOL_Default) \
|
|
61
|
-
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I32_Default,
|
|
62
|
-
DataType.I32_Default, DataType.I8_Default) \
|
|
63
|
-
.dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I32_Default,
|
|
64
|
-
DataType.I32_Default, DataType.I16_Default) \
|
|
65
|
-
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
|
|
66
|
-
DataType.I32_Default, DataType.I32_Default) \
|
|
67
|
-
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default,
|
|
68
|
-
DataType.I32_Default, DataType.I64_Default) \
|
|
69
|
-
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.I32_Default,
|
|
70
|
-
DataType.I32_Default, DataType.U8_Default) \
|
|
71
|
-
.dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.I32_Default,
|
|
72
|
-
DataType.I32_Default, DataType.U16_Default) \
|
|
73
|
-
.dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.I32_Default,
|
|
74
|
-
DataType.I32_Default, DataType.U32_Default) \
|
|
75
|
-
.dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.I32_Default,
|
|
76
|
-
DataType.I32_Default, DataType.U64_Default) \
|
|
77
|
-
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default,
|
|
78
|
-
DataType.I32_Default, DataType.F16_Default) \
|
|
79
|
-
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default,
|
|
80
|
-
DataType.I32_Default, DataType.F32_Default) \
|
|
81
|
-
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default,
|
|
82
|
-
DataType.I32_Default, DataType.F64_Default) \
|
|
83
|
-
.dtype_format(DataType.C64_Default, DataType.I32_Default, DataType.I32_Default,
|
|
84
|
-
DataType.I32_Default, DataType.C64_Default) \
|
|
85
|
-
.dtype_format(DataType.C128_Default, DataType.I32_Default, DataType.I32_Default,
|
|
86
|
-
DataType.I32_Default, DataType.C128_Default) \
|
|
87
|
-
.get_op_info()
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
@op_info_register(strided_slice_v2_op_info)
|
|
91
|
-
def _strided_slice_v2_aicpu():
|
|
92
|
-
"""StridedSliceV2 AiCPU register"""
|
|
93
|
-
return
|
|
@@ -1,66 +0,0 @@
|
|
|
1
|
-
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
"""StridedSliceGradV2 op"""
|
|
17
|
-
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
|
18
|
-
|
|
19
|
-
strided_slice_v2_grad_op_info = AiCPURegOp("StridedSliceV2Grad") \
|
|
20
|
-
.fusion_type("OPAQUE") \
|
|
21
|
-
.input(0, "shapex", "required") \
|
|
22
|
-
.input(1, "begin", "required") \
|
|
23
|
-
.input(2, "end", "required") \
|
|
24
|
-
.input(3, "strides", "required") \
|
|
25
|
-
.input(4, "dy", "required") \
|
|
26
|
-
.output(0, "output", "required") \
|
|
27
|
-
.attr("begin_mask", "int") \
|
|
28
|
-
.attr("end_mask", "int") \
|
|
29
|
-
.attr("ellipsis_mask", "int") \
|
|
30
|
-
.attr("new_axis_mask", "int") \
|
|
31
|
-
.attr("shrink_axis_mask", "int") \
|
|
32
|
-
.dtype_format(DataType.I32_Default, DataType.I32_Default,
|
|
33
|
-
DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
|
|
34
|
-
.dtype_format(DataType.I32_Default, DataType.I32_Default,
|
|
35
|
-
DataType.I32_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
|
|
36
|
-
.dtype_format(DataType.I32_Default, DataType.I32_Default,
|
|
37
|
-
DataType.I32_Default, DataType.I32_Default, DataType.I16_Default, DataType.I16_Default) \
|
|
38
|
-
.dtype_format(DataType.I32_Default, DataType.I32_Default,
|
|
39
|
-
DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
|
40
|
-
.dtype_format(DataType.I32_Default, DataType.I32_Default,
|
|
41
|
-
DataType.I32_Default, DataType.I32_Default, DataType.I64_Default, DataType.I64_Default) \
|
|
42
|
-
.dtype_format(DataType.I32_Default, DataType.I32_Default,
|
|
43
|
-
DataType.I32_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
|
|
44
|
-
.dtype_format(DataType.I32_Default, DataType.I32_Default,
|
|
45
|
-
DataType.I32_Default, DataType.I32_Default, DataType.U16_Default, DataType.U16_Default) \
|
|
46
|
-
.dtype_format(DataType.I32_Default, DataType.I32_Default,
|
|
47
|
-
DataType.I32_Default, DataType.I32_Default, DataType.U32_Default, DataType.U32_Default) \
|
|
48
|
-
.dtype_format(DataType.I32_Default, DataType.I32_Default,
|
|
49
|
-
DataType.I32_Default, DataType.I32_Default, DataType.U64_Default, DataType.U64_Default) \
|
|
50
|
-
.dtype_format(DataType.I32_Default, DataType.I32_Default,
|
|
51
|
-
DataType.I32_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
|
|
52
|
-
.dtype_format(DataType.I32_Default, DataType.I32_Default,
|
|
53
|
-
DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
|
|
54
|
-
.dtype_format(DataType.I32_Default, DataType.I32_Default,
|
|
55
|
-
DataType.I32_Default, DataType.I32_Default, DataType.F64_Default, DataType.F64_Default) \
|
|
56
|
-
.dtype_format(DataType.I32_Default, DataType.I32_Default,
|
|
57
|
-
DataType.I32_Default, DataType.I32_Default, DataType.C64_Default, DataType.C64_Default) \
|
|
58
|
-
.dtype_format(DataType.I32_Default, DataType.I32_Default,
|
|
59
|
-
DataType.I32_Default, DataType.I32_Default, DataType.C128_Default, DataType.C128_Default) \
|
|
60
|
-
.get_op_info()
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
@op_info_register(strided_slice_v2_grad_op_info)
|
|
64
|
-
def _strided_slice_v2_grad_aicpu():
|
|
65
|
-
"""StridedSliceV2Grad AiCPU register"""
|
|
66
|
-
return
|
mindspore/ops/extend/__init__.py
DELETED
|
@@ -1,53 +0,0 @@
|
|
|
1
|
-
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
"""
|
|
17
|
-
|
|
18
|
-
Operators with better performance
|
|
19
|
-
|
|
20
|
-
"""
|
|
21
|
-
|
|
22
|
-
from __future__ import absolute_import
|
|
23
|
-
|
|
24
|
-
from mindspore.common import Tensor
|
|
25
|
-
from mindspore.ops.primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
|
|
26
|
-
from mindspore.ops.vm_impl_registry import get_vm_impl_fn, vm_impl_registry
|
|
27
|
-
from mindspore.ops.op_info_register import op_info_register, custom_info_register, AkgGpuRegOp, AkgAscendRegOp, \
|
|
28
|
-
AiCPURegOp, TBERegOp, CpuRegOp, CustomRegOp, DataType
|
|
29
|
-
from mindspore.ops.primitive import constexpr
|
|
30
|
-
from . import (
|
|
31
|
-
array_func,
|
|
32
|
-
math_func,
|
|
33
|
-
nn_func,
|
|
34
|
-
)
|
|
35
|
-
|
|
36
|
-
from .array_func import gather, max, min, one_hot
|
|
37
|
-
from .math_func import (
|
|
38
|
-
baddbmm,
|
|
39
|
-
bmm,
|
|
40
|
-
add,
|
|
41
|
-
sub
|
|
42
|
-
)
|
|
43
|
-
|
|
44
|
-
from .nn_func import (
|
|
45
|
-
conv2d,
|
|
46
|
-
max_pool2d,
|
|
47
|
-
leaky_relu_ext
|
|
48
|
-
)
|
|
49
|
-
|
|
50
|
-
__all__ = []
|
|
51
|
-
__all__.extend(array_func.__all__)
|
|
52
|
-
__all__.extend(math_func.__all__)
|
|
53
|
-
__all__.extend(nn_func.__all__)
|
|
@@ -1,218 +0,0 @@
|
|
|
1
|
-
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
"""
|
|
17
|
-
|
|
18
|
-
Array Operators
|
|
19
|
-
|
|
20
|
-
"""
|
|
21
|
-
from mindspore.common import Tensor
|
|
22
|
-
from mindspore.ops.operations.array_ops import ArgMaxWithValue, ArgMinWithValue
|
|
23
|
-
from mindspore.ops._primitive_cache import _get_cache_prim
|
|
24
|
-
from mindspore.ops.auto_generate.gen_ops_prim import gather_d_op
|
|
25
|
-
from mindspore.ops.auto_generate.gen_ops_def import max_, min_
|
|
26
|
-
from mindspore.ops.auto_generate.pyboost_inner_prim import _PyboostOneHotExtPrim
|
|
27
|
-
one_hot_ext_impl = _PyboostOneHotExtPrim()
|
|
28
|
-
|
|
29
|
-
# define Primitive global variables
|
|
30
|
-
|
|
31
|
-
def gather(input, dim, index):
|
|
32
|
-
r"""
|
|
33
|
-
Gather data from a tensor by indices.
|
|
34
|
-
|
|
35
|
-
.. math::
|
|
36
|
-
output[(i_0, i_1, ..., i_{dim}, i_{dim+1}, ..., i_n)] =
|
|
37
|
-
input[(i_0, i_1, ..., index[(i_0, i_1, ..., i_{dim}, i_{dim+1}, ..., i_n)], i_{dim+1}, ..., i_n)]
|
|
38
|
-
|
|
39
|
-
.. warning::
|
|
40
|
-
On Ascend, the behavior is unpredictable in the following cases:
|
|
41
|
-
|
|
42
|
-
- the value of `index` is not in the range `[-input.shape[dim], input.shape[dim])` in forward;
|
|
43
|
-
- the value of `index` is not in the range `[0, input.shape[dim])` in backward.
|
|
44
|
-
|
|
45
|
-
Args:
|
|
46
|
-
input (Tensor): The target tensor to gather values.
|
|
47
|
-
dim (int): the axis to index along, must be in range `[-input.rank, input.rank)`.
|
|
48
|
-
index (Tensor): The index tensor, with int32 or int64 data type. An valid `index` should be:
|
|
49
|
-
|
|
50
|
-
- `index.rank == input.rank`;
|
|
51
|
-
- for `axis != dim`, `index.shape[axis] <= input.shape[axis]`;
|
|
52
|
-
- the value of `index` is in range `[-input.shape[dim], input.shape[dim])`.
|
|
53
|
-
|
|
54
|
-
Returns:
|
|
55
|
-
Tensor, has the same type as `input` and the same shape as `index`.
|
|
56
|
-
|
|
57
|
-
Raises:
|
|
58
|
-
ValueError: If the shape of `index` is illegal.
|
|
59
|
-
ValueError: If `dim` is not in `[-input.rank, input.rank)`.
|
|
60
|
-
ValueError: If the value of `index` is out of the valid range.
|
|
61
|
-
TypeError: If the type of `index` is illegal.
|
|
62
|
-
|
|
63
|
-
Supported Platforms:
|
|
64
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
65
|
-
|
|
66
|
-
Examples:
|
|
67
|
-
>>> import mindspore
|
|
68
|
-
>>> import numpy as np
|
|
69
|
-
>>> from mindspore import Tensor, ops
|
|
70
|
-
>>> input = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
|
|
71
|
-
>>> index = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
|
|
72
|
-
>>> output = ops.extend.gather(input, 1, index)
|
|
73
|
-
>>> print(output)
|
|
74
|
-
[[-0.1 -0.1]
|
|
75
|
-
[0.5 0.5]]
|
|
76
|
-
"""
|
|
77
|
-
return gather_d_op(input, dim, index)
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
def max(input, dim=None, keepdim=False):
|
|
81
|
-
"""
|
|
82
|
-
Calculates the maximum value along with the given dimension for the input tensor.
|
|
83
|
-
|
|
84
|
-
Args:
|
|
85
|
-
input (Tensor): The input tensor, can be any dimension. Complex tensor is not supported for now.
|
|
86
|
-
dim (int, optional): The dimension to reduce. Default: ``None`` .
|
|
87
|
-
keepdim (bool, optional): Whether to reduce dimension, if true, the output will keep same dimension
|
|
88
|
-
with the input, the output will reduce dimension if false. Default: ``False`` .
|
|
89
|
-
|
|
90
|
-
Returns:
|
|
91
|
-
Tensor if `dim` is the default value ``None`` , the maximum value of input tensor, with the shape :math:`()` ,
|
|
92
|
-
and same dtype as `input`.
|
|
93
|
-
|
|
94
|
-
tuple (Tensor) if `dim` is not the default value ``None`` , tuple of 2 tensors, containing the maximum
|
|
95
|
-
value of the input tensor along the given dimension `dim` and the corresponding index.
|
|
96
|
-
|
|
97
|
-
- **values (Tensor)** - The maximum value of input tensor along the given dimension `dim`, with same dtype as
|
|
98
|
-
`input`. If `keepdim` is ``True`` , the shape of output tensors is :math:`(input_1, input_2, ...,
|
|
99
|
-
input_{axis-1}, 1, input_{axis+1}, ..., input_N)` . Otherwise, the shape is :math:`(input_1, input_2, ...,
|
|
100
|
-
input_{axis-1}, input_{axis+1}, ..., input_N)` .
|
|
101
|
-
- **index (Tensor)** - The index for the maximum value of the input tensor along the given dimension `dim`, with
|
|
102
|
-
the same shape as `values`.
|
|
103
|
-
|
|
104
|
-
Raises:
|
|
105
|
-
ValueError: If `dim` is the default value ``None`` and `keepdim` is not ``False`` .
|
|
106
|
-
|
|
107
|
-
Supported Platforms:
|
|
108
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
109
|
-
|
|
110
|
-
Examples:
|
|
111
|
-
>>> import mindspore
|
|
112
|
-
>>> import numpy as np
|
|
113
|
-
>>> from mindspore import Tensor, ops
|
|
114
|
-
>>> y = Tensor(np.array([[0.0, 0.3, 0.4, 0.5, 0.1],
|
|
115
|
-
... [3.2, 0.4, 0.1, 2.9, 4.0]]), mindspore.float32)
|
|
116
|
-
>>> output, index = ops.extend.max(y, 0, True)
|
|
117
|
-
>>> print(output, index)
|
|
118
|
-
[[3.2 0.4 0.4 2.9 4. ]] [[1 1 0 1 1]]
|
|
119
|
-
"""
|
|
120
|
-
if dim is None:
|
|
121
|
-
if keepdim is not False:
|
|
122
|
-
raise ValueError(f"For 'max', the `keepdim` must be False when the `dim` is None, but got {keepdim}")
|
|
123
|
-
return max_(input)
|
|
124
|
-
argmax_with_value_op = _get_cache_prim(ArgMaxWithValue)(dim, keepdim)
|
|
125
|
-
indices, values = argmax_with_value_op(input)
|
|
126
|
-
return values, indices
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
def min(input, dim=None, keepdim=False):
    """
    Return the minimum of the input tensor, optionally along a given dimension.

    When `dim` is ``None`` the whole tensor is reduced to its scalar minimum;
    otherwise both the minimum values along `dim` and the indices where they
    occur are returned.

    Args:
        input (Tensor): The tensor to reduce, of any dimension. Complex
            tensors are not supported for now.
        dim (int, optional): The dimension to reduce. Default: ``None`` .
        keepdim (bool, optional): If ``True``, the reduced dimension is kept
            in the output with size 1; if ``False`` it is removed. Only
            meaningful when `dim` is given. Default: ``False`` .

    Returns:
        Tensor if `dim` is the default value ``None`` , the minimum value of
        input tensor, with the shape :math:`()` , and same dtype as `input`.

        tuple (Tensor) if `dim` is not the default value ``None`` , tuple of
        2 tensors, containing the minimum value of the input tensor along the
        given dimension `dim` and the corresponding index.

        - **values (Tensor)** - The minimum value of input tensor along the
          given dimension `dim`, with same dtype as `input`. If `keepdim` is
          ``True`` , the shape of output tensors is :math:`(input_1, input_2,
          ..., input_{axis-1}, 1, input_{axis+1}, ..., input_N)` . Otherwise,
          the shape is :math:`(input_1, input_2, ..., input_{axis-1},
          input_{axis+1}, ..., input_N)` .
        - **index (Tensor)** - The index for the minimum value of the input
          tensor along the given dimension `dim`, with the same shape as
          `values`.

    Raises:
        ValueError: If `dim` is the default value ``None`` and `keepdim` is
            not ``False`` .

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> x = Tensor(np.array([0.0, 0.4, 0.6, 0.7, 0.1]), mindspore.float32)
        >>> output, index = ops.extend.min(x, 0, keepdim=True)
        >>> print(output, index)
        [0.0] [0]
    """
    if dim is not None:
        # Per-dimension reduction: the primitive is cached per (dim, keepdim)
        # pair so repeated calls with the same arguments reuse one instance.
        reduce_op = _get_cache_prim(ArgMinWithValue)(dim, keepdim)
        idx, vals = reduce_op(input)
        return vals, idx
    # Full reduction to a scalar: `keepdim` has no meaning here, so anything
    # other than its default is rejected explicitly rather than ignored.
    if keepdim is not False:
        raise ValueError(f"For 'min', the `keepdim` must be False when the `dim` is None, but got {keepdim}")
    return min_(input)
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
def one_hot(tensor, num_classes):
    r"""
    Computes a one-hot tensor.

    The locations represented by tensor in `tensor` take value `1`, while all
    other locations take value `0`.

    Args:
        tensor (Tensor): A tensor of indices. Tensor of shape
            :math:`(X_0, \ldots, X_n)`. Data type must be int32 or int64.
        num_classes (int): A scalar defining the depth of the one-hot
            dimension.

    Returns:
        Tensor, one-hot tensor.

    Raises:
        TypeError: If `num_classes` is not an int.
        TypeError: If dtype of `tensor` is not int32 or int64.
        ValueError: If `num_classes` is less than 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import ops
        >>> from mindspore import Tensor
        >>> tensor = Tensor(np.array([0, 1, 2]), mindspore.int32)
        >>> num_classes = 3
        >>> output = ops.extend.one_hot(tensor, num_classes)
        >>> print(output)
        [[1. 0. 0.]
         [0. 1. 0.]
         [0. 0. 1.]]
    """
    # Fill values are built with the index tensor's own dtype; the trailing
    # -1 places the new one-hot axis last.
    hot = Tensor(1, dtype=tensor.dtype)
    cold = Tensor(0, dtype=tensor.dtype)
    return one_hot_ext_impl(tensor, num_classes, hot, cold, -1)
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
# Public API of this module; `max` and `min` intentionally shadow the builtins
# here because they are re-exported under the `ops.extend` namespace.
__all__ = ['gather', 'max', 'min', 'one_hot']
|
|
@@ -1,76 +0,0 @@
|
|
|
1
|
-
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
"""
|
|
17
|
-
|
|
18
|
-
Math Operators with better performance
|
|
19
|
-
|
|
20
|
-
"""
|
|
21
|
-
|
|
22
|
-
from mindspore.ops import auto_generate as P
|
|
23
|
-
from mindspore.ops.auto_generate.gen_ops_def import add_ext as add, sub_ext as sub, bmm_ext as bmm
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
# define Primitive global variables
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def baddbmm(input, batch1, batch2, beta=1, alpha=1):
    r"""
    Add `input` (scaled by `beta`) to a batch matrix-matrix product of
    `batch1` and `batch2` (scaled by `alpha`):

    .. math::
        \text{out}_{i} = \beta \text{input}_{i} + \alpha (\text{batch1}_{i} \mathbin{@} \text{batch2}_{i})

    Args:
        input (Tensor): The input Tensor. When batch1 is a :math:`(C, W, T)`
            Tensor and batch2 is a :math:`(C, T, H)` Tensor, input must be
            broadcastable with :math:`(C, W, H)` Tensor.
        batch1 (Tensor): :math:`batch1` in the above formula. Must be 3-D
            Tensor, dtype is same as input.
        batch2 (Tensor): :math:`batch2` in the above formula. Must be 3-D
            Tensor, dtype is same as input.
        beta (Union[float, int], optional): multiplier for input.
            Default: ``1`` .
        alpha (Union[float, int], optional): multiplier for
            :math:`batch1 @ batch2`. Default: ``1`` . Arguments beta and
            alpha must be integers when inputs of type not FloatTensor,
            otherwise they should be a real number.

    Returns:
        Tensor, has the same dtype as input, shape will be :math:`(C, W, H)`.

    Raises:
        TypeError: The type of `input`, `batch1`, `batch2` is not Tensor.
        TypeError: The types of `input`, `batch1`, `batch2` are different.
        TypeError: For inputs of type FloatTensor or DoubleTensor, arguments
            beta and alpha not be real numbers, otherwise not be integers.
        TypeError: For Baddbmm, attributes alpha and beta are not real numbers
        ValueError: If `batch1` and `batch2` are not 3-D tensors.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> input = Tensor(np.ones([1, 3, 3]).astype(np.float32))
        >>> batch1 = Tensor(np.ones([1, 3, 4]).astype(np.float32))
        >>> batch2 = Tensor(np.ones([1, 4, 3]).astype(np.float32))
        >>> output = ops.baddbmm(input, batch1, batch2)
        >>> print(output)
        [[[5. 5. 5.]
          [5. 5. 5.]
          [5. 5. 5.]]]
    """
    # Thin wrapper: all validation and computation happen inside the
    # auto-generated primitive.
    result = P.baddbmm(input, batch1, batch2, beta, alpha)
    return result
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
# Public API: `add`, `sub` and `bmm` are re-exported from the auto-generated
# extended ops alongside the local `baddbmm` wrapper.
__all__ = ['baddbmm', 'add', 'sub', 'bmm']
|