cornucopia 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cornucopia/__init__.py +73 -0
- cornucopia/base.py +1915 -0
- cornucopia/baseutils.py +575 -0
- cornucopia/contrast.py +260 -0
- cornucopia/ctx.py +25 -0
- cornucopia/fov.py +707 -0
- cornucopia/geometric.py +2068 -0
- cornucopia/intensity.py +1358 -0
- cornucopia/io.py +161 -0
- cornucopia/kspace.py +505 -0
- cornucopia/labels.py +1872 -0
- cornucopia/noise.py +508 -0
- cornucopia/psf.py +463 -0
- cornucopia/qmri.py +1288 -0
- cornucopia/random.py +1480 -0
- cornucopia/special.py +159 -0
- cornucopia/synth.py +708 -0
- cornucopia/tests/__init__.py +0 -0
- cornucopia/tests/test_backward_geometric.py +173 -0
- cornucopia/tests/test_backward_intensity.py +243 -0
- cornucopia/tests/test_backward_kspace.py +115 -0
- cornucopia/tests/test_backward_noise.py +169 -0
- cornucopia/tests/test_backward_psf.py +142 -0
- cornucopia/tests/test_backward_qmri.py +249 -0
- cornucopia/tests/test_backward_random.py +44 -0
- cornucopia/tests/test_backward_synth.py +72 -0
- cornucopia/tests/test_base.py +401 -0
- cornucopia/tests/test_geometric.py +26 -0
- cornucopia/tests/test_intensity.py +9 -0
- cornucopia/tests/test_random.py +722 -0
- cornucopia/tests/test_run_contrast.py +28 -0
- cornucopia/tests/test_run_fov.py +132 -0
- cornucopia/tests/test_run_geometric.py +157 -0
- cornucopia/tests/test_run_intensity.py +192 -0
- cornucopia/tests/test_run_kspace.py +70 -0
- cornucopia/tests/test_run_labels.py +224 -0
- cornucopia/tests/test_run_noise.py +127 -0
- cornucopia/tests/test_run_psf.py +115 -0
- cornucopia/tests/test_run_qmri.py +114 -0
- cornucopia/tests/test_run_synth.py +67 -0
- cornucopia/typing.py +97 -0
- cornucopia/utils/__init__.py +0 -0
- cornucopia/utils/b0.py +745 -0
- cornucopia/utils/bounds.py +412 -0
- cornucopia/utils/compat.py +47 -0
- cornucopia/utils/conv.py +305 -0
- cornucopia/utils/gmm.py +169 -0
- cornucopia/utils/indexing.py +911 -0
- cornucopia/utils/io.py +258 -0
- cornucopia/utils/jit.py +128 -0
- cornucopia/utils/kernels.py +288 -0
- cornucopia/utils/morpho.py +234 -0
- cornucopia/utils/mrf.py +574 -0
- cornucopia/utils/padding.py +173 -0
- cornucopia/utils/patch.py +302 -0
- cornucopia/utils/pool.py +282 -0
- cornucopia/utils/py.py +348 -0
- cornucopia/utils/smart_inplace.py +163 -0
- cornucopia/utils/version.py +57 -0
- cornucopia/utils/warps.py +606 -0
- cornucopia-0.0.0.dist-info/METADATA +92 -0
- cornucopia-0.0.0.dist-info/RECORD +65 -0
- cornucopia-0.0.0.dist-info/WHEEL +5 -0
- cornucopia-0.0.0.dist-info/licenses/LICENSE +21 -0
- cornucopia-0.0.0.dist-info/top_level.txt +1 -0
cornucopia/baseutils.py
ADDED
|
@@ -0,0 +1,575 @@
|
|
|
1
|
+
# stdlib
|
|
2
|
+
from collections import abc
|
|
3
|
+
from collections.abc import Mapping, Sequence
|
|
4
|
+
|
|
5
|
+
# dependencies
|
|
6
|
+
import torch
|
|
7
|
+
import typing_extensions as tx
|
|
8
|
+
|
|
9
|
+
# internal
|
|
10
|
+
from .utils.indexing import guess_shape
|
|
11
|
+
from .utils.py import ensure_list
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def get_first_element(x, include=None, exclude=None, types=None):
    """Return the first tensor (or instance of `types`) in a nested structure.

    The structure is searched depth-first, in iteration order. Objects that
    expose an ``items()`` method are searched as mappings; ``list`` and
    ``tuple`` objects are searched as sequences.

    Parameters
    ----------
    x : nested structure
        Object to search.
    include : container of keys, optional
        If given (and non-empty), only mapping entries whose key is in
        `include` are visited.
    exclude : container of keys, optional
        If given (and non-empty), mapping entries whose key is in `exclude`
        are skipped.
    types : type or iterable of types, optional
        Additional types (besides tensors) that count as a match.

    Returns
    -------
    element : object or None
        The first matching element. ``None`` if no element of a mapping or
        sequence matches. If `x` itself is not a mapping or sequence, it is
        returned as is.
    """
    # `isinstance` accepts a type or a *tuple* of types, never a list, so
    # normalize `types` to a tuple once up front.
    if not types:
        types = ()
    elif isinstance(types, type):
        types = (types,)
    else:
        types = tuple(types)

    def _recursive(x):
        # Returns (element, found) so that a falsy match is still reported.
        if hasattr(x, 'items'):
            for k, v in x.items():
                if include and k not in include:
                    continue
                if exclude and k in exclude:
                    continue
                v, ok = _recursive(v)
                if ok:
                    return v, True
            return None, False
        if isinstance(x, (list, tuple)):
            for v in x:
                v, ok = _recursive(v)
                if ok:
                    return v, True
            return None, False
        if torch.is_tensor(x) or (types and isinstance(x, types)):
            return x, True
        return x, False

    return _recursive(x)[0]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def recursive_cat(x, **kwargs):
    """Concatenate matching tensors across several nested structures.

    Parameters
    ----------
    x : sequence of nested structures
        Structures sharing the same layout (tensors, lists/tuples, or
        objects with ``items()``/``keys()``). Tensors found at the same
        position in every structure are concatenated together.
    **kwargs
        Forwarded to `torch.cat` (e.g. ``dim``; 0 by default).

    Returns
    -------
    Structure with the layout and container types of ``x[0]``, holding the
    concatenated tensors.

    Raises
    ------
    TypeError
        If a position holds something that is neither a tensor, a
        list/tuple, nor a mapping.
    """
    def _rec(*x):
        if all(torch.is_tensor(x1) for x1 in x):
            return torch.cat(x, **kwargs)
        if isinstance(x[0], (list, tuple)):
            return type(x[0])(_rec(*x1) for x1 in zip(*x))
        if hasattr(x[0], 'items'):
            # For each key, gather its value from *every* structure and
            # concatenate them together.
            return type(x[0])(**{k: _rec(*(x1[k] for x1 in x))
                                 for k in x[0].keys()})
        raise TypeError(f'What should I do with a {type(x[0])}?')
    return _rec(*x)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def prepare_output(results, returns):
    """Select the requested results and wrap them for `apply_transform`.

    Parameters
    ----------
    results : tensor or dict[str, tensor]
        Either a single tensor or the named results.
    returns : None or list[str] or tuple[str] or dict[str, str] or str
        Structure describing which results should be returned. The
        selection is returned in an object of the same type, holding the
        requested results under the same keys (if `dict`) or at the same
        positions (if `list`/`tuple`). If a `str`, only the requested
        result is selected. If `None`, `results` is selected when it is a
        tensor, and its ``'output'`` entry otherwise.
        Missing names select `None`.

    Returns
    -------
    Returned
        Wrapper whose `obj` attribute is the selection (a tensor, a list,
        tuple or dict of tensors, or `None`).
    """
    if returns is None:
        selected = (results if torch.is_tensor(results)
                    else results.get('output', None))
    elif isinstance(returns, dict):
        picked = {name: results.get(flag, None)
                  for name, flag in returns.items()}
        selected = type(returns)(**picked)
    elif isinstance(returns, (list, tuple)):
        selected = type(returns)([results.get(flag, None) for flag in returns])
    else:
        selected = results.get(returns, None)
    return Returned(selected)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def return_requires(returns) -> tx.Set[str]:
    """Collect, as a flat set, every result name required by `returns`.

    `None` stands for the default ``'output'`` result. Otherwise `returns`
    is flattened with `flatstruct`, and the names are the values of the
    resulting dict, the elements of the resulting list/tuple/set, or the
    single remaining object.
    """
    if returns is None:
        return {'output'}
    flat = flatstruct(returns)
    if isinstance(flat, dict):
        return set(flat.values())
    if isinstance(flat, (list, tuple, set)):
        return set(flat)
    return {flat}
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def returns_find(flag, returned, returns):
    """Find the object corresponding to `flag` in a returned structure.

    Parameters
    ----------
    flag : str
        Name of the requested result (e.g. ``'output'``).
    returned : object
        Structure built according to `returns` (as done by
        `prepare_output`).
    returns : None or dict[str, str] or list[str] or tuple[str] or str
        Structure describing which results were returned.

    Returns
    -------
    object or None
        The element of `returned` holding `flag`, or None if `flag` was not
        requested.

    Raises
    ------
    TypeError
        If `returns` has an unsupported type. This includes sets: they are
        unordered, so there is no position to look up in `returned`.
    """
    if returns is None:
        # Only the default 'output' result was returned, unwrapped.
        return returned if flag == 'output' else None
    if isinstance(returns, dict):
        # `returns` maps output keys to result names: match on values.
        for key, target in returns.items():
            if target == flag:
                return returned.get(key, None)
        return None
    if isinstance(returns, (list, tuple)):
        # Positional correspondence between `returns` and `returned`.
        if flag in returns:
            return returned[returns.index(flag)]
        return None
    if isinstance(returns, str):
        return returned if returns == flag else None
    raise TypeError(f'Unsupported type for `returns`: {type(returns)}')
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def returns_update(value, flag, returned, returns):
    """Replace the object corresponding to `flag` in a returned structure.

    Parameters
    ----------
    value : object
        New object to store.
    flag : str
        Name of the result to replace (e.g. ``'output'``).
    returned : object
        Structure built according to `returns` (as done by
        `prepare_output`).
    returns : None or dict[str, str] or list[str] or tuple[str] or str
        Structure describing which results were returned.

    Returns
    -------
    returned : object
        Updated structure. Dicts and lists are modified in place and
        returned; tuples are rebuilt. Every entry requesting `flag` is
        replaced. If `flag` was not requested, `returned` is returned
        unchanged.

    Raises
    ------
    TypeError
        If `returns` has an unsupported type.
    """
    if returns is None:
        # Only the default 'output' result was returned, unwrapped.
        return value if flag == 'output' else returned
    if isinstance(returns, dict):
        # `returns` maps output keys to result names (see `prepare_output`
        # and `returns_find`): match `flag` against the values and update
        # the corresponding keys of `returned`.
        for key, target in returns.items():
            if target == flag:
                returned[key] = value
        return returned
    if isinstance(returns, (list, tuple)):
        positions = {i for i, target in enumerate(returns) if target == flag}
        if not positions:
            return returned
        if isinstance(returned, tuple):
            # Tuples are immutable: build a new one of the same type.
            return type(returned)([value if i in positions else item
                                   for i, item in enumerate(returned)])
        for i in positions:
            returned[i] = value
        return returned
    if isinstance(returns, str):
        return value if returns == flag else returned
    raise TypeError(f'Unsupported type for `returns`: {type(returns)}')
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def flatstruct(x):
    """Flatten a nested structure of tensors.

    Anything that is not a ``dict``, ``list`` or ``tuple`` is a leaf.

    - A ``dict`` is flattened into a dict of the same type. The entries of
      nested dicts are hoisted to the top level, and a hoisted key
      overwrites an equal key set earlier. As soon as one value flattens
      to a list/tuple, the whole dict instead becomes a flat ``list`` of
      leaves.
    - A ``list``/``tuple`` becomes a flat sequence of the same type. Nested
      dicts contribute their values.
    - Any other object is returned unchanged.
    """
    containers = (dict, list, tuple)

    def _flatten_dict(mapping):
        merged = type(mapping)()
        stays_dict = True
        for key, value in mapping.items():
            value = _flatten(value)
            if isinstance(value, dict):
                for subkey, subvalue in value.items():
                    merged[subkey] = subvalue
                continue
            if isinstance(value, containers):
                stays_dict = False
            merged[key] = value
        if stays_dict:
            return merged
        leaves = []
        for value in merged.values():
            if isinstance(value, containers):
                leaves.extend(value)
            else:
                leaves.append(value)
        return leaves

    def _flatten_seq(seq):
        leaves = []
        for item in map(_flatten, seq):
            if isinstance(item, dict):
                leaves.extend(item.values())
            elif isinstance(item, (list, tuple)):
                leaves.extend(item)
            else:
                leaves.append(item)
        return type(seq)(leaves)

    def _flatten(obj):
        if isinstance(obj, dict):
            return _flatten_dict(obj)
        if isinstance(obj, (list, tuple)):
            return _flatten_seq(obj)
        return obj

    return _flatten(x)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def nested_get(nested, *keys, default=None):
    """Walk down a nested structure of mappings/sequences along `keys`.

    Mappings are indexed by key. Non-string sequences are indexed by
    non-negative in-range integers. If a key is missing (or an index is
    invalid) at any level, `default` is returned. Calling with no keys
    returns `nested` itself.

    Raises
    ------
    TypeError
        If a level that still has a key to consume is neither a mapping nor
        a non-string sequence.
    """
    node = nested
    for key in keys:
        if isinstance(node, abc.Mapping):
            if key not in node:
                return default
        elif isinstance(node, abc.Sequence) and not isinstance(node, str):
            in_range = isinstance(key, int) and 0 <= key < len(node)
            if not in_range:
                return default
        else:
            raise TypeError(
                f'Cannot get key {key} from object of type {type(node)}'
            )
        node = node[key]
    return node
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
class Arguments:
    """Base class for wrapping the arguments of a transform call.

    ``Arguments(...)`` never produces an instance of this exact class. It
    dispatches to a concrete subclass depending on what was passed:

    - `NoArguments`: when no arguments are passed
    - `Arg`: when a single positional argument is passed (an object that is
      already an `Arguments` is returned as is)
    - `Args`: when only positional arguments are passed
    - `Kwargs`: when only keyword arguments are passed
    - `ArgsAndKwargs`: when both positional and keyword arguments are passed
    """

    def __new__(cls, *args, **kwargs):
        if cls is not Arguments:
            # A concrete subclass is being built directly: plain allocation.
            return super().__new__(cls)
        if len(args) == 1 and not kwargs:
            (single,) = args
            return single if isinstance(single, Arguments) else Arg(single)
        if args:
            return ArgsAndKwargs(*args, **kwargs) if kwargs else Args(*args)
        return Kwargs(**kwargs) if kwargs else NoArguments()

    def __init__(self, *args, **kwargs):
        # Nothing to set up at this level; subclasses store their own state.
        pass

    def __str__(self):
        return repr(self)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
class NoArguments(Arguments):
    """Wrapper used when no arguments were passed.

    Behaves like an empty container: it is falsy, has length 0, and
    iterating over it (or over its keys/values/items) yields nothing.
    """

    def __bool__(self):
        return False

    def __len__(self):
        return 0

    def __iter__(self):
        # `__iter__` must return an iterator. A bare `return` would hand
        # back None and make `iter(NoArguments())` raise TypeError.
        return iter(())

    def keys(self):
        return set()

    def values(self):
        return set()

    def items(self):
        return set()

    def unwrap(self):
        return None

    def to_args_kwargs(self):
        return (), {}

    def __repr__(self):
        return "NoArguments()"
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
class Arg(Arguments):
    """Wrapper around a single argument.

    ``Arg(x)`` returns `x` itself if it already is an `Arguments`, a
    `DictArg` if `x` is a mapping, a `TupleArg` if `x` is a non-string
    sequence, and a plain `Arg` otherwise.
    """

    def __new__(cls, arg):
        if cls is not Arg:
            return super().__new__(cls)
        if isinstance(arg, Arguments):
            return arg
        if isinstance(arg, Mapping):
            return DictArg(arg)
        is_sequence = isinstance(arg, Sequence) and not isinstance(arg, str)
        return TupleArg(arg) if is_sequence else super().__new__(cls)

    def __init__(self, arg):
        # When `Arg` is called on an existing `Arg`, `__new__` hands back
        # that very object, yet Python still runs `__init__` on it (with
        # the object as argument). Only store the wrapped value otherwise.
        if arg is not self:
            self.arg = arg

    def unwrap(self):
        return self.arg

    def to_args_kwargs(self):
        positional = (self.arg,)
        return positional, {}

    def __repr__(self):
        return f"{type(self).__name__}({self.arg!r})"
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
class DictArg(Arg, Mapping):
    """Single argument that is a mapping, exposed through the Mapping API."""

    def __getitem__(self, key):
        wrapped = self.arg
        return wrapped[key]

    def __iter__(self):
        wrapped = self.arg
        return iter(wrapped)

    def __len__(self):
        wrapped = self.arg
        return len(wrapped)
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
class TupleArg(Arg, Sequence):
    """Single argument that is a sequence, exposed through the Sequence API."""

    def __getitem__(self, index):
        wrapped = self.arg
        return wrapped[index]

    def __len__(self):
        wrapped = self.arg
        return len(wrapped)
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
class Args(tuple, Arguments):
    """Tuple of positional arguments: ``*args``."""

    def __new__(cls, *args):
        # Call `tuple.__new__` directly: it expects a single iterable (the
        # packed `args`), which conflicts with `Arguments.__new__`.
        return tuple.__new__(cls, args)

    def __init__(self, *args):
        # The tuple content is fixed in `__new__`, so there is nothing to
        # initialize. This override only exists so that the variadic
        # signature is accepted (Python calls `__init__` with the same
        # arguments as `__new__`) instead of `Arguments.__init__`'s.
        pass

    def unwrap(self):
        return tuple(self)

    def to_args_kwargs(self):
        return tuple(self), {}

    def __repr__(self):
        args = ", ".join(repr(a) for a in self)
        return f"{self.__class__.__name__}({args})"

    # Helper views over positions, values and (position, value) pairs of a
    # parent sequence. NOTE(review): no `Args` method returns them; they are
    # kept as class attributes for external users.
    class Keys:
        def __init__(self, parent):
            self.parent = parent

        def __iter__(self):
            yield from range(len(self.parent))

        def __repr__(self):
            return f"Keys({list(self)!r})"

    class Values:
        def __init__(self, parent):
            self.parent = parent

        def __iter__(self):
            yield from self.parent

        def __repr__(self):
            return f"Values({list(self)!r})"

    class Items:
        def __init__(self, parent):
            self.parent = parent

        def __iter__(self):
            yield from enumerate(self.parent)

        def __repr__(self):
            return f"Items({list(self)!r})"
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
class Kwargs(dict, Arguments):
    """Dict of keyword arguments (``**kwargs``).

    Unlike a plain dict, iterating over a `Kwargs` yields its *values*
    instead of its keys, so that it can be unpacked like a tuple of values.
    Key-based dict methods (``keys()``, ``items()``, ``in``, indexing) are
    unchanged.
    """

    def __init__(self, *args, **kwargs):
        if not args:
            super().__init__(**kwargs)
            return
        # Positional input is only tolerated when `Arguments(...)` handed an
        # existing `Kwargs` back: Python then calls `__init__` again on it,
        # with that object as sole argument, and it must be left as is.
        rewrapping = (not kwargs and len(args) == 1
                      and isinstance(args[0], Kwargs))
        if not rewrapping:
            raise ValueError("Wrong instantiation")

    def unwrap(self):
        return dict(self)

    def to_args_kwargs(self):
        return (), dict(self)

    def __repr__(self):
        body = ", ".join(f"{name}={val!r}" for name, val in self.items())
        return f"{type(self).__name__}({body})"

    def __iter__(self):
        yield from self.values()
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
class ArgsAndKwargs(Arguments):
    """Wrapper holding both positional and keyword arguments.

    Iterating yields the positional values first, then the keyword values.
    ``keys()``/``values()``/``items()`` cover both: positional arguments are
    keyed by their integer position, keyword arguments by their name.
    """

    def __init__(self, *args, **kwargs):
        if len(args) == 1 and isinstance(args[0], ArgsAndKwargs):
            # Catch case where Arguments(...) was called on an ArgsAndKwargs:
            # `Arguments.__new__` hands the object back and Python calls
            # `__init__` on it again. Keep it as is.
            if kwargs:
                raise ValueError("Wrong instantiation")
            return
        self.args = Args(*args)
        self.kwargs = Kwargs(**kwargs)

    def to_args_kwargs(self):
        return tuple(self.args), dict(self.kwargs)

    def unwrap(self):
        return self.to_args_kwargs()

    def __repr__(self):
        # Render keyword entries from `self.kwargs` only: `self.items()`
        # would also yield the positional arguments (keyed by index) and
        # print them twice.
        parts = [repr(a) for a in self.args]
        parts.extend(f"{k}={v!r}" for k, v in self.kwargs.items())
        return f"{self.__class__.__name__}({', '.join(parts)})"

    def __iter__(self):
        # Iterate across values of both args and kwargs, in that order, so
        # that `ArgsAndKwargs` acts like the concatenation of an `Args` and
        # a `Kwargs` (whose iteration yields values).
        yield from self.args
        yield from self.kwargs

    def __len__(self):
        return len(self.args) + len(self.kwargs)

    def __getitem__(self, index):
        # Integers index positional arguments; anything else is a keyword.
        if isinstance(index, int):
            return self.args[index]
        else:
            return self.kwargs[index]

    def keys(self):
        return self.Keys(self)

    def values(self):
        return self.Values(self)

    def items(self):
        return self.Items(self)

    class Keys:
        def __init__(self, parent):
            self.parent = parent

        def __iter__(self):
            yield from range(len(self.parent.args))
            yield from self.parent.kwargs.keys()

        def __repr__(self):
            return f"Keys({list(self)!r})"

    class Values:
        def __init__(self, parent):
            self.parent = parent

        def __iter__(self):
            yield from self.parent.args
            yield from self.parent.kwargs.values()

        def __repr__(self):
            return f"Values({list(self)!r})"

    class Items:
        def __init__(self, parent):
            self.parent = parent

        def __iter__(self):
            yield from enumerate(self.parent.args)
            yield from self.parent.kwargs.items()

        def __repr__(self):
            return f"Items({list(self)!r})"
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
class Returned:
    """Internal marker wrapping an object returned by `transform_tensor`
    at the most nested level; the wrapped object is available as `obj`."""

    def __init__(self, obj):
        wrapped = obj
        self.obj = wrapped
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
class VirtualTensor:
    """Virtual tensor used to recursively compute final transforms.

    Stores only metadata: shape, dtype, device and optional statistics
    (`vmin`, `vmax`, `vmean`) reduced over every dimension but the first.
    """

    def __init__(self, shape, dtype=None, device=None,
                 vmin=None, vmax=None, vmean=None):
        self.shape = shape
        self.dtype = dtype
        self.device = device
        self.vmin = vmin
        self.vmax = vmax
        self.vmean = vmean

    @classmethod
    def from_tensor(cls, x, compute_stats=False):
        """Build a virtual tensor from the metadata of a real tensor.

        If `compute_stats`, also compute the min, max and mean of `x`
        along all dimensions except the first (the mean in float).
        """
        if compute_stats:
            # Flatten once; reduce per index of the leading dimension.
            flat = x.reshape([len(x), -1])
            vmin = flat.min(dim=-1).values
            vmax = flat.max(dim=-1).values
            vmean = x.float().mean(dim=list(range(1, x.ndim)))
        else:
            vmin = vmax = vmean = None
        return cls(x.shape, dtype=x.dtype, device=x.device,
                   vmin=vmin, vmax=vmax, vmean=vmean)

    @classmethod
    def from_virtual(cls, x):
        """Copy the metadata of another virtual tensor."""
        return cls(x.shape, dtype=x.dtype, device=x.device,
                   vmin=x.vmin, vmax=x.vmax, vmean=x.vmean)

    @classmethod
    def from_any(cls, x, compute_stats=False):
        """Build a virtual tensor from a tensor, a virtual tensor or a shape.

        Raises
        ------
        TypeError
            If `x` is none of the above.
        """
        if torch.is_tensor(x):
            return cls.from_tensor(x, compute_stats)
        elif isinstance(x, VirtualTensor):
            return cls.from_virtual(x)
        elif isinstance(x, (list, tuple, torch.Size)):
            return cls(x)
        else:
            raise TypeError(f"Don't know how to convert type {type(x)} "
                            f"to VirtualTensor")

    def __getitem__(self, index):
        # Only the shape is updated; statistics are carried over unchanged.
        outshape = guess_shape(index, self.shape)
        return type(self)(
            outshape, self.dtype, self.device,
            self.vmin, self.vmax, self.vmean
        )
|
|
573
|
+
|
|
574
|
+
|
|
575
|
+
# Module-level sentinel: a fresh `object()` that no caller-supplied value can
# be identical to (`is`). NOTE(review): presumably meant as a default that
# distinguishes "not provided" from `None`; it is not referenced elsewhere
# in this module — confirm against its users.
UNSET = object()
|