mindspore 2.7.0__cp310-cp310-win_amd64.whl → 2.7.0rc1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +1 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -0
- mindspore/_extends/parse/parser.py +22 -28
- mindspore/_extends/parse/standard_method.py +1 -15
- mindspore/_extends/pijit/pijit_func_white_list.py +5 -2
- mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
- mindspore/amp.py +18 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +12 -18
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +38 -102
- mindspore/common/_utils.py +1 -9
- mindspore/common/api.py +106 -155
- mindspore/common/{dynamic_shape/auto_dynamic_shape.py → auto_dynamic_shape.py} +23 -17
- mindspore/common/dtype.py +57 -98
- mindspore/common/dump.py +1 -1
- mindspore/common/file_system.py +9 -59
- mindspore/common/hook_handle.py +3 -22
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +20 -4
- mindspore/common/recompute.py +4 -2
- mindspore/common/tensor.py +52 -38
- mindspore/communication/_hccl_management.py +297 -0
- mindspore/context.py +21 -15
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +1 -35
- mindspore/dataset/engine/datasets.py +315 -330
- mindspore/dataset/engine/datasets_user_defined.py +22 -38
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +5 -17
- mindspore/dataset/vision/utils.py +21 -632
- mindspore/device_context/ascend/op_tuning.py +1 -35
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -3
- mindspore/include/api/cell.h +4 -28
- mindspore/include/api/cfg.h +7 -24
- mindspore/include/api/context.h +0 -1
- mindspore/include/api/delegate.h +2 -0
- mindspore/include/api/dual_abi_helper.h +19 -100
- mindspore/include/api/graph.h +1 -14
- mindspore/include/api/kernel.h +3 -16
- mindspore/include/api/kernel_api.h +1 -9
- mindspore/include/api/metrics/accuracy.h +0 -9
- mindspore/include/api/model.h +1 -5
- mindspore/include/api/model_group.h +0 -4
- mindspore/include/api/model_parallel_runner.h +0 -2
- mindspore/include/api/status.h +10 -48
- mindspore/include/api/types.h +1 -6
- mindspore/include/dataset/constants.h +0 -9
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/cifar10.py +2 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +5 -5
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mindspore_ops_host.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/distributed/__init__.py +0 -4
- mindspore/mint/distributed/distributed.py +14 -217
- mindspore/mint/nn/layer/_functions.py +2 -1
- mindspore/mint/nn/layer/conv.py +6 -6
- mindspore/mint/nn/layer/normalization.py +3 -3
- mindspore/nn/cell.py +174 -216
- mindspore/nn/layer/activation.py +2 -4
- mindspore/nn/layer/basic.py +13 -7
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/optim/adam.py +3 -1
- mindspore/nn/optim/lamb.py +3 -1
- mindspore/nn/optim/tft_wrapper.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +5 -39
- mindspore/nn/wrap/grad_reducer.py +15 -0
- mindspore/numpy/array_creations.py +2 -2
- mindspore/numpy/utils_const.py +1 -1
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_op_impl/cpu/__init__.py +0 -1
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +2 -12
- mindspore/ops/auto_generate/gen_extend_func.py +4 -4
- mindspore/ops/auto_generate/gen_ops_def.py +16 -290
- mindspore/ops/auto_generate/gen_ops_prim.py +76 -563
- mindspore/ops/composite/base.py +1 -1
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/function/__init__.py +0 -1
- mindspore/ops/function/array_func.py +6 -10
- mindspore/ops/function/debug_func.py +2 -4
- mindspore/ops/function/grad/grad_func.py +12 -4
- mindspore/ops/function/math_func.py +32 -44
- mindspore/ops/function/nn_func.py +20 -18
- mindspore/ops/functional.py +1 -2
- mindspore/ops/functional_overload.py +12 -23
- mindspore/ops/operations/_inner_ops.py +12 -11
- mindspore/ops/operations/array_ops.py +50 -4
- mindspore/ops/operations/comm_ops.py +15 -1
- mindspore/ops/operations/custom_ops.py +4 -10
- mindspore/ops/operations/debug_ops.py +6 -6
- mindspore/ops/operations/manually_defined/ops_def.py +12 -12
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +1 -1
- mindspore/ops/primitive.py +10 -3
- mindspore/ops/tensor_method.py +7 -16
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +16 -0
- mindspore/parallel/_auto_parallel_context.py +15 -5
- mindspore/parallel/_parallel_serialization.py +2 -3
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_transformer/transformer.py +4 -4
- mindspore/parallel/_utils.py +11 -5
- mindspore/parallel/auto_parallel.py +9 -23
- mindspore/parallel/checkpoint_transform.py +0 -2
- mindspore/parallel/cluster/process_entity/_api.py +1 -4
- mindspore/parallel/cluster/run.py +3 -5
- mindspore/parallel/function/reshard_func.py +5 -6
- mindspore/parallel/nn/parallel_cell_wrapper.py +3 -40
- mindspore/parallel/nn/parallel_grad_reducer.py +8 -0
- mindspore/parallel/shard.py +21 -7
- mindspore/parallel/transform_safetensors.py +4 -10
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +9 -10
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/common/msprof_cmd_tool.py +2 -2
- mindspore/profiler/common/path_manager.py +0 -9
- mindspore/profiler/common/profiler_context.py +2 -25
- mindspore/profiler/common/profiler_meta_data.py +0 -1
- mindspore/profiler/common/profiler_op_analyse.py +6 -10
- mindspore/{ops/_op_impl/cpu/joinedstr_op.py → profiler/common/validator/__init__.py} +1 -15
- mindspore/profiler/common/validator/validate_path.py +84 -0
- mindspore/profiler/dynamic_profiler.py +46 -91
- mindspore/profiler/envprofiler.py +5 -30
- mindspore/profiler/experimental_config.py +1 -16
- mindspore/profiler/platform/cpu_profiler.py +4 -10
- mindspore/profiler/platform/npu_profiler.py +1 -1
- mindspore/profiler/profiler.py +145 -193
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/runtime/__init__.py +4 -6
- mindspore/runtime/executor.py +0 -27
- mindspore/runtime/memory.py +0 -1
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +3 -3
- mindspore/train/amp.py +3 -0
- mindspore/train/callback/_callback.py +1 -2
- mindspore/train/callback/_checkpoint.py +8 -1
- mindspore/train/callback/_flops_collector.py +6 -10
- mindspore/train/callback/_train_fault_tolerance.py +7 -3
- mindspore/train/data_sink.py +4 -4
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +20 -4
- mindspore/train/serialization.py +15 -35
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/utils.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/METADATA +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/RECORD +193 -192
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +0 -1109
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/dynamic_shape/enable_dynamic.py +0 -197
- /mindspore/common/{dynamic_shape/_auto_dynamic.py → _auto_dynamic.py} +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -1,1109 +0,0 @@
|
|
|
1
|
-
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
"""Custom op dsl file, used for dynamic format/data type select, update akg info and compile akg info"""
|
|
16
|
-
from __future__ import absolute_import
|
|
17
|
-
import os
|
|
18
|
-
import sys
|
|
19
|
-
import json
|
|
20
|
-
import copy
|
|
21
|
-
import functools
|
|
22
|
-
import subprocess
|
|
23
|
-
import shutil
|
|
24
|
-
|
|
25
|
-
from tbe.common.buildcfg import get_current_build_config
|
|
26
|
-
from impl.util.util_select_op_base import gen_param
|
|
27
|
-
from impl.util.util_select_op_base import get_dynamic_param_in_json
|
|
28
|
-
|
|
29
|
-
# Numeric limits.
BLOCK = 16  # fractal block edge; FRACTAL_NZ needs both trailing dims to be multiples of it
FP16_MAX = 65504  # largest finite float16 magnitude; bigger float32 constants keep float32

# Keys and tags used in the op-info JSON dictionaries.
OP = "op"
STR = "str"
NAME = "name"
TENSOR_NAME = "tensor_name"
ATTR = "attr"
VALUE = "value"
SHAPE = "shape"
FORMAT = "format"
DATA_TYPE = "data_type"
ORI_SHAPE = "ori_shape"
ORI_FORMAT = "ori_format"
ORI_DATA_TYPE = "ori_data_type"
OP_DESC = "op_desc"
INPUT_DESC = "input_desc"
OUTPUT_DESC = "output_desc"

# Format and data type names.
FRACTAL_NZ = "FRACTAL_NZ"
DEFAULT_FORMAT = "DefaultFormat"
FLOAT16 = "float16"
FLOAT32 = "float32"

# File-name suffixes.
O_SUFFIX = ".o"
JSON_SUFFIX = ".json"
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
def copy_shape(shape):
    """Return a new list holding the dimensions of ``shape``.

    A bare ``int`` is treated as a 1-D shape and wrapped in a list. The copy is
    shallow: the result is a fresh list, so it can be mutated without touching
    the source, but the elements themselves are not copied (presumably they are
    plain ints).

    Args:
        shape: An ``int`` or an iterable of dimension sizes.

    Returns:
        list: A new list with the same elements as ``shape``.
    """
    if isinstance(shape, int):
        return [shape]
    return list(shape)
|
|
62
|
-
|
|
63
|
-
# InfoGlobalConfig is used to store global configuration for info files.
|
|
64
|
-
# It can be accessed or modified internally in custom.py using InfoGlobalConfig.xxx.
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
class InfoGlobalConfig:
    """Global switches consulted by the op-info classes in this module.

    The settings live in class attributes and are read/written as
    ``InfoGlobalConfig.<name>``.
    """

    # Whether the AKG CCE lib is enabled. MatMul restricts itself to ND formats
    # when it is, and BatchMatMul requires it.
    enable_cce_lib = False
    # Ascend arch type, for 910B and 910A. MatMul/BatchMatMul look for the
    # substring "910B" in it.
    ascend_arch = ""
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
class OpInfer:
    """Base infer class for one op of an op-info description.

    Provides the candidate formats and data types of the op and, once a
    format/type has been selected, re-infers the shape, format and data type of
    its outputs. Subclasses override the ``supported_*`` / ``infer_*`` hooks.
    """

    def __init__(self, op_desc):
        """Index the parts of ``op_desc`` used by the inference hooks.

        Args:
            op_desc (dict): Op description. ``name`` is required; ``input_desc``
                (list of lists of tensor dicts), ``attr`` (list of dicts keyed by
                ``name``) and ``output_desc`` (list of tensor dicts) are used
                when they are lists.

        The tensor and attr dicts are referenced, not copied, so updates made
        through ``input_desc``, ``output_desc`` and ``attr`` land in ``op_desc``.
        """
        self.name = op_desc[NAME]
        self.op_desc = op_desc
        self.input_desc = []
        self.output_desc = []
        self.attr = {}
        inputs = op_desc.get(INPUT_DESC)
        if isinstance(inputs, list):
            # Each input slot holds a list of tensor descs; flatten them in order.
            for desc in inputs:
                self.input_desc.extend(desc)
        attrs = op_desc.get(ATTR)
        if isinstance(attrs, list):
            for item in attrs:
                self.attr[item[NAME]] = item
        outputs = op_desc.get(OUTPUT_DESC)
        if isinstance(outputs, list):
            self.output_desc.extend(outputs)

    @staticmethod
    def is_nz(shape):
        """Return True if ``shape`` can be converted to FRACTAL_NZ.

        That requires rank >= 2 with both trailing dims divisible by BLOCK.
        """
        return len(shape) >= 2 and shape[-2] % BLOCK == 0 and shape[-1] % BLOCK == 0

    @staticmethod
    def update_format(formats, new_format):
        """Append ``new_format`` (a string or a list/tuple of strings) to ``formats`` in place, skipping duplicates."""
        new_formats = [new_format] if not isinstance(new_format, (list, tuple)) else new_format
        for f in new_formats:
            if f not in formats:
                formats.append(f)

    def get_attr(self, key):
        """Return the ``value`` of attr ``key``.

        Raises:
            KeyError: If the op has no attr named ``key``.
        """
        if key not in self.attr:
            raise KeyError("Can not find attr '{}' in op '{}'".format(key, self.name))
        return self.attr[key][VALUE]

    def set_attr(self, key, value):
        """Set the ``value`` of the existing attr ``key``.

        Raises:
            KeyError: If the op has no attr named ``key``.
        """
        if key not in self.attr:
            raise KeyError("Can not find attr '{}' in op '{}'".format(key, self.name))
        self.attr[key][VALUE] = value

    @staticmethod
    def _exceeds_fp16(value):
        """Return True if a constant value, scalar or (nested) list/tuple, has a magnitude above FP16_MAX."""
        if isinstance(value, (list, tuple)):
            # Constant tensors may carry a list of values; check every element
            # instead of letting abs() raise TypeError on the list.
            return any(OpInfer._exceeds_fp16(v) for v in value)
        return value is not None and abs(value) > FP16_MAX

    def supported_type(self):
        """Return candidate comma-joined data type strings (inputs then outputs).

        The current types come first, then the variant with every float16
        replaced by float32, then the variant with every float32 replaced by
        float16. The float16 variant is omitted when a float32 input carries a
        constant value whose magnitude exceeds FP16_MAX, because precision
        cannot be reduced then.
        """
        keep_fp32 = any(item[DATA_TYPE] == FLOAT32 and self._exceeds_fp16(item.get(VALUE))
                        for item in self.input_desc)
        io_type = ",".join([t[DATA_TYPE] for t in self.input_desc] + [t[DATA_TYPE] for t in self.output_desc])
        fp32_type = io_type.replace(FLOAT16, FLOAT32)
        fp16_type = io_type.replace(FLOAT32, FLOAT16)
        supported_types = [io_type]
        if fp32_type not in supported_types:
            supported_types.append(fp32_type)
        if not keep_fp32 and fp16_type not in supported_types:
            supported_types.append(fp16_type)
        return supported_types

    def supported_format(self):
        """Default: every input and output in ND format."""
        io_num = len(self.input_desc) + len(self.output_desc)
        return [",".join(["ND"] * io_num)]

    def infer_type(self):
        """Output 0 takes input 0's data type, except for comparison ops whose output type is fixed."""
        fixed_out_type_ops = ["Equal", "Less", "LessEqual", "Greater", "GreaterEqual"]
        if self.name not in fixed_out_type_ops:
            self.output_desc[0][DATA_TYPE] = self.input_desc[0][DATA_TYPE]

    def infer_format(self):
        """Output 0 takes input 0's format."""
        self.output_desc[0][FORMAT] = self.input_desc[0][FORMAT]

    def infer_shape(self):
        """Output 0 takes a copy of input 0's shape."""
        self.output_desc[0][SHAPE] = copy_shape(self.input_desc[0][SHAPE])

    def infer_ori_shape(self):
        """Each output's ori_shape becomes a copy of its own current shape."""
        for desc in self.output_desc:
            desc[ORI_SHAPE] = copy_shape(desc[SHAPE])

    def infer(self):
        """Infer data type, format and shape, in that order."""
        self.infer_type()
        self.infer_format()
        self.infer_shape()

    def post_process(self):
        """Hook run after infer(); does nothing by default."""

    def update(self):
        """Re-infer the output descs.

        The outputs' current data type and format are first saved as
        ``ori_data_type`` / ``ori_format``; then ori_shape is inferred, followed
        by infer() and post_process().
        """
        for desc in self.output_desc:
            desc[ORI_DATA_TYPE] = desc[DATA_TYPE]
            desc[ORI_FORMAT] = desc[FORMAT]
        self.infer_ori_shape()
        self.infer()
        self.post_process()
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
class Elemwise(OpInfer):
    """Elemwise op with one input and one output."""

    def supported_format(self):
        """Return candidate "input,output" format strings.

        Reciprocal is limited to ND, plus FRACTAL_NZ only when the input shape is
        already BLOCK aligned (padding would cause a divide by 0). Every other
        elementwise op accepts ND, FRACTAL_NZ, NC1HWC0 and FRACTAL_Z.
        """
        if self.name == "Reciprocal":
            supported_formats = ["ND,ND"]
            # pad will cause 'divided by 0'
            if self.is_nz(self.input_desc[0][SHAPE]):
                self.update_format(supported_formats, "FRACTAL_NZ,FRACTAL_NZ")
            return supported_formats
        return ["ND,ND", "FRACTAL_NZ,FRACTAL_NZ", "NC1HWC0,NC1HWC0", "FRACTAL_Z,FRACTAL_Z"]

    def infer_ori_shape(self):
        """Output ori_shape equals the input ori_shape.

        The list is copied (as the base class does) so that the input and
        output descs do not share one mutable shape list.
        """
        self.output_desc[0][ORI_SHAPE] = copy_shape(self.input_desc[0][ORI_SHAPE])
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
class Cast(Elemwise):
    """Cast op: the output data type is the cast target, not derived from the input."""

    def supported_type(self):
        """Only the current "input,output" data type pair is supported."""
        in_type = self.input_desc[0][DATA_TYPE]
        out_type = self.output_desc[0][DATA_TYPE]
        return [",".join([in_type, out_type])]

    def infer_type(self):
        """Keep the output data type unchanged.

        This deliberately overrides OpInfer.infer_type, which would copy the
        input type onto the output and so undo the cast.
        """
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
class ElemwiseBinaryNoBroadcast(OpInfer):
    """Elemwise op with two inputs and one output, not supports broadcast."""

    def supported_format(self):
        """All three tensors share one format: ND, FRACTAL_NZ, NC1HWC0 or FRACTAL_Z."""
        return ["ND,ND,ND", "FRACTAL_NZ,FRACTAL_NZ,FRACTAL_NZ", "NC1HWC0,NC1HWC0,NC1HWC0",
                "FRACTAL_Z,FRACTAL_Z,FRACTAL_Z"]

    def infer_ori_shape(self):
        """Output ori_shape equals input 0's ori_shape.

        The list is copied (as the base class does) so that the input and
        output descs do not share one mutable shape list.
        """
        self.output_desc[0][ORI_SHAPE] = copy_shape(self.input_desc[0][ORI_SHAPE])
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
class ElemwiseBinary(OpInfer):
    """Elemwise op with two inputs and one output, supports broadcast."""

    @staticmethod
    def nd2fractal_nz(shape):
        """Map a broadcastable ND shape onto its FRACTAL_NZ counterpart.

        Only shapes whose trailing dims are 1 and/or BLOCK aligned in the ways
        below are handled; any other shape yields an empty list.
        """
        rank = len(shape)
        if rank == 1:
            last = shape[-1]
            if last == 1:
                return [1, 1, 1, 1]
            if last % BLOCK == 0:
                return [last // BLOCK, 1, 1, BLOCK]
            return []
        if rank >= 2:
            rows, cols = shape[-2], shape[-1]
            batch = shape[:-2]
            if rows == 1 and cols == 1:
                return batch + [1, 1, 1, 1]
            if rows == 1 and cols % BLOCK == 0:
                return batch + [cols // BLOCK, 1, 1, BLOCK]
            if rows % BLOCK == 0 and cols == 1:
                return batch + [1, rows // BLOCK, BLOCK, 1]
        return []

    def broadcast_shape(self, sh0, sh1):
        """Broadcast two shapes.

        Returns:
            tuple: ``(padded sh0, padded sh1, broadcast shape)``; the padded
            shapes are left-filled with 1 up to the common rank.

        Raises:
            ValueError: If a dim pair is neither equal nor contains a 1.
        """
        rank = max(len(sh0), len(sh1))
        padded0 = [1] * (rank - len(sh0)) + sh0
        padded1 = [1] * (rank - len(sh1)) + sh1
        out_shape = []
        for dim0, dim1 in zip(padded0, padded1):
            if dim0 == 1:
                out_shape.append(dim1)
            elif dim1 == 1 or dim1 == dim0:
                out_shape.append(dim0)
            else:
                raise ValueError("For '{}', input shapes {} and {} can not broadcast".format(self.name, sh0, sh1))
        return padded0, padded1, out_shape

    def supported_format(self):
        """Return candidate "in0,in1,out" format strings.

        ND is always offered. Equal shapes (or a constant input) additionally
        allow FRACTAL_NZ, NC1HWC0 and FRACTAL_Z. Otherwise mixed
        FRACTAL_NZ/ND combinations are added depending on which padded input
        shape is NZ aligned.
        """
        desc0, desc1 = self.input_desc[0], self.input_desc[1]
        sh0, sh1 = desc0[SHAPE], desc1[SHAPE]
        formats = ["ND,ND,ND"]
        if sh0 == sh1 or VALUE in desc0 or VALUE in desc1:
            # No broadcast case
            self.update_format(formats, ["FRACTAL_NZ,FRACTAL_NZ,FRACTAL_NZ", "NC1HWC0,NC1HWC0,NC1HWC0",
                                         "FRACTAL_Z,FRACTAL_Z,FRACTAL_Z"])
            return formats
        # note: (1, 640), (640) "FRACTAL_NZ,ND,FRACTAL_NZ", (1, 640) comes from MatMul
        if len(sh0) == 2 and len(sh1) == 1:
            if sh0[-1] == sh1[-1] and sh1[-1] % BLOCK == 0:
                self.update_format(formats, "FRACTAL_NZ,ND,FRACTAL_NZ")
        elif len(sh0) == 1 and len(sh1) == 2:
            if sh0[-1] == sh1[-1] and sh0[-1] % BLOCK == 0:
                self.update_format(formats, "ND,FRACTAL_NZ,FRACTAL_NZ")
        # Broadcast case
        padded0, padded1, _ = self.broadcast_shape(sh0, sh1)
        # 1D with broadcast only supports "ND,ND,ND"
        if len(padded0) > 1:
            nz_choice = {
                (True, True): "FRACTAL_NZ,FRACTAL_NZ,FRACTAL_NZ",
                (True, False): "FRACTAL_NZ,ND,FRACTAL_NZ",
                (False, True): "ND,FRACTAL_NZ,FRACTAL_NZ",
            }.get((self.is_nz(padded0), self.is_nz(padded1)))
            if nz_choice:
                self.update_format(formats, nz_choice)
            # note: ND,ND,FRACTAL_NZ? e.g. (1024, 1), (1, 5120)
        return formats

    def infer_format(self):
        """Output takes input 0's format if it contains "FRACTAL" or "C0", otherwise input 1's."""
        format0 = self.input_desc[0][FORMAT]
        format1 = self.input_desc[1][FORMAT]
        is_special = any(tag in format0 for tag in ("FRACTAL", "C0"))
        self.output_desc[0][FORMAT] = format0 if is_special else format1

    def infer_shape(self):
        """Set the output shape to the broadcast of the two input shapes.

        When the formats differ and one side is FRACTAL_NZ, the other side's ND
        shape is first mapped through nd2fractal_nz (if that mapping exists).
        """
        sh0, sh1 = self.input_desc[0][SHAPE], self.input_desc[1][SHAPE]
        if sh0 == sh1:
            # Overwritten by the broadcast result below; only survives if
            # broadcast_shape raises.
            self.output_desc[0][SHAPE] = copy_shape(sh0)
        format0, format1 = self.input_desc[0][FORMAT], self.input_desc[1][FORMAT]
        lhs, rhs = sh0, sh1
        if format0 != format1:
            nz_sh0 = self.nd2fractal_nz(sh0)
            nz_sh1 = self.nd2fractal_nz(sh1)
            if format0 == FRACTAL_NZ and nz_sh1:
                rhs = nz_sh1
            elif format1 == FRACTAL_NZ and nz_sh0:
                lhs = nz_sh0
        self.output_desc[0][SHAPE] = self.broadcast_shape(lhs, rhs)[2]

    def infer_ori_shape(self):
        """Output ori_shape is the broadcast of the two input ori_shapes."""
        ori0, ori1 = self.input_desc[0][ORI_SHAPE], self.input_desc[1][ORI_SHAPE]
        self.output_desc[0][ORI_SHAPE] = self.broadcast_shape(ori0, ori1)[2]
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
class MatMul(OpInfer):
    """MatMul op: inputs A, B and an optional bias; attrs transpose_a / transpose_b."""

    def supported_format(self):
        """Return candidate "a,b[,bias],out" format strings.

        With the CCE lib only ND is offered. Otherwise two-input MatMul is
        FRACTAL_NZ only, and the bias variant uses FRACTAL_NZ with an ND bias
        when the bias is 1-D of length 1 or a multiple of BLOCK, else all ND.

        Raises:
            ValueError: If the op does not have 2 or 3 inputs.
        """
        input_num = len(self.input_desc)
        # MatMul cce only support ND
        if InfoGlobalConfig.enable_cce_lib and input_num == 2:
            return ["ND,ND,ND"]
        if InfoGlobalConfig.enable_cce_lib and input_num == 3:
            return ["ND,ND,ND,ND"]
        if input_num == 2:
            return ["FRACTAL_NZ,FRACTAL_NZ,FRACTAL_NZ"]
        if input_num == 3:
            bias_shape = self.input_desc[2][SHAPE]
            if len(bias_shape) == 1 and (bias_shape[-1] == 1 or bias_shape[-1] % BLOCK == 0):
                return ["FRACTAL_NZ,FRACTAL_NZ,ND,FRACTAL_NZ"]
            return ["ND,ND,ND,ND"]
        raise ValueError("MatMul only supports 2 or 3 input tensors, but got {} input tensors".format(input_num))

    def nd_infer(self, sh0, sh1, trans_a, trans_b):
        """Return the ND output shape ``batch + [m, n]`` honouring the transpose flags.

        Raises:
            ValueError: If the two input ranks differ.
        """
        if len(sh0) != len(sh1):
            raise ValueError("For '{}', input shape '{}' and '{}' are not supported".format(self.name, sh0, sh1))
        m = sh0[-2] if not trans_a else sh0[-1]
        n = sh1[-1] if not trans_b else sh1[-2]
        return sh0[:-2] + [m, n]

    def infer_shape(self):
        """Infer the output shape for ND (rank >= 2) or FRACTAL_NZ (rank >= 4) inputs.

        For FRACTAL_NZ the output is ``batch + [n1, m1, m0, n0]`` with the
        m/n fractal dims picked from the inputs according to the transpose flags.

        Raises:
            ValueError: If the inputs differ in format or rank, or the
                format/rank combination is unsupported.
        """
        sh0, sh1 = self.input_desc[0][SHAPE], self.input_desc[1][SHAPE]
        format0, format1 = self.input_desc[0][FORMAT], self.input_desc[1][FORMAT]
        trans_a, trans_b = self.get_attr("transpose_a"), self.get_attr("transpose_b")
        if format0 != format1 or len(sh0) != len(sh1):
            raise ValueError("For '{}', input '{}' and '{}' are not supported"
                             .format(self.name, self.input_desc[0], self.input_desc[1]))
        if format0 != FRACTAL_NZ and len(sh0) >= 2:
            self.output_desc[0][SHAPE] = self.nd_infer(sh0, sh1, trans_a, trans_b)
        elif format0 == FRACTAL_NZ and len(sh0) >= 4:
            m1, m0 = sh0[-3], sh0[-2]
            if trans_a:
                m1, m0 = sh0[-4], sh0[-1]
            n1, n0 = sh1[-4], sh1[-1]
            if trans_b:
                n1, n0 = sh1[-3], sh1[-2]
            self.output_desc[0][SHAPE] = sh0[:-4] + [n1, m1, m0, n0]
        else:
            raise ValueError("For '{}', input '{}' and '{}' are not supported"
                             .format(self.name, self.input_desc[0], self.input_desc[1]))

    def infer_ori_shape(self):
        """Output ori_shape is nd_infer() applied to the input ori_shapes."""
        sh0, sh1 = self.input_desc[0][ORI_SHAPE], self.input_desc[1][ORI_SHAPE]
        trans_a, trans_b = self.get_attr("transpose_a"), self.get_attr("transpose_b")
        self.output_desc[0][ORI_SHAPE] = self.nd_infer(sh0, sh1, trans_a, trans_b)

    def post_process(self):
        """Record the selected input formats and output dtype as string attrs on the op desc."""
        self.op_desc[ATTR].append({DATA_TYPE: STR, NAME: "left_format", VALUE: self.input_desc[0][FORMAT]})
        self.op_desc[ATTR].append({DATA_TYPE: STR, NAME: "right_format", VALUE: self.input_desc[1][FORMAT]})
        self.op_desc[ATTR].append({DATA_TYPE: STR, NAME: "dst_type", VALUE: self.output_desc[0][DATA_TYPE]})

    def infer_type(self):
        """On 910B without the CCE lib the output is float32; otherwise use the base rule."""
        if "910B" in InfoGlobalConfig.ascend_arch and not InfoGlobalConfig.enable_cce_lib:
            self.output_desc[0][DATA_TYPE] = "float32"
        else:
            super().infer_type()

    def supported_type(self):
        """Return candidate comma-joined io data type strings.

        On 910B without the CCE lib the only candidate is float16 inputs with a
        float32 output. It carries one entry per io tensor, so a bias input
        contributes its own entry, matching the 4-entry format strings that
        supported_format returns for the bias case.
        """
        if "910B" in InfoGlobalConfig.ascend_arch and not InfoGlobalConfig.enable_cce_lib:
            io_types = ["float16", "float16"]
            if len(self.input_desc) == 3:
                # NOTE(review): the bias keeps its declared dtype; confirm the
                # 910B kernel accepts it (float16 vs float32).
                io_types.append(self.input_desc[2][DATA_TYPE])
            io_types.append("float32")
            return [",".join(io_types)]
        return super().supported_type()
|
|
397
|
-
|
|
398
|
-
class BatchMatMul(MatMul):
    """BatchMatMul op. Only support cce lib"""

    def __init__(self, op_desc):
        """Build the op and reject configurations other than 910B with the CCE lib.

        Raises:
            ValueError: If the arch is not 910B or the CCE lib is disabled.
        """
        super().__init__(op_desc)
        if "910B" not in InfoGlobalConfig.ascend_arch or not InfoGlobalConfig.enable_cce_lib:
            raise ValueError("BatchMatMul only support 910B cce lib")

    def infer_shape(self):
        """Infer the output shape; only non-FRACTAL_NZ (ND) inputs are supported.

        Raises:
            ValueError: If either input is FRACTAL_NZ.
        """
        sh0, sh1 = self.input_desc[0][SHAPE], self.input_desc[1][SHAPE]
        format0, format1 = self.input_desc[0][FORMAT], self.input_desc[1][FORMAT]
        trans_a, trans_b = self.get_attr("transpose_a"), self.get_attr("transpose_b")
        # only support nd
        if format0 != FRACTAL_NZ and format1 != FRACTAL_NZ:
            self.output_desc[0][SHAPE] = self.nd_infer(sh0, sh1, trans_a, trans_b)
        else:
            raise ValueError("For '{}', input '{}' and '{}' are not supported"
                             .format(self.name, self.input_desc[0], self.input_desc[1]))

    def nd_infer(self, sh0, sh1, trans_a, trans_b):
        """Return ``sh0's batch dims + [m, n]``; unlike MatMul, the input ranks are not checked."""
        m = sh0[-2] if not trans_a else sh0[-1]
        n = sh1[-1] if not trans_b else sh1[-2]
        return sh0[:-2] + [m, n]

    def infer_type(self):
        """The output is always float16."""
        self.output_desc[0][DATA_TYPE] = "float16"

    def supported_type(self):
        """float16 for every input (including an optional bias) and the single output.

        The entry count is sized like the ND format strings from
        MatMul.supported_format (inputs + one output), so a bias input no
        longer produces a type/format count mismatch.
        """
        return [",".join(["float16"] * (len(self.input_desc) + 1))]
|
|
430
|
-
|
|
431
|
-
class Reduce(OpInfer):
    """Reduce op: input 0 is the data tensor, input 1 a constant holding the reduce axis."""

    @staticmethod
    def _out_nz(rank, axis):
        """Return True when neither of the two innermost dims is reduced (output can stay FRACTAL_NZ)."""
        return not (rank - 2 in axis or rank - 1 in axis)

    @staticmethod
    def _reduced_shape(shape, axis, keep_dims):
        """Drop the dims in ``axis`` from ``shape``, or set them to 1 when ``keep_dims`` is truthy."""
        return [1 if idx in axis else dim
                for idx, dim in enumerate(shape)
                if keep_dims or idx not in axis]

    def _get_axis(self, rank):
        """Return the reduce axis from input 1 as a list with negative entries offset by ``rank``."""
        raw_axis = self.input_desc[1][VALUE]
        values = [raw_axis] if isinstance(raw_axis, int) else raw_axis
        return [a + rank if a < 0 else a for a in values]

    def supported_type(self):
        """float16 input may also run as float32; the axis input is always int64."""
        in_type = self.input_desc[0][DATA_TYPE]
        known = {
            FLOAT16: ["float16,int64,float16", "float32,int64,float32"],
            FLOAT32: ["float32,int64,float32"],
        }
        if in_type in known:
            return known[in_type]
        return [",".join([in_type, "int64", in_type])]

    def supported_format(self):
        """ND always; FRACTAL_NZ too when the input is NZ aligned and the inner two dims are not reduced."""
        shape = self.input_desc[0][SHAPE]
        rank = len(shape)
        axis = self._get_axis(rank)
        formats = ["ND,DefaultFormat,ND"]
        if self.is_nz(shape) and self._out_nz(rank, axis):
            formats.append("FRACTAL_NZ,DefaultFormat,FRACTAL_NZ")
        return formats

    def infer_shape(self):
        """Infer the output shape.

        For a FRACTAL_NZ input whose ori_format differs, each ND reduce axis is
        remapped to the fractal axes (last dim -> rank-4 and rank-1, second to
        last -> rank-3 and rank-2), the axis input is rewritten accordingly and
        the reduced fractal shape is computed. Otherwise the output shape is
        the output ori_shape.
        """
        ori_format, cur_format = self.input_desc[0][ORI_FORMAT], self.input_desc[0][FORMAT]
        if cur_format != FRACTAL_NZ or cur_format == ori_format:
            self.output_desc[0][SHAPE] = self.output_desc[0][ORI_SHAPE]
            return
        ori_rank = len(self.input_desc[0][ORI_SHAPE])
        cur_shape = self.input_desc[0][SHAPE]
        rank = len(cur_shape)
        nz_axes = {
            ori_rank - 1: [rank - 4, rank - 1],
            ori_rank - 2: [rank - 3, rank - 2],
        }
        new_axis = []
        for a in self._get_axis(ori_rank):
            new_axis.extend(nz_axes.get(a, [a]))
        axis_desc = self.input_desc[1]
        axis_desc[VALUE] = new_axis
        axis_desc[SHAPE] = [len(new_axis)]
        self.output_desc[0][SHAPE] = self._reduced_shape(cur_shape, new_axis, self.get_attr("keep_dims"))

    def infer_ori_shape(self):
        """Output ori_shape is the input ori_shape reduced over the axis (honouring keep_dims)."""
        ori_shape = self.input_desc[0][ORI_SHAPE]
        reduce_axis = self._get_axis(len(ori_shape))
        self.output_desc[0][ORI_SHAPE] = self._reduced_shape(ori_shape, reduce_axis, self.get_attr("keep_dims"))
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
class Reshape(OpInfer):
    """Reshape op: input 0 is the data, input 1 a constant holding the target shape."""

    def supported_format(self):
        """Reshape is only offered in ND layout."""
        return ["ND,DefaultFormat,ND"]

    def infer_shape(self):
        """Reshape keeps ND format, so the output shape will not be changed"""
        out_desc = self.output_desc[0]
        out_desc[SHAPE] = out_desc[ORI_SHAPE]

    def infer_ori_shape(self):
        """Set the output ori_shape to the target shape, resolving a single -1 entry.

        Raises:
            ValueError: If the target shape contains more than one -1.
        """
        def _prod(dims):
            return functools.reduce(lambda x, y: x * y, dims, 1)

        in_shape = self.input_desc[0][ORI_SHAPE]
        target = copy_shape(self.input_desc[1][VALUE])
        if -1 in target:
            hole = target.index(-1)
            known = [dim for dim in target if dim != -1]
            if len(known) + 1 != len(target):
                raise ValueError("Find multiple -1 in attr 'shape' {}".format(target))
            known_size = _prod(known)
            target[hole] = _prod(in_shape) // known_size
        self.output_desc[0][ORI_SHAPE] = target

    def post_process(self):
        """Keep the given target shape as ``ori_value`` and replace ``value`` with the output shape."""
        shape_input = self.input_desc[1]
        shape_input["ori_value"] = shape_input[VALUE]
        shape_input[VALUE] = self.output_desc[0][SHAPE]
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
class ExpandDimAndSqueeze(Reshape):
    """Shared base for ExpandDims and Squeeze, which are Reshape-like ops driven by an axis input."""

    def copy_axis(self, axis):
        """Return a private copy of ``axis``; a bare int is wrapped in a one-element list."""
        if isinstance(axis, int):
            return [axis]
        return copy.deepcopy(axis)
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
class Squeeze(ExpandDimAndSqueeze):
    """Squeeze op: removes the size-1 axes listed in the axis input from the input shape."""

    def infer_ori_shape(self):
        """Compute the output original shape by dropping every axis listed in ``input_desc[1]``.

        All axes are checked and normalized against the *unmodified* input shape before
        anything is removed. Popping them one at a time in input order would shift the index
        of every later axis; for example, shape [1, 3, 1] with axis [0, 2] would fail.

        Raises:
            ValueError: if a listed axis does not have size 1.
            IndexError: if a listed axis is out of range for the input shape.
        """
        axis = self.copy_axis(self.input_desc[1][VALUE])
        input_shape = copy_shape(self.input_desc[0][SHAPE])
        rank = len(input_shape)
        squeezed = set()
        for idx in axis:
            if input_shape[idx] != 1:
                raise ValueError("The value of attr 'axis' is wrong , the squeezed axis must be 1, but got {}. 'axis': "
                                 "{}, input shape: {}".format(input_shape[idx], axis, input_shape))
            # Indexing succeeded, so idx is in [-rank, rank); map negatives to positive indices.
            squeezed.add(idx + rank if idx < 0 else idx)
        self.output_desc[0][ORI_SHAPE] = [dim for pos, dim in enumerate(input_shape) if pos not in squeezed]
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
class ExpandDim(ExpandDimAndSqueeze):
    """ExpandDims op: inserts size-1 axes at the positions listed in the axis input."""

    def infer_ori_shape(self):
        """Compute the output original shape by inserting a 1 at each requested axis.

        Negative axes are normalized against the output rank, ``len(input_shape) + len(axis)``,
        so that -1 appends a trailing axis. A raw ``list.insert(-1, 1)`` would instead place
        the new axis *before* the last existing one. Axes are then inserted in ascending order,
        so each index refers to its position in the final output shape.
        """
        axis = self.copy_axis(self.input_desc[1][VALUE])
        input_shape = copy_shape(self.input_desc[0][SHAPE])
        out_rank = len(input_shape) + len(axis)
        for idx in sorted(a + out_rank if a < 0 else a for a in axis):
            input_shape.insert(idx, 1)
        self.output_desc[0][ORI_SHAPE] = input_shape
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
class BroadcastTo(OpInfer):
    """BroadcastTo op: all inputs stay in ND format."""

    def supported_format(self):
        return [",".join(["ND"] * len(self.input_desc))]

    def infer_shape(self):
        """ND-only op: the device output shape equals the original output shape."""
        out = self.output_desc[0]
        out[SHAPE] = out[ORI_SHAPE]

    def infer_ori_shape(self):
        """Broadcast the input shape to the target taken from the ``shape`` attr or ``input_desc[1]``.

        The input shape is left-padded with 1s to the target's length. A -1 in the target
        keeps the (padded) input dimension; any other target value is used as-is.

        Raises:
            ValueError: if the target is shorter than the input shape.
        """
        in_shape = self.input_desc[0][ORI_SHAPE]
        if SHAPE in self.attr:
            target = self.get_attr(SHAPE)
        else:
            target = self.input_desc[1][VALUE]
        if len(target) < len(in_shape):
            raise ValueError("The length of attr 'shape' must be >= the length of input shape, but got attr 'shape': "
                             "{}, input shape: {}".format(target, in_shape))
        padded = [1] * (len(target) - len(in_shape)) + in_shape
        self.output_desc[0][ORI_SHAPE] = [p if b == -1 else b for b, p in zip(target, padded)]

    def post_process(self):
        """Keep the original ``shape`` attr in ``ori_value`` and replace it with the device output shape."""
        attrs = self.op_desc.get(ATTR)
        if isinstance(attrs, list):
            for item in attrs:
                if item[NAME] == SHAPE:
                    item["ori_value"] = item[VALUE]
                    item[VALUE] = self.output_desc[0][SHAPE]
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
class Tile(OpInfer):
    """Tile op."""

    def supported_format(self):
        return ["ND,ND"]

    def infer_shape(self):
        """Tile op keeps ND format, so the output shape will not be changed"""
        self.output_desc[0][SHAPE] = self.output_desc[0][ORI_SHAPE]

    def infer_ori_shape(self):
        """Multiply the (left-padded) input shape element-wise by ``multiples`` from ``input_desc[1]``.

        Raises:
            ValueError: if ``multiples`` is shorter than the input shape.
        """
        shape = self.input_desc[0][ORI_SHAPE]
        multiples = self.input_desc[1][VALUE]
        if len(multiples) < len(shape):
            raise ValueError("The length of attr 'multiples' must be >= the length of input shape, but got attr "
                             "'multiples': {}, input shape: {}".format(multiples, shape))
        # list() so a tuple-typed shape pads correctly too
        pad_shape = [1] * (len(multiples) - len(shape)) + list(shape)
        self.output_desc[0][ORI_SHAPE] = [m * dim for m, dim in zip(multiples, pad_shape)]
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
class PagedAttention(OpInfer):
    """PagedAttention op: ND only; the output keeps the first input's original shape."""

    def supported_format(self):
        return ["ND,ND,ND,ND,ND,ND"]

    def infer_shape(self):
        """ND-only op: the device output shape equals the original output shape."""
        out = self.output_desc[0]
        out["shape"] = out["ori_shape"]

    def infer_ori_shape(self):
        """Output original shape mirrors the first input's original shape."""
        first_input = self.input_desc[0]
        self.output_desc[0]["ori_shape"] = first_input["ori_shape"]
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
class ReshapeAndCache(OpInfer):
    """ReshapeAndCache op: ND only; the output keeps the first input's original shape."""

    def supported_format(self):
        return ["ND,ND,ND,ND,ND,ND"]

    def infer_shape(self):
        """ND-only op: the device output shape equals the original output shape."""
        out = self.output_desc[0]
        out["shape"] = out["ori_shape"]

    def infer_ori_shape(self):
        """Output original shape mirrors the first input's original shape."""
        first_input = self.input_desc[0]
        self.output_desc[0]["ori_shape"] = first_input["ori_shape"]
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
class PagedAttentionMask(PagedAttention):
    """PagedAttention variant with one extra ND input; shape inference is inherited unchanged."""

    def supported_format(self):
        return ["ND,ND,ND,ND,ND,ND,ND"]
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
# GE converts dtype bool to int8, and the expander expands ReLU into a Greater op;
# the dtype of the Greater op is bool, which is incompatible with int8.
# As a result, AKG raises an error when parsing a Greater op with dtype int8.
# The expander expands ExpandDims and Squeeze ops into Reshape ops,
# but in dynamic-shape scenarios -1 means something different in Reshape than in ExpandDims/Squeeze,
# which leads to shape-inference errors.
# To solve these problems, we cluster these ops into a subgraph and update the info file here.
|
|
670
|
-
# Op name -> OpInfer subclass used to update that op's desc (insertion order preserved).
prims = {
    # element-wise unary
    "Abs": Elemwise,
    "Neg": Elemwise,
    "Sqrt": Elemwise,
    "Rsqrt": Elemwise,
    "Reciprocal": Elemwise,
    "FastGeLU": Elemwise,
    "Round": Elemwise,
    # element-wise binary
    "Assign": ElemwiseBinaryNoBroadcast,
    "Add": ElemwiseBinary,
    "Sub": ElemwiseBinary,
    "Mul": ElemwiseBinary,
    "Div": ElemwiseBinary,
    "Mod": ElemwiseBinary,
    "RealDiv": ElemwiseBinary,
    "Maximum": ElemwiseBinary,
    "Minimum": ElemwiseBinary,
    # matmul / reduce / shape ops
    "MatMul": MatMul,
    "BatchMatMul": BatchMatMul,
    "ReduceSum": Reduce,
    "Reshape": Reshape,
    "ExpandDims": ExpandDim,
    "Squeeze": Squeeze,
    "BroadcastTo": BroadcastTo,
    "Tile": Tile,
    # more element-wise unary
    "Log": Elemwise,
    "Exp": Elemwise,
    "Pow": Elemwise,
    "Sign": Elemwise,
    "ReLU": Elemwise,
    "Tanh": Elemwise,
    # more reduce ops, cast, attention kernels
    "ReduceMax": Reduce,
    "ReduceMin": Reduce,
    "Cast": Cast,
    "PagedAttention": PagedAttention,
    "PagedAttentionMask": PagedAttentionMask,
    "ReshapeAndCache": ReshapeAndCache,
}
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
def convert_to_default_format(desc):
    """Convert to DefaultFormat.

    Covers the global inputs, each op's inputs, each op's outputs and the global outputs.
    For inputs, only the first tensor of every input group is touched. A tensor whose format
    is ND/NCHW/NHWC/HWCN/DefaultFormat gets DEFAULT_FORMAT, and an empty shape becomes [1].
    """
    default_formats = ("ND", "NCHW", "NHWC", "HWCN", DEFAULT_FORMAT)

    def _normalize(tensor):
        """Apply the DefaultFormat / scalar-shape normalization to one tensor desc in place."""
        if tensor[FORMAT] in default_formats:
            tensor[FORMAT] = DEFAULT_FORMAT
        if not tensor[SHAPE]:
            tensor[SHAPE] = [1]

    for group in desc[INPUT_DESC]:
        _normalize(group[0])
    for op_desc in desc[OP_DESC]:
        for group in op_desc[INPUT_DESC]:
            _normalize(group[0])
        for tensor in op_desc[OUTPUT_DESC]:
            _normalize(tensor)
    for tensor in desc[OUTPUT_DESC]:
        _normalize(tensor)
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
def update_global_input_desc(info_desc, args):
    """Update the global input of the fused info file.

    For each global input group, the first tensor desc is rewritten from ``args[i]``.
    The previous dtype/format/shape are kept in the ``ori_*`` keys where ``args[i]`` does
    not supply them. Nothing happens if the info has no list of inputs.
    """

    def _to_info_dtype(tbe_dtype, ori_dtype):
        """Map a TBE dtype name back to the name stored in the info file."""
        if tbe_dtype == "float":
            return FLOAT32
        # GE passes int8 when the real dtype is bool, but bool must be reported back to GE;
        # otherwise GE raises "current op does not support bool".
        if tbe_dtype == "int8" and ori_dtype == "bool":
            return ori_dtype
        return tbe_dtype

    def _to_info_shape(tbe_shape):
        """Copy a TBE shape, representing a scalar (empty shape) as [1]."""
        return copy_shape(tbe_shape) if tbe_shape else [1]

    inputs = info_desc.get(INPUT_DESC)
    if not isinstance(inputs, list):
        return
    for pos, group in enumerate(inputs):
        arg = args[pos]
        tensor = group[0]
        tensor[ORI_DATA_TYPE] = tensor[DATA_TYPE]
        tensor[DATA_TYPE] = _to_info_dtype(arg["dtype"], tensor[ORI_DATA_TYPE])
        tensor[ORI_FORMAT] = arg.get(ORI_FORMAT, tensor[FORMAT])
        tensor[FORMAT] = arg[FORMAT]
        tensor[ORI_SHAPE] = _to_info_shape(arg.get(ORI_SHAPE, tensor[SHAPE]))
        tensor[SHAPE] = list(arg[SHAPE])
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
def update_global_output_desc(info_desc, tensor_desc):
    """Update the global output of the fused info file.

    Replaces every global output entry with the cached tensor desc of the same name.

    Raises:
        RuntimeError: if an output tensor name is missing from ``tensor_desc``.
    """
    outputs = info_desc[OUTPUT_DESC]
    for pos, out in enumerate(outputs):
        name = out[TENSOR_NAME]
        if name not in tensor_desc:
            raise RuntimeError("tensor '{}' not exist in op_desc".format(name))
        outputs[pos] = tensor_desc[name]
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
def update_op_input_desc(op_desc, tensor_desc):
    """Update the input of operator.

    How each input item is handled:
    - Non-const items (no VALUE key) are replaced by the cached desc of the same tensor name.
    - Const items keep their desc, and their current dtype/format/shape are saved into the
      ``ori_*`` keys.
    - The first item of each group that holds a const then takes the updated dtype of the
      first non-const input that originally had the same dtype, and is cached by tensor name.

    Raises:
        RuntimeError: if a non-const input tensor has not been cached yet.
    """
    groups = op_desc.get(INPUT_DESC)
    if not isinstance(groups, list):
        return
    # Parallel per-item lists: dtype before / after resolution (None for const items).
    orig_types, new_types = [], []
    const_groups = []
    for group_idx, group in enumerate(groups):
        for item_idx, item in enumerate(group):
            if VALUE not in item:
                orig_types.append(item[DATA_TYPE])
                name = item[TENSOR_NAME]
                if name not in tensor_desc:
                    raise RuntimeError("tensor '{}' used without initialization".format(name))
                resolved = tensor_desc[name]
                group[item_idx] = resolved
                new_types.append(resolved[DATA_TYPE])
                continue
            orig_types.append(None)
            new_types.append(None)
            const_groups.append(group_idx)
            item[ORI_DATA_TYPE] = item[DATA_TYPE]
            item[ORI_FORMAT] = item[FORMAT]
            item[ORI_SHAPE] = copy_shape(item[SHAPE])

    for group_idx in const_groups:
        const_item = groups[group_idx][0]
        dtype = const_item[DATA_TYPE]
        if dtype in orig_types:
            const_item[DATA_TYPE] = new_types[orig_types.index(dtype)]
        tensor_desc[const_item[TENSOR_NAME]] = const_item
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
def cache_input_tensors(tensor_desc, input_desc):
    """Cache input tensor desc by tensor name; non-list ``input_desc`` is ignored."""
    if not isinstance(input_desc, list):
        return
    for group in input_desc:
        tensor_desc.update((item[TENSOR_NAME], item) for item in group)
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
def cache_output_tensors(tensor_desc, output_desc):
    """Cache output tensor desc by tensor name."""
    tensor_desc.update((out[TENSOR_NAME], out) for out in output_desc)
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
def save(filename, contents):
    """Save to file: create or truncate ``filename`` (mode 0o660 when created) and write ``contents``."""
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    fd = os.open(os.path.realpath(filename), flags, 0o660)
    with os.fdopen(fd, 'w') as handle:
        handle.write(contents)
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
def update_akg_info(args, info_path, kernel_name=None):
    """Update akg info base on the current inputs provided by GE.

    Steps:
    - Load the JSON info file, remember the original op name in ``op_ori``, and rename the
      op to ``kernel_name`` when one is given.
    - Refresh the global inputs from ``args`` and propagate dtypes/formats/shapes through
      every op via its ``prims`` handler.
    - Refresh the global outputs and normalize formats.
    - Set ``backend`` to "GE".

    Returns:
        dict: the updated info desc (nothing is written to disk here).

    Raises:
        KeyError: if an op in the info has no entry in ``prims``.
    """
    with open(os.path.realpath(info_path), 'r') as f:
        desc = json.load(f)
    desc["op_ori"] = desc[OP]
    if kernel_name:
        desc[OP] = kernel_name

    cached = {}  # tensor name -> tensor desc

    # global inputs first, so the ops can resolve them by name
    update_global_input_desc(desc, args)
    cache_input_tensors(cached, desc.get(INPUT_DESC))

    InfoGlobalConfig.enable_cce_lib = desc.get("enable_cce_lib")
    target_info = desc.get("target_info")
    if target_info is not None:
        InfoGlobalConfig.ascend_arch = target_info.get("arch")

    for op_desc in desc[OP_DESC]:
        update_op_input_desc(op_desc, cached)
        op_name = op_desc[NAME]
        if op_name not in prims:
            raise KeyError("Not supported op: {}".format(op_name))
        prims.get(op_name)(op_desc).update()
        cache_output_tensors(cached, op_desc[OUTPUT_DESC])

    update_global_output_desc(desc, cached)
    convert_to_default_format(desc)
    # GE backend must use old CCE
    desc["backend"] = "GE"
    return desc
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
def save_updated_akg_info(*args):
    """Save the updated akg info.

    ``args[-2]`` is the original info path and ``args[-1]`` the kernel name. The updated info
    is written next to the original as ``<kernel_name>.info``.

    Returns:
        str: path of the written file, or "" when ``args[-2]`` is not a string.
    """
    info_path, kernel_name = args[-2], args[-1]
    if not isinstance(info_path, str):
        # kernel_name was not passed by GE: nothing to compile
        return ""
    updated = update_akg_info(args, info_path, kernel_name)
    target_dir = os.path.realpath(os.path.dirname(info_path))
    out_path = os.path.join(target_dir, kernel_name + ".info")
    save(out_path, json.dumps(updated))
    return out_path
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
def create_dirs(*dirs):
    """Create directories (with parents); paths that already exist are left alone."""
    for path in dirs:
        if os.path.isdir(path):
            continue
        try:
            os.makedirs(path)
        except FileExistsError:
            # EEXIST: created concurrently, or the path already exists as a non-directory
            pass
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
def copy_file(src_path, dst_path):
    """Copy file to dst, best effort.

    An existing destination file is removed first (OSError ignored). A PermissionError during
    the copy is ignored too, e.g. when a read-only destination is still present.
    """
    if os.path.isfile(dst_path):
        try:
            os.remove(dst_path)
        except OSError:
            pass
    try:
        shutil.copy(src_path, dst_path)
    except PermissionError:
        pass
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
def _compile_subprocess(kernel_meta_dirs, info_path, is_lite=True, compile_backend=None, attrs=None):
    """Use a new process to compile info.

    Runs ``compiler.py`` (next to this file) with the current interpreter.

    Args:
        kernel_meta_dirs: ``(kernel_meta_parent_dir, kernel_meta_dir)``; exported to the child
            process as ``MS_COMPILER_CACHE_PATH`` and ``KERNEL_META_DIR``.
        info_path: path of the info file to compile.
        is_lite: when True only ``info_path`` is passed to the compiler; otherwise
            ``compile_backend``, ``attrs`` and the parent dir are appended.
        compile_backend: backend name passed when ``is_lite`` is False.
        attrs: serialized attrs string passed when ``is_lite`` is False.

    Returns:
        subprocess.CompletedProcess with text stdout/stderr captured. A non-zero exit code
        is not raised (``check=False``); callers inspect ``returncode``.
    """
    kernel_meta_parent_dir, kernel_meta_dir = kernel_meta_dirs
    # Work on a copy. Assigning into os.environ itself would permanently leak these variables
    # into the calling process and into every subprocess it starts later.
    my_env = os.environ.copy()
    my_env["MS_COMPILER_CACHE_PATH"] = kernel_meta_parent_dir
    my_env["KERNEL_META_DIR"] = kernel_meta_dir
    compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0], "compiler.py")
    if is_lite:
        run_args = [sys.executable, compiler, info_path]
    else:
        run_args = [sys.executable, compiler, info_path, compile_backend, attrs, kernel_meta_parent_dir]
    return subprocess.run(run_args, text=True, check=False, capture_output=True, env=my_env)
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
def search_supported_types_formats(info):
    """Get the supported data types and formats of the fused info file.

    Runs a depth-first search over ``info[OP_DESC]``. For each op, every dtype/format
    combination reported by its ``prims`` handler is tried, and an assignment is kept only
    if it agrees on every tensor name already fixed by earlier ops. Each complete assignment
    is projected onto the top-level inputs followed by the top-level outputs.

    Returns:
        (supported_types, supported_formats): lists of tuples with duplicates removed,
        keeping first-seen order.

    Raises:
        KeyError: if an op has no entry in ``prims``.
    """

    class DfsSearcher:
        """Use DFS"""

        def __init__(self, top_io_names, ops_desc):
            self.supported_types = []
            self.supported_formats = []
            self.top_io_names = top_io_names
            self.tensor_types = {}    # tensor name -> dtype chosen on the current DFS path
            self.tensor_formats = {}  # tensor name -> format chosen on the current DFS path
            self.ops_desc = ops_desc
            self.cache = []           # per-op (io_formats, io_types, io_names), filled in op order

        def set_current_format(self, cur_format, io_names):
            """set tensor format; False on conflict with an already chosen format"""
            for i, fmt in enumerate(cur_format):
                if self.tensor_formats.get(io_names[i], fmt) != fmt:
                    return False
                self.tensor_formats[io_names[i]] = fmt
            return True

        def set_current_type(self, cur_type, io_names):
            """set tensor data type; False on conflict with an already chosen type"""
            for i, data_type in enumerate(cur_type):
                if self.tensor_types.get(io_names[i], data_type) != data_type:
                    return False
                self.tensor_types[io_names[i]] = data_type
            return True

        def get_desc(self, opid):
            """get desc: candidate formats, candidate types and io tensor names of op ``opid``"""
            if opid < len(self.cache):
                return self.cache[opid]
            op_desc = self.ops_desc[opid]
            io_names = [item[TENSOR_NAME] for group in op_desc[INPUT_DESC] for item in group]
            io_names.append(op_desc[OUTPUT_DESC][0][TENSOR_NAME])
            op_name = op_desc[NAME]
            if op_name not in prims:
                raise KeyError("Not supported op: {}".format(op_name))
            prim = prims.get(op_name)(op_desc)
            io_formats = [f.split(",") for f in prim.supported_format()]
            io_types = [t.split(",") for t in prim.supported_type()]
            self.cache.append((io_formats, io_types, tuple(io_names)))
            return self.cache[-1]

        def search_types(self, opid):
            """search the supported types"""
            if opid == len(self.ops_desc):
                top_tensor_types = tuple(self.tensor_types.get(t) for t in self.top_io_names)
                self.supported_types.append(top_tensor_types)
                return
            _, op_io_types, io_names = self.get_desc(opid)
            for cur_type in op_io_types:
                # values are immutable strings, so a shallow copy is a complete backup
                bak_tensor_types = dict(self.tensor_types)
                if self.set_current_type(cur_type, io_names):
                    self.search_types(opid + 1)
                self.tensor_types = bak_tensor_types

        def search_formats(self, opid):
            """search the supported formats"""
            if opid == len(self.ops_desc):
                top_tensor_formats = tuple(self.tensor_formats.get(t) for t in self.top_io_names)
                self.supported_formats.append(top_tensor_formats)
                return
            op_io_formats, _, io_names = self.get_desc(opid)
            for cur_format in op_io_formats:
                bak_tensor_formats = dict(self.tensor_formats)
                if self.set_current_format(cur_format, io_names):
                    self.search_formats(opid + 1)
                self.tensor_formats = bak_tensor_formats

    def remove_dup(data):
        """Drop duplicate entries (compared by their comma-joined form), keeping first-seen order."""
        res = []
        seen = set()
        for t in data:
            t_str = ",".join(t)
            if t_str not in seen:
                seen.add(t_str)
                res.append(t)
        return res

    top_io_names = [t[0][TENSOR_NAME] for t in info[INPUT_DESC]] + [t[TENSOR_NAME] for t in info[OUTPUT_DESC]]
    handle = DfsSearcher(top_io_names, info[OP_DESC])
    handle.search_types(0)
    handle.search_formats(0)
    return remove_dup(handle.supported_types), remove_dup(handle.supported_formats)
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
def op_select_format(*args, **kwags):
    """Entrance for format/data type selection, invoked by GE.

    ``args[-1]`` is the info path. Every supported dtype combination is paired with every
    supported format combination. For each top-level input ``x<i>`` / output ``y<i>``, a
    param is emitted whose datatype and format lists enumerate that cartesian product in
    lock-step.

    Raises:
        RuntimeError: if no supported dtype or format combination is found.
    """
    info_path = args[-1]
    desc = update_akg_info(args, info_path)
    supported_io_type, supported_io_format = search_supported_types_formats(desc)
    if not supported_io_type or not supported_io_format:
        raise RuntimeError("Select format failed for info: {}".format(info_path))
    input_num = len(desc[INPUT_DESC])
    output_num = len(desc[OUTPUT_DESC])
    type_count = len(supported_io_type)
    format_count = len(supported_io_format)
    param_list = []
    for i in range(input_num + output_num):
        # dtypes cycle fastest; each format is repeated once per dtype combination
        dtype_list = [item[i] for item in supported_io_type] * format_count
        format_list = [item[i] for item in supported_io_format for _ in range(type_count)]
        if i < input_num:
            classify, name = "input" + str(i), "x" + str(i)
        else:
            classify, name = "output" + str(i - input_num), "y" + str(i - input_num)
        param = gen_param(classify=classify,
                          name=name,
                          datatype=",".join(dtype_list),
                          format=",".join(format_list))
        param_list.append(param)
    return get_dynamic_param_in_json(param_list)
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
def custom(*args, **kwags):
    """Entrance for akg info compiling, invoked by GE.

    Writes the updated info, compiles it in a subprocess into
    ``<kernel_meta_parent_dir>/kernel_meta``, and checks that ``<kernel_name>`` + JSON_SUFFIX
    was produced there.

    Raises:
        RuntimeError: if the compiler exits non-zero or produced no kernel json.
    """
    kernel_name = args[-1]
    info_file = save_updated_akg_info(*args)
    if not info_file:
        # kernel_name was not passed by GE: skip compiling
        return
    parent_dir = get_current_build_config("kernel_meta_parent_dir")
    meta_dir = "kernel_meta"
    result = _compile_subprocess([parent_dir, meta_dir], info_file, is_lite=True)
    json_path = os.path.join(parent_dir, meta_dir, kernel_name + JSON_SUFFIX)
    if result.returncode or not os.path.exists(json_path):
        raise RuntimeError("Compile {} failed! Detailed compile message: {}, {}"
                           .format(kernel_name, result.stdout.strip(), result.stderr.strip()))
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
def custom_train(*args, **kwags):
    """Entrance for akg info compiling, invoked by GE.

    Compile flow:
    - Compile with AKG first, asking it to dump its optimized info into ``<info_dir>/composite``.
    - If AKG produced no kernel json, compile with TBE, preferring the AKG-dumped info.
    - Copy the resulting ``.json`` / ``.o`` into ``<kernel_meta_parent_dir>/kernel_meta``.

    Raises:
        RuntimeError: if neither AKG nor TBE produced the kernel json.
    """

    def _get_optimized_info_path():
        """Get the info optimized by akg.

        Returns the AKG-dumped info file, first copying ``target_info`` into it from the
        updated info when present. If AKG dumped nothing, returns the updated info path.
        """
        key = "target_info"
        optimized_path = os.path.join(composite_graph_dir, kernel_name + ".info")
        if not os.path.isfile(optimized_path):
            return real_info_path
        with open(os.path.realpath(real_info_path), 'r') as f:
            updated_desc = json.load(f)
        if key in updated_desc:
            with open(os.path.realpath(optimized_path), 'r') as fo:
                optimized_desc = json.load(fo)
            optimized_desc[key] = updated_desc[key]
            save(optimized_path, json.dumps(optimized_desc))
        return optimized_path

    info_path, kernel_name = args[-2], args[-1]
    real_info_path = save_updated_akg_info(*args)
    if not real_info_path:
        return
    info_dir = os.path.realpath(os.path.dirname(info_path))
    kernel_meta = "kernel_meta"
    kernel_meta_dir = os.path.join(get_current_build_config("kernel_meta_parent_dir"), kernel_meta)
    akg_compile_dir = os.path.join(info_dir, "akg")
    tbe_compile_dir = os.path.join(info_dir, "tbe")
    composite_graph_dir = os.path.join(info_dir, "composite")  # akg dumps its optimized info here
    akg_kernel_meta_dir = os.path.join(akg_compile_dir, kernel_meta)  # akg compile result
    tbe_kernel_meta_dir = os.path.join(tbe_compile_dir, kernel_meta)  # tbe compile result
    create_dirs(kernel_meta_dir, composite_graph_dir, akg_kernel_meta_dir, tbe_kernel_meta_dir)

    attrs = json.dumps({"dump_composite_graph": composite_graph_dir, "optimize_for_tbe": True})
    akg_result = _compile_subprocess([akg_compile_dir, kernel_meta], real_info_path,
                                     is_lite=False, compile_backend="AKG", attrs=attrs)
    result_dir = akg_kernel_meta_dir
    if not os.path.exists(os.path.join(result_dir, kernel_name + JSON_SUFFIX)):
        # AKG failed to produce a kernel: fall back to TBE
        tbe_result = _compile_subprocess([tbe_compile_dir, kernel_meta], _get_optimized_info_path(),
                                         is_lite=False, compile_backend="TBE", attrs=attrs)
        result_dir = tbe_kernel_meta_dir
        if not os.path.exists(os.path.join(result_dir, kernel_name + JSON_SUFFIX)):
            raise RuntimeError("Compile {} failed! Detailed akg compile message: {}, {}\n"
                               "Detailed tbe compile message: {}, {}"
                               .format(kernel_name,
                                       akg_result.stdout.strip(), akg_result.stderr.strip(),
                                       tbe_result.stdout.strip(), tbe_result.stderr.strip()))
    for suffix in (JSON_SUFFIX, O_SUFFIX):
        copy_file(os.path.join(result_dir, kernel_name + suffix), os.path.join(kernel_meta_dir, kernel_name + suffix))
|