onnxruntime_extensions 0.14.0__cp313-cp313-macosx_11_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. onnxruntime_extensions/__init__.py +82 -0
  2. onnxruntime_extensions/_cuops.py +564 -0
  3. onnxruntime_extensions/_extensions_pydll.cpython-313-darwin.so +0 -0
  4. onnxruntime_extensions/_extensions_pydll.pyi +45 -0
  5. onnxruntime_extensions/_hf_cvt.py +331 -0
  6. onnxruntime_extensions/_ocos.py +133 -0
  7. onnxruntime_extensions/_ortapi2.py +274 -0
  8. onnxruntime_extensions/_torch_cvt.py +231 -0
  9. onnxruntime_extensions/_version.py +2 -0
  10. onnxruntime_extensions/cmd.py +66 -0
  11. onnxruntime_extensions/cvt.py +306 -0
  12. onnxruntime_extensions/onnxprocess/__init__.py +12 -0
  13. onnxruntime_extensions/onnxprocess/_builder.py +53 -0
  14. onnxruntime_extensions/onnxprocess/_onnx_ops.py +1507 -0
  15. onnxruntime_extensions/onnxprocess/_session.py +355 -0
  16. onnxruntime_extensions/onnxprocess/_tensor.py +628 -0
  17. onnxruntime_extensions/onnxprocess/torch_wrapper.py +31 -0
  18. onnxruntime_extensions/pnp/__init__.py +13 -0
  19. onnxruntime_extensions/pnp/_base.py +124 -0
  20. onnxruntime_extensions/pnp/_imagenet.py +65 -0
  21. onnxruntime_extensions/pnp/_nlp.py +148 -0
  22. onnxruntime_extensions/pnp/_onnx_ops.py +1544 -0
  23. onnxruntime_extensions/pnp/_torchext.py +310 -0
  24. onnxruntime_extensions/pnp/_unifier.py +45 -0
  25. onnxruntime_extensions/pnp/_utils.py +302 -0
  26. onnxruntime_extensions/pp_api.py +83 -0
  27. onnxruntime_extensions/tools/__init__.py +0 -0
  28. onnxruntime_extensions/tools/add_HuggingFace_CLIPImageProcessor_to_model.py +171 -0
  29. onnxruntime_extensions/tools/add_pre_post_processing_to_model.py +535 -0
  30. onnxruntime_extensions/tools/pre_post_processing/__init__.py +4 -0
  31. onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py +395 -0
  32. onnxruntime_extensions/tools/pre_post_processing/step.py +227 -0
  33. onnxruntime_extensions/tools/pre_post_processing/steps/__init__.py +6 -0
  34. onnxruntime_extensions/tools/pre_post_processing/steps/general.py +366 -0
  35. onnxruntime_extensions/tools/pre_post_processing/steps/nlp.py +344 -0
  36. onnxruntime_extensions/tools/pre_post_processing/steps/vision.py +1157 -0
  37. onnxruntime_extensions/tools/pre_post_processing/utils.py +139 -0
  38. onnxruntime_extensions/util.py +186 -0
  39. onnxruntime_extensions-0.14.0.dist-info/LICENSE +21 -0
  40. onnxruntime_extensions-0.14.0.dist-info/METADATA +102 -0
  41. onnxruntime_extensions-0.14.0.dist-info/RECORD +43 -0
  42. onnxruntime_extensions-0.14.0.dist-info/WHEEL +6 -0
  43. onnxruntime_extensions-0.14.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,628 @@
1
+ import torch
2
+ import builtins
3
+ import functools
4
+ import numpy as np
5
+ from onnx import onnx_pb as onnx_proto
6
+ from typing import List, Tuple, Optional, Union, Any, ContextManager, overload, Iterator, NamedTuple
7
+ from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme, _size, _layout # noqa
8
+ from torch import strided, memory_format, contiguous_format, StringType # noqa
9
+
10
+ from ._onnx_ops import ox as _ox
11
+ from .._ortapi2 import OrtPyFunction
12
+
13
+
14
+ class _EagerTensor:
15
+ def __init__(self, _t, name=None, sess=None, raw_data: Any = None):
16
+ self._t = _t if isinstance(_t, torch.Tensor) else torch.tensor(_t)
17
+ if isinstance(name, (tuple, list)):
18
+ assert len(name) == 1, "Multiple names for one tensor!"
19
+ name = name[0]
20
+ self.name = '' if name is None else name
21
+ self.raw_data = raw_data
22
+ self.symbolic_shape = []
23
+
24
+ def __repr__(self):
25
+ if self.raw_data is not None:
26
+ return "name: {}, \"{}\"".format(self.name, str(self.raw_data))
27
+ else:
28
+ return "name: {}, {}, dtype={}".format(self.name, repr(self._t), str(self._t.dtype))
29
+
30
+ _all_ops = {}
31
+
32
+ @property
33
+ def value(self) -> Union[torch.Tensor, Any]:
34
+ return self.raw_data if self.raw_data else self._t
35
+
36
+ @property
37
+ def t(self):
38
+ return self._t
39
+
40
+ @property
41
+ def dtype(self):
42
+ return self._t.dtype
43
+
44
+ @property
45
+ def onnx_type(self):
46
+ return self.to_onnx_type(self._t.dtype)
47
+
48
+ @classmethod
49
+ def is_numeric(cls, np_arr):
50
+ return np_arr.dtype.kind in set('buifc')
51
+
52
+ @classmethod
53
+ def set_active_session(cls, sess):
54
+ """
55
+ set the active operator tracing log session. if sess is None, the active session will be removed
56
+ :param sess:
57
+ :return:
58
+ """
59
+ if not hasattr(cls, '_active_session'):
60
+ cls._active_session = sess
61
+ if sess is None:
62
+ raise RuntimeError("unset the active session twice!")
63
+ else:
64
+ if sess is not None:
65
+ raise RuntimeError("The active session already assigned!")
66
+ delattr(cls, '_active_session')
67
+
68
+ @classmethod
69
+ def get_trace_session(cls):
70
+ if not hasattr(cls, '_active_session'):
71
+ raise RuntimeError("the tracing not started yet!")
72
+ return cls._active_session # noqa
73
+
74
+ @classmethod
75
+ def get_container(cls):
76
+ return cls.get_trace_session().container
77
+
78
+ @classmethod
79
+ def from_onnx(cls, raw_val, ort_sess, name):
80
+ raw_data = None
81
+ if cls.is_numeric(raw_val):
82
+ val = torch.from_numpy(raw_val)
83
+ else:
84
+ # only keep the shape and the value was stored by it-self.
85
+ val = torch.empty(*raw_val.shape, dtype=torch.uint8)
86
+ raw_data = raw_val
87
+ t = cls(val, name, ort_sess, raw_data)
88
+ return t
89
+
90
+ @classmethod
91
+ def from_torch(cls, _t, name):
92
+ t_name = name if name is not None else "id_{}".format(id(_t))
93
+ ts = cls(_t, t_name)
94
+ return ts
95
+
96
+ @classmethod
97
+ # torch.tensor prototype
98
+ def mytensor(cls, data: Any, dtype: Optional[_dtype] = None, device: Union[_device, str, None] = None, requires_grad: _bool = False): # noqa
99
+ y = torch.tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
100
+ val = _ox.make_tensor(cls.to_onnx_type(y.dtype), list(y.size()),
101
+ [data] if isinstance(data, (int, float, str, bool)) else data)
102
+ s = _ox.constant([], [_ox.get_unique_tensor_name('const')], cls.get_container(), None, value=val)
103
+ return cls.from_torch(y, s)
104
+
105
+ def numpy(self):
106
+ return self._t.numpy() if self.raw_data is None else self.raw_data
107
+
108
+ def item(self):
109
+ return self.numpy().item()
110
+
111
+ def get_shape(self):
112
+ return self.t.size() if len(self.symbolic_shape) == 0 else self.symbolic_shape
113
+
114
+ def _to_binary_tensor_args(self, other):
115
+ # convert self, other to [self, other], but if either is a number, convert that to a constant
116
+ x, y = self, other
117
+ if isinstance(y, (int, float, bool, np.ndarray)):
118
+ y = self.mytensor(y)
119
+ elif isinstance(x, (int, float, bool, np.ndarray)):
120
+ x = self.mytensor(x)
121
+ return x, y
122
+
123
+ _dup_id = 0
124
+
125
+ def __copy__(self):
126
+ new_t = _EagerTensor.from_torch(self.t, self.name + '_{}'.format(_EagerTensor._dup_id))
127
+ self._dup_id += 1
128
+ new_t.raw_data = self.raw_data
129
+ return new_t
130
+
131
+ def __add__(self, other):
132
+ x0, x1 = self._to_binary_tensor_args(other)
133
+ y = torch.add(x0._t, x1._t)
134
+ s = _ox.add(*_EagerTensor.ox_args([x0, x1]))
135
+ return self.from_torch(y, s)
136
+
137
+ def __sub__(self, other):
138
+ x0, x1 = self._to_binary_tensor_args(other)
139
+ y = torch.sub(x0._t, x1._t)
140
+ s = _ox.sub(*_EagerTensor.ox_args([x0, x1]))
141
+ return self.from_torch(y, s)
142
+
143
+ def __mul__(self, other):
144
+ x0, x1 = self._to_binary_tensor_args(other)
145
+ y = torch.mul(x0._t, x1._t)
146
+ s = _ox.mul(*_EagerTensor.ox_args([x0, x1]))
147
+ return self.from_torch(y, s)
148
+
149
+ def __div__(self, other):
150
+ x0, x1 = self._to_binary_tensor_args(other)
151
+ y = torch.div(x0._t, x1._t)
152
+ s = _ox.div(*_EagerTensor.ox_args([x0, x1]))
153
+ return self.from_torch(y, s)
154
+
155
+ def __pow__(self, other):
156
+ x0, x1 = self._to_binary_tensor_args(other)
157
+ y = torch.pow(x0._t, x1._t)
158
+ s = _ox.pow(*_EagerTensor.ox_args([x0, x1]))
159
+ return self.from_torch(y, s)
160
+
161
+ def __matmul__(self, other):
162
+ x0, x1 = self._to_binary_tensor_args(other)
163
+ y = torch.matmul(x0._t, x1._t)
164
+ s = _ox.matmul(*_EagerTensor.ox_args([x0, x1]))
165
+ return self.from_torch(y, s)
166
+
167
+ def __lt__(self, other):
168
+ x0, x1 = self._to_binary_tensor_args(other)
169
+ y = torch.less(x0._t, x1._t)
170
+ s = _ox.less(*_EagerTensor.ox_args([x0, x1]))
171
+ return self.from_torch(y, s)
172
+
173
+ def __le__(self, other):
174
+ x0, x1 = self._to_binary_tensor_args(other)
175
+ y = torch.less_equal(x0._t, x1._t)
176
+ s = _ox.less_equal(*_EagerTensor.ox_args([x0, x1]))
177
+ return self.from_torch(y, s)
178
+
179
+ def __eq__(self, other):
180
+ x0, x1 = self._to_binary_tensor_args(other)
181
+ y = torch.equal(x0._t, x1._t)
182
+ s = _ox.equal(*_EagerTensor.ox_args([x0, x1]))
183
+ return self.from_torch(y, s)
184
+
185
+ def __ne__(self, other):
186
+ x0, x1 = self._to_binary_tensor_args(other)
187
+ y = torch.not_equal(x0._t, x1._t)
188
+ s = _ox.not_equal(*_EagerTensor.ox_args([x0, x1]))
189
+ return self.from_torch(y, s)
190
+
191
+ def __gt__(self, other):
192
+ x0, x1 = self._to_binary_tensor_args(other)
193
+ y = torch.greater(x0._t, x1._t)
194
+ s = _ox.greater(*_EagerTensor.ox_args([x0, x1]))
195
+ return self.from_torch(y, s)
196
+
197
+ def __ge__(self, other):
198
+ x0, x1 = self._to_binary_tensor_args(other)
199
+ y = torch.greater_equal(x0._t, x1._t)
200
+ s = _ox.greater_equal(*_EagerTensor.ox_args([x0, x1]))
201
+ return self.from_torch(y, s)
202
+
203
+ def __invert__(self):
204
+ if self.t.dtype is torch.bool:
205
+ y = torch.logical_not(self.t)
206
+ s = _ox.not_op(*self.my_args())
207
+ return self.from_torch(y, s)
208
+ else:
209
+ raise NotImplementedError("no numeric tensor inverse supported yet.")
210
+
211
+ def __neg__(self):
212
+ y = torch.neg([self.t])
213
+ s = _ox.neg(*self.my_args())
214
+ return self.from_torch(y, s)
215
+
216
+ def __not__(self):
217
+ y = torch.logical_not(self.t)
218
+ s = _ox.not_op(*self.my_args())
219
+ return self.from_torch(y, s)
220
+
221
+ def __or__(self, other):
222
+ x0, x1 = self._to_binary_tensor_args(other)
223
+ y = torch.logical_or(x0._t, x1._t)
224
+ s = _ox.or_op(*_EagerTensor.ox_args([x0, x1]))
225
+ return self.from_torch(y, s)
226
+
227
+ def __getitem__(self, indices):
228
+ y = self.value.__getitem__(indices)
229
+
230
+ # normalize indices to tuples of slices
231
+ # Formats encountered:
232
+ # - a single int
233
+ # - a tuple of (int or slice)
234
+ if not isinstance(indices, (tuple, list)): # single item: make it a tuple
235
+ indices = (indices,)
236
+ squeeze = [axis for axis, index in enumerate(indices) if
237
+ isinstance(index, int)] # which axes had a single index?
238
+ indices = tuple(
239
+ index if isinstance(index, slice) else slice(index, index + 1 if index != -1 else None, 1) for index in
240
+ indices) # make all tuple items of type Slice
241
+ bs, es, ss, ds = [], [], [], []
242
+ INT_MAX = 2 ** 63 - 1
243
+ for axis, index in enumerate(indices):
244
+ if not isinstance(index, slice):
245
+ raise ValueError("Index expected")
246
+ if index.start is None and index.stop is None: # [:] can be skipped
247
+ continue
248
+ b, e, s = index.start, index.stop, index.step
249
+ bs.append(b if b is not None else 0)
250
+ es.append(e if e is not None else INT_MAX)
251
+ ss.append(s if s is not None else 1)
252
+ ds.append(axis)
253
+ s = _ox.slice(*self.my_args(), starts=bs, ends=es, axes=ds, steps=ss)
254
+ if squeeze: # single index means we must drop the axis
255
+ s = _ox.squeeze(*self.ox_name_args(s), axes=squeeze)
256
+
257
+ return self.from_torch(y, s)
258
+
259
+ def __getattribute__(self, attr):
260
+ """
261
+ A little hack that allows to call unary operators in a chaining fashion,
262
+ e.g. x.shape() instead of ox.shape(x).
263
+ """
264
+ if attr in _EagerTensor._all_ops:
265
+ f = _EagerTensor._all_ops[attr]
266
+ return functools.partial(f, self)
267
+ else:
268
+ return object.__getattribute__(self, attr)
269
+
270
+ @classmethod
271
+ def ox_name_args(cls, input_names, output_names=None):
272
+ """
273
+ generate the arguments for ONNX model builder.
274
+ :param input_names: input name list
275
+ :param output_names: output name list, can be None, or [None]*output_n
276
+ :return: input_names, output_names, container, operator_name
277
+ """
278
+ container = cls.get_trace_session().container
279
+ if output_names is None:
280
+ output_names = [None] # by default, there is only one output
281
+
282
+ output_names = [_ox.get_unique_tensor_name(str(n_))
283
+ if output_names[n_] is None else
284
+ output_names[n_] for n_ in range(len(output_names))]
285
+ operator_name = None
286
+ return input_names, output_names, container, operator_name
287
+
288
+ @classmethod
289
+ def ort_verify(cls, ts_from, ts_to):
290
+ result, model = cls.get_trace_session().runops(ts_from, ts_to)
291
+ for idx in range(len(ts_to)):
292
+ if not np.allclose(ts_to[idx].numpy(), result[idx]):
293
+ # ONNX cannot be import globally, which is conflict with torch.onnx
294
+ import onnx # noqa
295
+ onnx.save_model(model, 'mt_debmodel.onnx')
296
+ raise RuntimeError("ONNXRuntime Result is not same pytorch!")
297
+
298
+ def create_and_verify(self, value, name, additional_inputs=None):
299
+ ts_y = self.from_torch(value, name)
300
+ inputs = [self] + ([] if additional_inputs is None else additional_inputs)
301
+ self.ort_verify(inputs, [ts_y])
302
+ return ts_y
303
+
304
+ @classmethod
305
+ def ox_args(cls, tensors, output_names=None):
306
+ input_names = [ts_ if isinstance(ts_, str) else ts_.name for ts_ in tensors]
307
+ return cls.ox_name_args(input_names, output_names)
308
+
309
+ def my_args(self):
310
+ return self.ox_args([self])
311
+
312
+ @staticmethod
313
+ def normalize_seq(list_or_tuple):
314
+ return [x.value.item() if isinstance(x, _EagerTensor) else x for x in list_or_tuple]
315
+
316
+ @staticmethod
317
+ def to_onnx_type(torch_type):
318
+ ty_dict = {torch.bool: onnx_proto.TensorProto.BOOL,
319
+ torch.float32: onnx_proto.TensorProto.FLOAT,
320
+ torch.long: onnx_proto.TensorProto.INT64,
321
+ torch.int32: onnx_proto.TensorProto.INT32}
322
+ # ...
323
+ return ty_dict.get(torch_type, onnx_proto.TensorProto.STRING)
324
+
325
+ def long(self):
326
+ y = self._t.long()
327
+ s = _ox.cast(*self.my_args(), to=onnx_proto.TensorProto.INT64)
328
+ return self.create_and_verify(y, s[0])
329
+
330
+ def cumsum(self, dim: _int, *, dtype: Optional[_dtype] = None): # noqa
331
+ y = self._t.cumsum(dim, dtype=dtype)
332
+ s = _ox.cumsum(*self.my_args(), axis=dim)
333
+ return self.create_and_verify(y, s[0])
334
+
335
+ def size(self):
336
+ y = self._t.size()
337
+ s = _ox.shape(*self.my_args())
338
+ return self.create_and_verify(y, s[0])
339
+
340
+ def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False):
341
+ y = self._t.type(dtype, non_blocking)
342
+ s = _ox.cast(*self.my_args(), to=self.to_onnx_type(dtype))
343
+ return self.create_and_verify(y, s)
344
+
345
+ def to(self, device):
346
+ y = self._t.to(device)
347
+ s = _ox.identity(*self.my_args())
348
+ return self.create_and_verify(y, s[0])
349
+
350
+ def cpu(self):
351
+ y = self._t.cpu()
352
+ s = _ox.identity(*self.my_args())
353
+ return self.create_and_verify(y, s[0])
354
+
355
+ def detach(self):
356
+ y = self._t.detach()
357
+ s = _ox.identity(*self.my_args())
358
+ return self.create_and_verify(y, s[0])
359
+
360
+ def clone(self):
361
+ y = self._t.clone()
362
+ s = _ox.identity(*self.my_args())
363
+ return self.create_and_verify(y, s[0])
364
+
365
+ def masked_fill(self, mask, value):
366
+ y = self._t.masked_fill(mask.value, value)
367
+ if not isinstance(value, _EagerTensor):
368
+ value = _EagerTensor.mytensor(value)
369
+ s = _ox.where(*_EagerTensor.ox_args([mask, value, self]))
370
+ return self.create_and_verify(y, s[0], additional_inputs=[mask, value])
371
+
372
+ def unsqueeze(self, dim: _int):
373
+ y = self._t.unsqueeze(dim)
374
+ s = _ox.unsqueeze(*self.my_args(), [dim])
375
+ return self.create_and_verify(y, s[0])
376
+
377
+ def squeeze(self, dim: _int):
378
+ y = self._t.squeeze(dim)
379
+ s = _ox.squeeze(*self.my_args(), [dim])
380
+ return self.create_and_verify(y, s[0])
381
+
382
+
383
def _create_ox_sequence(*size):
    """Emit graph nodes producing a 1-D INT64 tensor holding *size*.

    When every entry is a plain integer a single Constant initializer is
    emitted; otherwise each dynamic entry (an _EagerTensor) is unsqueezed and
    all pieces are joined with a Concat node.  Returns the output name list of
    the final node.
    """
    container = _EagerTensor.get_container()
    has_dynamic_dim = builtins.any(isinstance(d_, _EagerTensor) for d_ in size)
    if not has_dynamic_dim:
        # fully static shape -> one constant tensor
        shape_tensor = _ox.make_tensor(onnx_proto.TensorProto.INT64, [len(size)], size)
        return _ox.constant([], [_ox.get_unique_tensor_name('const')], container, None, value=shape_tensor)

    pieces = []
    for dim in size:
        if isinstance(dim, _EagerTensor):
            piece = _ox.unsqueeze(*_EagerTensor.ox_args([dim]))[0]
        else:
            dim_tensor = _ox.make_tensor(onnx_proto.TensorProto.INT64, [1], [dim])
            piece = _ox.constant([], [_ox.get_unique_tensor_name('const')], container, None, value=dim_tensor)[0]
        pieces.append(piece)
    return _ox.concat(pieces, [_ox.get_unique_tensor_name('concat')], container, None)
398
+
399
+
400
def _create_ox_sequence_constant(*size, init_value=None, onnx_type=None):
    """Emit a ConstantOfShape node producing a tensor of shape *size* filled
    with *init_value* (FLOAT by default).  Returns the node's output name."""
    fill_type = onnx_proto.TensorProto.FLOAT if onnx_type is None else onnx_type
    shape_names = _create_ox_sequence(*size)
    fill_value = _ox.make_tensor(fill_type, [1], [init_value])
    container = _EagerTensor.get_container()
    out_names = _ox.constant_of_shape(shape_names, [_ox.get_unique_tensor_name('cos')], container, None,
                                      value=fill_value)
    return out_names[0]
409
+
410
+
411
def empty(*size: Union[_int, _EagerTensor], memory_format: Optional[memory_format] = None, out: Optional[_EagerTensor] = None,
          dtype: _dtype = None, layout: _layout = strided, device: Union[_device, str, None] = None,
          requires_grad: _bool = False) -> _EagerTensor:  # noqa
    """torch.empty lookalike that also records the tensor creation.  The
    traced graph uses ConstantOfShape filled with 0.0, so the ONNX side is
    deterministic even though torch.empty is uninitialized."""
    if len(size) == 1 and isinstance(size[0], list):
        size = size[0]  # accept empty([d0, d1, ...]) as well as empty(d0, d1, ...)
    concrete_dims = _EagerTensor.normalize_seq(size)
    result = torch.empty(*concrete_dims, memory_format=memory_format, out=out,
                         dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
    graph_name = _create_ox_sequence_constant(*size, init_value=0.,
                                              onnx_type=_EagerTensor.to_onnx_type(result.dtype))
    return _EagerTensor.from_torch(result, graph_name)
422
+
423
+
424
def zeros(*size: Union[_int, _EagerTensor], out: Optional[_EagerTensor] = None, dtype: _dtype = None, layout: _layout = strided,
          device: Union[_device, str, None] = None, requires_grad: _bool = False) -> _EagerTensor:  # noqa
    """torch.zeros lookalike that also records a ConstantOfShape(0) node."""
    if len(size) == 1 and isinstance(size[0], list):
        size = size[0]  # accept a single list of dims
    concrete_dims = _EagerTensor.normalize_seq(size)
    result = torch.zeros(*concrete_dims, out=out, dtype=dtype,
                         layout=layout, device=device, requires_grad=requires_grad)
    graph_name = _create_ox_sequence_constant(*size, init_value=0,
                                              onnx_type=_EagerTensor.to_onnx_type(result.dtype))
    return _EagerTensor.from_torch(result, graph_name)
434
+
435
+
436
def ones(*size: Union[_int, _EagerTensor], out: Optional[_EagerTensor] = None, dtype: _dtype = None, layout: _layout = strided,
         device: Union[_device, str, None] = None, requires_grad: _bool = False) -> _EagerTensor:  # noqa
    """torch.ones lookalike that also records a ConstantOfShape(1) node."""
    if len(size) == 1 and isinstance(size[0], list):
        size = size[0]  # accept a single list of dims
    concrete_dims = _EagerTensor.normalize_seq(size)
    result = torch.ones(*concrete_dims, out=out, dtype=dtype,
                        layout=layout, device=device, requires_grad=requires_grad)
    graph_name = _create_ox_sequence_constant(*size, init_value=1,
                                              onnx_type=_EagerTensor.to_onnx_type(result.dtype))
    return _EagerTensor.from_torch(result, graph_name)
446
+
447
+
448
def repeat(input_ts: _EagerTensor, *repeats: Union[_int, _EagerTensor]) -> _EagerTensor:  # noqa
    """Tensor.repeat lookalike traced as an ONNX Tile node."""
    if len(repeats) == 1 and isinstance(repeats[0], list):
        repeats = repeats[0]  # accept a single list of repeat counts
    counts = _EagerTensor.normalize_seq(repeats)
    result = input_ts.t.repeat(*counts)
    repeats_seq = _create_ox_sequence(*repeats)
    out_names = _ox.tile(*input_ts.my_args(), repeats=repeats_seq[0])
    return _EagerTensor.from_torch(result, out_names[0])
457
+
458
+
459
def argmax(input_ts: _EagerTensor, dim: Optional[_int] = None, keepdim: _bool = False) -> _EagerTensor:  # noqa
    """torch.argmax lookalike traced as an ONNX ArgMax node."""
    result = torch.argmax(input_ts.value, dim, keepdim)
    out_name = _ox.argmax(*input_ts.my_args(), axis=dim, keepdims=keepdim)
    return _EagerTensor.from_torch(result, out_name)
463
+
464
+
465
def softmax(input_ts: _EagerTensor, dim: _int, dtype: Optional[_dtype] = None) -> _EagerTensor:
    """torch.softmax lookalike traced as an ONNX Softmax node."""
    result = torch.softmax(input_ts.value, dim, dtype)
    out_name = _ox.softmax(*input_ts.my_args(), axis=dim)
    return _EagerTensor.from_torch(result, out_name)
469
+
470
+
471
def cat(tensors: Union[Tuple[_EagerTensor, ...], List[_EagerTensor]],
        dim, *, out: Optional[_EagerTensor] = None) -> _EagerTensor:  # noqa
    """torch.cat lookalike traced as an ONNX Concat node; the traced output is
    cross-checked against the eager result via ort_verify."""
    merged = torch.cat([t_.value for t_ in tensors], dim, out=out)
    out_names = _ox.concat(*_EagerTensor.ox_args(tensors), dim)
    wrapped = _EagerTensor.from_torch(merged, out_names[0])
    _EagerTensor.ort_verify(tensors, [wrapped])
    return wrapped
478
+
479
+
480
def all(input_ts: _EagerTensor, out: Optional[_EagerTensor] = None) -> _EagerTensor:  # noqa
    """torch.all lookalike.  Traced as Cast(INT64) -> ReduceMin -> Greater(0):
    the minimum over the last axis is > 0 exactly when every element is true."""
    container = _EagerTensor.get_container()
    result = torch.all(input_ts.value)
    as_int64 = _ox.cast(*input_ts.my_args(), to=onnx_proto.TensorProto.INT64)
    reduced = _ox.reducemin(as_int64, [_ox.get_unique_tensor_name('reducemin')], container, None, axes=[-1])
    zero = _ox.constant([], [_ox.get_unique_tensor_name('const')],
                        container, None, value=_ox.make_tensor(onnx_proto.TensorProto.INT64, [1], [0]))
    out_name = _ox.greater(reduced + zero, [_ox.get_unique_tensor_name('greater')], container, None)
    return input_ts.create_and_verify(result, out_name[0])
489
+
490
+
491
def any(input_ts: _EagerTensor, out: Optional[_EagerTensor] = None) -> _EagerTensor:  # noqa
    """torch.any lookalike.  Traced as Cast(INT64) -> ReduceSum -> Greater(0):
    the sum over the last axis is > 0 exactly when some element is true."""
    container = _EagerTensor.get_container()
    result = torch.any(input_ts.value)
    as_int64 = _ox.cast(*input_ts.my_args(), to=onnx_proto.TensorProto.INT64)
    reduced = _ox.reducesum(as_int64, [_ox.get_unique_tensor_name('reducesum')], container, None, axes=[-1])
    zero = _ox.constant([], [_ox.get_unique_tensor_name('const')],
                        container, None, value=_ox.make_tensor(onnx_proto.TensorProto.INT64, [1], [0]))
    out_name = _ox.greater(reduced + zero, [_ox.get_unique_tensor_name('greater')], container, None)
    return input_ts.create_and_verify(result, out_name[0])
500
+
501
+
502
def reshape(input_ts: _EagerTensor, shape: _size):
    """Tensor.reshape lookalike traced as an ONNX Reshape node."""
    result = input_ts.t.reshape(shape)
    out_name = _ox.reshape(*input_ts.my_args(), desired_shape=shape)
    return input_ts.create_and_verify(result, out_name[0])
506
+
507
+
508
def transpose(input_ts: _EagerTensor, dim0: _int, dim1: _int):
    """Tensor.transpose lookalike: swap two axes, traced as an ONNX Transpose
    node with the corresponding full permutation."""
    result = input_ts.t.transpose(dim0, dim1)
    perm = list(range(result.dim()))
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    out_name = _ox.transpose(*input_ts.my_args(), perm=perm)
    return input_ts.create_and_verify(result, out_name[0])
514
+
515
+
516
class _LoopIterator:
    """Iterator facade over a _ControlFlowContext: yields the loop-carried
    state each iteration and pops the tracing sub-container when the loop's
    stop condition is reached."""

    def __init__(self, ctx):
        self.context = ctx

    def __iter__(self):
        return self

    def __next__(self):
        ctx = self.context
        if ctx.is_stopped():
            # the loop body's nodes live in a nested container; drop it now
            _EagerTensor.get_trace_session().pop_container()
            raise StopIteration
        return ctx.current()
528
+
529
+
530
class _ControlFlowContext:
    """Tracing context for an ONNX Loop node.

    Usage pattern (driven by _LoopIterator): ``loop()`` starts the traced
    loop, each iteration calls ``flow_output()`` with the body's results, and
    ``finalize()`` emits the Loop node once iteration ends.
    """

    def __init__(self):
        # condition_i / states_i: the Loop node's *initial* inputs
        self.condition_i = None
        # condition / loop_states: the current per-iteration values
        self.condition = None
        self.loop_count = None
        self.iteration_num = None
        self.states_i = []
        self.loop_states = []
        # scan outputs accumulated across iterations (one entry per extra output)
        self.scan_outputs = []
        self.sub_graph = None

    def flow_output(self, cond, *outputs):
        """Record one loop-body iteration: *cond* is the next-iteration
        condition, *outputs* are the updated loop states followed by any scan
        outputs.  The loop body sub-graph is built once, on the first call."""
        assert len(outputs) >= len(self.loop_states), "The loop body doesn't return enough objects"
        if self.sub_graph is None:
            trc = _EagerTensor.get_trace_session()
            self.sub_graph = trc.build_graph(trc.container,
                                             [self.iteration_num, self.condition] + self.loop_states,
                                             [cond] + list(outputs))

        self.condition = cond
        c_state = len(self.loop_states)
        self.loop_states = list(outputs[:c_state])
        if len(self.scan_outputs) == 0:
            # first iteration: seed each scan output with a length-1 leading axis
            sc = [_EagerTensor(torch.unsqueeze(sci_.value, 0), 'sc_' + sci_.name) for sci_ in outputs[c_state:]]
            self.scan_outputs = sc
        else:
            # later iterations: append this iteration's values along axis 0
            next_extra_vars = []
            for idx_, ext_ in enumerate(outputs[c_state:]):
                et = self.scan_outputs[idx_]
                next_extra_vars.append(_EagerTensor(
                    torch.cat([et.value, torch.unsqueeze(outputs[c_state + idx_].value, 0)]), name=et.name))
            self.scan_outputs = next_extra_vars
        # in-place increment of the traced iteration counter
        self.iteration_num.value.add_(1)

    def current(self):
        # the values handed to the loop body each iteration
        return [self.iteration_num] + list(self.loop_states)

    def finalize(self):
        # generate the outputs from the enclosing scope variables
        full_outputs = [_EagerTensor(o_.value, 'lp_' + o_.name) for o_ in self.loop_states + self.scan_outputs]
        _ox.loop(*_EagerTensor.ox_args(
            [self.loop_count, self.condition_i] + list(self.states_i),
            [ts_.name for ts_ in full_outputs]), body=self.sub_graph)
        return tuple(full_outputs)

    def is_stopped(self):
        # stop when the traced condition turned False or the trip count is reached
        return self.condition.item() is False or self.iteration_num.item() >= self.loop_count.item()

    def loop(self, loop_c, condition, *states):
        """Begin a traced loop with trip count *loop_c*, initial *condition*
        and loop-carried *states*; returns an iterator over iterations."""
        self.condition = condition
        self.condition_i = condition
        self.states_i = states
        # loop-body nodes are traced into a nested container
        _EagerTensor.get_trace_session().stack_container()
        self.iteration_num = _EagerTensor.mytensor(0)
        # clone the variables for the sub graph.
        self.loop_states = [_EagerTensor(st_.value, st_.name) for st_ in states]
        self.loop_count = loop_c
        loop_b = _LoopIterator(self)
        return iter(loop_b)
589
+
590
+
591
def control_flow():
    """Create a fresh tracing context for a control-flow (Loop) construct."""
    ctx = _ControlFlowContext()
    return ctx
593
+
594
+
595
class _TracingEagerOp(OrtPyFunction):
    """An OrtPyFunction whose invocation is also recorded into the active
    tracing session: it runs the wrapped ONNX model eagerly and emits a
    model_call node with the same inputs and outputs."""

    def __call__(self, *args, **kwargs):
        # unwrap _EagerTensor arguments to plain numpy for the ORT run
        np_args = [ts_.numpy() if isinstance(ts_, _EagerTensor) else ts_ for ts_ in args]
        outseq = super().__call__(*np_args, **kwargs)
        # a single output comes back bare; normalize to a sequence
        outseq = outseq if isinstance(outseq, (list, tuple)) else [outseq]

        # wrap each ORT output, keeping the model's declared output names
        outputs = [_EagerTensor.from_onnx(outseq[n_], self.ort_session, out_.name)
                   for n_, out_ in enumerate(self.ort_session.get_outputs())]

        y_names = [y.name for y in outputs]
        # record the whole model invocation as one traced call
        _ox.model_call(*_EagerTensor.ox_args(args, output_names=y_names), oxml=self.onnx_model)
        return tuple(outputs) if len(outputs) > 1 else outputs[0]
607
+
608
+
609
def op_from_customop(op_type, *args, **kwargs) -> _TracingEagerOp:
    """Build a traced eager op from an onnxruntime-extensions custom op type."""
    traced_op = _TracingEagerOp.from_customop(op_type, *args, **kwargs)
    return traced_op
611
+
612
+
613
def op_from_model(path_or_model, *args, **kwargs) -> _TracingEagerOp:
    """Build a traced eager op from an ONNX model object or model file path."""
    traced_op = _TracingEagerOp.from_model(path_or_model, *args, **kwargs)
    return traced_op
615
+
616
+
617
# Register the unary/reduction ops so _EagerTensor.__getattribute__ can expose
# them in a chaining style, e.g. x.argmax(...) instead of argmax(x, ...).
_EagerTensor._all_ops = {'argmax': argmax,
                         'softmax': softmax,
                         'reshape': reshape,
                         'transpose': transpose,
                         'repeat': repeat,
                         'any': any,
                         'all': all}

# Public aliases mirroring the torch namespace for traced tensor creation
# and session management.
tensor = _EagerTensor.mytensor
tensor_from_onnx = _EagerTensor.from_onnx
tensor_from_torch = _EagerTensor.from_torch
tensor_set_session = _EagerTensor.set_active_session
@@ -0,0 +1,31 @@
1
# PyTorch is a hard requirement for this tracing frontend; fail immediately
# with an actionable message instead of a bare ImportError later on.
try:
    import torch
except ImportError:
    raise RuntimeError("pytorch not installed, which is required by this ONNX build tool")
5
+
6
+ from torch import (float32,
7
+ float,
8
+ float64,
9
+ double,
10
+ float16,
11
+ bfloat16,
12
+ half,
13
+ uint8,
14
+ int8,
15
+ int16,
16
+ short,
17
+ int32,
18
+ int,
19
+ int64,
20
+ long,
21
+ complex64,
22
+ cfloat,
23
+ complex128,
24
+ cdouble,
25
+ quint8,
26
+ qint8,
27
+ qint32,
28
+ bool) # noqa
29
+
30
+ from torch import randn, onnx # noqa
31
+ from ._tensor import * # noqa
@@ -0,0 +1,13 @@
1
# onnxruntime-extensions pre&post processing frontend depends on the PyTorch
# Print a hint for the user, then re-raise so the original traceback is kept.
try:
    import torch
except ImportError as e:
    print("No torch installation found, which is required by the pre&post scripting!")
    raise e
7
+
8
+ from ._base import ProcessingTracedModule, ProcessingScriptModule, CustomFunction
9
+ from ._torchext import * # noqa
10
+ from ._unifier import export
11
+
12
+ from ._imagenet import * # noqa
13
+ from ._nlp import PreHuggingFaceGPT2, PreHuggingFaceBert, HfBertTokenizer # noqa