onnxruntime_extensions 0.14.0__cp313-cp313-macosx_11_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime_extensions/__init__.py +82 -0
- onnxruntime_extensions/_cuops.py +564 -0
- onnxruntime_extensions/_extensions_pydll.cpython-313-darwin.so +0 -0
- onnxruntime_extensions/_extensions_pydll.pyi +45 -0
- onnxruntime_extensions/_hf_cvt.py +331 -0
- onnxruntime_extensions/_ocos.py +133 -0
- onnxruntime_extensions/_ortapi2.py +274 -0
- onnxruntime_extensions/_torch_cvt.py +231 -0
- onnxruntime_extensions/_version.py +2 -0
- onnxruntime_extensions/cmd.py +66 -0
- onnxruntime_extensions/cvt.py +306 -0
- onnxruntime_extensions/onnxprocess/__init__.py +12 -0
- onnxruntime_extensions/onnxprocess/_builder.py +53 -0
- onnxruntime_extensions/onnxprocess/_onnx_ops.py +1507 -0
- onnxruntime_extensions/onnxprocess/_session.py +355 -0
- onnxruntime_extensions/onnxprocess/_tensor.py +628 -0
- onnxruntime_extensions/onnxprocess/torch_wrapper.py +31 -0
- onnxruntime_extensions/pnp/__init__.py +13 -0
- onnxruntime_extensions/pnp/_base.py +124 -0
- onnxruntime_extensions/pnp/_imagenet.py +65 -0
- onnxruntime_extensions/pnp/_nlp.py +148 -0
- onnxruntime_extensions/pnp/_onnx_ops.py +1544 -0
- onnxruntime_extensions/pnp/_torchext.py +310 -0
- onnxruntime_extensions/pnp/_unifier.py +45 -0
- onnxruntime_extensions/pnp/_utils.py +302 -0
- onnxruntime_extensions/pp_api.py +83 -0
- onnxruntime_extensions/tools/__init__.py +0 -0
- onnxruntime_extensions/tools/add_HuggingFace_CLIPImageProcessor_to_model.py +171 -0
- onnxruntime_extensions/tools/add_pre_post_processing_to_model.py +535 -0
- onnxruntime_extensions/tools/pre_post_processing/__init__.py +4 -0
- onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py +395 -0
- onnxruntime_extensions/tools/pre_post_processing/step.py +227 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/__init__.py +6 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/general.py +366 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/nlp.py +344 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/vision.py +1157 -0
- onnxruntime_extensions/tools/pre_post_processing/utils.py +139 -0
- onnxruntime_extensions/util.py +186 -0
- onnxruntime_extensions-0.14.0.dist-info/LICENSE +21 -0
- onnxruntime_extensions-0.14.0.dist-info/METADATA +102 -0
- onnxruntime_extensions-0.14.0.dist-info/RECORD +43 -0
- onnxruntime_extensions-0.14.0.dist-info/WHEEL +6 -0
- onnxruntime_extensions-0.14.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,306 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
3
|
+
# license information.
|
|
4
|
+
###############################################################################
|
|
5
|
+
|
|
6
|
+
"""
|
|
7
|
+
cvt.py: Processing Graph Converter and Generator
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from typing import Union
|
|
11
|
+
|
|
12
|
+
from ._hf_cvt import HFTokenizerConverter, HFTokenizerOnnxGraph # noqa
|
|
13
|
+
from ._ortapi2 import make_onnx_model, SingleOpGraph
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
import numpy as np
|
|
17
|
+
import tempfile
|
|
18
|
+
import shutil
|
|
19
|
+
|
|
20
|
+
# edit environment variables to avoid protobuf version mismatch
|
|
21
|
+
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
|
22
|
+
|
|
23
|
+
from transformers.convert_slow_tokenizer import SpmConverter # noqa: E402
|
|
24
|
+
from transformers import AutoTokenizer # noqa: E402
|
|
25
|
+
from tokenizers import decoders, normalizers, pre_tokenizers, Regex # noqa: E402
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# The C-API based tokenizer bindings (pp_api) are optional; fall back to None when unavailable.
try:
    from onnxruntime_extensions.pp_api import Tokenizer as OrtxTokenizer
except ImportError:
    OrtxTokenizer = None
|
|
33
|
+
|
|
34
|
+
_is_torch_available = False
|
|
35
|
+
try:
|
|
36
|
+
import torch # noqa
|
|
37
|
+
_is_torch_available = True
|
|
38
|
+
from ._torch_cvt import WhisperDataProcGraph
|
|
39
|
+
except ImportError:
|
|
40
|
+
WhisperDataProcGraph = None
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
_PRE_POST_PAIR = {'TrieTokenizer': "TrieDetokenizer"}
|
|
44
|
+
|
|
45
|
+
def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
|
|
46
|
+
if add_prefix_space:
|
|
47
|
+
prepend_scheme = "always"
|
|
48
|
+
if not getattr(original_tokenizer, "legacy", True):
|
|
49
|
+
prepend_scheme = "first"
|
|
50
|
+
else:
|
|
51
|
+
prepend_scheme = "never"
|
|
52
|
+
return prepend_scheme
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class Baichuan2Converter(SpmConverter):
    """Converts a (slow, SentencePiece-based) Baichuan tokenizer into a HF ``tokenizers`` model."""

    # presumably consumed by SpmConverter to enable byte fallback in the generated model — verify
    # against the transformers version in use
    handle_byte_fallback = True

    def __init__(self, original_tokenizer):
        super().__init__(original_tokenizer)
        # NOTE: this mutates the caller's tokenizer object after the base class is initialized.
        original_tokenizer.add_prefix_space = False

    def vocab(self, proto):
        """Return (piece, score) pairs; ids 0-2 come from the tokenizer itself with score 0.0."""
        tok = self.original_tokenizer
        special_entries = [(tok.convert_ids_to_tokens(idx), 0.0) for idx in range(3)]
        return special_entries + [(entry.piece, entry.score) for entry in proto.pieces[3:]]

    def unk_id(self, proto):
        """The unknown token is always id 0."""
        return 0

    def decoder(self, replacement, add_prefix_space):
        """Map "▁" back to spaces, decode byte-fallback tokens, and fuse; optionally strip one leading space."""
        steps = [decoders.Replace("▁", " "), decoders.ByteFallback(), decoders.Fuse()]
        if add_prefix_space:
            steps.append(decoders.Strip(content=" ", left=1))
        return decoders.Sequence(steps)

    def normalizer(self, proto):
        """Legacy tokenizers replace spaces with "▁"; non-legacy ones get no normalizer (None)."""
        if not getattr(self.original_tokenizer, "legacy", True):
            return None
        steps = []
        # __init__ forces add_prefix_space to False, so normally only the Replace step is added.
        if getattr(self.original_tokenizer, "add_prefix_space", True):
            steps.append(normalizers.Prepend(prepend="▁"))
        steps.append(normalizers.Replace(pattern=" ", content="▁"))
        return normalizers.Sequence(steps)

    def pre_tokenizer(self, replacement, add_prefix_space):
        """Legacy tokenizers use the base class; non-legacy ones use a non-splitting Metaspace."""
        if getattr(self.original_tokenizer, "legacy", True):
            return super().pre_tokenizer(replacement, add_prefix_space)
        scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=scheme, split=False)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class ChatGlmConverter(SpmConverter):
    """Converts a (slow, SentencePiece-based) ChatGLM tokenizer into a HF ``tokenizers`` model."""

    def normalizer(self, proto):
        """Apply the SentencePiece precompiled charsmap, strip trailing whitespace and
        collapse runs of two or more spaces into a single "▁"."""
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        _normalizers = [
            normalizers.Strip(left=False, right=True),  # stripping is important
            normalizers.Replace(Regex(" {2,}"), "▁"),
        ]
        return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)

    def pre_tokenizer(self, replacement, add_prefix_space):
        """Metaspace pre-tokenizer: no prefix when ``add_prefix_space`` is false, otherwise
        "always", or "first" for tokenizers with a falsy ``legacy`` attribute."""
        # Newer `tokenizers` releases (the API Baichuan2Converter targets via `split=`) no longer
        # accept `add_prefix_space` on Metaspace; the prefix behaviour is carried entirely by
        # `prepend_scheme`, which _get_prepend_scheme derives from add_prefix_space + legacy.
        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
# HF tokenizer class name -> converter used to regenerate its tokenizer.json.
JSON_TOKEN_CONVERTERS = dict(
    BaichuanTokenizer=Baichuan2Converter,
    ChatGLMTokenizer=ChatGlmConverter,
)
|
|
124
|
+
|
|
125
|
+
# Save tokenizer JSON files using HuggingFace AutoTokenizer
def convert_tokenizer(model_path, output_dir=None):
    """Save a HuggingFace tokenizer's files, regenerating tokenizer.json when a converter exists.

    Parameters
    ----------
    model_path: str
        Model id or local directory accepted by ``AutoTokenizer.from_pretrained``.
    output_dir: str, optional
        Destination directory. When None, ``model_path`` itself is used if it is a directory
        (files there may be overwritten); otherwise a new temporary directory is created, which
        the caller owns and should clean up.

    Returns
    -------
    str
        The directory the tokenizer files were written to.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if output_dir is None:
        output_dir = model_path if os.path.isdir(model_path) else tempfile.mkdtemp()
    tokenizer.save_pretrained(output_dir)
    json_path = os.path.join(output_dir, "tokenizer.json")

    # Tokenizer classes registered in JSON_TOKEN_CONVERTERS get tokenizer.json rebuilt by
    # their dedicated SpmConverter subclass.
    converter_cls = JSON_TOKEN_CONVERTERS.get(type(tokenizer).__name__)
    if converter_cls is not None:
        converted = converter_cls(tokenizer).converted()
        converted.save(json_path)
        print(f"**Tokenizer saved to {json_path}")
    return output_dir
|
|
144
|
+
|
|
145
|
+
# Validate the ORTx tokenizer built from saved files against the HuggingFace reference tokenizer
def validate_tokenizer(model_path, output_dir):
    """Compare ORTx tokenization with the HuggingFace slow tokenizer on a fixed sentence.

    Parameters
    ----------
    model_path: str
        Model id or directory for ``AutoTokenizer.from_pretrained`` (the reference tokenizer).
    output_dir: str
        Directory with the tokenizer files loaded by ``onnxruntime_extensions.pp_api.Tokenizer``.

    Raises
    ------
    AssertionError
        If the token ids produced by the two tokenizers differ.
    """
    test_sentence = "I like walking my cute dog\n and\x17 then, 生活的真谛是 \t\t\t\t \n\n61"
    if OrtxTokenizer is None:
        # pp_api could not be imported, so there is no ORTx tokenizer to compare against.
        print("onnxruntime_extensions package was built without C API enabled, skipping tokenization test")
        return

    ortx_tokenizer = OrtxTokenizer(output_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
    expected_ids = tokenizer(test_sentence, return_tensors="np")["input_ids"][0]
    ortx_ids = np.asarray(ortx_tokenizer.tokenize(test_sentence))
    # Explicit raise (not `assert`) so the check still runs under `python -O`.
    if not np.array_equal(expected_ids, ortx_ids):
        raise AssertionError(f"Tokenization mismatch: {expected_ids} != {ortx_ids}")
    print("Tokenization test passed")
|
|
156
|
+
|
|
157
|
+
# Resolve tokenizer JSON files through the HuggingFace cache
def download_tokenizer(tokenizer_dir, output_dir):
    """Resolve tokenizer.json and tokenizer_config.json via ``transformers.utils.cached_file``.

    Parameters
    ----------
    tokenizer_dir: str
        HF model id or directory passed to ``cached_file`` (which may download the files).
    output_dir: str or None
        When None or empty, the directory holding the resolved files is returned unchanged.
        Otherwise both files are copied into it, creating the directory if it does not exist.

    Returns
    -------
    str
        A directory containing both tokenizer files.

    Raises
    ------
    ValueError
        If transformers (with ``cached_file``) is not importable.
    FileNotFoundError
        If a resolved file is missing, or the two files are not in the same directory.
    """
    try:
        from transformers.utils import cached_file
    except ImportError as err:
        raise ValueError(f"Directory '{tokenizer_dir}' not found and transformers is not available") from err

    resolved_full_file = cached_file(tokenizer_dir, "tokenizer.json")
    resolved_config_file = cached_file(tokenizer_dir, "tokenizer_config.json")
    # Guard both results, including a None result, before using them as paths.
    for resolved in (resolved_full_file, resolved_config_file):
        if resolved is None or not os.path.exists(resolved):
            raise FileNotFoundError(f"Downloaded HF file '{resolved}' cannot be found")
    if os.path.dirname(resolved_full_file) != os.path.dirname(resolved_config_file):
        raise FileNotFoundError(
            f"Downloaded HF files '{resolved_full_file}' " f"and '{resolved_config_file}' are not in the same directory"
        )

    if not output_dir:
        output_dir = os.path.dirname(resolved_full_file)
        print(f"Using {output_dir} as output directory")
        return output_dir

    # Without an existing directory, shutil.copy would treat output_dir as a destination *file*
    # and the second copy would overwrite the first.
    os.makedirs(output_dir, exist_ok=True)
    shutil.copy(resolved_full_file, output_dir)
    shutil.copy(resolved_config_file, output_dir)
    return output_dir
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def gen_processing_models(processor: Union[str, object],
                          pre_kwargs: dict = None,
                          post_kwargs: dict = None,
                          opset: int = None,
                          schema_v2: bool = False,
                          **kwargs):
    """
    Generate the pre- and post-processing ONNX models, based on the name or the HF class of the processor.

    Parameters
    ----------
    processor:
        the HF processor/tokenizer instance, or the name (str) of a Data Processor.
        The instance is preferred; when a name is given, the corresponding configuration for the
        processor has to be provided in the kwargs.
    pre_kwargs: dict
        Keyword arguments for generating the pre-processing model; None means no pre-processing model.
            WITH_DEFAULT_INPUTS: bool, add default inputs to the graph, default is True
            CAST_TOKEN_ID: bool, add a cast op to output token IDs to be int64 if needed, default is False
        The dict is not modified.
    post_kwargs: dict
        Keyword arguments for generating the post-processing model; None means no post-processing model.
        The dict is not modified.
    opset: int
        the target opset version of the model (only used for WhisperProcessor)
    schema_v2: bool
        the flag for using embedded tokenizer files; this option leverages the blob-loading functionality
        which loads HF tokenizers from memory rather than using the tokenizer files in HF JSON format.
    kwargs:
        The additional arguments for generating models (only used for WhisperProcessor)

    Returns
    -------
    ONNX-Models
        The pre- and post-processing ONNX models; an entry is None when its kwargs were None.

    Raises
    ------
    ValueError
        If both pre_kwargs and post_kwargs are None, or the processor type is unsupported.
    RuntimeError
        If the post-processing operator for a named processor cannot be determined.
    """
    if pre_kwargs is None and post_kwargs is None:
        raise ValueError(
            "Either pre_kwargs or post_kwargs should be provided. None means no processing graph output.")

    def _post_op_name():
        # Without a pre-processing graph the processor name is itself the post-processing op;
        # otherwise the paired decoding op is looked up in _PRE_POST_PAIR.
        if pre_kwargs is None:
            return processor
        if processor not in _PRE_POST_PAIR:
            raise RuntimeError(
                f"Cannot locate the post processing operator name from {processor}")
        return _PRE_POST_PAIR[processor]

    # If true, we get the tokenizer JSON files by either downloading from cache or using HuggingFace AutoTokenizer
    # to convert them, and then create an ONNX model with the JSON files as strings in the model attributes (attrs).
    if schema_v2:
        model_name = processor if isinstance(processor, str) else type(processor).__name__

        # Tokenizers whose tokenizer.json has to be regenerated locally before use.
        converted_tokenizer = {"Baichuan2", "chatglm"}
        need_convert = any(token in model_name for token in converted_tokenizer)

        if need_convert:
            model_dir = convert_tokenizer(model_name, None)
            # Validate the regenerated files in model_dir against the HF reference tokenizer.
            validate_tokenizer(model_name, model_dir)
        else:
            model_dir = download_tokenizer(model_name, None)

        # Load the contents of tokenizer.json and tokenizer_config.json as strings
        with open(os.path.join(model_dir, "tokenizer.json"), "r", encoding="utf-8") as f:
            tokenizer_vocab = f.read()
        with open(os.path.join(model_dir, "tokenizer_config.json"), "r", encoding="utf-8") as f:
            tokenizer_config = f.read()

        # The JSON strings are passed to build_graph so they are added to attrs. They are merged
        # into new dicts so the caller's kwargs are left untouched.
        json_attrs = {'tokenizer_vocab': tokenizer_vocab, 'tokenizer_config': tokenizer_config}
        g_pre, g_post = (None, None)
        if pre_kwargs is not None:
            g_pre = SingleOpGraph.build_graph("HfJsonTokenizer", **{**pre_kwargs, **json_attrs})
        if post_kwargs is not None:
            g_post = SingleOpGraph.build_graph(_post_op_name(), **{**post_kwargs, **json_attrs})
        return make_onnx_model(g_pre) if g_pre else None, make_onnx_model(g_post) if g_post else None

    if isinstance(processor, str):
        # `is not None` (not truthiness) so an empty dict still requests a graph with default attrs.
        g_pre, g_post = (None, None)
        if pre_kwargs is not None:
            g_pre = SingleOpGraph.build_graph(processor, **pre_kwargs)
        if post_kwargs is not None:
            g_post = SingleOpGraph.build_graph(_post_op_name(), **post_kwargs)
        return make_onnx_model(g_pre) if g_pre else None, make_onnx_model(g_post) if g_post else None

    cls_name = type(processor).__name__
    if cls_name == "WhisperProcessor":
        if WhisperDataProcGraph is None:
            raise ValueError(
                "The Whisper processor needs torch.onnx support, please install pytorch 2.0 and above")
        _converter = WhisperDataProcGraph(processor, opset=opset, **kwargs)
        pre_m = _converter.pre_processing(
            **pre_kwargs) if pre_kwargs is not None else None
        post_m = _converter.post_processing(
            **post_kwargs) if post_kwargs is not None else None
        return pre_m, post_m
    elif HFTokenizerOnnxGraph.is_supported(processor):
        _converter = HFTokenizerOnnxGraph(processor)
        pre_g = _converter.pre_processing(
            **pre_kwargs) if pre_kwargs is not None else None
        post_g = _converter.post_processing(
            **post_kwargs) if post_kwargs is not None else None
        return make_onnx_model(pre_g) if pre_g else None, \
            make_onnx_model(post_g) if post_g else None
    else:
        raise ValueError(f"Unsupported processor/tokenizer: {cls_name}")
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Overrides the torch import so that all torch operators used by the processing code can be dumped.
|
|
3
|
+
!!!This package depends on onnxruntime_extensions root package, but not vice versa.!!!
|
|
4
|
+
This is because this package relies entirely on PyTorch, while onnxruntime_extensions does not.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from . _tensor import op_from_customop as pyfunc_from_custom_op
|
|
8
|
+
from . _tensor import op_from_model as pyfunc_from_model
|
|
9
|
+
from ._builder import build_customop_model
|
|
10
|
+
from ._session import ONNXTraceSession
|
|
11
|
+
|
|
12
|
+
# Package-level alias for the tracing entry point defined on ONNXTraceSession.
trace_for_onnx = getattr(ONNXTraceSession, "trace_for_onnx")
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import pathlib
|
|
3
|
+
from ._onnx_ops import make_model_ex
|
|
4
|
+
from .._cuops import SingleOpGraph, GPT2Tokenizer, VectorToString
|
|
5
|
+
from .._ortapi2 import default_opset_domain
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def is_path(name_or_buffer):
    """Return True when *name_or_buffer* is a ``str`` or ``pathlib.Path``, False otherwise."""
    return isinstance(name_or_buffer, (str, pathlib.Path))
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class _GPT2Tokenizer(GPT2Tokenizer):
    @classmethod
    def serialize_attr(cls, kwargs):
        """Build the 'vocab' and 'merges' attributes from a HF GPT-2 tokenizer passed as kwargs['model'].

        'vocab' is the tokenizer's ``encoder`` as compact JSON; 'merges' lists each BPE pair as
        "a b", one per line, in rank order 0..N-1 (ranks are assumed to be contiguous).
        """
        assert 'model' in kwargs, "Need model parameter to build the tokenizer"
        hf_tokenizer = kwargs['model']
        vocab_json = json.dumps(hf_tokenizer.encoder, separators=(',', ':'))
        rank_to_pair = dict((rank, pair) for pair, rank in hf_tokenizer.bpe_ranks.items())
        merge_lines = ["{} {}".format(*rank_to_pair[rank]) for rank in range(len(rank_to_pair))]
        return {'vocab': vocab_json, 'merges': '\n'.join(merge_lines)}
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class _VectorToString(VectorToString):
    @classmethod
    def serialize_attr(cls, kwargs):
        """Invert kwargs['decoder'] (token -> id) into id -> [token] and delegate to the base class,
        using '<unknown>' as the unknown-token string."""
        assert 'decoder' in kwargs, "Need decoder parameter to build the tokenizer"
        id_to_tokens = {}
        for token, token_id in kwargs['decoder'].items():
            id_to_tokens[token_id] = [token]
        return super().serialize_attr({'map': id_to_tokens, 'unk': '<unknown>'})
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# op_type -> local builder class that overrides attribute serialization for that op.
customop_mbuilder = dict(
    (builder.op_type(), builder) for builder in (_GPT2Tokenizer, _VectorToString)
)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def build_customop_model(op_type, f, opset_version=11, **attrs):
    """Build a single custom-op ONNX model and serialize it to *f*.

    Parameters
    ----------
    op_type: str
        Name of the custom operator.
    f: str, pathlib.Path or file-like
        Destination path, or an object with ``write`` and ``flush`` methods.
    opset_version: int
        Opset version passed to ``make_model_ex``.
    attrs:
        Attributes forwarded to ``SingleOpGraph.build_my_graph``.
    """
    op_class = SingleOpGraph.get_op_class(op_type)
    # A local builder in customop_mbuilder, when present, replaces the default op class.
    op_class = customop_mbuilder.get(op_type, op_class)

    graph = SingleOpGraph.build_my_graph(op_class, **attrs)
    model = make_model_ex(graph, [(default_opset_domain(), 1)], opset_version)
    if not is_path(f):
        f.write(model.SerializeToString())
        f.flush()
        return
    with open(f, 'wb') as out_file:
        out_file.write(model.SerializeToString())
|