mindspore 2.2.0-cp38-cp38-manylinux1_x86_64.whl → 2.2.11-cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/_akg/akg/composite/build_module.py +104 -20
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +7 -2
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/kernel_exec.py +41 -15
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +27 -6
- mindspore/_akg/akg/utils/util.py +56 -1
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_checkparam.py +3 -3
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/splitter.py +3 -2
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
- mindspore/_extends/parse/__init__.py +3 -2
- mindspore/_extends/parse/parser.py +6 -1
- mindspore/_extends/parse/standard_method.py +14 -11
- mindspore/_extends/remote/kernel_build_server.py +2 -1
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/common/_utils.py +16 -0
- mindspore/common/api.py +1 -1
- mindspore/common/auto_dynamic_shape.py +81 -85
- mindspore/common/dump.py +1 -1
- mindspore/common/tensor.py +3 -20
- mindspore/config/op_info.config +1 -1
- mindspore/context.py +11 -4
- mindspore/dataset/engine/cache_client.py +8 -5
- mindspore/dataset/engine/datasets_standard_format.py +5 -0
- mindspore/dataset/vision/transforms.py +21 -21
- mindspore/experimental/optim/adam.py +1 -1
- mindspore/gen_ops.py +1 -1
- mindspore/include/api/model.h +17 -0
- mindspore/include/api/status.h +8 -3
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +78 -80
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/nn/cell.py +0 -3
- mindspore/nn/layer/activation.py +4 -5
- mindspore/nn/layer/conv.py +39 -23
- mindspore/nn/layer/flash_attention.py +54 -129
- mindspore/nn/layer/math.py +3 -7
- mindspore/nn/layer/rnn_cells.py +5 -5
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +12 -3
- mindspore/numpy/utils_const.py +5 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
- mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_utils/utils.py +2 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
- mindspore/ops/function/array_func.py +10 -7
- mindspore/ops/function/grad/grad_func.py +0 -1
- mindspore/ops/function/nn_func.py +98 -9
- mindspore/ops/function/random_func.py +2 -1
- mindspore/ops/op_info_register.py +24 -21
- mindspore/ops/operations/__init__.py +6 -2
- mindspore/ops/operations/_grad_ops.py +25 -6
- mindspore/ops/operations/_inner_ops.py +155 -23
- mindspore/ops/operations/array_ops.py +9 -7
- mindspore/ops/operations/comm_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +85 -68
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +7 -6
- mindspore/ops/operations/nn_ops.py +193 -49
- mindspore/parallel/_parallel_serialization.py +10 -3
- mindspore/parallel/_tensor.py +4 -1
- mindspore/parallel/checkpoint_transform.py +13 -2
- mindspore/parallel/shard.py +17 -10
- mindspore/profiler/common/util.py +1 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
- mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
- mindspore/profiler/parser/ascend_op_generator.py +1 -1
- mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
- mindspore/profiler/parser/base_timeline_generator.py +1 -1
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
- mindspore/profiler/parser/framework_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +19 -0
- mindspore/profiler/profiling.py +46 -24
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/parsers/for_parser.py +7 -7
- mindspore/rewrite/parsers/module_parser.py +4 -4
- mindspore/rewrite/symbol_tree.py +1 -4
- mindspore/run_check/_check_version.py +5 -3
- mindspore/safeguard/rewrite_obfuscation.py +52 -28
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/train/callback/_summary_collector.py +1 -1
- mindspore/train/dataset_helper.py +1 -0
- mindspore/train/model.py +2 -2
- mindspore/train/serialization.py +97 -11
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +23 -7
- mindspore/version.py +1 -1
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +3 -2
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +160 -151
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
- mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
- {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
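A file-level summary like the one above can be reproduced locally. The sketch below (not part of the diff) compares the two wheels using only the Python standard library; the wheel filenames are placeholders for the files you downloaded, and it reports added, removed, and changed archive members rather than per-file line counts.

# Minimal sketch: file-level comparison of two wheels with the standard library only.
# The wheel paths are placeholders; adjust them to the files you actually have.
import zipfile

OLD = "mindspore-2.2.0-cp38-cp38-manylinux1_x86_64.whl"
NEW = "mindspore-2.2.11-cp38-cp38-manylinux1_x86_64.whl"

def entries(path):
    # Map archive member name -> (size, CRC) so unchanged files can be skipped cheaply.
    with zipfile.ZipFile(path) as zf:
        return {i.filename: (i.file_size, i.CRC) for i in zf.infolist()}

old, new = entries(OLD), entries(NEW)
for name in sorted(new.keys() - old.keys()):
    print("added  ", name)
for name in sorted(old.keys() - new.keys()):
    print("removed", name)
for name in sorted(old.keys() & new.keys()):
    if old[name] != new[name]:
        print("changed", name)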
@@ -26,7 +26,7 @@ from mindspore.ops.operations._scalar_ops import bit_or, bit_and
 from mindspore.ops.operations.comm_ops import ReduceOp
 from mindspore.ops import signature as sig
 from mindspore.ops.operations.math_ops import _infer_shape_reduce
-from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive
+from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, \
     _run_op, _check_contains_variable
 from mindspore._c_expression import Tensor as Tensor_
 from mindspore._c_expression import typing

@@ -167,6 +167,7 @@ class Quant(PrimitiveWithInfer):
         self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
         self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"],
                                                  "round_mode", self.name)
+        self.add_prim_attr("dst_type", mstype.int8)

     def infer_shape(self, x_shape):
         return x_shape

@@ -174,7 +175,7 @@ class Quant(PrimitiveWithInfer):
     def infer_dtype(self, x_type):
         validator.check_subclass("input_x", x_type, mstype.tensor_type, self.name)
         validator.check_type_name("input_x", x_type, [mstype.float16, mstype.float32], self.name)
-        return
+        return self.get_attr_dict()['dst_type']


 class Lamb(PrimitiveWithInfer):
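The two Quant hunks above make the primitive carry its output dtype explicitly: __init__ now records a dst_type attribute (int8), and infer_dtype returns that attribute instead of returning nothing. A minimal sketch of what this looks like from Python follows; only the attribute name dst_type and the get_attr_dict() accessor come from the hunks themselves, and the constructor arguments are assumptions for illustration.

# Hedged sketch (assumes mindspore 2.2.11 is installed; the constructor arguments
# below are illustrative, not taken from the diff). After this change the primitive
# exposes a "dst_type" attribute, which infer_dtype reports as the output dtype.
from mindspore.ops.operations._inner_ops import Quant

quant = Quant(scale=1.0, offset=0.0, sqrt_mode=False, round_mode="Round")
print(quant.get_attr_dict()["dst_type"])  # expected: Int8, per the added add_prim_attr call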
@@ -491,7 +492,7 @@ class Receive(PrimitiveWithInfer):
         self.dtype = dtype
         self.group = group
         self.add_prim_attr("no_eliminate", True)
-        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
+        valid_type = [mstype.float16, mstype.bfloat16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
         args = {"dtype": dtype}
         validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)

@@ -2146,13 +2147,14 @@ class ClipByNorm(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, axis=None):
         """Initialize ClipByNorm"""
+        self.axis_str = 'axis'
         self.axis = () if axis is None else axis
-        validator.check_value_type(
+        validator.check_value_type(self.axis_str, self.axis, [int, tuple, list], self.name)
         axis_check = self.axis if isinstance(self.axis, Iterable) else (self.axis,)
         for i, value in enumerate(axis_check):
             validator.check_value_type('axis[%d]' % i, value, [int], self.name)
-        self.init_attrs[
-        self.add_prim_attr(
+        self.init_attrs[self.axis_str] = self.axis
+        self.add_prim_attr(self.axis_str, self.axis)
         self.init_prim_io_names(inputs=['x', 'clip_norm'], outputs=['output'])

     def infer_shape(self, x_shape, clip_norm_shape):
@@ -2729,27 +2731,29 @@ class CopyWithSlice(Primitive):
         self.init_prim_io_names(inputs=['x', 'y'], outputs=['x'])


-class
+class FFN(Primitive):
     r"""
-    The
+    The FFN computation is similar to Feed-Forward Network, it contains matmul + gelu + matmul.

     Args:
         activation (string): The activation type, set to 'fastgelu' or 'gelu'.
-
+            Only support 'fastgelu' for now. Default: "fastgelu".
+        inner_precise (int): The precise mode, set to 0 for high precision or 1 for high performance.
+            Only support 1 for now. Default: 0.

     Inputs:
         - **x** (Tensor) - The input tensor with data type of int8, float16.
           Input tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`.
+        - **weight1** (Tensor) - The weight1 tensor with data type of float16.
+          Weight1 tensor of shape :math:`(expert\_num, hidden\_size, ffn\_hidden\_size)`.
+        - **weight2** (Tensor) - The weight2 tensor with data type of float16.
+          Weight2 tensor of shape :math:`(expert\_num, ffn\_hidden\_size, hidden\_size)`.
         - **expert_tokens** (Tensor]) - The expert tokens tensor with data type of int64.
           Expert tokens tensor of shape :math:`(16,)`. For example, `(2, 1, 0, .., 9)`
           indicate that the 0th expert deals with 2 tokens, the 1th expert deals with 1 tokens,
           the 2th expert do noting and so on.
-        - **weight1** (Tensor) - The weight1 tensor with data type of float16.
-          Weight1 tensor of shape :math:`(expert\_num, hidden\_size, ffn\_hidden\_size)`.
         - **bias1** (Tensor) - The bias1 tensor with data type of float16.
           Bias1 tensor of shape :math:`(expert\_num, ffn\_hidden\_size)`.
-        - **weight2** (Tensor) - The weight2 tensor with data type of float16.
-          Weight2 tensor of shape :math:`(expert\_num, ffn\_hidden\_size, hidden\_size)`.
         - **bias2** (Tensor) - The bias2 tensor with data type of float16.
           Bias2 tensor of shape :math:`(expert\_num, hidden\_size)`.
         - **scale** (Tensor) - The scale tensor with data type of float16. Not enable now.
@@ -2771,21 +2775,149 @@ class MoeFFN(Primitive):
         >>> h_f = 4 * h
         >>> e = 16
         >>> x = Tensor(np.random.randn(b * s, h).astype(np.float16))
-        >>> expert_tokens = Tensor(np.random.randn(e).astype(np.int64))
         >>> w1 = Tensor(np.random.randn(e, h, h_f).astype(np.float16))
-        >>> bias1 = Tensor(np.random.randn(e, h_f).astype(np.float16))
         >>> w2 = Tensor(np.random.randn(e, h_f, h).astype(np.float16))
+        >>> expert_tokens = Tensor(np.random.randn(e).astype(np.int64))
+        >>> bias1 = Tensor(np.random.randn(e, h_f).astype(np.float16))
         >>> bias2 = Tensor(np.random.randn(e, h).astype(np.float16))
-        >>>
-        >>> output =
+        >>> ffn = _inner_ops.FFN("fastgelu", 1)
+        >>> output = ffn(x, w1, w2, expert_tokens, bias1, bias2)
         >>> print(output)
     """

     @prim_attr_register
-    def __init__(self, activation):
-        """Initialize
-        self.init_prim_io_names(inputs=["x", "
-                                        "
-                                        "deq_scale2"],
+    def __init__(self, activation, inner_precise):
+        """Initialize FFN."""
+        self.init_prim_io_names(inputs=["x", "weight1", "weight2", "expert_tokens", "bias1",
+                                        "bias2", "scale", "offset", "deq_scale1", "deq_scale2"],
                                 outputs=["y"])
-
+        cls_name = self.name
+        validator.check_value_type("activation", activation, [str], cls_name)
+        validator.check_value_type("inner_precise", inner_precise, [int], cls_name)
+
+
+class DecoderKVCache(Primitive):
+    r"""
+    The DecoderKVCache is used for decoding the KVCache of transformer network.
+
+    Args:
+        cache (Tensor): The cahe tensor with data type of int8, uint8, int16, uint16, float16, float32 and int32.
+            When seq_len_axis is 2, cache tensor of shape
+            :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)`.
+            When seq_len_axis is 1, cache tensor of shape
+            :math:`(batch\_size, max\_seq\_length, num_head, hidden\_size)`.
+        update (Tensor]): The tensor which is used to update the cache tensor. Same data type as cache tensor.
+            When seq_len_axis is 2, update tensor of shape
+            :math:`(batch\_size, num_head, update\_seq\_length, hidden\_size)`.
+            When seq_len_axis is 1, update tensor of shape
+            :math:`(batch\_size, update\_seq\_length, num_head, hidden\_size)`.
+        valid_seq_len (Tensor): The valid_seq_len tensor with data type of int64.
+            Valid_seq_len tensor of shape :math:`(batch\_size)`.
+        batch_index (Tensor): The batch_index tensor with data type of int64.
+            Batch_index tensor of shape :math:`(1)`. Indicate that which batch of cache tensor is going to be update.
+        seq_len_axis (int64): The seq_len_axis indicate which axis is seq_eln, set to '1' or '2'. Default: "2".
+        new_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
+            New_max_seq_len tensor of shape :math:`(1)`.
+            Indicate that user want to change the shape of cache tensor from
+            :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)` to
+            :math:
+            `(batch\_size * max\_seq\_length / new\_max\_seq\_length, num_head, new\_max\_seq\_length, hidden\_size)`
+            to update the cache tensor. This will not real change the shape of `cache` tensor. Not able for now.
+        cur_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
+            Cur_max_seq_len tensor of shape :math:`(1)`. Keep the current seq_len of cache tensor. Not abel for now.
+
+    Outputs:
+        With same data type and same shape as `cache` tensor.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore.ops.operations import _inner_ops
+        >>> b = 4
+        >>> h = 40
+        >>> max_s = 1024
+        >>> s = 1
+        >>> d = 128
+        >>> cache = Tensor(np.random.randn(b, h, max_s, d).astype(np.float16))
+        >>> update = Tensor(np.random.randn(b, h, s, d).astype(np.float16))
+        >>> valid_seq_len = Tensor(np.random.randn(b).astype(np.int64))
+        >>> batch_index = Tensor(np.random.randn(1).astype(np.int64))
+        >>> new_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+        >>> cur_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+        >>> decoder_kv_cache = _inner_ops.DecoderKVCache()
+        >>> output = decoder_kv_cache(cache, update, valid_seq_len, batch_index, 2, new_max_seq_len, cur_max_seq_len)
+        >>> print(cache)
+    """
+    @prim_attr_register
+    def __init__(self):
+        """Initialize DecoderKVCache."""
+        self.init_prim_io_names(inputs=["cache", "update", "valid_seq_len", "batch_index", "seq_len_axis",
+                                        "new_max_seq_len", "cur_max_seq_len"],
+                                outputs=["out"])
+        self.add_prim_attr('side_effect_mem', True)
+
+
+class PromptKVCache(Primitive):
+    r"""
+    The PromptKVCache is used for prefill the KVCache of transformer network.
+
+    Args:
+        cache (Tensor): The cahe tensor with data type of int8, uint8, int16, uint16, float16, float32 and int32.
+            When seq_len_axis is 2, cache tensor of shape
+            :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)`.
+            When seq_len_axis is 1, cache tensor of shape
+            :math:`(batch\_size, max\_seq\_length, num_head, hidden\_size)`.
+        update (Tensor]): The tensor which is used to update the cache tensor. Same data type as cache tensor.
+            When seq_len_axis is 2, update tensor of shape
+            :math:`(batch\_size, num_head, update\_seq\_length, hidden\_size)`.
+            When seq_len_axis is 1, update tensor of shape
+            :math:`(batch\_size, update\_seq\_length, num_head, hidden\_size)`.
+        valid_seq_len (Tensor): The valid_seq_len tensor with data type of int64.
+            Valid_seq_len tensor of shape :math:`(batch\_size)`.
+        batch_index (Tensor): The batch_index tensor with data type of int64.
+            Batch_index tensor of shape :math:`(1)`. Indicate that which batch of cache tensor is going to be update.
+        seq_len_axis (int64): The seq_len_axis indicate which axis is seq_eln, set to '1' or '2'. Default: "2".
+        new_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
+            New_max_seq_len tensor of shape :math:`(1)`.
+            Indicate that user want to change the shape of cache tensor from
+            :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)` to
+            :math:
+            `(batch\_size * max\_seq\_length / new\_max\_seq\_length, num_head, new\_max\_seq\_length, hidden\_size)`
+            to update the cache tensor. This will not real change the shape of `cache` tensor. Not able for now.
+        cur_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
+            Cur_max_seq_len tensor of shape :math:`(1)`. Keep the current seq_len of cache tensor. Not abel for now.
+        align_mode (int64): indicate which axis is seq_eln, 0 is 'right', 1 is 'left'. Default: 0.
+
+    Outputs:
+        With same data type and same shape as `cache` tensor.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops.operations import _inner_ops
+        >>> b = 4
+        >>> h = 40
+        >>> max_s = 1024
+        >>> s = 256
+        >>> d = 128
+        >>> cache = Tensor(np.random.randn(b, h, max_s, d).astype(np.float16))
+        >>> update = Tensor(np.random.randn(b, h, s, d).astype(np.float16))
+        >>> valid_seq_len = Tensor(np.random.randn(b).astype(np.int64))
+        >>> batch_index = Tensor(np.random.randn(1).astype(np.int64))
+        >>> new_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+        >>> cur_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+        >>> prompt_kv_cache = _inner_ops.PromptKVCache(0)
+        >>> output = prompt_kv_cache(cache, update, valid_seq_len, batch_index, 2, new_max_seq_len, cur_max_seq_len)
+        >>> print(cache)
+    """
+    @prim_attr_register
+    def __init__(self, padding_mode="right"):
+        """Initialize PromptKVCache."""
+        self.init_prim_io_names(inputs=["cache", "update", "valid_seq_len", "batch_index", "seq_len_axis",
+                                        "new_max_seq_len", "cur_max_seq_len"],
+                                outputs=["out"])
+        self.add_prim_attr('side_effect_mem', True)
+        self.padding_mode = padding_mode
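The new DecoderKVCache and PromptKVCache primitives above only describe their interface. As a conceptual reference, the following plain-NumPy sketch shows the bookkeeping a decode-step cache update performs for the seq_len_axis=2 layout; it is an illustration of the idea, not the Ascend kernel's exact semantics, and every name in it is local to the sketch.

# Conceptual NumPy reference for a decode-step KV-cache update (seq_len_axis=2 layout:
# cache is (batch, num_head, max_seq_len, hidden)). Not the operator's kernel, just
# the bookkeeping the docstrings above describe.
import numpy as np

def decoder_kv_cache_ref(cache, update, valid_seq_len):
    # cache:  (b, h, max_s, d), updated in place, mirroring side_effect_mem=True
    # update: (b, h, 1, d)      one new decode step per batch element
    # valid_seq_len: (b,)       position at which each batch element's step is written
    for b, pos in enumerate(valid_seq_len):
        cache[b, :, pos, :] = update[b, :, 0, :]
    return cache

b, h, max_s, d = 2, 4, 8, 16
cache = np.zeros((b, h, max_s, d), np.float16)
update = np.ones((b, h, 1, d), np.float16)
decoder_kv_cache_ref(cache, update, np.array([3, 5]))
print(cache[0, 0, 3, 0], cache[1, 0, 5, 0])  # 1.0 1.0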
@@ -1208,7 +1208,7 @@ class UniqueWithPad(Primitive):


 class Split(Primitive):
-    """
+    r"""
     Splits the input tensor into output_num of tensors along the given axis and output numbers.

     Refer to :func:`mindspore.ops.split` for more details.

@@ -1222,7 +1222,7 @@ class Split(Primitive):

     Outputs:
         tuple[Tensor], the shape of each output tensor is the same, which is
-        :math:`(x_0, x_1, ..., x_{axis}/{
+        :math:`(x_0, x_1, ..., x_{axis}/{output\_num}, ..., x_{R-1})`.
         And the data type is the same as `input_x`.

     Supported Platforms:
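The corrected formula states that only the split axis is divided by output_num while every other dimension is unchanged. A short sketch, assuming a CPU-capable mindspore install, checks this with the public Split primitive:

# Minimal sketch (assumes a CPU-capable mindspore install): splitting a (2, 6) tensor
# into 3 parts along axis 1 yields pieces of shape (2, 6/3) = (2, 2), matching the
# corrected :math: formula in the docstring above.
import numpy as np
from mindspore import Tensor, ops

x = Tensor(np.arange(12).reshape(2, 6).astype(np.float32))
split = ops.Split(axis=1, output_num=3)
pieces = split(x)
print([p.shape for p in pieces])  # [(2, 2), (2, 2), (2, 2)]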
@@ -1763,16 +1763,18 @@ class FillV2(PrimitiveWithCheck):
         self.init_prim_io_names(inputs=['shape', 'value'], outputs=['y'])

     def check_elim(self, dims, x):
-
-
-                isinstance(dims, (Tensor, Tensor_))
+        x_is_invalid = x is None or (not isinstance(x, (Tensor, Tensor_))) or (x.shape != ())
+        dims_is_invalid = dims is None or (isinstance(dims, (tuple, list)) and dims) or\
+            isinstance(dims, (Tensor, Tensor_))
+        if x_is_invalid or dims_is_invalid:
             return (False, None)
         return (True, x)

     def infer_value(self, dims, x):
-
+        dims_is_invalid = dims is None or\
             (isinstance(dims, (tuple, list)) and dims) or\
-                isinstance(dims, (Tensor, Tensor_))
+            isinstance(dims, (Tensor, Tensor_))
+        if x is None or dims_is_invalid:
             return None
         return x

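For context on the refactored check_elim: FillV2 takes a shape tensor and a 0-d value tensor, and check_elim only decides whether the call can be replaced by its value at compile time; the rewrite splits the old combined condition into the named predicates x_is_invalid and dims_is_invalid without changing what is accepted. A hedged usage sketch follows (assumes a CPU-capable mindspore install; shapes and values are illustrative).

# Minimal usage sketch for FillV2: first input is the target shape, second is a 0-d
# tensor holding the fill value. The refactor above only rephrases the validity test
# applied to these two inputs.
import mindspore as ms
from mindspore import Tensor, ops

fill = ops.FillV2()
y = fill(Tensor([2, 3], ms.int32), Tensor(4.0, ms.float32))
print(y)  # a 2x3 tensor filled with 4.0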
@@ -94,7 +94,7 @@ class ReduceOp:

 def check_collective_target_dtype(data_name, data_dtype, prim_name):
     """Check if data type is valid."""
-    default_target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32)
+    default_target_dtypes = (mstype.int8, mstype.uint8, mstype.int32, mstype.float16, mstype.bfloat16, mstype.float32)
     gpu_target_dtypes = (mstype.bool_, mstype.int8, mstype.int32, mstype.int64, mstype.uint32, mstype.uint64,
                          mstype.float16, mstype.float32, mstype.float64)

@@ -1310,4 +1310,4 @@ class _GetTensorSlice(PrimitiveWithInfer):
         from mindspore.parallel._tensor import _load_tensor
         validator.check_value_type("dev_mat", dev_mat, [tuple], self.name)
         validator.check_value_type("tensor_map", tensor_map, [tuple], self.name)
-        return Tensor(_load_tensor(x, dev_mat, tensor_map))
+        return Tensor(_load_tensor(x, dev_mat, tensor_map), x.dtype)
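The one-line _GetTensorSlice change passes the source tensor's dtype when the sliced host buffer is wrapped back into a Tensor, so the slice keeps its original dtype instead of whatever NumPy infers. A small sketch of the underlying Tensor behaviour (assumes mindspore is installed; float16 stands in for dtypes NumPy cannot represent, such as bfloat16):

# Why passing dtype matters when re-wrapping a NumPy buffer: without it, the new
# Tensor inherits NumPy's dtype. float16 is used here as a stand-in; the same pattern
# is what keeps e.g. bfloat16 slices from silently changing type.
import numpy as np
import mindspore as ms
from mindspore import Tensor

buf = np.ones((2, 2), dtype=np.float32)       # pretend this is the sliced host buffer
print(Tensor(buf).dtype)                      # Float32, inherited from NumPy
print(Tensor(buf, ms.float16).dtype)          # Float16, the dtype we asked to preserve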
@@ -42,6 +42,24 @@ from ._pyfunc_registry import add_pyfunc
 if platform.system() != "Windows":
     import fcntl

+KEY_ATTR = "attr"
+KEY_NAME = "name"
+INPUT_NAMES = "input_names"
+ATTR_NAMES = "attr_names"
+AUTO_DIFF = "autodiff"
+IMPLY_TYPE = "imply_type"
+FUSION_TYPE = "fusion_type"
+MS_KERNEL_FLAG = "ms_kernel_flag"
+AKG = "AKG"
+TBE = "TBE"
+CUDA = "CUDA"
+AICORE = "AiCore"
+CPU = "CPU"
+GPU = "GPU"
+ASCEND = "Ascend"
+HYBRID_TYPE = "hybrid"
+OP_NAME = "op_name"
+

 def _get_cache_path():
     """
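The new module-level constants are simply the string keys and target names that the registration-info handling below keeps repeating ('attr', 'name', 'op_name', 'imply_type', and so on). As a reminder of what such a reg_info looks like, here is a hedged sketch that builds one with the public CustomRegOp helper and prints its keys; the op name, attribute, and dtype choices are illustrative.

# Hedged sketch: build a registration-info dict with the public CustomRegOp helper and
# show the kind of keys the new constants (OP_NAME="op_name", KEY_ATTR="attr",
# IMPLY_TYPE="imply_type", ...) index into. Names and formats here are illustrative.
from mindspore.ops import CustomRegOp, DataType

reg_info = CustomRegOp("my_custom_op") \
    .input(0, "x") \
    .output(0, "y") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .attr("scale", "required", "float") \
    .target("CPU") \
    .get_op_info()

print(sorted(reg_info.keys()))  # expect entries such as 'op_name', 'attr', 'inputs', 'outputs'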
@@ -150,7 +168,6 @@ class Custom(ops.PrimitiveWithInfer):

     .. warning::
         - This is an experimental API that is subject to change.
-        - Currently, the functionality of Custom does not support Ascend 910B.

     .. note::
         The supported platforms are determined by the input `func_type`. The supported platforms are as follows:
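The removed warning line means the 2.2.11 documentation no longer singles out Ascend 910B as unsupported by Custom. For readers unfamiliar with the operator, a minimal hedged usage sketch in pyfunc mode on CPU follows; the function, shapes, and context settings are illustrative and not taken from the diff.

# Minimal Custom usage sketch in "pyfunc" mode (assumes a CPU-capable mindspore
# install; everything here is illustrative). The Python function runs as-is, and
# out_shape/out_dtype are callables mapping input shape/dtype to output shape/dtype.
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

ms.set_context(mode=ms.PYNATIVE_MODE, device_target="CPU")

def add_one(x):
    return x + 1.0

add_one_op = ops.Custom(add_one, out_shape=lambda x: x, out_dtype=lambda x: x, func_type="pyfunc")
print(add_one_op(Tensor(np.zeros((2, 2), np.float32))))  # all ones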
@@ -453,10 +470,10 @@ class Custom(ops.PrimitiveWithInfer):
     op_path_in_cache = []  # Save paths for op functions created in the cached.
     custom_aot_warning = True  # Flag to enable warnings about custom aot path white list

-    def __init__(self, func, out_shape=None, out_dtype=None, func_type=
-
+    def __init__(self, func, out_shape=None, out_dtype=None, func_type=HYBRID_TYPE, bprop=None, reg_info=None):
+        super().__init__("Custom")

-        self.supported_targets = [
+        self.supported_targets = [ASCEND, GPU, CPU]
         self.supported_func_type = ["hybrid", "akg", "tbe", "aicpu", "aot", "pyfunc", "julia"]
         self.log_prefix = "For '{}', 'func_type': {}, 'func': {}".format(self.name, func_type, func)
         self.func = func

@@ -473,7 +490,7 @@ class Custom(ops.PrimitiveWithInfer):
         self._update_func_info(reg_info)
         self.add_prim_attr("func_name", self.func_name)
         self.add_prim_attr("uniq_name", self.uniq_name)
-        if self.func_type ==
+        if self.func_type == HYBRID_TYPE:
             self.add_prim_attr("func_compile_attrs", self._func_compile_attrs)

         self.add_prim_attr("imply_path", self.imply_path)

@@ -502,7 +519,7 @@ class Custom(ops.PrimitiveWithInfer):
         if func_type == "akg":
             self._set_akg_kernel_type()

-        if not self.bprop and self.func_type ==
+        if not self.bprop and self.func_type == HYBRID_TYPE:
             self._hybrid_autodiff(func_type)

         self.add_prim_attr("func_type", self.func_type)

@@ -577,7 +594,7 @@ class Custom(ops.PrimitiveWithInfer):
         elif "compute" in self.func_source_str:
             self.func_type = "tvm_compute"
         else:
-            self.func_type =
+            self.func_type = HYBRID_TYPE
             self._hybrid_func_analyser()

     def _check_julia_func(self):
@@ -633,18 +650,18 @@ class Custom(ops.PrimitiveWithInfer):

         elif self.func_type == "julia":
             self._check_julia_func()
-        elif self.func_type ==
-            if not hasattr(self.func,
+        elif self.func_type == HYBRID_TYPE:
+            if not hasattr(self.func, MS_KERNEL_FLAG):
                 raise TypeError("{}, 'func' must be a function decorated by kernel".format(self.log_prefix))
             self._is_ms_kernel = True
             self._func_compile_attrs = getattr(self.func, "compile_attrs", {})
         elif self.func_type == "akg":
-            if hasattr(self.func,
+            if hasattr(self.func, MS_KERNEL_FLAG):
                 logger.warning("{}. To have a better user experience, the mode hybrid is suggested "
                                "for the input function with decorator @kernel. "
                                "To enable this mode, set the 'func_type' to be \"hybrid\"".format(self.log_prefix))
         elif self.func_type == "pyfunc":
-            if hasattr(self.func,
+            if hasattr(self.func, MS_KERNEL_FLAG):
                 logger.warning("{}. Now you are using the function with decorator @kernel in the mode pyfunc. "
                                "The kernel will be executed as a native python function, which might lead to "
                                "low efficiency. To accelerate the kernel, set the 'func_type' to be \"hybrid\""

@@ -758,7 +775,7 @@ class Custom(ops.PrimitiveWithInfer):
                 continue
             if isinstance(reg_info_item, str):
                 reg_info_item = json.loads(reg_info_item)
-            prefix = "_".join([prefix, reg_info_item.get(
+            prefix = "_".join([prefix, reg_info_item.get(OP_NAME, "")])
             self.uniq_name = prefix + "_" + self.func_name
         else:
             raise TypeError("For '{}', 'func' must be of type function or str, but got {}"
@@ -768,23 +785,23 @@ class Custom(ops.PrimitiveWithInfer):
         """Update op attrs in reg_info."""
         output_name_list = []
         for _, item in enumerate(reg_info.get("outputs", [])):
-            if isinstance(item, dict) and item.get(
-                output_name_list.append(item.get(
+            if isinstance(item, dict) and item.get(KEY_NAME):
+                output_name_list.append(item.get(KEY_NAME))
         if output_name_list:
             self.add_prim_attr("output_names", output_name_list)

-        if isinstance(reg_info.get(
-            self.add_prim_attr("reg_op_name", reg_info.get(
+        if isinstance(reg_info.get(OP_NAME), str):
+            self.add_prim_attr("reg_op_name", reg_info.get(OP_NAME))

         if self.func_type == "aicpu":
-            self.uniq_name = reg_info[
+            self.uniq_name = reg_info[OP_NAME]
             self.add_prim_attr("uniq_name", self.uniq_name)

         if self.func_type in ["aot", "aicpu"]:
-            if reg_info.get(
-                for item in reg_info[
+            if reg_info.get(KEY_ATTR) is not None and isinstance(reg_info[KEY_ATTR], list):
+                for item in reg_info[KEY_ATTR]:
                     if isinstance(item, dict) and item.get("value") is not None:
-                        self.add_prim_attr(item[
+                        self.add_prim_attr(item[KEY_NAME], item["value"])

     def _register_info(self, info):
         """Register reg_info."""

@@ -802,7 +819,7 @@ class Custom(ops.PrimitiveWithInfer):
             if isinstance(reg_info, str):
                 reg_info = json.loads(reg_info)
             if self.fake_output:
-                reg_info["outputs"].append(dict({"index": 0,
+                reg_info["outputs"].append(dict({"index": 0, KEY_NAME: "y", "param_type": "required"}))
             new_dtype_format = []
             for i in reg_info["dtype_format"]:
                 new_dtype_format.append(i + (DataType.I32_Default,))
@@ -874,16 +891,16 @@ class Custom(ops.PrimitiveWithInfer):
                             "'CustomRegOp' to generate the registration information, then pass it to 'reg_info' or "
                             "use 'custom_info_register' to bind it to 'func' if 'func' is a function."
                             .format(self.log_prefix, reg_info, type(reg_info)))
-        reg_info[
-        reg_info[
-        if not isinstance(reg_info.get(
-            reg_info[
+        reg_info[OP_NAME] = self.uniq_name
+        reg_info[IMPLY_TYPE] = self._get_imply_type(reg_info, target)
+        if not isinstance(reg_info.get(FUSION_TYPE), str) or not reg_info[FUSION_TYPE].strip():
+            reg_info[FUSION_TYPE] = "OPAQUE"
         # Supplement necessary info for TBE if these information is missing in reg_info
-        if reg_info[
-            if reg_info.get(
-                for i, item in enumerate(reg_info[
+        if reg_info[IMPLY_TYPE] == TBE:
+            if reg_info.get(KEY_ATTR) is not None and isinstance(reg_info[KEY_ATTR], list):
+                for i, item in enumerate(reg_info[KEY_ATTR]):
                     if isinstance(item, dict) and item.get("value") is None:
-                        reg_info[
+                        reg_info[KEY_ATTR][i]["value"] = "all"
             reg_info["async_flag"] = reg_info.get("async_flag", False)
             reg_info["binfile"] = "%s.so" % self.func_name
             reg_info["compute_cost"] = reg_info.get("compute_cost", 10)

@@ -891,8 +908,8 @@ class Custom(ops.PrimitiveWithInfer):
             reg_info["partial_flag"] = reg_info.get("partial_flag", True)
             reg_info["needCheckSupport"] = reg_info.get("need_check_supported", False)
         # Supplement necessary info for AKG if these information is missing in reg_info
-        if reg_info[
-            target_to_processor = {
+        if reg_info[IMPLY_TYPE] == AKG:
+            target_to_processor = {ASCEND: AICORE, GPU: CUDA, CPU: CPU}
             reg_info["processor"] = reg_info.get("processor", target_to_processor.get(target))
         return reg_info

@@ -905,15 +922,15 @@ class Custom(ops.PrimitiveWithInfer):
         # Infer target from reg_info["processor"], reg_info generated from AkgGpuRegOp or AkgAscendRegOp
         # will have the processor information.
         if target not in self.supported_targets:
-            processor_to_target = {
+            processor_to_target = {AICORE: ASCEND, CUDA: GPU, CPU: CPU}
             target = processor_to_target.get(reg_info.get("processor"))
-        # Infer target from reg_info[
+        # Infer target from reg_info[IMPLY_TYPE]
         if target not in self.supported_targets:
-            imply_type_to_target = {
-            target = imply_type_to_target.get(reg_info.get(
+            imply_type_to_target = {TBE: ASCEND, GPU: GPU, CPU: CPU}
+            target = imply_type_to_target.get(reg_info.get(IMPLY_TYPE))
         # Infer target from func_type
         if target not in self.supported_targets:
-            func_type_to_target = {"tbe":
+            func_type_to_target = {"tbe": ASCEND, "pyfunc": CPU}
             target = func_type_to_target.get(self.func_type)
         if target not in self.supported_targets:
             raise ValueError("{}, target set in registration information must be one of {}, but got {}"

@@ -922,14 +939,14 @@ class Custom(ops.PrimitiveWithInfer):

     def _get_imply_type(self, reg_info, target):
         """Get imply_typ information."""
-        # Get imply_type from reg_info[
-        if isinstance(reg_info, dict) and isinstance(reg_info.get(
-            reg_info[
-            return reg_info[
+        # Get imply_type from reg_info[IMPLY_TYPE]
+        if isinstance(reg_info, dict) and isinstance(reg_info.get(IMPLY_TYPE), str) and \
+                reg_info[IMPLY_TYPE].strip():
+            return reg_info[IMPLY_TYPE]
         # Infer imply_type from func_type
-        func_type_to_imply_type = {"hybrid":
-                                   "julia": target, "aot": "BiSheng" if target ==
-        return func_type_to_imply_type.get(self.func_type,
+        func_type_to_imply_type = {"hybrid": AKG, "akg": AKG, "tbe": TBE, "aicpu": "AiCPU", "pyfunc": target,
+                                   "julia": target, "aot": "BiSheng" if target == ASCEND else target}
+        return func_type_to_imply_type.get(self.func_type, AKG)

     def _save_attr(self, reg_info):
         """Save input_names and attr_names of current func."""
@@ -943,18 +960,18 @@ class Custom(ops.PrimitiveWithInfer):
             return value

         tensor_inputs = _get_value_list("inputs")
-        attr = _get_value_list(
+        attr = _get_value_list(KEY_ATTR)
         input_names = []  # include tensor input names and attr input names
         attr_names = []
         pure_input_names = []
         for item in tensor_inputs:
-            if isinstance(item, dict) and item.get(
-                input_names.append(item[
-                pure_input_names.append(item[
+            if isinstance(item, dict) and item.get(KEY_NAME) is not None:
+                input_names.append(item[KEY_NAME])
+                pure_input_names.append(item[KEY_NAME])
         # attr is converted from inputs only when graph mode or when inputs name is also in reg info
         attr_to_input_safe = bool(input_names) or context.get_context("mode") == ms.GRAPH_MODE
         for item in attr:
-            if isinstance(item, dict) and item.get(
+            if isinstance(item, dict) and item.get(KEY_NAME) is not None:
                 # for custom op with function tbe, we always add attrs to inputs as we don't
                 # deal with attr value here and leave them to the backend process to fit the
                 # usual process of tbe op compiling in mindspore

@@ -963,9 +980,9 @@ class Custom(ops.PrimitiveWithInfer):
                 # add attr name to input name only when the value of attr is None in reg info
                 # as we need to get values of attrs from inputs
                 if attr_to_input_safe and (self.func_type == "tbe" or item.get("value", None) is None):
-                    input_names.append(item[
-                attr_names.append(item[
-        cur_attr = {
+                    input_names.append(item[KEY_NAME])
+                attr_names.append(item[KEY_NAME])
+        cur_attr = {INPUT_NAMES: input_names, ATTR_NAMES: attr_names, "pure_input_names": pure_input_names}
         # If func does not have attr, save current attr.
         # Else, check if current attr is same as previous saved one.
         prev_attr_names = attr_names
@@ -974,13 +991,13 @@ class Custom(ops.PrimitiveWithInfer):
             if not isinstance(func_attr, dict):
                 setattr(self.func, "func_attr", cur_attr)
             else:
-                prev_attr_names = func_attr.get(
+                prev_attr_names = func_attr.get(ATTR_NAMES)
         elif isinstance(self.func, str):
             func_attr = Custom.attr_dict.get(self.func)
             if not isinstance(func_attr, dict):
                 Custom.attr_dict[self.func] = cur_attr
             else:
-                prev_attr_names = func_attr.get(
+                prev_attr_names = func_attr.get(ATTR_NAMES)
         if attr_names != prev_attr_names:
             raise ValueError("{}, attr names set in registration information must be the same as previous saved one, "
                              "but got {} vs {}".format(self.log_prefix, attr_names, prev_attr_names))

@@ -989,23 +1006,23 @@ class Custom(ops.PrimitiveWithInfer):
         """Add primitive_target to primitive's attr."""
         registered_targets = self._get_registered_targets()
         if self.func_type == "pyfunc":
-            self.set_device(
-            if registered_targets and registered_targets != [
+            self.set_device(CPU)
+            if registered_targets and registered_targets != [CPU]:
                 logger.warning("{}, only supports CPU platform, but got registered target {}. "
                                "We will run it on CPU".format(self.log_prefix, registered_targets))
         elif self.func_type == "aot":
             if len(registered_targets) != 1:
                 logger.info("{}, target will be set according to context.".format(self.log_prefix))
-            elif registered_targets == [
-                self.set_device(
-            elif registered_targets == [
-                self.set_device(
+            elif registered_targets == [GPU]:
+                self.set_device(GPU)
+            elif registered_targets == [CPU]:
+                self.set_device(CPU)
         elif self.func_type == "julia":
-            self.set_device(
+            self.set_device(CPU)
             device_target = context.get_context('device_target')
-            if device_target ==
+            if device_target == CPU:
                 pass
-            elif device_target ==
+            elif device_target == GPU and registered_targets and registered_targets == [CPU]:
                 logger.warning("{}, only supports CPU platform, but got registered target {}. "
                                "We will run it on CPU".format(self.log_prefix, registered_targets))
             else:
@@ -1028,15 +1045,15 @@ class Custom(ops.PrimitiveWithInfer):
         elif isinstance(self.func, str):
             func_attr = Custom.attr_dict.get(self.func)
             if isinstance(func_attr, dict):
-                _add_prim_attr(
-                _add_prim_attr(
+                _add_prim_attr(INPUT_NAMES)
+                _add_prim_attr(ATTR_NAMES)
                 _add_prim_attr("pure_input_names")
         self._add_prim_target()
         if callable(self.func) and callable(self.out_shape):
-            if hasattr(self.out_shape, "type") and getattr(self.out_shape, "type") ==
-                self.add_prim_attr(
+            if hasattr(self.out_shape, "type") and getattr(self.out_shape, "type") == AUTO_DIFF:
+                self.add_prim_attr(AUTO_DIFF, True)
             else:
-                self.add_prim_attr(
+                self.add_prim_attr(AUTO_DIFF, False)

     def _hybrid_autodiff(self, input_func_type):
         """generate backward op for a custom hybrid op"""

@@ -1052,7 +1069,7 @@ class Custom(ops.PrimitiveWithInfer):
         def infer_func(*args):
             return args[:inputs_num]

-        setattr(infer_func, "type",
+        setattr(infer_func, "type", AUTO_DIFF)
         op = Custom(func=self.func, out_shape=infer_func, out_dtype=infer_func,
                     func_type=input_func_type, bprop=True)
         self.bprop = grad_func(op)