onnxruntime_extensions-0.14.0-cp313-cp313-macosx_11_0_x86_64.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (43)
  1. onnxruntime_extensions/__init__.py +82 -0
  2. onnxruntime_extensions/_cuops.py +564 -0
  3. onnxruntime_extensions/_extensions_pydll.cpython-313-darwin.so +0 -0
  4. onnxruntime_extensions/_extensions_pydll.pyi +45 -0
  5. onnxruntime_extensions/_hf_cvt.py +331 -0
  6. onnxruntime_extensions/_ocos.py +133 -0
  7. onnxruntime_extensions/_ortapi2.py +274 -0
  8. onnxruntime_extensions/_torch_cvt.py +231 -0
  9. onnxruntime_extensions/_version.py +2 -0
  10. onnxruntime_extensions/cmd.py +66 -0
  11. onnxruntime_extensions/cvt.py +306 -0
  12. onnxruntime_extensions/onnxprocess/__init__.py +12 -0
  13. onnxruntime_extensions/onnxprocess/_builder.py +53 -0
  14. onnxruntime_extensions/onnxprocess/_onnx_ops.py +1507 -0
  15. onnxruntime_extensions/onnxprocess/_session.py +355 -0
  16. onnxruntime_extensions/onnxprocess/_tensor.py +628 -0
  17. onnxruntime_extensions/onnxprocess/torch_wrapper.py +31 -0
  18. onnxruntime_extensions/pnp/__init__.py +13 -0
  19. onnxruntime_extensions/pnp/_base.py +124 -0
  20. onnxruntime_extensions/pnp/_imagenet.py +65 -0
  21. onnxruntime_extensions/pnp/_nlp.py +148 -0
  22. onnxruntime_extensions/pnp/_onnx_ops.py +1544 -0
  23. onnxruntime_extensions/pnp/_torchext.py +310 -0
  24. onnxruntime_extensions/pnp/_unifier.py +45 -0
  25. onnxruntime_extensions/pnp/_utils.py +302 -0
  26. onnxruntime_extensions/pp_api.py +83 -0
  27. onnxruntime_extensions/tools/__init__.py +0 -0
  28. onnxruntime_extensions/tools/add_HuggingFace_CLIPImageProcessor_to_model.py +171 -0
  29. onnxruntime_extensions/tools/add_pre_post_processing_to_model.py +535 -0
  30. onnxruntime_extensions/tools/pre_post_processing/__init__.py +4 -0
  31. onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py +395 -0
  32. onnxruntime_extensions/tools/pre_post_processing/step.py +227 -0
  33. onnxruntime_extensions/tools/pre_post_processing/steps/__init__.py +6 -0
  34. onnxruntime_extensions/tools/pre_post_processing/steps/general.py +366 -0
  35. onnxruntime_extensions/tools/pre_post_processing/steps/nlp.py +344 -0
  36. onnxruntime_extensions/tools/pre_post_processing/steps/vision.py +1157 -0
  37. onnxruntime_extensions/tools/pre_post_processing/utils.py +139 -0
  38. onnxruntime_extensions/util.py +186 -0
  39. onnxruntime_extensions-0.14.0.dist-info/LICENSE +21 -0
  40. onnxruntime_extensions-0.14.0.dist-info/METADATA +102 -0
  41. onnxruntime_extensions-0.14.0.dist-info/RECORD +43 -0
  42. onnxruntime_extensions-0.14.0.dist-info/WHEEL +6 -0
  43. onnxruntime_extensions-0.14.0.dist-info/top_level.txt +1 -0
onnxruntime_extensions/_hf_cvt.py
@@ -0,0 +1,331 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+###############################################################################
+
+"""
+_hf_cvt.py: HuggingFace Tokenizer/Processor Converter
+"""
+import os
+import json
+import onnx
+from numpy import array as nparray
+from functools import partial
+from collections import namedtuple, OrderedDict
+
+from .util import read_file
+from ._cuops import CustomOpConverter, SingleOpGraph
+
+
+class HFTokenizerConverter(CustomOpConverter):
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+
+    @staticmethod
+    def convert_bpe_vocab(hf_tokenizer):
+        attrs = {'vocab': json.dumps(
+            hf_tokenizer.encoder, separators=(',', ':'))}
+        if hf_tokenizer.added_tokens_encoder:
+            token_map = [f"{_k}={_v}" for _k,
+                         _v in hf_tokenizer.added_tokens_encoder.items()]
+            attrs.update({"added_token": "\n".join(token_map)})
+
+        sorted_merges = {v_: k_ for k_, v_ in hf_tokenizer.bpe_ranks.items()}
+        attrs['merges'] = '\n'.join("{} {}".format(
+            *sorted_merges[n_]) for n_ in range(len(sorted_merges)))
+        return attrs
+
+    @staticmethod
+    def convert_json_vocab(hf_tokenizer):
+        filenames = getattr(hf_tokenizer, "vocab_files_names", None)
+        if filenames is None:
+            raise ValueError(
+                f"{hf_tokenizer.__name__}: vocab_files_names is not found")
+
+        tokenizer_file = filenames["tokenizer_file"]
+        vocab_file = getattr(hf_tokenizer, "vocab_file", None)
+        if (vocab_file is None) or (not os.path.exists(vocab_file)):
+            model_dir = hf_tokenizer.name_or_path
+        else:
+            model_dir = os.path.dirname(vocab_file)
+        f = open(os.path.join(model_dir, tokenizer_file), "r", encoding="utf-8")
+        tokenizer_json = json.load(f)
+        f.close()
+        # get vocab object from json file
+        vocab = tokenizer_json.get("model", {}).get("vocab", {})
+        sorted_merges = tokenizer_json.get("model", {}).get("merges", [])
+        sorted_merges = [v_.replace("\n", "<0x0A>") for v_ in sorted_merges]
+        attrs = {"vocab": json.dumps(vocab, separators=(",", ":"))}
+        attrs["merges"] = "\n".join(sorted_merges)
+        if hf_tokenizer.added_tokens_encoder:
+            token_map = [f"{_k}={_v}" for _k,
+                         _v in hf_tokenizer.added_tokens_encoder.items()]
+            attrs.update({"added_token": "\n".join(token_map)})
+
+        return attrs
+
+    @staticmethod
+    def get_model_name(hf_tokenizer):
+        name = hf_tokenizer.__class__.__name__
+        if name.endswith("Fast"):
+            name = name[: -len("Fast")]
+        if name.endswith("Tokenizer"):
+            name = name[: -len("Tokenizer")]
+        return name
+
+    def bpe_tokenizer(self, **kwargs):
+        hf_bpe_tokenizer = self.tokenizer
+        if getattr(hf_bpe_tokenizer, "is_fast", True):
+            attrs = self.convert_json_vocab(hf_bpe_tokenizer)
+        else:
+            attrs = self.convert_bpe_vocab(hf_bpe_tokenizer)
+
+        attrs.update({"model_name": self.get_model_name(hf_bpe_tokenizer)})
+        attrs.update(**kwargs)
+        return attrs
+
+    def bert_tokenizer(self, **kwargs):
+        hf_bert_tokenizer = self.tokenizer
+        # has to be sorted since the id of token was generated automatically.
+        ordered_vocab = OrderedDict(
+            sorted(hf_bert_tokenizer.vocab.items(), key=lambda item: int(item[1])))
+        vocab = '\n'.join(ordered_vocab.keys())
+        attrs = dict(vocab=vocab)
+        init_kwargs = hf_bert_tokenizer.init_kwargs
+        attrs['do_lower_case'] = 1 if 'do_lower_case' in init_kwargs and init_kwargs.get(
+            'do_lower_case') else 0
+        attrs['strip_accents'] = 1 if 'strip_accents' in init_kwargs and init_kwargs.get(
+            'strip_accents') else 0
+        attrs.update(**kwargs)
+        return attrs
+
+    def bert_decoder(self, **kwargs):
+        hf_bert_tokenizer = self.tokenizer
+        attrs = {'vocab': json.dumps(
+            hf_bert_tokenizer.ids_to_tokens, separators=(',', ':'))}
+        attrs.update(**kwargs)
+        return attrs
+
+    def bpe_decoder(self, **kwargs):
+        decoder = self.tokenizer.decoder
+        # if decoder is not iterable, build it from the vocab.
+        if not hasattr(decoder, "__iter__"):
+            decoder = {id: token for token, id in self.tokenizer.vocab.items()}
+        id_vocab = "\n".join([decoder[_idx] for _idx in sorted(decoder)])
+        byte_decoder = getattr(self.tokenizer, "byte_decoder", None)
+        if byte_decoder is None:
+            # let's take it as a SPM tokenizer
+            byte_decoder = {chr(0x2581): ord(' ')}
+        str_byte_decoder = "\n".join(
+            ["{}\t{}".format(ord(_c), str(byte_decoder[_c])) for _c in byte_decoder])
+        all_special_ids = self.tokenizer.all_special_ids
+        added_tokens = self.tokenizer.added_tokens_decoder
+        str_all_special_ids = "\n".join([str(_id) for _id in all_special_ids])
+        str_added_tokens = "\n".join(
+            ["{}\t{}".format(str(_id), added_tokens[_id]) for _id in added_tokens])
+        kwargs.update({
+            "id_vocab": id_vocab,
+            "byte_decoder": str_byte_decoder,
+            "added_tokens": str_added_tokens,
+            "all_special_ids": str_all_special_ids,
+            "skip_special_tokens": kwargs.get("skip_special_tokens", False)
+        })
+        return kwargs
+
+    def clip_tokenizer(self, **kwargs):
+        hf_clip_tokenizer = self.tokenizer
+
+        if type(self.tokenizer).__name__.endswith('Fast'):
+            raise ValueError(
+                'Please use the slow version of the tokenizer (ex: CLIPTokenizer).')
+
+        attrs = self.convert_bpe_vocab(hf_clip_tokenizer)
+        attrs.update(**kwargs)
+        return attrs
+
+    def roberta_tokenizer(self, **kwargs):
+        hf_roberta_tokenizer = self.tokenizer
+
+        if type(self.tokenizer).__name__.endswith('Fast'):
+            raise ValueError(
+                'Please use the slow version of the tokenizer (ex: RobertaTokenizer).')
+
+        attrs = self.convert_bpe_vocab(hf_roberta_tokenizer)
+        attrs.update(**kwargs)
+        return attrs
+
+    def spm_tokenizer(self, **kwargs):
+        attrs = {'model': read_file(self.tokenizer.vocab_file, 'rb')}
+        attrs.update(**kwargs)
+        return attrs
+
+    def spm_decoder(self, **kwargs):
+        attrs = {'model': read_file(self.tokenizer.vocab_file, 'rb')}
+        attrs.update(**kwargs)
+        return attrs
+
+
+TokenOpParam = namedtuple("TokenOpParam",
+                          ["pre_op", "pre_attribute_cvt",
+                           "post_op", "post_attribute_cvt",
+                           "default_encoder_inputs",
+                           "default_decoder_inputs"],
+                          defaults=(None, None, None, None, None))
+
+# Some tokenizers can be added by this table
+# https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py#L1252
+# @formatter:off
+_PROCESSOR_DICT = {
+    "BertTokenizer": TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
+                                  'BertDecoder', HFTokenizerConverter.bpe_decoder, None, None),
+    "DistilBertTokenizer": TokenOpParam('BertTokenizer', HFTokenizerConverter.bert_tokenizer,
+                                        'BertDecoder', HFTokenizerConverter.bpe_decoder, None, None),
+    "GPT2Tokenizer": TokenOpParam('GPT2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
+    "CodeGenTokenizer": TokenOpParam('GPT2Tokenizer', HFTokenizerConverter.bpe_tokenizer,
+                                     'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
+    "CLIPTokenizer": TokenOpParam('CLIPTokenizer', HFTokenizerConverter.clip_tokenizer,
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
+    "RobertaTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
+                                     'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
+    "BartTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
+                                  'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
+    "LayoutLMv3Tokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
+                                        'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
+    "LongformerTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
+                                        'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
+    "LEDTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
+                                 'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
+    "MvpTokenizer": TokenOpParam('RobertaTokenizer', HFTokenizerConverter.roberta_tokenizer,
+                                 'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
+    "T5Tokenizer": TokenOpParam('SentencepieceTokenizer', HFTokenizerConverter.spm_tokenizer,
+                                'SentencepieceDecoder', HFTokenizerConverter.spm_decoder,
+                                default_encoder_inputs={'add_eos': [True]}, default_decoder_inputs=None),
+    "LlamaTokenizer": TokenOpParam('SpmTokenizer', HFTokenizerConverter.bpe_tokenizer,
+                                   'BpeDecoder', HFTokenizerConverter.bpe_decoder, None, None),
+    "XLMRobertaTokenizer": TokenOpParam('SentencepieceTokenizer', HFTokenizerConverter.spm_tokenizer,
+                                        'SentencepieceDecoder', HFTokenizerConverter.spm_decoder,
+                                        default_encoder_inputs={'add_bos': [True], 'add_eos': [True], 'fairseq': [True]},
+                                        default_decoder_inputs={'fairseq': [True]}),
+}
+# @formatter:on
+
+
+class HFTokenizerOnnxGraph:
+
+    @staticmethod
+    def extract_cls_name(processor):
+        cls_name = processor if isinstance(
+            processor, str) else type(processor).__name__
+        if cls_name.endswith("TokenizerFast"):
+            cls_name = cls_name[:-len("Fast")]
+        return cls_name
+
+    @classmethod
+    def is_supported(cls, processor):
+        cls_name = cls.extract_cls_name(processor)
+        return cls_name in _PROCESSOR_DICT
+
+    def __init__(self, processor, **kwargs):
+        cls_name = self.extract_cls_name(processor)
+        self.cvt_quadruple = _PROCESSOR_DICT[cls_name]
+        self.cvt_obj = HFTokenizerConverter(processor)
+
+    def pre_processing(self, **kwargs):
+        with_default_inputs = kwargs.pop("WITH_DEFAULT_INPUTS", True)
+        cast_token_id = kwargs.pop("CAST_TOKEN_ID", False)
+
+        _cvt_op = self.cvt_quadruple.pre_op
+        _cvt_func = self.cvt_quadruple.pre_attribute_cvt
+        cvt = partial(_cvt_func, self.cvt_obj)
+        g = SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
+        default_inputs = []
+        if with_default_inputs:
+            op_class = SingleOpGraph.get_op_class(_cvt_op)
+            default_inputs = op_class.input_default_values()
+            if default_inputs is None:
+                return g
+
+        # add default_inputs into initializers to simplify the model input
+        n_inputs = len(default_inputs)
+        if self.cvt_quadruple.default_encoder_inputs is not None:
+            default_inputs.update(self.cvt_quadruple.default_encoder_inputs)
+        if len(default_inputs) != n_inputs:
+            raise ValueError(
+                "Op: {} does not have the inputs from its TokenOpParam.".format(_cvt_op))
+
+        new_initializers = []
+
+        for k, v in default_inputs.items():
+            input_value_info = next((i for i in g.input if i.name == k), None)
+            if input_value_info is None:
+                raise ValueError(
+                    "The input {} is not found in the graph".format(k))
+
+            np_dtype = onnx.helper.tensor_dtype_to_np_dtype(
+                input_value_info.type.tensor_type.elem_type)
+            value = nparray(v, np_dtype)
+            new_initializers.append(onnx.numpy_helper.from_array(value, k))
+        g.initializer.extend(new_initializers)
+        new_inputs = [i for i in g.input if i.name not in default_inputs]
+        g.ClearField("input")
+        g.input.extend(new_inputs)
+
+        if cast_token_id:
+            # assume the first output is always the token ID.
+            if g.output[0].type.tensor_type.elem_type != onnx.onnx_pb.TensorProto.INT64:
+                new_output_name = g.output[0].name + '_cast'
+                shape = g.output[0].type.tensor_type.shape
+                cast_node = onnx.helper.make_node('Cast', [g.output[0].name], [new_output_name],
+                                                  to=onnx.onnx_pb.TensorProto.INT64)
+                new_output = [onnx.helper.make_tensor_value_info(
+                    new_output_name, onnx.onnx_pb.TensorProto.INT64, None)] + list(g.output)[1:]
+                if shape is not None:
+                    new_output[0].type.tensor_type.shape.CopyFrom(shape)
+                g.node.append(cast_node)
+                g.ClearField('output')
+                g.output.extend(new_output)
+
+        return g
+
+    def post_processing(self, **kwargs):
+        with_default_inputs = kwargs.pop("WITH_DEFAULT_INPUTS", True)
+
+        _cvt_op = self.cvt_quadruple.post_op
+        _cvt_func = self.cvt_quadruple.post_attribute_cvt
+        cvt = partial(_cvt_func, self.cvt_obj)
+        g = SingleOpGraph.build_graph(_cvt_op, cvt=cvt, **kwargs)
+
+        default_inputs = {}
+        if with_default_inputs:
+            op_class = SingleOpGraph.get_op_class(_cvt_op)
+            default_inputs = op_class.input_default_values()
+            if default_inputs is None:
+                encoder_inputs = self.cvt_quadruple.default_encoder_inputs
+                if encoder_inputs is not None and encoder_inputs["fairseq"]:
+                    default_inputs = {}  # need to set to empty dict to call .update later
+                else:
+                    return g
+
+        # add default_inputs into initializers to simplify the model input
+        if self.cvt_quadruple.default_decoder_inputs is not None:
+            default_inputs.update(self.cvt_quadruple.default_decoder_inputs)
+
+        new_initializers = []
+
+        for k, v in default_inputs.items():
+            input_value_info = next((i for i in g.input if i.name == k), None)
+            if input_value_info is None:
+                raise ValueError(
+                    "The input {} is not found in the graph".format(k))
+
+            np_dtype = onnx.helper.tensor_dtype_to_np_dtype(
+                input_value_info.type.tensor_type.elem_type)
+            value = nparray(v, np_dtype)
+            new_initializers.append(onnx.numpy_helper.from_array(value, k))
+        g.initializer.extend(new_initializers)
+        new_inputs = [i for i in g.input if i.name not in default_inputs]
+        g.ClearField("input")
+        g.input.extend(new_inputs)
+
+        return g
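
For orientation only, a minimal usage sketch of the converter above; it is not part of the packaged files. It assumes a locally available Hugging Face GPT2Tokenizer, and the ModelProto wrapping (the "ai.onnx.contrib" domain name and the opset versions) is an illustrative assumption rather than something this diff specifies.

    # Hypothetical driver for HFTokenizerOnnxGraph (illustrative sketch, not from the wheel).
    import onnx
    from transformers import GPT2Tokenizer
    from onnxruntime_extensions._hf_cvt import HFTokenizerOnnxGraph

    hf_tok = GPT2Tokenizer.from_pretrained("gpt2")   # slow tokenizer, so the BPE vocab path is used
    assert HFTokenizerOnnxGraph.is_supported(hf_tok)

    cvt = HFTokenizerOnnxGraph(hf_tok)
    # pre_processing() returns an onnx.GraphProto holding a single custom tokenizer op;
    # CAST_TOKEN_ID appends a Cast node so the token IDs are emitted as int64.
    graph = cvt.pre_processing(CAST_TOKEN_ID=True)

    # Wrap the graph in a model; the custom-op domain and opset versions below are assumptions.
    model = onnx.helper.make_model(
        graph,
        opset_imports=[onnx.helper.make_opsetid("", 18),
                       onnx.helper.make_opsetid("ai.onnx.contrib", 1)])
    onnx.save(model, "gpt2_tokenizer.onnx")
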
onnxruntime_extensions/_ocos.py
@@ -0,0 +1,133 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+###############################################################################
+"""
+_ocos.py: PythonOp implementation
+"""
+import os
+import sys
+import glob
+
+
+def _search_cuda_dir():
+    paths = os.getenv("PATH", "").split(os.pathsep)
+    for path in paths:
+        for filename in glob.glob(os.path.join(path, "cudart64*.dll")):
+            return os.path.dirname(filename)
+
+    return None
+
+
+if sys.platform == "win32":
+    from . import _version  # noqa: E402
+
+    if hasattr(_version, "cuda"):
+        cuda_path = _search_cuda_dir()
+        if cuda_path is None:
+            raise RuntimeError("Cannot locate CUDA directory in the environment variable for GPU package")
+
+        os.add_dll_directory(cuda_path)
+
+
+from ._extensions_pydll import (  # noqa
+    PyCustomOpDef,
+    enable_py_op,
+    add_custom_op,
+    hash_64,
+    default_opset_domain,
+)
+
+
+def get_library_path():
+    """
+    The custom operator library binary path
+    :return: A string of this library path.
+    """
+    mod = sys.modules["onnxruntime_extensions._extensions_pydll"]
+    return mod.__file__
+
+
+class Opdef:
+    _odlist = {}
+
+    def __init__(self, op_type, func):
+        self.op_type = op_type
+        self.body = func
+        self._id = id(self)
+
+    @staticmethod
+    def declare(*args, **kwargs):
+        if len(args) > 0 and hasattr(args[0], "__call__"):
+            raise RuntimeError("Unexpected arguments {}.".format(args))
+            # return Opdef._create(args[0])
+        return lambda f: Opdef.create(f, *args, **kwargs)
+
+    @staticmethod
+    def create(func, *args, **kwargs):
+        name = kwargs.get("op_type", None)
+        op_type = name or func.__name__
+        opdef = Opdef(op_type, func)
+        od_id = id(opdef)
+
+        # Tells python this object cannot be destroyed
+        # because it is also stored in C++ container.
+        Opdef._odlist[od_id] = opdef
+        opdef._nativedef = PyCustomOpDef()
+        opdef._nativedef.op_type = op_type
+        opdef._nativedef.obj_id = od_id
+
+        inputs = kwargs.get("inputs", None)
+        if inputs is None:
+            inputs = [PyCustomOpDef.dt_float]
+        opdef._nativedef.input_types = inputs
+        outputs = kwargs.get("outputs", None)
+        if outputs is None:
+            outputs = [PyCustomOpDef.dt_float]
+        opdef._nativedef.output_types = outputs
+        attrs = kwargs.get("attrs", None)
+        if attrs is None:
+            attrs = {}
+        elif isinstance(attrs, (list, tuple)):
+            attrs = {k: PyCustomOpDef.dt_string for k in attrs}
+        opdef._nativedef.attrs = attrs
+        add_custom_op(opdef._nativedef)
+        return opdef
+
+    def __call__(self, *args, **kwargs):
+        return self.body(*args, **kwargs)
+
+    def cast_attributes(self, attributes):
+        res = {}
+        for k, v in attributes.items():
+            if self._nativedef.attrs[k] == PyCustomOpDef.dt_int64:
+                res[k] = int(v)
+            elif self._nativedef.attrs[k] == PyCustomOpDef.dt_float:
+                res[k] = float(v)
+            elif self._nativedef.attrs[k] == PyCustomOpDef.dt_string:
+                res[k] = v
+            else:
+                raise RuntimeError("Unsupported attribute type {}.".format(self._nativedef.attrs[k]))
+        return res
+
+
+def _on_pyop_invocation(k_id, feed, attributes):
+    if k_id not in Opdef._odlist:
+        raise RuntimeError(
+            "Unable to find function id={}. " "Did you decorate the operator with @onnx_op?.".format(k_id)
+        )
+    op_ = Opdef._odlist[k_id]
+    rv = op_.body(*feed, **op_.cast_attributes(attributes))
+    if isinstance(rv, tuple):
+        # Multiple outputs.
+        res = []
+        for r in rv:
+            res.append(r.shape)
+            res.append(r.flatten().tolist())
+        res = tuple(res)
+    else:
+        res = (rv.shape, rv.flatten().tolist())
+    return (k_id,) + res
+
+
+PyCustomOpDef.install_hooker(_on_pyop_invocation)
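
As a closing note on the PythonOp plumbing above: Opdef.declare builds a PyCustomOpDef, registers it through add_custom_op, and _on_pyop_invocation dispatches runtime calls back to the Python body. A small, hypothetical registration sketch follows (the package normally exposes this via the @onnx_op decorator; the op name "ReverseMatrix" here is just an example):

    # Illustrative sketch of Opdef.declare; assumes numpy is available and the op name is arbitrary.
    import numpy as np
    from onnxruntime_extensions._ocos import Opdef, PyCustomOpDef

    @Opdef.declare(op_type="ReverseMatrix",
                   inputs=[PyCustomOpDef.dt_float],
                   outputs=[PyCustomOpDef.dt_float])
    def reverse_matrix(x):
        # _on_pyop_invocation flattens the returned ndarray and reports its shape.
        return np.flip(x, axis=0)

    # The decorated name is the Opdef instance; calling it runs the Python body directly.
    print(reverse_matrix(np.arange(6, dtype=np.float32).reshape(3, 2)))

An ONNX model referencing the "ReverseMatrix" custom op in the extensions domain can then run it in onnxruntime once the session registers the library returned by get_library_path().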