mindspore 2.2.0-cp37-cp37m-manylinux1_x86_64.whl → 2.2.11-cp37-cp37m-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/_akg/akg/composite/build_module.py +104 -20
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +7 -2
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/kernel_exec.py +41 -15
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +27 -6
- mindspore/_akg/akg/utils/util.py +56 -1
- mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_checkparam.py +3 -3
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/splitter.py +3 -2
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
- mindspore/_extends/parse/__init__.py +3 -2
- mindspore/_extends/parse/parser.py +6 -1
- mindspore/_extends/parse/standard_method.py +14 -11
- mindspore/_extends/remote/kernel_build_server.py +2 -1
- mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/common/_utils.py +16 -0
- mindspore/common/api.py +1 -1
- mindspore/common/auto_dynamic_shape.py +81 -85
- mindspore/common/dump.py +1 -1
- mindspore/common/tensor.py +3 -20
- mindspore/config/op_info.config +1 -1
- mindspore/context.py +11 -4
- mindspore/dataset/engine/cache_client.py +8 -5
- mindspore/dataset/engine/datasets_standard_format.py +5 -0
- mindspore/dataset/vision/transforms.py +21 -21
- mindspore/experimental/optim/adam.py +1 -1
- mindspore/gen_ops.py +1 -1
- mindspore/include/api/model.h +17 -0
- mindspore/include/api/status.h +8 -3
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +78 -80
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/nn/cell.py +0 -3
- mindspore/nn/layer/activation.py +4 -5
- mindspore/nn/layer/conv.py +39 -23
- mindspore/nn/layer/flash_attention.py +54 -129
- mindspore/nn/layer/math.py +3 -7
- mindspore/nn/layer/rnn_cells.py +5 -5
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +12 -3
- mindspore/numpy/utils_const.py +5 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
- mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_utils/utils.py +2 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
- mindspore/ops/function/array_func.py +10 -7
- mindspore/ops/function/grad/grad_func.py +0 -1
- mindspore/ops/function/nn_func.py +98 -9
- mindspore/ops/function/random_func.py +2 -1
- mindspore/ops/op_info_register.py +24 -21
- mindspore/ops/operations/__init__.py +6 -2
- mindspore/ops/operations/_grad_ops.py +25 -6
- mindspore/ops/operations/_inner_ops.py +155 -23
- mindspore/ops/operations/array_ops.py +9 -7
- mindspore/ops/operations/comm_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +85 -68
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +7 -6
- mindspore/ops/operations/nn_ops.py +193 -49
- mindspore/parallel/_parallel_serialization.py +10 -3
- mindspore/parallel/_tensor.py +4 -1
- mindspore/parallel/checkpoint_transform.py +13 -2
- mindspore/parallel/shard.py +17 -10
- mindspore/profiler/common/util.py +1 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
- mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
- mindspore/profiler/parser/ascend_op_generator.py +1 -1
- mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
- mindspore/profiler/parser/base_timeline_generator.py +1 -1
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
- mindspore/profiler/parser/framework_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +19 -0
- mindspore/profiler/profiling.py +46 -24
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/parsers/for_parser.py +7 -7
- mindspore/rewrite/parsers/module_parser.py +4 -4
- mindspore/rewrite/symbol_tree.py +1 -4
- mindspore/run_check/_check_version.py +5 -3
- mindspore/safeguard/rewrite_obfuscation.py +52 -28
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/train/callback/_summary_collector.py +1 -1
- mindspore/train/dataset_helper.py +1 -0
- mindspore/train/model.py +2 -2
- mindspore/train/serialization.py +97 -11
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +23 -7
- mindspore/version.py +1 -1
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +3 -2
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +160 -151
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
- mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/flash_attention.py
CHANGED

@@ -21,9 +21,7 @@ import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore import ops
 from mindspore.nn.cell import Cell
-from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_impl import get_flash_attention
 from mindspore.ops.operations.nn_ops import FlashAttentionScore
-from mindspore._c_expression import MSContext

 __all__ = ['FlashAttention']

@@ -46,25 +44,25 @@ class FlashAttention(Cell):
             Default 65536.
         next_block_num(int): A integer to define the number of blocks to look behind for local block sparse attention.
             Default 65536.
-        tiling_stgy_name(str): A str to define tiling strategy of flash attention.
         dp(int): data parallel.
             Default 1.
         mp(int): model parallel.
             Default 1.
-        high_precision(bool): This mode has higher precision but some performance loss.
+        high_precision(bool): This mode has higher precision but some performance loss. Only take effect on Ascend910A.
             Default False.
         have_attention_mask_batch(bool): indicates whether attention_mask contains the batch dimension.
             Default True
         alibi(bool): This parameter indicates whether the flashattention supports the Alibi.
             Default: False
+        use_mqa(bool): Using MQA if True, only take effect under 910B. Default: False.


     Inputs:
         - **query** (Tensor) - Tensor query (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
         - **key** (Tensor) - Tensor key (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
         - **value** (Tensor) - Tensor value (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
-        - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16`
-
+        - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16` `mstype.uint8`
+          [batch_size, seq_length, seq_length]): A matrix to pass masked information.

     Outputs:
         A Tensor. The output of the attention with shape [batch_size, head_num, seq_length, head_dim]

@@ -97,56 +95,51 @@ class FlashAttention(Cell):
                  dropout_rate=0.0,
                  prev_block_num=65536,
                  next_block_num=65536,
-                 tiling_stgy_name="sparse",
                  dp=1,
                  mp=1,
                  high_precision=False,
                  have_attention_mask_batch=True,
-                 alibi=False
+                 alibi=False,
+                 use_mqa=False
                  ):
         super(FlashAttention, self).__init__()

         scaling_constant = math.sqrt(head_dim)
         if scaling_constant == 0:
             raise ValueError("the scaling constant must not be 0.")
-        self.
+        self.dropout_rate = dropout_rate
+        self.alibi = alibi
+        self.have_attention_mask_batch = have_attention_mask_batch

-        self.
-
-
-
-
-
-
-            )
-
-        else:
-            if alibi:
-                raise ValueError(f"When soc_version is not Ascend910A, alibi must be False")
-            self.transpose_4d_pre = ops.Transpose().shard(((dp, mp, 1, 1),))
-            self.transpose_4d_post = ops.Transpose().shard(((dp, 1, mp, 1),))
-            self.reshape = ops.Reshape()
-            self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
-            self.zeros = ops.Zeros()
-            self.attn_expand_dims = ops.ExpandDims().shard(((dp, 1, 1),))
-            fa_strategies = ((dp, 1, mp),
-                             (dp, 1, mp),
-                             (dp, 1, mp),
+        self.transpose_4d_pre = ops.Transpose().shard(((dp, mp, 1, 1),))
+        self.transpose_4d_post = ops.Transpose().shard(((dp, 1, mp, 1),))
+        self.reshape = ops.Reshape()
+        self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
+        self.zeros = ops.Zeros()
+        self.attn_cast = ops.Cast()
+        if use_mqa:
+            fa_strategies = ((dp, mp, 1, 1),
+                             (dp, 1, 1, 1),
                              (dp, 1, 1, 1))
-
-
-
-
-
-
-
+        else:
+            fa_strategies = ((dp, mp, 1, 1),
+                             (dp, mp, 1, 1),
+                             (dp, mp, 1, 1))
+        if self.alibi:
+            self.alibi_rescale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
+            self.alibi_rescale_factor = Tensor([scaling_constant], dtype=mstype.float16)
+            fa_strategies += ((dp, mp, 1, 1),)
+        if dropout_rate > 1e-5:
+            fa_strategies += ((dp, mp, 1, 1),)
+        fa_strategies += ((dp, 1, 1, 1),)
+        self.flash_attention = FlashAttentionScore(head_num=head_num, pre_tokens=prev_block_num,
+                                                   next_tokens=next_block_num,
+                                                   keep_prob=1 - dropout_rate,
+                                                   scale_value=1. / scaling_constant,
+                                                   inner_precise=0,
+                                                   input_layout="BNSD").shard(fa_strategies)

-        self.ones = ops.Ones()
-        self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
-        self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
         self.dropout_rate = dropout_rate
-        self.have_attention_mask_batch = have_attention_mask_batch
-        self.alibi = alibi
         if self.dropout_rate > 1e-5:
             self.keep_prob = Tensor(1 - self.dropout_rate, dtype=mstype.float16)
             self.fill_v2 = ops.FillV2().shard(((dp, mp, 1, 1), ()))

@@ -162,46 +155,7 @@ class FlashAttention(Cell):
                            such as MatMul. Default: None.
         :return:
         """
-        if in_strategy is None:
-            # default: dp=1, mp=1, construct inputs only contain query, key, value
-            in_strategy = (
-                (1, 1, 1, 1),
-                (1, 1, 1, 1),
-                (1, 1, 1, 1),
-            )
         self.flash_attention.shard(in_strategy)
-        dp = in_strategy[0][0]
-        mp = in_strategy[0][1]
-        self.flash_attention.add_prim_attr("dev_matrix_shape", [dp, mp, 1, 1])
-        inputs_tensor_map = [
-            [3, 2, 1, 0],
-            [3, 2, 1, 0],
-            [3, 2, 1, 0],
-        ]
-        if self.have_attention_mask_batch:
-            inputs_tensor_map.append([3, 1, 0])
-        else:
-            inputs_tensor_map.append([-1, 1, 0])
-
-        input_empty_args_num = 2
-        # dropout_mask
-        if self.dropout_rate > 1e-5:
-            input_empty_args_num -= 1
-            inputs_tensor_map.append([3, 2, 1, 0])
-
-        if self.alibi:
-            input_empty_args_num -= 1
-            inputs_tensor_map.append([3, 2, 1, 0])
-
-        self.flash_attention.add_prim_attr("inputs_tensor_map", inputs_tensor_map)
-
-        self.flash_attention.add_prim_attr("outputs_tensor_map", [
-            [3, 2, 1, 0],  # O
-            [3, 2, 1],  # L
-            [3, 2, 1]  # M
-        ])
-        self.flash_attention.add_prim_attr("as_loss_divisor", 0)
-        self.flash_attention.add_prim_attr("empty_mirror_ops", input_empty_args_num)

     def construct(self, query, key, value, attn_mask=None, alibi_mask=None):
         """FlashAttention forward

@@ -212,53 +166,24 @@ class FlashAttention(Cell):
         :param alibi_mask: [bsz, head_num, 1, seq_len], if not None
         :return: output [bsz, head_num, seq_len, head_dim]
         """
-
-
-
-
-
-            raise ValueError(
-                "the head_num of query, key and value must be the same, "
-                "If different head_num are used, users need to change themselves to be same by tile.")
-        if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
-            raise ValueError(
-                "query, key, value seq_len must be a multiple of 16, and key seq_len, value seq_len must be the same.")
-
-        if head_dim > 304:
-            raise ValueError(
-                "the head_dim must be less than 304, otherwise the ub would be OOM.")
-
-        if self.is_910A:
-            # 910A -- FlashAttentionPrimtive
-            if self.dropout_rate > 1e-5:
-                drop_mask_bits = self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob)
-                tensor_shape = Tensor((bsz, head_num, seq_len, seq_len), mstype.int32)
-                ones = self.fill_v2(tensor_shape, self.tensor_one)
-                ones = self.depend(ones, query)
-                drop_mask = self.do_dropout(ones, drop_mask_bits, self.keep_prob)
-            else:
-                drop_mask = None
-            output, _, _ = self.flash_attention(query, key, value, attn_mask, drop_mask, alibi_mask)
+        bsz, head_num, seq_len, _ = query.shape
+        # 910B -- FlashAttentionScore
+        if self.dropout_rate > 1e-5:
+            drop_mask_bits = self.reshape(self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob),
+                                          (bsz, head_num, seq_len, seq_len // 8))
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                                                    value,
-                                                    attn_mask,
-                                                    drop_mask_bits,
-                                                    None,
-                                                    None)
-            output = self.transpose_4d_post(self.reshape(output, (bsz, seq_len, head_num, head_dim)), (0, 2, 1, 3))
-
+            drop_mask_bits = None
+        if self.alibi:
+            alibi_mask = self.alibi_rescale_mul(alibi_mask, self.cast(self.alibi_rescale_factor, alibi_mask.dtype))
+        # (B, S, S) -> (B, 1, S, S)
+        if self.have_attention_mask_batch:
+            attn_mask = self.cast(self.reshape(attn_mask, (bsz, 1, seq_len, seq_len)), mstype.uint8)
+        _, _, _, output = self.flash_attention(query,
+                                               key,
+                                               value,
+                                               alibi_mask,
+                                               drop_mask_bits,
+                                               None,
+                                               attn_mask,
+                                               None)
         return output
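Reading the new __init__ and construct signatures together gives a rough picture of how the 2.2.11 wrapper is meant to be called. The sketch below is illustrative only and is based solely on the parameters visible in this diff; the leading constructor arguments (head count and head dim) are not shown in the hunk above, so their keyword names here are assumptions, and actually running it requires an Ascend back end.

import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.nn.layer.flash_attention import FlashAttention

# Shapes follow the docstring: [batch_size, head_num, seq_length, head_dim], fp16 inputs.
bsz, head_num, seq_len, head_dim = 1, 8, 1024, 128
query = Tensor(np.random.randn(bsz, head_num, seq_len, head_dim), ms.float16)
key = Tensor(np.random.randn(bsz, head_num, seq_len, head_dim), ms.float16)
value = Tensor(np.random.randn(bsz, head_num, seq_len, head_dim), ms.float16)
# attention_mask is (batch_size, seq_length, seq_length); the new code reshapes it to
# (batch_size, 1, seq_length, seq_length) and casts it to uint8 before the fused op.
attn_mask = Tensor(np.zeros((bsz, seq_len, seq_len)), ms.float16)

# head_num / head_dim keyword names are assumed, not confirmed by this hunk.
# use_mqa is the flag added in 2.2.11 and, per the docstring, only takes effect on 910B.
fa = FlashAttention(head_num=head_num, head_dim=head_dim, dropout_rate=0.0, use_mqa=False)
output = fa(query, key, value, attn_mask)   # [bsz, head_num, seq_len, head_dim]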
mindspore/nn/layer/math.py
CHANGED

@@ -375,9 +375,6 @@ class DiGamma(Cell):
                              nan, real_result)


-eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)
-
-
 def _while_helper_func(cond, body, vals):
     while cond(vals).any():
         vals = body(vals)

@@ -394,7 +391,7 @@ def _igamma_series(ax, x, a, enabled):
     select = P.Select()

     # If more data types are supported, this epsilon need to be selected.
-    epsilon =
+    epsilon = Tensor(np.finfo(np.float32).eps, mstype.float32)

     def cond(vals):
         enabled = vals[0]

@@ -443,7 +440,7 @@ def _igammac_continued_fraction(ax, x, a, enabled):
     select = P.Select()

     # If more data types are supported, this epsilon need to be selected.
-    epsilon =
+    epsilon = Tensor(np.finfo(np.float32).eps, mstype.float32)

     def cond(vals):
         enabled = vals[0]

@@ -620,8 +617,7 @@ class IGamma(Cell):
         x = F.broadcast_to(x, para_shape)
         a = F.broadcast_to(a, para_shape)
         x_is_zero = self.equal(x, 0)
-
-        underflow = self.less(ax, self.neg(log_maxfloat))
+        underflow = self.less(ax, self.neg(self.log_maxfloat32))
         ax = self.exp(ax)
         enabled = self.logicalnot(self.logicalor(self.logicalor(x_is_zero, domain_error), underflow))
         output = self.select(use_igammac,
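Both epsilon hunks above now build float32 machine epsilon inline rather than reading a module-level constant; the value itself is easy to sanity-check with plain NumPy (nothing MindSpore-specific here):

import numpy as np

eps = np.finfo(np.float32).eps           # 1.1920929e-07, gap between 1.0 and the next float32
one = np.float32(1.0)
print(one + np.float32(eps) > one)       # True: adding a full epsilon is representable
print(one + np.float32(eps) / 2 == one)  # True: half an epsilon rounds back to 1.0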
mindspore/nn/layer/rnn_cells.py
CHANGED

@@ -83,7 +83,7 @@ def _check_lstmcell_init(func):


 def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
-
+    """RNN cell function with tanh activation"""
     if b_ih is None:
         igates = P.MatMul(False, True)(inputs, w_ih)
         hgates = P.MatMul(False, True)(hidden, w_hh)

@@ -94,7 +94,7 @@ def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):


 def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
-
+    """RNN cell function with relu activation"""
     if b_ih is None:
         igates = P.MatMul(False, True)(inputs, w_ih)
         hgates = P.MatMul(False, True)(hidden, w_hh)

@@ -105,7 +105,7 @@ def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):


 def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
-
+    """LSTM cell function"""
     hx, cx = hidden
     if b_ih is None:
         gates = P.MatMul(False, True)(inputs, w_ih) + P.MatMul(False, True)(hx, w_hh)

@@ -125,7 +125,7 @@ def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):


 def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
-
+    """GRU cell function"""
     if b_ih is None:
         gi = P.MatMul(False, True)(inputs, w_ih)
         gh = P.MatMul(False, True)(hidden, w_hh)

@@ -144,7 +144,7 @@ def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):


 class RNNCellBase(Cell):
-
+    """Basic class for RNN Cells"""
     def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int,
                  dtype=mstype.float32):
         super().__init__()
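The newly documented helpers all share one pattern: project the input and the hidden state with MatMul(transpose_b=True), add the biases if present, then apply the cell's gating nonlinearity. A plain-NumPy sketch of the tanh variant (illustrative only, not the MindSpore kernel):

import numpy as np

def rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    # inputs: (batch, input_size), hidden: (batch, hidden_size)
    # w_ih: (hidden_size, input_size), w_hh: (hidden_size, hidden_size)
    igates = inputs @ w_ih.T             # MatMul(False, True): transpose the weight
    hgates = hidden @ w_hh.T
    if b_ih is not None:
        igates = igates + b_ih
        hgates = hgates + b_hh
    return np.tanh(igates + hgates)      # h' = tanh(W_ih x + b_ih + W_hh h + b_hh)

h_next = rnn_tanh_cell(np.zeros((2, 3)), np.zeros((2, 4)),
                       np.random.randn(4, 3), np.random.randn(4, 4))
print(h_next.shape)                      # (2, 4)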
mindspore/nn/wrap/__init__.py
CHANGED

@@ -20,7 +20,8 @@ Use the Wrapper to combine the loss or build the training steps.
 from __future__ import absolute_import

 from mindspore.nn.wrap.cell_wrapper import ForwardValueAndGrad, TrainOneStepCell, WithLossCell, WithGradCell, \
-    WithEvalCell, ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple, MicroBatchInterleaved, PipelineCell
+    WithEvalCell, ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple, MicroBatchInterleaved, PipelineCell, \
+    GradAccumulationCell
 from mindspore.nn.wrap.loss_scale import TrainOneStepWithLossScaleCell,\
     DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer

@@ -42,5 +43,6 @@ __all__ = [
     "ParameterUpdate",
     "DynamicLossScaleUpdateCell",
     "FixedLossScaleUpdateCell",
-    "VirtualDatasetCellTriple"
+    "VirtualDatasetCellTriple",
+    "GradAccumulationCell"
 ]

mindspore/nn/wrap/cell_wrapper.py
CHANGED

@@ -644,6 +644,9 @@ class PipelineCell(Cell):
         self.micro_inputs = nn.CellList()
         self.micro_size = micro_size
         self.add_list = []
+        if not isinstance(network, Cell):
+            raise TypeError("For 'PipelineCell', the argument 'network' must cell type, "
+                            "but got the type : {}.".format(type(network)))
         if not isinstance(micro_size, int):
             raise TypeError("For 'PipelineCell', the argument 'micro_size' must be integer, "
                             "but got the type : {}.".format(type(micro_size)))

@@ -670,7 +673,7 @@ class PipelineCell(Cell):

 class GradAccumulationCell(Cell):
     """
-    Wrap the network with Micro Batch.
+    Wrap the network with Micro Batch to enable the grad accumulation in semi_auto_parallel/auto_parallel mode.

     Args:
         network (Cell): The target network to wrap.

@@ -680,8 +683,11 @@ class GradAccumulationCell(Cell):
         ``Ascend`` ``GPU``

     Examples:
-        >>>
-        >>>
+        >>> import mindspore.nn as nn
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
+        >>> net = LeNet5()
+        >>> net = nn.GradAccumulationCell(net, 4)
     """
     def __init__(self, network, micro_size):
         super(GradAccumulationCell, self).__init__(auto_prefix=False)

@@ -689,6 +695,9 @@ class GradAccumulationCell(Cell):
         self.micro_inputs = nn.CellList()
         self.micro_size = micro_size
         self.add_list = []
+        if not isinstance(network, Cell):
+            raise TypeError("For 'GradAccumulationCell', the argument 'network' must cell type, "
+                            "but got the type : {}.".format(type(network)))
         if not isinstance(micro_size, int):
             raise TypeError("For 'GradAccumulationCell', the argument 'micro_size' must be integer, "
                             "but got the type : {}.".format(type(micro_size)))
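Both wrappers now validate the network argument up front, so passing something that is not a Cell fails immediately with a TypeError instead of surfacing a less obvious error later during graph construction. A hypothetical check of that behaviour (GradAccumulationCell is re-exported from mindspore.nn.wrap in this release; normal usage follows the docstring example above):

from mindspore.nn.wrap import GradAccumulationCell

try:
    GradAccumulationCell("not a cell", 4)    # network must be a Cell
except TypeError as err:
    print(err)   # "For 'GradAccumulationCell', the argument 'network' must cell type, ..."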
mindspore/numpy/utils_const.py
CHANGED

@@ -143,8 +143,8 @@ def _infer_out_shape(*shapes):
     shape_out = list()
     max_len = max([len(it) for it in shapes])
     for i in range(max_len):
-        items = [
-
+        items = [
+            it[i - max_len + len(it)] if i - max_len + len(it) >= 0 else 1 for it in shapes]
         max_size = 0 if 0 in items else max(items)
         _check()
         shape_out.append(max_size)

@@ -158,8 +158,8 @@ def _can_broadcast(*shapes):
     """
     max_len = max([len(it) for it in shapes])
     for i in range(max_len):
-        items = [
-
+        items = [
+            it[i - max_len + len(it)] if i - max_len + len(it) >= 0 else 1 for it in shapes]
         max_size = 0 if 0 in items else max(items)
         if any(item not in (1, max_size) for item in items):
             return False

@@ -399,7 +399,7 @@ def _broadcast_tuples(tup1, tup2):
     if not isinstance(tup1, (tuple, list)) or not isinstance(tup2, (tuple, list)):
         raise TypeError("input shift and axis must be tuple or list or int.")
     if len(tup1) == len(tup2) or len(tup1) == 1 or len(tup2) == 1:
-        return
+        return
     raise ValueError("shape mismatch: objects cannot be broadcast to a single shape")

     tup1 = (tup1,) if isinstance(tup1, int) else tup1
mindspore/ops/_grad_experimental/grad_array_ops.py
CHANGED

@@ -203,7 +203,7 @@ def get_bprop_index_put(self):
         if is_ascend:
             indices_ms = [convert_idx_positive(indices_ms[i], x1.shape[i]) for i in range(len(indices_ms))]
             indices_me = stack(indices_ms)
-            indices_grad = F.transpose(indices_me, F.make_range(F.rank(indices_me)-1, -1, -1))
+            indices_grad = F.transpose(indices_me, F.make_range(F.rank(indices_me) - 1, -1, -1))
             values_grad = gather_nd(dout, indices_grad)
             if equal(cast(x2.shape[0], mstype.int32), Tensor(1)):
                 values_grad = values_grad.sum().reshape(1)
mindspore/ops/_grad_experimental/grad_implementations.py
CHANGED

@@ -19,7 +19,7 @@ from mindspore.ops import functional as F
 from mindspore.ops import operations as P
 from mindspore.ops.composite import multitype_ops as C
 from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
-from mindspore.ops._grad_experimental.grad_base import bprops
+from mindspore.ops._grad_experimental.grad_base import bprops, bprop_getters
 from mindspore.common import dtype as mstype

 get_dtype = P.DType()

@@ -193,7 +193,7 @@ def bprop_tensor_move(x, out, dout):
     return (dout,)


-@
+@bprop_getters.register("DictInplaceSetItem")
 def get_bprop_dict_inplace_setitem(self):
     """Generate bprop for dict inplace pop"""

mindspore/ops/_grad_experimental/grad_math_ops.py
CHANGED

@@ -135,7 +135,7 @@ def get_bprop_matrix_triangular_solve(self):

     def bprop(matrix, rhs, out, dout):
         grad_rhs = matrix_triangular_solve_op(matrix, dout)
-        if matrix.dtype
+        if matrix.dtype in (mstype.complex64, mstype.complex128):
             grad_rhs_temp = _adjoint(grad_rhs)
             out_temp = _adjoint(out)
         else:

@@ -156,14 +156,14 @@ def get_bprop_matrix_triangular_solve(self):
         grad_matrix = mat_mul_op(grad_rhs, out_temp)
         grad_matrix = neg_op(grad_matrix)
         if lower_a:
-            if grad_matrix.dtype
+            if grad_matrix.dtype in (mstype.complex64, mstype.complex128):
                 grad_matrix_real = matrix_band_part_op(real_op(grad_matrix), -1, 0)
                 grad_matrix_imag = matrix_band_part_op(imag_op(grad_matrix), -1, 0)
                 grad_matrix = complex_op(grad_matrix_real, grad_matrix_imag)
             else:
                 grad_matrix = matrix_band_part_op(grad_matrix, -1, 0)
         else:
-            if grad_matrix.dtype
+            if grad_matrix.dtype in (mstype.complex64, mstype.complex128):
                 grad_matrix_real = matrix_band_part_op(real_op(grad_matrix), 0, -1)
                 grad_matrix_imag = matrix_band_part_op(imag_op(grad_matrix), 0, -1)
                 grad_matrix = complex_op(grad_matrix_real, grad_matrix_imag)

@@ -219,7 +219,7 @@ def get_bprop_matrix_solve(self):
 @_primexpr
 def _generate_perm_matrix_solve_ls(x_dim):
     perm = tuple(range(x_dim - 2))
-    perm = perm + (x_dim-1, x_dim-2)
+    perm = perm + (x_dim - 1, x_dim - 2)
     return perm


@@ -647,20 +647,21 @@ def _fft_rank_offset(norm_shape, rank):
 @_primexpr
 def _fft_with_size_back_norm(norm_shape, norm, inverse, rank):
     """generate reverse term for fft_with_size"""
+    norm_ = None
     if inverse is False:
         if norm == "forward":
-            norm_ = 1 / _fft_rank_offset(norm_shape, rank)
-
-            norm_ = 1 * _fft_rank_offset(norm_shape, rank)
-
-            norm_ = 1
-
+            norm_ = 1.0 / _fft_rank_offset(norm_shape, rank)
+        elif norm == "backward":
+            norm_ = 1.0 * _fft_rank_offset(norm_shape, rank)
+        elif norm == "ortho":
+            norm_ = 1.0
+    else:
         if norm == "forward":
-            norm_ = 1 * _fft_rank_offset(norm_shape, rank)
-
-            norm_ = 1 / _fft_rank_offset(norm_shape, rank)
-
-            norm_ = 1
+            norm_ = 1.0 * _fft_rank_offset(norm_shape, rank)
+        elif norm == "backward":
+            norm_ = 1.0 / _fft_rank_offset(norm_shape, rank)
+        elif norm == "ortho":
+            norm_ = 1.0
     return norm_


@@ -670,9 +671,9 @@ def _rfft_norm(norm_shape, norm, rank):
     norm_ = 1.0
     if norm == "forward":
         norm_ = 1 / _fft_rank_offset(norm_shape, rank)
-
-        norm_ = 1
-
+    elif norm == "backward":
+        norm_ = 1.0
+    elif norm == "ortho":
         norm_ = 1 / np.sqrt(_fft_rank_offset(norm_shape, rank))
     return norm_

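The rewritten branches encode the usual FFT normalization conventions: "forward" scales the forward transform by 1/N, "backward" leaves it unscaled, and "ortho" uses 1/sqrt(N), with the roles of forward and inverse swapped on the way back. Those conventions can be checked independently with NumPy (illustration only, unrelated to the MindSpore gradient code):

import numpy as np

x = np.random.randn(8).astype(np.complex64)
n = x.size

# "backward" (NumPy's default): forward FFT unscaled, inverse divides by n.
assert np.allclose(np.fft.ifft(np.fft.fft(x)), x, atol=1e-5)
# "ortho": both directions scale by 1/sqrt(n), so the round trip is still the identity.
assert np.allclose(np.fft.ifft(np.fft.fft(x, norm="ortho"), norm="ortho"), x, atol=1e-5)
# "forward": the forward FFT itself carries the 1/n factor.
assert np.allclose(np.fft.fft(x, norm="forward"), np.fft.fft(x) / n, atol=1e-5)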
mindspore/ops/_grad_experimental/grad_sparse_ops.py
CHANGED

@@ -358,10 +358,10 @@ def get_bprop_ragged_tensor_to_sparse(self):
             split.append(zeros_like(i))
             all_d = (split, ragged_values_grad)
             return all_d
-
+        split_ = ()
         for i in enumerate(rt_nested_splits):
-
-        all_d = (
+            split_ = split_ + (zeros_like(i),)
+        all_d = (split_, ragged_values_grad)
         return all_d

     return bprop
mindspore/ops/_op_impl/aicpu/add.py
CHANGED

@@ -29,9 +29,9 @@ add_op_info = AiCPURegOp("Add") \
     .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
     .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
     .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
-    .dtype_format(DataType.U16_Default, DataType.
-    .dtype_format(DataType.U32_Default, DataType.
-    .dtype_format(DataType.U64_Default, DataType.
+    .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
+    .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
+    .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
     .dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
     .dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
     .get_op_info()
mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2023 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -24,11 +24,30 @@ lsap_op_info = AiCPURegOp("LinearSumAssignment") \
     .input(2, 'maximize', "required") \
     .output(0, "row_ind", "required") \
     .output(1, "col_ind", "required") \
-    .attr("cust_aicpu", "str") \
     .dtype_format(DataType.F64_Default, DataType.I64_Default,
                   DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
     .dtype_format(DataType.F32_Default, DataType.I64_Default,
                   DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+    .dtype_format(DataType.F16_Default, DataType.I64_Default,
+                  DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+    .dtype_format(DataType.BOOL_Default, DataType.I64_Default,
+                  DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+    .dtype_format(DataType.I16_Default, DataType.I64_Default,
+                  DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+    .dtype_format(DataType.I32_Default, DataType.I64_Default,
+                  DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default,
+                  DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+    .dtype_format(DataType.I8_Default, DataType.I64_Default,
+                  DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+    .dtype_format(DataType.U16_Default, DataType.I64_Default,
+                  DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+    .dtype_format(DataType.U32_Default, DataType.I64_Default,
+                  DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+    .dtype_format(DataType.U64_Default, DataType.I64_Default,
+                  DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+    .dtype_format(DataType.U8_Default, DataType.I64_Default,
+                  DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
     .get_op_info()

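The registration above mirrors the contract of SciPy's linear_sum_assignment (a cost matrix plus a maximize flag in, row_ind/col_ind out), which makes SciPy a convenient host-side reference for sanity-checking results. The snippet below uses SciPy itself, not the MindSpore operator:

import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[4, 1, 3],
                 [2, 0, 5],
                 [3, 2, 2]])
row_ind, col_ind = linear_sum_assignment(cost, maximize=False)
print(row_ind, col_ind)              # [0 1 2] [1 0 2]
print(cost[row_ind, col_ind].sum())  # 5, the minimal total cost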
mindspore/ops/_utils/utils.py
CHANGED

@@ -77,10 +77,12 @@ def get_broadcast_shape(x_shape, y_shape, prim_name, arg_name1="x", arg_name2="y
     broadcast_shape = list(broadcast_shape_front) + broadcast_shape_back
     return broadcast_shape

+
 def dim_not_equal(dim1, dim2):
     """Compare dim in shape"""
     return dim1 != dim2 and dim1 >= 0 and dim2 >= 0

+
 def get_concat_offset(x_shp, x_type, axis, prim_name):
     """for concat and concatoffset check args and compute offset"""
     validator.check_value_type("shape", x_shp, [tuple, list], prim_name)
mindspore/ops/composite/multitype_ops/_compile_utils.py
CHANGED

@@ -1255,7 +1255,8 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
     index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
     index = F.broadcast_to(index, data.shape)
     value = F.cast(value, F.dtype(data))
-
+    while value.ndim < data.ndim:
+        value = value.unsqueeze(-1)
     value = F.broadcast_to(value, data.shape)
     result = F.select(index, value, data)
     return result
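The added while-loop pads trailing singleton axes onto value until its rank matches data, which is what makes the subsequent broadcast_to(data.shape) legal when value carries fewer dimensions than data. The same alignment issue is easy to reproduce in plain NumPy (illustrative analogue, not the MindSpore code path):

import numpy as np

data = np.zeros((3, 4))
value = np.array([1.0, 2.0, 3.0])          # one value per row, shape (3,)

try:
    np.broadcast_to(value, data.shape)     # fails: (3,) aligns with the last axis (size 4)
except ValueError as err:
    print(err)

while value.ndim < data.ndim:              # analogue of the new unsqueeze(-1) loop
    value = value[..., np.newaxis]         # shape becomes (3, 1)
print(np.broadcast_to(value, data.shape))  # rows of 1.0, 2.0, 3.0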
mindspore/ops/composite/multitype_ops/getitem_impl.py
CHANGED

@@ -161,7 +161,7 @@ def _tuple_getitem_by_slice(data, slice_index):
         if start is None:
             start = 0 if step >= 1 else -1
         if stop is None:
-            stop = (2**31-1) if step >= 1 else -(2**31-1)
+            stop = (2**31 - 1) if step >= 1 else -(2**31 - 1)
         return sequence_slice(data, start, stop, step)
     return _tuple_slice(data, slice_index)

@@ -236,7 +236,7 @@ def _list_getitem_by_slice(data, slice_index):
         if start is None:
             start = 0 if step >= 1 else -1
         if stop is None:
-            stop = (2**31-1) if step >= 1 else -(2**31-1)
+            stop = (2**31 - 1) if step >= 1 else -(2**31 - 1)
         return sequence_slice(data, start, stop, step)
     return _list_slice(data, slice_index)
