onnxruntime_extensions 0.14.0__cp313-cp313-macosx_11_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime_extensions/__init__.py +82 -0
- onnxruntime_extensions/_cuops.py +564 -0
- onnxruntime_extensions/_extensions_pydll.cpython-313-darwin.so +0 -0
- onnxruntime_extensions/_extensions_pydll.pyi +45 -0
- onnxruntime_extensions/_hf_cvt.py +331 -0
- onnxruntime_extensions/_ocos.py +133 -0
- onnxruntime_extensions/_ortapi2.py +274 -0
- onnxruntime_extensions/_torch_cvt.py +231 -0
- onnxruntime_extensions/_version.py +2 -0
- onnxruntime_extensions/cmd.py +66 -0
- onnxruntime_extensions/cvt.py +306 -0
- onnxruntime_extensions/onnxprocess/__init__.py +12 -0
- onnxruntime_extensions/onnxprocess/_builder.py +53 -0
- onnxruntime_extensions/onnxprocess/_onnx_ops.py +1507 -0
- onnxruntime_extensions/onnxprocess/_session.py +355 -0
- onnxruntime_extensions/onnxprocess/_tensor.py +628 -0
- onnxruntime_extensions/onnxprocess/torch_wrapper.py +31 -0
- onnxruntime_extensions/pnp/__init__.py +13 -0
- onnxruntime_extensions/pnp/_base.py +124 -0
- onnxruntime_extensions/pnp/_imagenet.py +65 -0
- onnxruntime_extensions/pnp/_nlp.py +148 -0
- onnxruntime_extensions/pnp/_onnx_ops.py +1544 -0
- onnxruntime_extensions/pnp/_torchext.py +310 -0
- onnxruntime_extensions/pnp/_unifier.py +45 -0
- onnxruntime_extensions/pnp/_utils.py +302 -0
- onnxruntime_extensions/pp_api.py +83 -0
- onnxruntime_extensions/tools/__init__.py +0 -0
- onnxruntime_extensions/tools/add_HuggingFace_CLIPImageProcessor_to_model.py +171 -0
- onnxruntime_extensions/tools/add_pre_post_processing_to_model.py +535 -0
- onnxruntime_extensions/tools/pre_post_processing/__init__.py +4 -0
- onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py +395 -0
- onnxruntime_extensions/tools/pre_post_processing/step.py +227 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/__init__.py +6 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/general.py +366 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/nlp.py +344 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/vision.py +1157 -0
- onnxruntime_extensions/tools/pre_post_processing/utils.py +139 -0
- onnxruntime_extensions/util.py +186 -0
- onnxruntime_extensions-0.14.0.dist-info/LICENSE +21 -0
- onnxruntime_extensions-0.14.0.dist-info/METADATA +102 -0
- onnxruntime_extensions-0.14.0.dist-info/RECORD +43 -0
- onnxruntime_extensions-0.14.0.dist-info/WHEEL +6 -0
- onnxruntime_extensions-0.14.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,628 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import builtins
|
|
3
|
+
import functools
|
|
4
|
+
import numpy as np
|
|
5
|
+
from onnx import onnx_pb as onnx_proto
|
|
6
|
+
from typing import List, Tuple, Optional, Union, Any, ContextManager, overload, Iterator, NamedTuple
|
|
7
|
+
from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme, _size, _layout # noqa
|
|
8
|
+
from torch import strided, memory_format, contiguous_format, StringType # noqa
|
|
9
|
+
|
|
10
|
+
from ._onnx_ops import ox as _ox
|
|
11
|
+
from .._ortapi2 import OrtPyFunction
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class _EagerTensor:
    """Eager tensor that also records each operation into an ONNX graph.

    An instance pairs a ``torch.Tensor`` (``_t``), computed eagerly, with
    ``name``, the graph tensor name produced by the matching ``_ox`` call in
    the active tracing session's container. Non-numeric data returned by
    ONNX Runtime is kept in ``raw_data``; ``_t`` then only holds a uint8
    placeholder of the same shape (see ``from_onnx``).
    """

    # Free functions exposed as chained methods (e.g. ``x.argmax()``) through
    # __getattribute__; populated at module level once they are defined.
    _all_ops = {}

    # Class-wide counter used to give every copy a distinct name.
    _dup_id = 0

    def __init__(self, _t, name=None, sess=None, raw_data: Any = None):
        """
        :param _t: a torch.Tensor, or anything ``torch.tensor`` accepts.
        :param name: graph tensor name; a single-element list/tuple is unwrapped.
        :param sess: accepted for call-site compatibility; not stored.
        :param raw_data: non-numeric payload kept alongside the placeholder ``_t``.
        """
        self._t = _t if isinstance(_t, torch.Tensor) else torch.tensor(_t)
        if isinstance(name, (tuple, list)):
            assert len(name) == 1, "Multiple names for one tensor!"
            name = name[0]
        self.name = '' if name is None else name
        self.raw_data = raw_data
        # Optional symbolic dims; get_shape() prefers them when non-empty.
        self.symbolic_shape = []

    def __repr__(self):
        if self.raw_data is not None:
            return "name: {}, \"{}\"".format(self.name, str(self.raw_data))
        return "name: {}, {}, dtype={}".format(self.name, repr(self._t), str(self._t.dtype))

    @property
    def value(self) -> Union[torch.Tensor, Any]:
        """``raw_data`` when it is set, otherwise the torch tensor."""
        # Compare with None: raw_data may be a numpy array, whose truth value
        # is ambiguous (ValueError) for more than one element.
        return self._t if self.raw_data is None else self.raw_data

    @property
    def t(self):
        return self._t

    @property
    def dtype(self):
        return self._t.dtype

    @property
    def onnx_type(self):
        return self.to_onnx_type(self._t.dtype)

    @classmethod
    def is_numeric(cls, np_arr):
        """True for numpy arrays of bool, signed/unsigned int, float or complex kind."""
        return np_arr.dtype.kind in set('buifc')

    @classmethod
    def set_active_session(cls, sess):
        """Set or clear the active operator-tracing session.

        :param sess: the session to activate, or None to remove the active one.
        :raises RuntimeError: when activating while a session is already active,
            or when clearing while no session is active.
        """
        active = hasattr(cls, '_active_session')
        if sess is None:
            # Check before touching state so a double unset leaves nothing behind.
            if not active:
                raise RuntimeError("unset the active session twice!")
            delattr(cls, '_active_session')
        else:
            if active:
                raise RuntimeError("The active session already assigned!")
            cls._active_session = sess

    @classmethod
    def get_trace_session(cls):
        """Return the active tracing session; raise RuntimeError if none is set."""
        if not hasattr(cls, '_active_session'):
            raise RuntimeError("the tracing not started yet!")
        return cls._active_session  # noqa

    @classmethod
    def get_container(cls):
        """Return the container of the active tracing session."""
        return cls.get_trace_session().container

    @classmethod
    def from_onnx(cls, raw_val, ort_sess, name):
        """Wrap a numpy array returned by ONNX Runtime.

        Numeric arrays are wrapped with ``torch.from_numpy``. Other arrays
        (e.g. strings) are stored in ``raw_data``, and ``_t`` becomes an
        uninitialized uint8 placeholder with the same shape.
        """
        raw_data = None
        if cls.is_numeric(raw_val):
            val = torch.from_numpy(raw_val)
        else:
            # Only keep the shape; the value is stored by itself. Passing the
            # shape as one tuple also handles 0-d arrays (shape ()).
            val = torch.empty(raw_val.shape, dtype=torch.uint8)
            raw_data = raw_val
        return cls(val, name, ort_sess, raw_data)

    @classmethod
    def from_torch(cls, _t, name):
        """Wrap a torch tensor, naming it ``id_<id>`` when no name is given."""
        t_name = name if name is not None else "id_{}".format(id(_t))
        return cls(_t, t_name)

    @classmethod
    def mytensor(cls, data: Any, dtype: Optional[_dtype] = None, device: Union[_device, str, None] = None, requires_grad: _bool = False):  # noqa
        """Traced ``torch.tensor`` (same prototype): record ``data`` as a graph constant."""
        y = torch.tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        val = _ox.make_tensor(cls.to_onnx_type(y.dtype), list(y.size()),
                              [data] if isinstance(data, (int, float, str, bool)) else data)
        s = _ox.constant([], [_ox.get_unique_tensor_name('const')], cls.get_container(), None, value=val)
        return cls.from_torch(y, s)

    def numpy(self):
        return self._t.numpy() if self.raw_data is None else self.raw_data

    def item(self):
        return self.numpy().item()

    def get_shape(self):
        return self.t.size() if len(self.symbolic_shape) == 0 else self.symbolic_shape

    def _to_binary_tensor_args(self, other):
        """Return ``(self, other)``, lifting a Python/numpy ``other`` into a traced constant."""
        # ``self`` is always an _EagerTensor here, so only ``other`` may need lifting.
        if isinstance(other, (int, float, bool, np.ndarray)):
            other = self.mytensor(other)
        return self, other

    def _binary_op(self, other, torch_fn, ox_fn):
        """Compute ``torch_fn`` eagerly and record ``ox_fn`` on the same two operands."""
        x0, x1 = self._to_binary_tensor_args(other)
        y = torch_fn(x0._t, x1._t)
        s = ox_fn(*_EagerTensor.ox_args([x0, x1]))
        return self.from_torch(y, s)

    def __copy__(self):
        new_t = _EagerTensor.from_torch(self.t, self.name + '_{}'.format(_EagerTensor._dup_id))
        # Advance the class-wide counter (not an instance attribute) so the
        # next copy gets a different name.
        _EagerTensor._dup_id += 1
        new_t.raw_data = self.raw_data
        return new_t

    def __add__(self, other):
        return self._binary_op(other, torch.add, _ox.add)

    def __sub__(self, other):
        return self._binary_op(other, torch.sub, _ox.sub)

    def __mul__(self, other):
        return self._binary_op(other, torch.mul, _ox.mul)

    def __truediv__(self, other):
        return self._binary_op(other, torch.div, _ox.div)

    # Python 2 spelling, kept for callers that invoke it explicitly.
    __div__ = __truediv__

    def __pow__(self, other):
        return self._binary_op(other, torch.pow, _ox.pow)

    def __matmul__(self, other):
        return self._binary_op(other, torch.matmul, _ox.matmul)

    def __lt__(self, other):
        return self._binary_op(other, torch.less, _ox.less)

    def __le__(self, other):
        return self._binary_op(other, torch.less_equal, _ox.less_equal)

    def __eq__(self, other):
        # Elementwise, like the recorded ``_ox.equal`` node (torch.equal would
        # return a single Python bool).
        return self._binary_op(other, torch.eq, _ox.equal)

    def __ne__(self, other):
        return self._binary_op(other, torch.not_equal, _ox.not_equal)

    def __gt__(self, other):
        return self._binary_op(other, torch.greater, _ox.greater)

    def __ge__(self, other):
        return self._binary_op(other, torch.greater_equal, _ox.greater_equal)

    def __invert__(self):
        if self.t.dtype is torch.bool:
            y = torch.logical_not(self.t)
            s = _ox.not_op(*self.my_args())
            return self.from_torch(y, s)
        raise NotImplementedError("no numeric tensor inverse supported yet.")

    def __neg__(self):
        y = torch.neg(self.t)
        s = _ox.neg(*self.my_args())
        return self.from_torch(y, s)

    def __not__(self):
        y = torch.logical_not(self.t)
        s = _ox.not_op(*self.my_args())
        return self.from_torch(y, s)

    def __or__(self, other):
        return self._binary_op(other, torch.logical_or, _ox.or_op)

    def __getitem__(self, indices):
        """Traced indexing with ints and slices, recorded as ``_ox.slice`` (+ ``_ox.squeeze``)."""
        y = self.value.__getitem__(indices)

        # Normalize indices to a tuple of slices. Formats encountered:
        #   - a single int or slice
        #   - a tuple/list of ints and slices
        if not isinstance(indices, (tuple, list)):
            indices = (indices,)
        # Axes addressed by a single int must be dropped afterwards.
        squeeze = [axis for axis, index in enumerate(indices) if isinstance(index, int)]
        indices = tuple(
            index if isinstance(index, slice) else slice(index, index + 1 if index != -1 else None, 1)
            for index in indices)
        bs, es, ss, ds = [], [], [], []
        INT_MAX = 2 ** 63 - 1
        for axis, index in enumerate(indices):
            if not isinstance(index, slice):
                raise ValueError("Index expected")
            b, e, step = index.start, index.stop, index.step
            # [:] / [::1] are no-ops on this axis; a strided full slice such as
            # [::2] still has to be emitted.
            if b is None and e is None and (step is None or step == 1):
                continue
            bs.append(b if b is not None else 0)
            es.append(e if e is not None else INT_MAX)
            ss.append(step if step is not None else 1)
            ds.append(axis)
        s = _ox.slice(*self.my_args(), starts=bs, ends=es, axes=ds, steps=ss)
        if squeeze:  # single index means we must drop the axis
            s = _ox.squeeze(*self.ox_name_args(s), axes=squeeze)

        return self.from_torch(y, s)

    def __getattribute__(self, attr):
        """
        A little hack that allows to call unary operators in a chaining fashion,
        e.g. x.shape() instead of ox.shape(x).
        """
        if attr in _EagerTensor._all_ops:
            return functools.partial(_EagerTensor._all_ops[attr], self)
        return object.__getattribute__(self, attr)

    @classmethod
    def ox_name_args(cls, input_names, output_names=None):
        """Build the positional arguments expected by the ``_ox`` graph builders.

        :param input_names: input name list
        :param output_names: output name list; None, or entries that are None,
            get unique generated names.
        :return: input_names, output_names, container, operator_name (always None)
        """
        container = cls.get_trace_session().container
        if output_names is None:
            output_names = [None]  # by default, there is only one output
        output_names = [_ox.get_unique_tensor_name(str(idx)) if out_name is None else out_name
                        for idx, out_name in enumerate(output_names)]
        return input_names, output_names, container, None

    @classmethod
    def ort_verify(cls, ts_from, ts_to):
        """Run the traced ops with ONNX Runtime and compare with the eager results.

        :raises RuntimeError: on mismatch; the model is saved to 'mt_debmodel.onnx' first.
        """
        result, model = cls.get_trace_session().runops(ts_from, ts_to)
        for idx, expected in enumerate(ts_to):
            if not np.allclose(expected.numpy(), result[idx]):
                # onnx is imported lazily; the original author notes a conflict
                # with torch.onnx when importing it at module level.
                import onnx  # noqa
                onnx.save_model(model, 'mt_debmodel.onnx')
                raise RuntimeError("ONNXRuntime Result is not same pytorch!")

    def create_and_verify(self, value, name, additional_inputs=None):
        """Wrap ``value`` as ``name`` and verify it against ONNX Runtime."""
        ts_y = self.from_torch(value, name)
        inputs = [self] + ([] if additional_inputs is None else additional_inputs)
        self.ort_verify(inputs, [ts_y])
        return ts_y

    @classmethod
    def ox_args(cls, tensors, output_names=None):
        """``ox_name_args`` for a mix of tensor names and _EagerTensor objects."""
        input_names = [ts_ if isinstance(ts_, str) else ts_.name for ts_ in tensors]
        return cls.ox_name_args(input_names, output_names)

    def my_args(self):
        return self.ox_args([self])

    @staticmethod
    def normalize_seq(list_or_tuple):
        """Replace _EagerTensor entries by their Python scalar value."""
        return [x.value.item() if isinstance(x, _EagerTensor) else x for x in list_or_tuple]

    @staticmethod
    def to_onnx_type(torch_type):
        """Map a torch dtype to an ONNX TensorProto element type (STRING when unknown)."""
        # uint8 is deliberately not mapped: from_onnx uses uint8 placeholders
        # for string data, and those must keep reporting STRING.
        ty_dict = {torch.bool: onnx_proto.TensorProto.BOOL,
                   torch.float32: onnx_proto.TensorProto.FLOAT,
                   torch.float64: onnx_proto.TensorProto.DOUBLE,
                   torch.float16: onnx_proto.TensorProto.FLOAT16,
                   torch.long: onnx_proto.TensorProto.INT64,
                   torch.int32: onnx_proto.TensorProto.INT32,
                   torch.int16: onnx_proto.TensorProto.INT16,
                   torch.int8: onnx_proto.TensorProto.INT8}
        return ty_dict.get(torch_type, onnx_proto.TensorProto.STRING)

    def long(self):
        y = self._t.long()
        s = _ox.cast(*self.my_args(), to=onnx_proto.TensorProto.INT64)
        return self.create_and_verify(y, s[0])

    def cumsum(self, dim: _int, *, dtype: Optional[_dtype] = None):  # noqa
        y = self._t.cumsum(dim, dtype=dtype)
        s = _ox.cumsum(*self.my_args(), axis=dim)
        return self.create_and_verify(y, s[0])

    def size(self):
        y = self._t.size()
        s = _ox.shape(*self.my_args())
        return self.create_and_verify(y, s[0])

    def type(self, dtype: Union[str, _dtype], non_blocking: _bool = False):
        y = self._t.type(dtype, non_blocking)
        s = _ox.cast(*self.my_args(), to=self.to_onnx_type(dtype))
        return self.create_and_verify(y, s[0])

    def to(self, device):
        y = self._t.to(device)
        s = _ox.identity(*self.my_args())
        return self.create_and_verify(y, s[0])

    def cpu(self):
        y = self._t.cpu()
        s = _ox.identity(*self.my_args())
        return self.create_and_verify(y, s[0])

    def detach(self):
        y = self._t.detach()
        s = _ox.identity(*self.my_args())
        return self.create_and_verify(y, s[0])

    def clone(self):
        y = self._t.clone()
        s = _ox.identity(*self.my_args())
        return self.create_and_verify(y, s[0])

    def masked_fill(self, mask, value):
        """Traced ``Tensor.masked_fill``, recorded as ``_ox.where(mask, value, self)``.

        :param mask: an _EagerTensor holding the mask.
        :param value: a Python scalar or an _EagerTensor.
        """
        # torch needs the underlying tensor, not the _EagerTensor wrapper.
        fill = value.value if isinstance(value, _EagerTensor) else value
        y = self._t.masked_fill(mask.value, fill)
        if not isinstance(value, _EagerTensor):
            value = _EagerTensor.mytensor(value)
        s = _ox.where(*_EagerTensor.ox_args([mask, value, self]))
        return self.create_and_verify(y, s[0], additional_inputs=[mask, value])

    def unsqueeze(self, dim: _int):
        y = self._t.unsqueeze(dim)
        s = _ox.unsqueeze(*self.my_args(), [dim])
        return self.create_and_verify(y, s[0])

    def squeeze(self, dim: _int):
        y = self._t.squeeze(dim)
        s = _ox.squeeze(*self.my_args(), [dim])
        return self.create_and_verify(y, s[0])
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def _create_ox_sequence(*size):
    """Record ``size`` as an INT64 1-D tensor in the active container.

    If every entry is a plain int, a single ``_ox.constant`` holds them all.
    Otherwise each entry becomes a one-element tensor (``_ox.unsqueeze`` for
    _EagerTensor entries, ``_ox.constant`` for ints), and the pieces are joined
    with ``_ox.concat``.

    :return: the output-name list returned by the final ``_ox`` call.
    """
    container = _EagerTensor.get_container()
    # ``any`` is shadowed by this module's traced any(), hence builtins.any.
    has_traced = builtins.any(isinstance(item, _EagerTensor) for item in size)
    if not has_traced:
        shape_const = _ox.make_tensor(onnx_proto.TensorProto.INT64, [len(size)], size)
        return _ox.constant([], [_ox.get_unique_tensor_name('const')], container, None, value=shape_const)

    pieces = []
    for item in size:
        if isinstance(item, _EagerTensor):
            piece = _ox.unsqueeze(*_EagerTensor.ox_args([item]))[0]
        else:
            scalar = _ox.make_tensor(onnx_proto.TensorProto.INT64, [1], [item])
            piece = _ox.constant([], [_ox.get_unique_tensor_name('const')], container, None, value=scalar)[0]
        pieces.append(piece)
    return _ox.concat(pieces, [_ox.get_unique_tensor_name('concat')], container, None)
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def _create_ox_sequence_constant(*size, init_value=None, onnx_type=None):
    """Record a tensor of shape ``size`` filled with ``init_value``.

    :param size: dimensions, as accepted by ``_create_ox_sequence``.
    :param init_value: the fill value.
    :param onnx_type: TensorProto element type of the fill value; FLOAT when None.
    :return: the name of the ``_ox.constant_of_shape`` output.
    """
    element_type = onnx_proto.TensorProto.FLOAT if onnx_type is None else onnx_type
    shape_names = _create_ox_sequence(*size)
    fill = _ox.make_tensor(element_type, [1], [init_value])

    container = _EagerTensor.get_container()
    out_names = [_ox.get_unique_tensor_name('cos')]
    outputs = _ox.constant_of_shape(shape_names, out_names, container, None, value=fill)
    return outputs[0]
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def empty(*size: Union[_int, _EagerTensor], memory_format: Optional[memory_format] = None, out: Optional[_EagerTensor] = None,
          dtype: _dtype = None, layout: _layout = strided, device: Union[_device, str, None] = None,
          requires_grad: _bool = False) -> _EagerTensor:  # noqa
    """Traced counterpart of ``torch.empty``.

    The eager value comes from ``torch.empty``. The graph side is built by
    ``_create_ox_sequence_constant`` with a fill value of 0.

    :param size: dimensions, as varargs or as a single list/tuple (including
        ``torch.Size``); entries may be ints or scalar _EagerTensors.
    :return: the traced tensor.
    """
    # A single sequence argument (``empty([2, 3])``, ``empty((2, 3))``) is
    # unpacked so the ONNX shape builder sees individual dimensions.
    if len(size) == 1 and isinstance(size[0], (list, tuple)):
        size = tuple(size[0])
    n_size = _EagerTensor.normalize_seq(size)
    y = torch.empty(*n_size, memory_format=memory_format, out=out,
                    dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
    # Int fill value, matching zeros(). A float 0. may be rejected for integral
    # element types (NOTE(review): depends on _ox.make_tensor).
    s = _create_ox_sequence_constant(*size, init_value=0, onnx_type=_EagerTensor.to_onnx_type(y.dtype))
    return _EagerTensor.from_torch(y, s)
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
def zeros(*size: Union[_int, _EagerTensor], out: Optional[_EagerTensor] = None, dtype: _dtype = None, layout: _layout = strided,
          device: Union[_device, str, None] = None, requires_grad: _bool = False) -> _EagerTensor:  # noqa
    """Traced counterpart of ``torch.zeros``.

    The eager value comes from ``torch.zeros``. The graph side is built by
    ``_create_ox_sequence_constant`` with a fill value of 0.

    :param size: dimensions, as varargs or as a single list/tuple (including
        ``torch.Size``); entries may be ints or scalar _EagerTensors.
    :return: the traced tensor.
    """
    # A single sequence argument (``zeros([2, 3])``, ``zeros((2, 3))``) is
    # unpacked so the ONNX shape builder sees individual dimensions.
    if len(size) == 1 and isinstance(size[0], (list, tuple)):
        size = tuple(size[0])
    n_size = _EagerTensor.normalize_seq(size)
    y = torch.zeros(*n_size, out=out, dtype=dtype,
                    layout=layout, device=device, requires_grad=requires_grad)
    s = _create_ox_sequence_constant(*size, init_value=0, onnx_type=_EagerTensor.to_onnx_type(y.dtype))
    return _EagerTensor.from_torch(y, s)
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
def ones(*size: Union[_int, _EagerTensor], out: Optional[_EagerTensor] = None, dtype: _dtype = None, layout: _layout = strided,
         device: Union[_device, str, None] = None, requires_grad: _bool = False) -> _EagerTensor:  # noqa
    """Traced counterpart of ``torch.ones``.

    The eager value comes from ``torch.ones``. The graph side is built by
    ``_create_ox_sequence_constant`` with a fill value of 1.

    :param size: dimensions, as varargs or as a single list/tuple (including
        ``torch.Size``); entries may be ints or scalar _EagerTensors.
    :return: the traced tensor.
    """
    # A single sequence argument (``ones([2, 3])``, ``ones((2, 3))``) is
    # unpacked so the ONNX shape builder sees individual dimensions.
    if len(size) == 1 and isinstance(size[0], (list, tuple)):
        size = tuple(size[0])
    n_size = _EagerTensor.normalize_seq(size)
    y = torch.ones(*n_size, out=out, dtype=dtype,
                   layout=layout, device=device, requires_grad=requires_grad)
    s = _create_ox_sequence_constant(*size, init_value=1, onnx_type=_EagerTensor.to_onnx_type(y.dtype))
    return _EagerTensor.from_torch(y, s)
|
|
446
|
+
|
|
447
|
+
|
|
448
|
+
def repeat(input_ts: _EagerTensor, *repeats: Union[_int, _EagerTensor]) -> _EagerTensor:  # noqa
    """Traced ``Tensor.repeat``, recorded as ``_ox.tile``.

    :param repeats: repeat counts, as varargs or as a single list/tuple;
        entries may be ints or scalar _EagerTensors.
    """
    # Accept ``repeat(x, [2, 3])`` and ``repeat(x, (2, 3))`` like torch does.
    if len(repeats) == 1 and isinstance(repeats[0], (list, tuple)):
        repeats = tuple(repeats[0])
    n_size = _EagerTensor.normalize_seq(repeats)
    y = input_ts.t.repeat(*n_size)
    seq = _create_ox_sequence(*repeats)
    s = _ox.tile(*input_ts.my_args(), repeats=seq[0])
    return _EagerTensor.from_torch(y, s[0])
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def argmax(input_ts: _EagerTensor, dim: Optional[_int] = None, keepdim: _bool = False) -> _EagerTensor:  # noqa
    """Traced ``torch.argmax``, recorded with ``_ox.argmax``.

    NOTE(review): with ``dim=None`` torch flattens the input; confirm that
    ``_ox.argmax`` handles ``axis=None`` equivalently.
    """
    eager_result = torch.argmax(input_ts.value, dim, keepdim)
    graph_names = _ox.argmax(*input_ts.my_args(), axis=dim, keepdims=keepdim)
    return _EagerTensor.from_torch(eager_result, graph_names)
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
def softmax(input_ts: _EagerTensor, dim: _int, dtype: Optional[_dtype] = None) -> _EagerTensor:
    """Traced ``torch.softmax`` along ``dim``, recorded with ``_ox.softmax``.

    ``dtype`` only affects the eager torch computation; it is not passed to
    the graph builder.
    """
    eager_result = torch.softmax(input_ts.value, dim, dtype)
    graph_names = _ox.softmax(*input_ts.my_args(), axis=dim)
    return _EagerTensor.from_torch(eager_result, graph_names)
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
def cat(tensors: Union[Tuple[_EagerTensor, ...], List[_EagerTensor]],
        dim, *, out: Optional[_EagerTensor] = None) -> _EagerTensor:  # noqa
    """Traced ``torch.cat``, recorded with ``_ox.concat`` and checked with ORT.

    :raises RuntimeError: from ``ort_verify`` when ONNX Runtime disagrees.
    """
    eager_inputs = [item.value for item in tensors]
    eager_result = torch.cat(eager_inputs, dim, out=out)
    concat_names = _ox.concat(*_EagerTensor.ox_args(tensors), dim)
    traced = _EagerTensor.from_torch(eager_result, concat_names[0])
    _EagerTensor.ort_verify(tensors, [traced])
    return traced
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def all(input_ts: _EagerTensor, out: Optional[_EagerTensor] = None) -> _EagerTensor:  # noqa
    """Traced ``torch.all``: whether every element is non-zero.

    Graph: cast to INT64, reduce-min over all axes, then compare greater than 0.

    :param out: accepted for signature compatibility; unused.
    """
    container = _EagerTensor.get_container()
    y = torch.all(input_ts.value)
    # torch.all reduces over every axis, so the graph must too. Reducing only
    # the last axis disagreed for rank >= 2. A 0-d input keeps axes=[-1].
    axes = list(range(input_ts.t.dim())) or [-1]
    s_casted = _ox.cast(*input_ts.my_args(), to=onnx_proto.TensorProto.INT64)
    s_redm = _ox.reducemin(s_casted, [_ox.get_unique_tensor_name('reducemin')], container, None, axes=axes)
    s0 = _ox.constant([], [_ox.get_unique_tensor_name('const')],
                      container, None, value=_ox.make_tensor(onnx_proto.TensorProto.INT64, [1], [0]))
    # ``s_redm + s0`` concatenates the two name lists into greater()'s inputs.
    s = _ox.greater(s_redm + s0, [_ox.get_unique_tensor_name('greater')], container, None)
    return input_ts.create_and_verify(y, s[0])
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def any(input_ts: _EagerTensor, out: Optional[_EagerTensor] = None) -> _EagerTensor:  # noqa
    """Traced ``torch.any``: whether at least one element is non-zero.

    Graph: cast to INT64, reduce-sum over all axes, then compare greater than 0.
    NOTE(review): assumes non-negative values after the cast (e.g. bool
    inputs); negative values could cancel out in the sum.

    :param out: accepted for signature compatibility; unused.
    """
    container = _EagerTensor.get_container()
    y = torch.any(input_ts.value)
    # torch.any reduces over every axis, so the graph must too. Reducing only
    # the last axis disagreed for rank >= 2. A 0-d input keeps axes=[-1].
    axes = list(range(input_ts.t.dim())) or [-1]
    s_casted = _ox.cast(*input_ts.my_args(), to=onnx_proto.TensorProto.INT64)
    s_redm = _ox.reducesum(s_casted, [_ox.get_unique_tensor_name('reducesum')], container, None, axes=axes)
    s0 = _ox.constant([], [_ox.get_unique_tensor_name('const')],
                      container, None, value=_ox.make_tensor(onnx_proto.TensorProto.INT64, [1], [0]))
    # ``s_redm + s0`` concatenates the two name lists into greater()'s inputs.
    s = _ox.greater(s_redm + s0, [_ox.get_unique_tensor_name('greater')], container, None)
    return input_ts.create_and_verify(y, s[0])
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def reshape(input_ts: _EagerTensor, shape: _size):
    """Traced ``Tensor.reshape``, recorded with ``_ox.reshape`` and checked with ORT."""
    eager_result = input_ts.t.reshape(shape)
    graph_names = _ox.reshape(*input_ts.my_args(), desired_shape=shape)
    first_output = graph_names[0]
    return input_ts.create_and_verify(eager_result, first_output)
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
def transpose(input_ts: _EagerTensor, dim0: _int, dim1: _int):
    """Traced ``Tensor.transpose``, recorded as an ``_ox.transpose`` permutation.

    The permutation is the identity with ``dim0`` and ``dim1`` exchanged.
    Negative dims work because they index the permutation list from the end.
    """
    eager_result = input_ts.t.transpose(dim0, dim1)
    perm = list(range(eager_result.dim()))
    perm[dim1], perm[dim0] = perm[dim0], perm[dim1]
    graph_names = _ox.transpose(*input_ts.my_args(), perm=perm)
    return input_ts.create_and_verify(eager_result, graph_names[0])
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
class _LoopIterator:
    """Iterator that drives a traced loop body.

    Each step yields ``context.current()``. Once the context reports that it
    is stopped, the trace session's current container is popped (the one
    pushed by ``_ControlFlowContext.loop`` via ``stack_container``) and
    iteration ends.
    """

    def __init__(self, ctx):
        self.context = ctx

    def __iter__(self):
        return self

    def __next__(self):
        ctx = self.context
        if not ctx.is_stopped():
            return ctx.current()
        _EagerTensor.get_trace_session().pop_container()
        raise StopIteration
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
class _ControlFlowContext:
    """Traces a Python loop body and records it through ``_ox.loop``.

    Usage: ``loop()`` pushes a container on the trace session and returns an
    iterator; the body reports each iteration's results with
    ``flow_output()``; ``finalize()`` records the loop node and returns its
    outputs.
    """

    def __init__(self):
        self.condition_i = None    # initial condition passed to loop()
        self.condition = None      # condition produced by the latest iteration
        self.loop_count = None     # trip-count tensor passed to loop()
        self.iteration_num = None  # eager iteration counter (traced constant 0)
        self.states_i = []         # initial loop-carried states passed to loop()
        self.loop_states = []      # loop-carried states after the latest iteration
        self.scan_outputs = []     # per-iteration extra outputs, stacked on axis 0
        self.sub_graph = None      # body graph, built on the first flow_output()

    def flow_output(self, cond, *outputs):
        """Record one iteration's results.

        :param cond: the continue-condition produced by the body.
        :param outputs: updated loop-carried states first (one per state given
            to ``loop``), then any extra per-iteration (scan) outputs.
        """
        assert len(outputs) >= len(self.loop_states), "The loop body doesn't return enough objects"
        if self.sub_graph is None:
            # The first iteration's trace becomes the loop body graph.
            trc = _EagerTensor.get_trace_session()
            self.sub_graph = trc.build_graph(trc.container,
                                             [self.iteration_num, self.condition] + self.loop_states,
                                             [cond] + list(outputs))

        self.condition = cond
        n_states = len(self.loop_states)
        self.loop_states = list(outputs[:n_states])
        new_scans = outputs[n_states:]
        if not self.scan_outputs:
            self.scan_outputs = [_EagerTensor(torch.unsqueeze(sc_.value, 0), 'sc_' + sc_.name)
                                 for sc_ in new_scans]
        else:
            stacked = []
            for idx_, sc_ in enumerate(new_scans):
                acc_ = self.scan_outputs[idx_]
                stacked.append(_EagerTensor(
                    torch.cat([acc_.value, torch.unsqueeze(sc_.value, 0)]), name=acc_.name))
            self.scan_outputs = stacked
        self.iteration_num.value.add_(1)

    def current(self):
        """Values handed to the loop body: iteration number, then loop states."""
        return [self.iteration_num] + list(self.loop_states)

    def finalize(self):
        """Record the ``_ox.loop`` node and return its outputs.

        :return: tuple of _EagerTensor, final loop states followed by the
            stacked scan outputs, each renamed with an ``lp_`` prefix.
        """
        # generate the outputs from the enclosing scope variables
        full_outputs = [_EagerTensor(o_.value, 'lp_' + o_.name) for o_ in self.loop_states + self.scan_outputs]
        _ox.loop(*_EagerTensor.ox_args(
            [self.loop_count, self.condition_i] + list(self.states_i),
            [ts_.name for ts_ in full_outputs]), body=self.sub_graph)
        return tuple(full_outputs)

    def is_stopped(self):
        """True when the condition is falsy or the trip count is reached."""
        # Use truthiness rather than ``is False``: an integer-valued condition
        # (e.g. 0) must also stop the loop.
        return not self.condition.item() or self.iteration_num.item() >= self.loop_count.item()

    def loop(self, loop_c, condition, *states):
        """Start tracing a loop and return the iterator driving its body.

        :param loop_c: trip-count tensor.
        :param condition: initial condition tensor.
        :param states: initial loop-carried state tensors.
        """
        self.condition = condition
        self.condition_i = condition
        self.states_i = states
        # The body is traced into a newly stacked container on the session.
        _EagerTensor.get_trace_session().stack_container()
        self.iteration_num = _EagerTensor.mytensor(0)
        # clone the variables for the sub graph.
        self.loop_states = [_EagerTensor(st_.value, st_.name) for st_ in states]
        self.loop_count = loop_c
        return iter(_LoopIterator(self))
|
|
589
|
+
|
|
590
|
+
|
|
591
|
+
def control_flow():
    """Create a fresh ``_ControlFlowContext`` for tracing one loop."""
    ctx = _ControlFlowContext()
    return ctx
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
class _TracingEagerOp(OrtPyFunction):
    """OrtPyFunction that also records its invocation in the traced graph.

    _EagerTensor arguments are converted to numpy before the ONNX Runtime
    call. The results are wrapped with ``_EagerTensor.from_onnx``, and an
    ``_ox.model_call`` referencing ``self.onnx_model`` is recorded.
    """

    def __call__(self, *args, **kwargs):
        numpy_inputs = []
        for arg in args:
            numpy_inputs.append(arg.numpy() if isinstance(arg, _EagerTensor) else arg)
        raw_out = super().__call__(*numpy_inputs, **kwargs)
        if not isinstance(raw_out, (list, tuple)):
            raw_out = [raw_out]

        outputs = []
        for idx, meta in enumerate(self.ort_session.get_outputs()):
            outputs.append(_EagerTensor.from_onnx(raw_out[idx], self.ort_session, meta.name))

        output_names = [item.name for item in outputs]
        _ox.model_call(*_EagerTensor.ox_args(args, output_names=output_names), oxml=self.onnx_model)
        if len(outputs) > 1:
            return tuple(outputs)
        return outputs[0]
|
|
607
|
+
|
|
608
|
+
|
|
609
|
+
def op_from_customop(op_type, *args, **kwargs) -> _TracingEagerOp:
    """Build a tracing op for custom operator ``op_type``.

    Thin forwarder to ``_TracingEagerOp.from_customop``; extra arguments are
    passed through unchanged.
    """
    factory = _TracingEagerOp.from_customop
    return factory(op_type, *args, **kwargs)
|
|
611
|
+
|
|
612
|
+
|
|
613
|
+
def op_from_model(path_or_model, *args, **kwargs) -> _TracingEagerOp:
    """Build a tracing op from an ONNX model (path or object).

    Thin forwarder to ``_TracingEagerOp.from_model``; extra arguments are
    passed through unchanged.
    """
    factory = _TracingEagerOp.from_model
    return factory(path_or_model, *args, **kwargs)
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
# Free functions that _EagerTensor.__getattribute__ exposes as chained
# methods, e.g. ``x.argmax(...)`` instead of ``argmax(x, ...)``.
# ``any``/``all`` here are this module's traced functions, not the builtins.
_EagerTensor._all_ops = dict(
    argmax=argmax,
    softmax=softmax,
    reshape=reshape,
    transpose=transpose,
    repeat=repeat,
    any=any,
    all=all,
)

# Public, torch-like aliases for the tracing API.
tensor = _EagerTensor.mytensor
tensor_from_onnx = _EagerTensor.from_onnx
tensor_from_torch = _EagerTensor.from_torch
tensor_set_session = _EagerTensor.set_active_session
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
try:
|
|
2
|
+
import torch
|
|
3
|
+
except ImportError:
|
|
4
|
+
raise RuntimeError("pytorch not installed, which is required by this ONNX build tool")
|
|
5
|
+
|
|
6
|
+
from torch import (float32,
|
|
7
|
+
float,
|
|
8
|
+
float64,
|
|
9
|
+
double,
|
|
10
|
+
float16,
|
|
11
|
+
bfloat16,
|
|
12
|
+
half,
|
|
13
|
+
uint8,
|
|
14
|
+
int8,
|
|
15
|
+
int16,
|
|
16
|
+
short,
|
|
17
|
+
int32,
|
|
18
|
+
int,
|
|
19
|
+
int64,
|
|
20
|
+
long,
|
|
21
|
+
complex64,
|
|
22
|
+
cfloat,
|
|
23
|
+
complex128,
|
|
24
|
+
cdouble,
|
|
25
|
+
quint8,
|
|
26
|
+
qint8,
|
|
27
|
+
qint32,
|
|
28
|
+
bool) # noqa
|
|
29
|
+
|
|
30
|
+
from torch import randn, onnx # noqa
|
|
31
|
+
from ._tensor import * # noqa
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# The onnxruntime-extensions pre- and post-processing frontend depends on PyTorch.
|
|
2
|
+
try:
|
|
3
|
+
import torch
|
|
4
|
+
except ImportError as e:
|
|
5
|
+
print("No torch installation found, which is required by the pre&post scripting!")
|
|
6
|
+
raise e
|
|
7
|
+
|
|
8
|
+
from ._base import ProcessingTracedModule, ProcessingScriptModule, CustomFunction
|
|
9
|
+
from ._torchext import * # noqa
|
|
10
|
+
from ._unifier import export
|
|
11
|
+
|
|
12
|
+
from ._imagenet import * # noqa
|
|
13
|
+
from ._nlp import PreHuggingFaceGPT2, PreHuggingFaceBert, HfBertTokenizer # noqa
|