onnxruntime_extensions 0.14.0__cp313-cp313-macosx_11_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. onnxruntime_extensions/__init__.py +82 -0
  2. onnxruntime_extensions/_cuops.py +564 -0
  3. onnxruntime_extensions/_extensions_pydll.cpython-313-darwin.so +0 -0
  4. onnxruntime_extensions/_extensions_pydll.pyi +45 -0
  5. onnxruntime_extensions/_hf_cvt.py +331 -0
  6. onnxruntime_extensions/_ocos.py +133 -0
  7. onnxruntime_extensions/_ortapi2.py +274 -0
  8. onnxruntime_extensions/_torch_cvt.py +231 -0
  9. onnxruntime_extensions/_version.py +2 -0
  10. onnxruntime_extensions/cmd.py +66 -0
  11. onnxruntime_extensions/cvt.py +306 -0
  12. onnxruntime_extensions/onnxprocess/__init__.py +12 -0
  13. onnxruntime_extensions/onnxprocess/_builder.py +53 -0
  14. onnxruntime_extensions/onnxprocess/_onnx_ops.py +1507 -0
  15. onnxruntime_extensions/onnxprocess/_session.py +355 -0
  16. onnxruntime_extensions/onnxprocess/_tensor.py +628 -0
  17. onnxruntime_extensions/onnxprocess/torch_wrapper.py +31 -0
  18. onnxruntime_extensions/pnp/__init__.py +13 -0
  19. onnxruntime_extensions/pnp/_base.py +124 -0
  20. onnxruntime_extensions/pnp/_imagenet.py +65 -0
  21. onnxruntime_extensions/pnp/_nlp.py +148 -0
  22. onnxruntime_extensions/pnp/_onnx_ops.py +1544 -0
  23. onnxruntime_extensions/pnp/_torchext.py +310 -0
  24. onnxruntime_extensions/pnp/_unifier.py +45 -0
  25. onnxruntime_extensions/pnp/_utils.py +302 -0
  26. onnxruntime_extensions/pp_api.py +83 -0
  27. onnxruntime_extensions/tools/__init__.py +0 -0
  28. onnxruntime_extensions/tools/add_HuggingFace_CLIPImageProcessor_to_model.py +171 -0
  29. onnxruntime_extensions/tools/add_pre_post_processing_to_model.py +535 -0
  30. onnxruntime_extensions/tools/pre_post_processing/__init__.py +4 -0
  31. onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py +395 -0
  32. onnxruntime_extensions/tools/pre_post_processing/step.py +227 -0
  33. onnxruntime_extensions/tools/pre_post_processing/steps/__init__.py +6 -0
  34. onnxruntime_extensions/tools/pre_post_processing/steps/general.py +366 -0
  35. onnxruntime_extensions/tools/pre_post_processing/steps/nlp.py +344 -0
  36. onnxruntime_extensions/tools/pre_post_processing/steps/vision.py +1157 -0
  37. onnxruntime_extensions/tools/pre_post_processing/utils.py +139 -0
  38. onnxruntime_extensions/util.py +186 -0
  39. onnxruntime_extensions-0.14.0.dist-info/LICENSE +21 -0
  40. onnxruntime_extensions-0.14.0.dist-info/METADATA +102 -0
  41. onnxruntime_extensions-0.14.0.dist-info/RECORD +43 -0
  42. onnxruntime_extensions-0.14.0.dist-info/WHEEL +6 -0
  43. onnxruntime_extensions-0.14.0.dist-info/top_level.txt +1 -0
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###############################################################################
"""
The `onnxruntime-extensions` Python package offers an API that allows users to generate models for pre-processing and
post-processing tasks. In addition, it also provides an API to register custom operations implemented in Python.
This enables more flexibility and control over model execution, thus expanding the functionality of the ONNX Runtime.
"""

__author__ = "Microsoft"

from ._version import __version__
from ._ocos import get_library_path
from ._ocos import Opdef, PyCustomOpDef
from ._ocos import hash_64
from ._ocos import enable_py_op
from ._ocos import default_opset_domain


# True when only the native custom-op library is usable (onnx/onnxruntime
# Python packages not installed); the offline model-building API is disabled.
_lib_only = False

try:
    import onnx  # noqa
    import onnxruntime  # noqa
except ImportError:
    _lib_only = True


# Names that are only functional when onnx and onnxruntime are importable.
_offline_api = [
    "gen_processing_models",
    "ort_inference",
    "OrtPyFunction",
    "PyOrtFunction",
    "optimize_model",
    "make_onnx_model",
    "ONNXRuntimeError",
]

__all__ = [
    "get_library_path",
    "Opdef",
    "onnx_op",
    "PyCustomOpDef",
    "PyOp",
    "enable_py_op",
    "expand_onnx_inputs",
    "hook_model_op",
    "default_opset_domain",
    "hash_64",
    "__version__",
]

# rename the implementation with a more formal name
onnx_op = Opdef.declare
PyOp = PyCustomOpDef


if _lib_only:

    def _unimplemented(*args, **kwargs):
        """Stand-in for the offline API when onnx/onnxruntime are missing."""
        raise NotImplementedError("ONNX or ONNX Runtime is not installed")

    class ONNXRuntimeError(NotImplementedError):
        """Placeholder for onnxruntime's error type when it is not installed."""
        # Defined as an exception class (previously this name was bound to
        # the plain function _unimplemented), so that caller code doing
        # `except ONNXRuntimeError:` remains valid instead of raising
        # TypeError for catching a non-exception.

    gen_processing_models = _unimplemented
    OrtPyFunction = _unimplemented
    ort_inference = _unimplemented
    PyOrtFunction = _unimplemented
    optimize_model = _unimplemented
    make_onnx_model = _unimplemented

else:
    __all__ += _offline_api

    from ._cuops import *  # noqa
    from ._ortapi2 import hook_model_op
    from ._ortapi2 import expand_onnx_inputs
    from ._ortapi2 import OrtPyFunction, ort_inference, optimize_model, make_onnx_model
    from ._ortapi2 import OrtPyFunction as PyOrtFunction  # backward compatibility
    from ._ortapi2 import ONNXRuntimeError, ONNXRuntimeException  # noqa
    from .cvt import gen_processing_models
@@ -0,0 +1,564 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License. See License.txt in the project root for
3
+ # license information.
4
+ ###############################################################################
5
+
6
+ """
7
+ _cuops.py: Custom operators signatures for Python usage.
8
+ """
9
+
10
+ import onnx
11
+ import numpy
12
+ from onnx import onnx_pb as onnx_proto
13
+ from ._ocos import default_opset_domain, Opdef, PyCustomOpDef
14
+
15
+
16
class CustomOp:
    """Base class that declares the I/O signature of a custom operator.

    Subclasses override get_inputs/get_outputs (and optionally
    input_default_values / serialize_attr) to describe one operator.
    """

    @classmethod
    def op_type(cls):
        """Return the operator name: the name of the direct CustomOp subclass."""
        klass = cls
        while klass.__base__ is not CustomOp:
            klass = klass.__base__
        return klass.__name__

    @classmethod
    def get_inputs(cls):
        """Input value-infos; None means 'not declared'."""
        return None

    @classmethod
    def get_outputs(cls):
        """Output value-infos; None means 'not declared'."""
        return None

    @classmethod
    def input_default_values(cls):
        """Optional {input_name: default} mapping for trailing inputs."""
        return None

    @classmethod
    def serialize_attr(cls, attrs):
        """
        Only support serialize the basic python type like list or dict,
        All other types needs to be serialized by the users
        :param attrs: the dict attributes
        :return: the dict of serialized data
        """
        return attrs

    # Shorthand for building ONNX tensor value-info records.
    io_def = onnx.helper.make_tensor_value_info
48
+
49
+
50
class GPT2Tokenizer(CustomOp):
    """Declares the I/O signature of the GPT2Tokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def('input_text', onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        ids = cls.io_def("input_ids", onnx.TensorProto.INT64, [None, None])
        mask = cls.io_def('attention_mask', onnx.TensorProto.INT64, [None, None])
        return [ids, mask]
64
+
65
+
66
class CLIPTokenizer(CustomOp):
    """Declares the I/O signature of the CLIPTokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def('input_text', onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        i64 = onnx.TensorProto.INT64
        return [
            cls.io_def("input_ids", i64, [None, None]),
            cls.io_def('attention_mask', i64, [None, None]),
            cls.io_def('offset_mapping', i64, [None, None, 2]),
        ]
82
+
83
+
84
class RobertaTokenizer(CustomOp):
    """Declares the I/O signature of the RobertaTokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def('input_text', onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        i64 = onnx.TensorProto.INT64
        return [
            cls.io_def("input_ids", i64, [None, None]),
            cls.io_def('attention_mask', i64, [None, None]),
            cls.io_def('offset_mapping', i64, [None, None, 2]),
        ]
100
+
101
+
102
class BpeDecoder(CustomOp):
    """Declares the I/O signature of the BpeDecoder custom op."""

    @classmethod
    def get_inputs(cls):
        ids = cls.io_def("ids", onnx.TensorProto.INT64, None)
        return [ids]

    @classmethod
    def get_outputs(cls):
        text = cls.io_def('str', onnx_proto.TensorProto.STRING, None)
        return [text]
112
+
113
+
114
class SpmTokenizer(CustomOp):
    """Declares the I/O signature of the SpmTokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("input_text", onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        i64 = onnx.TensorProto.INT64
        return [
            cls.io_def("input_ids", i64, [None, None]),
            cls.io_def("attention_mask", i64, [None, None]),
            cls.io_def("offset_mapping", i64, [None, None, 2]),
        ]
128
+
129
+
130
class VectorToString(CustomOp):
    """Declares the I/O signature of the VectorToString custom op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("token_ids", onnx.TensorProto.INT64, [])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('text', onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Render a dict-valued 'map' attribute as newline-separated
        'key<TAB>v0 v1 ...' rows; every other attribute (including an
        already-serialized string 'map') passes through unchanged."""
        serialized = {}
        for key, value in attrs.items():
            if key == 'map' and isinstance(value, dict):
                rows = (k + "\t" + " ".join(str(i) for i in v)
                        for k, v in value.items())
                serialized[key] = '\n'.join(rows)
            else:
                serialized[key] = value
        return serialized
153
+
154
+
155
class StringMapping(CustomOp):
    """Declares the I/O signature of the StringMapping custom op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("input", onnx.TensorProto.STRING, [])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('output', onnx_proto.TensorProto.STRING, [])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Render a dict-valued 'map' attribute as newline-separated
        'key<TAB>value' rows; other attributes pass through unchanged."""
        serialized = {}
        for key, value in attrs.items():
            if key == 'map' and isinstance(value, dict):
                serialized[key] = '\n'.join(
                    k + "\t" + v for k, v in value.items())
            else:
                serialized[key] = value
        return serialized
176
+
177
+
178
class MaskedFill(CustomOp):
    """Declares the I/O signature of the MaskedFill custom op."""

    @classmethod
    def get_inputs(cls):
        values = cls.io_def("value", onnx.TensorProto.STRING, [None])
        mask = cls.io_def("mask", onnx.TensorProto.BOOL, [None])
        return [values, mask]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('output', onnx_proto.TensorProto.STRING, [None])]
190
+
191
+
192
class StringToVector(CustomOp):
    """Declares the I/O signature of the StringToVector custom op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("text", onnx.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('token_ids', onnx_proto.TensorProto.INT64, [])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Serialize 'map' (dict -> 'key<TAB>v0 v1 ...' lines) and 'unk'
        (list -> space-joined string); everything else passes through."""
        serialized = {}
        for key, value in attrs.items():
            if key == 'map' and isinstance(value, dict):
                rows = (k + "\t" + " ".join(str(i) for i in v)
                        for k, v in value.items())
                serialized[key] = '\n'.join(rows)
            elif key == 'unk' and isinstance(value, list):
                serialized[key] = ' '.join(str(i) for i in value)
            else:
                serialized[key] = value
        return serialized
217
+
218
+
219
class BlingFireSentenceBreaker(CustomOp):
    """Declares the I/O signature of the BlingFireSentenceBreaker custom op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("text", onnx.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('sentence', onnx_proto.TensorProto.STRING, [])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Replace a 'model' path attribute by the raw bytes of that file;
        other attributes pass through unchanged."""
        serialized = {}
        for key, value in attrs.items():
            if key == 'model':
                with open(value, "rb") as model_file:
                    serialized[key] = model_file.read()
            else:
                serialized[key] = value
        return serialized
239
+
240
+
241
class SegmentExtraction(CustomOp):
    """Declares the I/O signature of the SegmentExtraction custom op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("input", onnx.TensorProto.INT64, [None, None])]

    @classmethod
    def get_outputs(cls):
        i64 = onnx_proto.TensorProto.INT64
        return [
            cls.io_def('position', i64, [None, 2]),
            cls.io_def('value', i64, [None]),
        ]
253
+
254
+
255
class BertTokenizer(CustomOp):
    """Declares the I/O signature of the BertTokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def("text", onnx.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        i64 = onnx_proto.TensorProto.INT64
        return [
            cls.io_def('input_ids', i64, [None]),
            cls.io_def('token_type_ids', i64, [None]),
            cls.io_def('attention_mask', i64, [None]),
            cls.io_def('offset_mapping', onnx.TensorProto.INT64, [None, 2]),
        ]

    @classmethod
    def serialize_attr(cls, attrs):
        """Normalize vocabulary attributes: 'vocab' is forwarded under the
        'vocab_file' key, while a 'vocab_file' path is replaced by the
        file's text content; other attributes pass through unchanged."""
        serialized = {}
        for key, value in attrs.items():
            if key == 'vocab':
                serialized['vocab_file'] = value
            elif key == 'vocab_file':
                with open(value, "r", encoding='utf-8') as vocab:
                    content = vocab.readlines()
                # NOTE(review): joining readlines() output (which keeps the
                # trailing '\n') with '\n' doubles every newline; presumably
                # the consumer tolerates blank lines -- confirm before changing.
                serialized[key] = '\n'.join(content)
            else:
                serialized[key] = value
        return serialized
283
+
284
+
285
class StringECMARegexReplace(CustomOp):
    """Declares the I/O signature of the StringECMARegexReplace custom op."""

    @classmethod
    def get_inputs(cls):
        str_t = onnx.TensorProto.STRING
        return [
            cls.io_def("input", str_t, [None]),
            cls.io_def("pattern", str_t, [None]),
            cls.io_def("rewrite", str_t, [None]),
        ]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('output', onnx_proto.TensorProto.STRING, [None])]
298
+
299
+
300
class BertTokenizerDecoder(CustomOp):
    """Declares the I/O signature of the BertTokenizerDecoder custom op."""

    @classmethod
    def get_inputs(cls):
        i64 = onnx.TensorProto.INT64
        return [
            cls.io_def("ids", i64, [None]),
            cls.io_def("position", i64, [None, None]),
        ]

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('str', onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def serialize_attr(cls, attrs):
        """Replace a 'vocab_file' path attribute by the file's text content;
        other attributes pass through unchanged."""
        serialized = {}
        for key, value in attrs.items():
            if key == 'vocab_file':
                with open(value, "r", encoding='utf-8') as vocab:
                    content = vocab.readlines()
                # NOTE(review): '\n'.join over readlines() doubles newlines;
                # confirm the consumer tolerates blank lines before changing.
                serialized[key] = '\n'.join(content)
            else:
                serialized[key] = value
        return serialized
324
+
325
+
326
class SentencepieceTokenizer(CustomOp):
    """Declares the I/O signature of the SentencepieceTokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        str_t = onnx_proto.TensorProto.STRING
        i64 = onnx_proto.TensorProto.INT64
        f32 = onnx_proto.TensorProto.FLOAT
        b8 = onnx_proto.TensorProto.BOOL
        return [
            cls.io_def('inputs', str_t, [None]),
            cls.io_def('nbest_size', i64, [None]),
            cls.io_def('alpha', f32, [None]),
            cls.io_def('add_bos', b8, [None]),
            cls.io_def('add_eos', b8, [None]),
            cls.io_def('reverse', b8, [None]),
            cls.io_def('fairseq', b8, [None]),
        ]

    # The dict order must match get_inputs; beyond Python 3.7, insertion
    # order of a dict is guaranteed.
    @classmethod
    def input_default_values(cls):
        defaults = {'nbest_size': [0], 'alpha': [0]}
        for flag in ('add_bos', 'add_eos', 'reverse', 'fairseq'):
            defaults[flag] = [False]
        return defaults

    @classmethod
    def get_outputs(cls):
        i32 = onnx_proto.TensorProto.INT32
        i64 = onnx_proto.TensorProto.INT64
        return [
            cls.io_def('tokens', i32, [None]),
            cls.io_def('instance_indices', i64, [None]),
            cls.io_def('token_indices', i32, [None]),
        ]
360
+
361
+
362
class SentencepieceDecoder(CustomOp):
    """Declares the I/O signature of the SentencepieceDecoder custom op."""

    @classmethod
    def get_inputs(cls):
        return [
            cls.io_def("ids", onnx.TensorProto.INT64, [None]),
            cls.io_def('fairseq', onnx_proto.TensorProto.BOOL, [None]),
        ]

    @classmethod
    def input_default_values(cls):
        # Only the trailing 'fairseq' input is optional.
        return {'fairseq': [False]}

    @classmethod
    def get_outputs(cls):
        return [cls.io_def('str', onnx_proto.TensorProto.STRING, [None])]
380
+
381
+
382
class TrieTokenizer(CustomOp):
    """Declares the I/O signature of the TrieTokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        text = cls.io_def('str', onnx_proto.TensorProto.STRING, ['N'])
        return [text]

    @classmethod
    def get_outputs(cls):
        ids = cls.io_def("ids", onnx.TensorProto.INT64, ['N', None])
        return [ids]
390
+
391
+
392
class TrieDetokenizer(CustomOp):
    """Declares the I/O signature of the TrieDetokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        ids = cls.io_def("ids", onnx.TensorProto.INT64, ['N', None])
        return [ids]

    @classmethod
    def get_outputs(cls):
        text = cls.io_def('str', onnx_proto.TensorProto.STRING, [None])
        return [text]
400
+
401
+
402
class Inverse(CustomOp):
    """Declares the I/O signature of the Inverse custom op."""

    @classmethod
    def get_inputs(cls):
        f32 = onnx_proto.TensorProto.FLOAT
        return [cls.io_def('input', f32, [None, None])]

    @classmethod
    def get_outputs(cls):
        f32 = onnx_proto.TensorProto.FLOAT
        return [cls.io_def('output', f32, [None, None])]
415
+
416
+
417
class ImageReader(CustomOp):
    """Declares the I/O signature of the ImageReader custom op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def('image_paths', onnx_proto.TensorProto.STRING, [None])]

    @classmethod
    def get_outputs(cls):
        any_4d = [None, None, None, None]
        return [cls.io_def('nchw_bytes', onnx_proto.TensorProto.UINT8, any_4d)]
431
+
432
+
433
class GaussianBlur(CustomOp):
    """Declares the I/O signature of the GaussianBlur custom op."""

    @classmethod
    def get_inputs(cls):
        any_4d = [None, None, None, None]
        return [
            cls.io_def('nhwc', onnx_proto.TensorProto.FLOAT, any_4d),
            cls.io_def('kernel_size', onnx_proto.TensorProto.INT64, [None]),
            cls.io_def('sigma_xy', onnx_proto.TensorProto.DOUBLE, [None]),
        ]

    @classmethod
    def get_outputs(cls):
        any_4d = [None, None, None, None]
        return [cls.io_def('gb_nhwc', onnx_proto.TensorProto.FLOAT, any_4d)]
450
+
451
+
452
class ImageDecoder(CustomOp):
    """Declares the I/O signature of the ImageDecoder custom op."""

    @classmethod
    def get_inputs(cls):
        return [cls.io_def('raw_input_image', onnx_proto.TensorProto.UINT8, [])]

    @classmethod
    def get_outputs(cls):
        hwc = [None, None, 3]  # output is declared height x width x 3
        return [cls.io_def('decoded_image', onnx_proto.TensorProto.UINT8, hwc)]
466
+
467
+
468
class AudioDecoder(CustomOp):
    """Declares the I/O signature of the AudioDecoder custom op."""

    @classmethod
    def get_inputs(cls):
        stream = cls.io_def('audio_stream', onnx_proto.TensorProto.UINT8, [1, None])
        return [stream]

    @classmethod
    def get_outputs(cls):
        pcm = cls.io_def('floatPCM', onnx_proto.TensorProto.FLOAT, [1, None])
        return [pcm]
480
+
481
+
482
class StftNorm(CustomOp):
    """Declares the I/O signature of the StftNorm custom op."""

    @classmethod
    def get_inputs(cls):
        f32 = onnx_proto.TensorProto.FLOAT
        i64 = onnx_proto.TensorProto.INT64
        return [
            cls.io_def('pcm_wave', f32, [1, None]),
            cls.io_def('n_fft', i64, []),
            cls.io_def('hop_length', i64, []),
            cls.io_def('window', f32, [None]),
            cls.io_def('frame_size', i64, []),
        ]

    @classmethod
    def get_outputs(cls):
        out = cls.io_def('stft_norm', onnx_proto.TensorProto.FLOAT,
                         [1, None, None])
        return [out]
499
+
500
+
501
class HfJsonTokenizer(CustomOp):
    """Declares the I/O signature of the HfJsonTokenizer custom op."""

    @classmethod
    def get_inputs(cls):
        text = cls.io_def('str', onnx_proto.TensorProto.STRING, ['N'])
        return [text]

    @classmethod
    def get_outputs(cls):
        ids = cls.io_def("ids", onnx.TensorProto.INT64, ['N', None])
        return [ids]
509
+
510
+
511
+ # TODO: have a C++ impl.
512
+ def _argsort_op(x, dim):
513
+ d = numpy.argsort(x, dim)
514
+ return d[:, ::-1]
515
+
516
+
517
# Register the Python fallback implementation as the ArgSort custom op.
Opdef.create(
    _argsort_op,
    op_type='ArgSort',
    inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_int64],
    outputs=[PyCustomOpDef.dt_int64],
)
521
+
522
+
523
class CustomOpConverter:
    """Marker base class for attribute converters accepted by
    SingleOpGraph.build_graph (positionally or via the 'cvt' keyword)."""
525
+
526
+
527
class SingleOpGraph:
    """Factory for a one-node ONNX graph hosting a single custom operator."""

    # Monotonic counter used to build unique node/graph names. A plain
    # class attribute replaces the original hasattr-based lazy init;
    # the first get_next_id() call still returns 1.
    _id_counter = 0

    @classmethod
    def get_next_id(cls):
        """Return the next unique integer id (1, 2, 3, ...)."""
        cls._id_counter += 1
        return cls._id_counter

    @classmethod
    def build_graph(cls, op_class, *args, **kwargs):
        """Build an onnx GraphProto containing one node of `op_class`.

        :param op_class: a CustomOp subclass, or its name as a string.
        :param args: optionally a CustomOpConverter as the first positional
            argument (equivalent to passing it as the 'cvt' keyword).
        :param kwargs: the node attributes; pre-processed by the converter
            when one is supplied, then serialized via serialize_attr.
        :return: the GraphProto wrapping the single custom-op node.
        """
        if isinstance(op_class, str):
            op_class = cls.get_op_class(op_class)

        # The attribute converter may arrive positionally or as 'cvt'.
        cvt = kwargs.pop('cvt', None)
        if cvt is None and len(args) > 0 and isinstance(args[0], CustomOpConverter):
            cvt = args[0]
            args = args[1:]

        new_kwargs = kwargs if cvt is None else cvt(**kwargs)

        op_type = op_class.op_type()
        inputs = op_class.get_inputs()
        outputs = op_class.get_outputs()
        attrs = op_class.serialize_attr(new_kwargs)
        # Note: the node and the graph intentionally receive distinct ids.
        cuop = onnx.helper.make_node(op_type, [i_.name for i_ in inputs],
                                     [o_.name for o_ in outputs],
                                     "{}_{}".format(op_type,
                                                    cls.get_next_id()),
                                     **attrs,
                                     domain=default_opset_domain())
        graph = onnx.helper.make_graph([cuop], "og_{}_{}".format(
            op_type, cls.get_next_id()), inputs, outputs)
        return graph

    @staticmethod
    def get_op_class(op_type):
        """Resolve an op-type name to the CustomOp subclass in this module."""
        return globals()[op_type]
@@ -0,0 +1,45 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License. See License.txt in the project root for
3
+ # license information.
4
+ ###############################################################################
5
+ from typing import Callable
6
+
7
+
8
class PyCustomOpDef:
    """Type stub for the custom-op definition object implemented in the
    native _extensions_pydll module."""

    # Element-type code constants exposed by the native module.
    undefined: int = ...
    dt_float: int = ...
    dt_uint8: int = ...
    dt_int8: int = ...
    dt_uint16: int = ...
    dt_int16: int = ...
    dt_int32: int = ...
    dt_int64: int = ...
    dt_string: int = ...
    dt_bool: int = ...
    dt_float16: int = ...
    dt_double: int = ...
    dt_uint32: int = ...
    dt_uint64: int = ...
    dt_complex64: int = ...
    dt_complex128: int = ...
    dt_bfloat16: int = ...

    def install_hooker(self, invocation_handler: Callable) -> None:
        """Stub; implemented in the native extension module."""
        ...
30
+
31
+
32
def enable_py_op(enabled: bool) -> bool:
    """Type stub; implemented in the native extension module."""
    ...
34
+
35
+
36
+ def add_custom_op(opdef: PyCustomOpDef) -> None:
37
+ ...
38
+
39
+
40
def hash_64(s: str, num_buckets: int, fast: int) -> int:
    """Type stub; implemented in the native extension module."""
    ...
42
+
43
+
44
def default_opset_domain() -> str:
    """Type stub; implemented in the native extension module."""
    ...