cornucopia 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. cornucopia/__init__.py +73 -0
  2. cornucopia/base.py +1915 -0
  3. cornucopia/baseutils.py +575 -0
  4. cornucopia/contrast.py +260 -0
  5. cornucopia/ctx.py +25 -0
  6. cornucopia/fov.py +707 -0
  7. cornucopia/geometric.py +2068 -0
  8. cornucopia/intensity.py +1358 -0
  9. cornucopia/io.py +161 -0
  10. cornucopia/kspace.py +505 -0
  11. cornucopia/labels.py +1872 -0
  12. cornucopia/noise.py +508 -0
  13. cornucopia/psf.py +463 -0
  14. cornucopia/qmri.py +1288 -0
  15. cornucopia/random.py +1480 -0
  16. cornucopia/special.py +159 -0
  17. cornucopia/synth.py +708 -0
  18. cornucopia/tests/__init__.py +0 -0
  19. cornucopia/tests/test_backward_geometric.py +173 -0
  20. cornucopia/tests/test_backward_intensity.py +243 -0
  21. cornucopia/tests/test_backward_kspace.py +115 -0
  22. cornucopia/tests/test_backward_noise.py +169 -0
  23. cornucopia/tests/test_backward_psf.py +142 -0
  24. cornucopia/tests/test_backward_qmri.py +249 -0
  25. cornucopia/tests/test_backward_random.py +44 -0
  26. cornucopia/tests/test_backward_synth.py +72 -0
  27. cornucopia/tests/test_base.py +401 -0
  28. cornucopia/tests/test_geometric.py +26 -0
  29. cornucopia/tests/test_intensity.py +9 -0
  30. cornucopia/tests/test_random.py +722 -0
  31. cornucopia/tests/test_run_contrast.py +28 -0
  32. cornucopia/tests/test_run_fov.py +132 -0
  33. cornucopia/tests/test_run_geometric.py +157 -0
  34. cornucopia/tests/test_run_intensity.py +192 -0
  35. cornucopia/tests/test_run_kspace.py +70 -0
  36. cornucopia/tests/test_run_labels.py +224 -0
  37. cornucopia/tests/test_run_noise.py +127 -0
  38. cornucopia/tests/test_run_psf.py +115 -0
  39. cornucopia/tests/test_run_qmri.py +114 -0
  40. cornucopia/tests/test_run_synth.py +67 -0
  41. cornucopia/typing.py +97 -0
  42. cornucopia/utils/__init__.py +0 -0
  43. cornucopia/utils/b0.py +745 -0
  44. cornucopia/utils/bounds.py +412 -0
  45. cornucopia/utils/compat.py +47 -0
  46. cornucopia/utils/conv.py +305 -0
  47. cornucopia/utils/gmm.py +169 -0
  48. cornucopia/utils/indexing.py +911 -0
  49. cornucopia/utils/io.py +258 -0
  50. cornucopia/utils/jit.py +128 -0
  51. cornucopia/utils/kernels.py +288 -0
  52. cornucopia/utils/morpho.py +234 -0
  53. cornucopia/utils/mrf.py +574 -0
  54. cornucopia/utils/padding.py +173 -0
  55. cornucopia/utils/patch.py +302 -0
  56. cornucopia/utils/pool.py +282 -0
  57. cornucopia/utils/py.py +348 -0
  58. cornucopia/utils/smart_inplace.py +163 -0
  59. cornucopia/utils/version.py +57 -0
  60. cornucopia/utils/warps.py +606 -0
  61. cornucopia-0.0.0.dist-info/METADATA +92 -0
  62. cornucopia-0.0.0.dist-info/RECORD +65 -0
  63. cornucopia-0.0.0.dist-info/WHEEL +5 -0
  64. cornucopia-0.0.0.dist-info/licenses/LICENSE +21 -0
  65. cornucopia-0.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,575 @@
1
+ # stdlib
2
+ from collections import abc
3
+ from collections.abc import Mapping, Sequence
4
+
5
+ # dependencies
6
+ import torch
7
+ import typing_extensions as tx
8
+
9
+ # internal
10
+ from .utils.indexing import guess_shape
11
+ from .utils.py import ensure_list
12
+
13
+
14
def get_first_element(x, include=None, exclude=None, types=None):
    """Return the first matching element in a nested structure.

    The structure is traversed depth-first. Objects exposing an `items()`
    method (mappings) are visited in iteration order, and lists/tuples in
    positional order. A leaf matches if it is a tensor, or if it is an
    instance of one of `types`.

    Parameters
    ----------
    x : nested structure
        Arbitrarily nested mappings / lists / tuples.
    include : container of keys, optional
        If provided, only mapping entries whose key is in `include` are
        visited (checked at every nesting level).
    exclude : container of keys, optional
        Mapping entries whose key is in `exclude` are skipped (checked at
        every nesting level).
    types : type or sequence of types, optional
        Additional leaf types (besides tensors) that count as a match.

    Returns
    -------
    element : object or None
        The first match. If `x` is a mapping or a list/tuple without any
        match, `None` is returned. If `x` itself is a non-container leaf,
        it is returned as-is even when it does not match.
    """
    # `isinstance` requires a type or a *tuple* of types (a list raises
    # TypeError), so normalize `types` to a tuple once, up front.
    if not types:
        types = ()
    elif isinstance(types, (list, tuple, set, frozenset)):
        types = tuple(types)
    else:
        types = (types,)

    def _recursive(x):
        # Returns (element, found). `found` tells the caller whether to
        # stop searching.
        if hasattr(x, 'items'):
            for k, v in x.items():
                if include and k not in include:
                    continue
                if exclude and k in exclude:
                    continue
                v, ok = _recursive(v)
                if ok:
                    return v, True
            return None, False
        if isinstance(x, (list, tuple)):
            for v in x:
                v, ok = _recursive(v)
                if ok:
                    return v, True
            return None, False
        if torch.is_tensor(x) or (types and isinstance(x, types)):
            return x, True
        return x, False

    return _recursive(x)[0]
40
+
41
+
42
def recursive_cat(x, **kwargs):
    """Concatenate matching tensors across a sequence of nested structures.

    Parameters
    ----------
    x : sequence of nested structures
        All structures must share the same layout (same container types,
        same lengths, same keys). Tensors found at the same location in
        every structure are concatenated together.
    **kwargs
        Forwarded to `torch.cat` (e.g. `dim`).

    Returns
    -------
    structure
        A structure with the same layout as `x[0]` whose leaves are the
        concatenated tensors.

    Raises
    ------
    TypeError
        If a leaf is neither a tensor, a list/tuple, nor a mapping.
    """
    def _rec(*elems):
        if all(torch.is_tensor(e) for e in elems):
            return torch.cat(elems, **kwargs)
        first = elems[0]
        if isinstance(first, (list, tuple)):
            # zip(*elems) groups the i-th entry of every structure.
            return type(first)(_rec(*group) for group in zip(*elems))
        if hasattr(first, 'items'):
            # Gather the value stored under `k` in *every* structure so
            # they get concatenated together (zipping a single generator
            # would only ever keep the last structure's value).
            return type(first)(**{
                k: _rec(*(e[k] for e in elems)) for k in first.keys()
            })
        raise TypeError(f'What should I do with a {type(x[0])}?')
    return _rec(*x)
55
+
56
+
57
def prepare_output(results, returns):
    """Select the requested results and wrap them for `apply_transform`.

    Parameters
    ----------
    results : dict[str, tensor] or tensor
        Named results. A bare tensor is only passed through when `returns`
        is None; every other branch looks results up with `results.get`.
    returns : list[str] or tuple[str] or dict[str, str] or str or None
        Structure describing which results should be returned:
        - None: the 'output' result (or `results` itself if a tensor);
        - str: the single result with that name;
        - list/tuple: a container of the same type, holding the named
          results at the same positions;
        - dict: a mapping of the same type, associating each key with the
          result named by its value.
        Missing names map to None.

    Returns
    -------
    Returned
        Wrapper whose `obj` attribute holds the selected result(s).
    """
    if returns is None:
        selected = (results if torch.is_tensor(results)
                    else results.get('output', None))
    elif isinstance(returns, dict):
        selected = type(returns)(**{
            name: results.get(flag, None) for name, flag in returns.items()
        })
    elif isinstance(returns, (list, tuple)):
        picked = [results.get(flag, None) for flag in returns]
        selected = type(returns)(picked)
    else:
        selected = results.get(returns, None)
    return Returned(selected)
91
+
92
+
93
def return_requires(returns) -> tx.Set[str]:
    """Return the set of every result name required by `returns`.

    `returns` is None (only 'output' is required), a single name, or a
    possibly nested list/tuple/set/dict of names. Nested structures are
    first flattened with `flatstruct`; for dicts, the required names are
    the *values*.
    """
    if returns is None:
        return {'output'}
    flat = flatstruct(returns)
    if isinstance(flat, dict):
        return set(flat.values())
    if isinstance(flat, (list, tuple, set)):
        return set(flat)
    return {flat}
104
+
105
+
106
def returns_find(flag, returned, returns):
    """Find the object associated with `flag` in a returned structure.

    Parameters
    ----------
    flag : str
        Name of the result to look for (e.g. 'output').
    returned : object or list or tuple or dict
        Structure laid out according to `returns` (e.g. the object built
        by `prepare_output`).
    returns : list[str] or tuple[str] or dict[str, str] or str or None
        Structure describing the layout of `returned`.

    Returns
    -------
    object or None
        The entry of `returned` associated with `flag`, or None if `flag`
        is not part of `returns`.

    Raises
    ------
    TypeError
        If `returns` has an unsupported type (e.g. a set, which has no
        positional correspondence with `returned`).
    """
    if returns is None:
        # Without an explicit `returns`, only 'output' is returned, as is.
        return returned if flag == 'output' else None
    if isinstance(returns, dict):
        # `returns` maps output keys -> result names: match on the values.
        for key, target in returns.items():
            if target == flag:
                return returned.get(key, None)
        return None
    if isinstance(returns, (list, tuple)):
        if flag in returns:
            return returned[returns.index(flag)]
        return None
    if not isinstance(returns, str):
        raise TypeError(f'Unsupported type for `returns`: {type(returns)}')
    return returned if returns == flag else None
128
+
129
+
130
def returns_update(value, flag, returned, returns):
    """Store `value` at the location associated with `flag`.

    This is the write counterpart of `returns_find`: `returns` describes
    the layout of `returned`, and every slot whose requested result name
    equals `flag` is replaced by `value`.

    Parameters
    ----------
    value : object
        New value for the result named `flag`.
    flag : str
        Name of the result to update (e.g. 'output').
    returned : object or list or tuple or dict
        Structure laid out according to `returns`.
    returns : list[str] or tuple[str] or dict[str, str] or str or None
        Structure describing the layout of `returned`.

    Returns
    -------
    returned : object
        The updated structure. Lists and dicts are modified in place (and
        also returned); tuples are rebuilt with the same type. If `flag`
        is not part of `returns`, `returned` is returned unchanged.

    Raises
    ------
    TypeError
        If `returns` has an unsupported type.
    """
    if returns is None:
        # Without an explicit `returns`, `returned` *is* the 'output'.
        return value if flag == 'output' else returned
    if isinstance(returns, dict):
        # `returns` maps output keys -> result names, so `flag` must be
        # matched against the values; the key tells where to write.
        for key, target in returns.items():
            if target == flag:
                returned[key] = value
        return returned
    if isinstance(returns, (list, tuple)):
        positions = [i for i, target in enumerate(returns) if target == flag]
        if not positions:
            return returned
        if isinstance(returned, tuple):
            # Tuples are immutable: rebuild one of the same type.
            updated = list(returned)
            for i in positions:
                updated[i] = value
            return type(returned)(updated)
        for i in positions:
            returned[i] = value
        return returned
    if not isinstance(returns, str):
        raise TypeError(f'Unsupported type for `returns`: {type(returns)}')
    return value if returns == flag else returned
151
+
152
+
153
def flatstruct(x):
    """Flatten a nested structure of dicts/lists/tuples.

    Rules, applied recursively from the leaves up:

    - A leaf (anything that is not a dict, list or tuple) is returned as is.
    - A list/tuple becomes a container of the same type holding all of its
      flattened leaves, in order. The values of nested dicts are included.
    - A dict stays a dict (of the same type) as long as every value
      flattens to a leaf or to a dict. Entries of nested dicts are merged
      into the parent under their own keys: the parent key is dropped, and
      a later entry overwrites an earlier one with the same key. If any
      value flattens to a list/tuple, the result is instead a plain list of
      all leaves, in the order of the (merged) dict.
    """

    def _flatten(nested):
        if isinstance(nested, dict):
            flat = type(nested)()
            # Becomes False as soon as one value flattens to a list/tuple.
            is_dict = True
            for key, elem in nested.items():
                elem = _flatten(elem)
                if isinstance(elem, dict):
                    # Merge the sub-dict into the parent (parent key dropped).
                    for subkey, subelem in elem.items():
                        flat[subkey] = subelem
                elif not isinstance(elem, (dict, list, tuple)):
                    flat[key] = elem
                else:
                    is_dict = False
                    flat[key] = elem
            if not is_dict:
                # Fall back to a plain list of every leaf.
                flat, flatdict = [], flat
                for elem in flatdict.values():
                    if not isinstance(elem, (dict, list, tuple)):
                        flat.append(elem)
                    else:
                        flat.extend(elem)
            return flat
        elif isinstance(nested, (list, tuple)):
            flat = []
            for elem in nested:
                elem = _flatten(elem)
                if not isinstance(elem, (dict, list, tuple)):
                    flat.append(elem)
                elif isinstance(elem, dict):
                    flat.extend(elem.values())
                else:
                    flat.extend(elem)
            # Preserve the container type (list or tuple).
            flat = type(nested)(flat)
            return flat
        else:
            return nested

    return _flatten(x)
194
+
195
+
196
def nested_get(nested, *keys, default=None):
    """Follow `keys` down a nested structure of mappings and sequences.

    Parameters
    ----------
    nested : Mapping or Sequence (or any object when `keys` is empty)
        Structure to walk.
    *keys
        Path to follow. Mapping levels are indexed by key; sequence levels
        (excluding `str`) only accept non-negative, in-range `int` indices.
    default : object
        Value returned as soon as a key is missing or an index is invalid.

    Raises
    ------
    TypeError
        If keys remain but the current level is neither a Mapping nor a
        non-string Sequence.
    """
    if not keys:
        return nested
    key, rest = keys[0], keys[1:]
    if isinstance(nested, abc.Mapping):
        found = key in nested
    elif isinstance(nested, abc.Sequence) and not isinstance(nested, str):
        found = isinstance(key, int) and 0 <= key < len(nested)
    else:
        raise TypeError(
            f'Cannot get key {keys[0]} from object of type {type(nested)}'
        )
    if not found:
        return default
    return nested_get(nested[key], *rest, default=default)
214
+
215
+
216
class Arguments:
    """Base class for wrapping the arguments of a transform call.

    Calling `Arguments(...)` never yields a plain `Arguments` instance;
    `__new__` dispatches to one of the concrete subclasses:

    - `NoArguments`: when no arguments are passed
    - `Arg`: when a single positional argument is passed (an object that
      already is an `Arguments` is returned unchanged)
    - `Args`: when several positional arguments and no keywords are passed
    - `Kwargs`: when only keyword arguments are passed
    - `ArgsAndKwargs`: when both positional and keyword arguments are passed
    """

    def __new__(cls, *args, **kwargs):
        if cls is not Arguments:
            # Concrete subclasses are allocated normally.
            return super().__new__(cls)
        if len(args) == 1 and not kwargs:
            (single,) = args
            return single if isinstance(single, Arguments) else Arg(single)
        if args:
            return ArgsAndKwargs(*args, **kwargs) if kwargs else Args(*args)
        return Kwargs(**kwargs) if kwargs else NoArguments()

    def __init__(self, *args, **kwargs):
        # All construction work happens in `__new__` and the subclasses.
        pass

    def __str__(self):
        return repr(self)
251
+
252
+
253
class NoArguments(Arguments):
    """Wrapper used when no arguments were passed.

    Behaves like an empty container: falsy, zero length, empty iteration,
    and empty `keys()` / `values()` / `items()`.
    """

    def __bool__(self):
        return False

    def __len__(self):
        return 0

    def __iter__(self):
        # `__iter__` must return an *iterator*. A bare `return` (None)
        # makes `iter(NoArguments())` -- and any `for` loop over it --
        # raise TypeError.
        return iter(())

    def keys(self):
        return set()

    def values(self):
        return set()

    def items(self):
        return set()

    def unwrap(self):
        return None

    def to_args_kwargs(self):
        return (), {}

    def __repr__(self):
        return "NoArguments()"
282
+
283
+
284
class Arg(Arguments):
    """Wrapper around a single positional argument.

    `Arg(x)` dispatches on the type of `x`:
    - an `Arguments` object is returned unchanged;
    - a `Mapping` is wrapped in a `DictArg`;
    - a non-string `Sequence` is wrapped in a `TupleArg`;
    - anything else is wrapped in a plain `Arg`.
    """

    def __new__(cls, arg):
        if cls is not Arg:
            # Subclasses (`DictArg`, `TupleArg`) skip the dispatch.
            return super().__new__(cls)
        if isinstance(arg, Arguments):
            return arg
        if isinstance(arg, Mapping):
            return DictArg(arg)
        if isinstance(arg, Sequence) and not isinstance(arg, str):
            return TupleArg(arg)
        return super().__new__(cls)

    def __init__(self, arg):
        # When `__new__` hands back an existing `Arg` (e.g. `Arg(Arg(x))`),
        # Python still runs `__init__` on it with itself as `arg`; keep the
        # stored value untouched in that case.
        if arg is not self:
            self.arg = arg

    def unwrap(self):
        return self.arg

    def to_args_kwargs(self):
        return (self.arg,), {}

    def __repr__(self):
        return f"{self.__class__.__name__}({self.arg!r})"
314
+
315
+
316
class DictArg(Arg, Mapping):
    """Single argument that is a mapping.

    Implements the abstract methods of `collections.abc.Mapping` by
    delegating to the wrapped object (`self.arg`), so the wrapper can be
    used wherever a read-only mapping is expected.
    """

    def __getitem__(self, key):
        return self.arg[key]

    def __iter__(self):
        return iter(self.arg)

    def __len__(self):
        return len(self.arg)
327
+
328
+
329
class TupleArg(Arg, Sequence):
    """Single argument that is a (non-string) sequence.

    Implements the abstract methods of `collections.abc.Sequence` by
    delegating to the wrapped object (`self.arg`).
    """

    def __getitem__(self, index):
        return self.arg[index]

    def __len__(self):
        return len(self.arg)
337
+
338
+
339
class Args(tuple, Arguments):
    """Tuple of positional arguments: `*args`.

    NOTE(review): the nested `Keys` / `Values` / `Items` views mirror those
    of `ArgsAndKwargs`, but `Args` defines no `keys()` / `values()` /
    `items()` methods, so nothing in this class instantiates them.
    """

    def __new__(cls, *args):
        # `tuple.__new__` expects one iterable whereas `Arguments.__new__`
        # expects `*args`: call the tuple constructor directly.
        return tuple.__new__(cls, args)

    def __init__(self, *args):
        # Deliberately a no-op: the tuple is fully built in `__new__`, and
        # `Arguments.__init__` is bypassed on purpose.
        tuple.__init__(self)

    def unwrap(self):
        return tuple(self)

    def to_args_kwargs(self):
        return tuple(self), {}

    def __repr__(self):
        args = ", ".join(map(repr, self))
        return f"{self.__class__.__name__}({args})"

    class Keys:
        """View over the positions (0, 1, ...) of the parent."""

        def __init__(self, parent):
            self.parent = parent

        def __iter__(self):
            yield from range(len(self.parent))

        def __repr__(self):
            return f"Keys({list(self)!r})"

    class Values:
        """View over the positional values of the parent."""

        def __init__(self, parent):
            self.parent = parent

        def __iter__(self):
            yield from self.parent

        def __repr__(self):
            return f"Values({list(self)!r})"

    class Items:
        """View over `(position, value)` pairs of the parent."""

        def __init__(self, parent):
            self.parent = parent

        def __iter__(self):
            yield from enumerate(self.parent)

        def __repr__(self):
            return f"Items({list(self)!r})"
395
+
396
+
397
class Kwargs(dict, Arguments):
    """Dict of keyword arguments: `**kwargs`.

    Behaves like a `dict`, except that iterating over it (and therefore
    unzipping it) yields values instead of keys, so that it acts like a
    tuple of values. Keys remain available through `keys()` / `items()`.
    """

    def __init__(self, *args, **kwargs):
        if args:
            # A positional argument is only allowed when it is a single
            # `Kwargs`. Either `Arguments(kw)` returned `kw` itself and
            # Python is re-running `__init__` on it (nothing to do), or
            # `Kwargs(kw)` was called directly, in which case `__new__`
            # allocated a fresh empty dict that must receive a copy.
            if kwargs or len(args) != 1 or not isinstance(args[0], Kwargs):
                raise ValueError("Wrong instantiation")
            other = args[0]
            if other is not self:
                # Use dict.items explicitly: our __iter__ yields values.
                super().__init__(dict.items(other))
            return
        super().__init__(**kwargs)

    def unwrap(self):
        return dict(self)

    def to_args_kwargs(self):
        return (), dict(self)

    def __repr__(self):
        kwargs = ", ".join(f"{k}={v!r}" for k, v in self.items())
        return f"{self.__class__.__name__}({kwargs})"

    def __iter__(self):
        # Iterate across values instead of keys, so that `Kwargs` acts
        # like a tuple of values.
        yield from self.values()
425
+
426
+
427
class ArgsAndKwargs(Arguments):
    """Positional and keyword arguments: `*args, **kwargs`.

    Iterating yields the positional values followed by the keyword values,
    so that it behaves like the concatenation of an `Args` and a `Kwargs`.
    `keys()` yields positions (ints) then keyword names. Integer indices
    select positional arguments; any other key selects a keyword argument.
    """

    def __init__(self, *args, **kwargs):
        if len(args) == 1 and isinstance(args[0], ArgsAndKwargs):
            # Either `Arguments(x)` returned `x` itself and Python is
            # re-running `__init__` on it (nothing to do), or
            # `ArgsAndKwargs(x)` was called directly: copy `x` so the new
            # object is not left without `args` / `kwargs` attributes.
            if kwargs:
                raise ValueError("Wrong instantiation")
            other = args[0]
            if other is not self:
                self.args = Args(*other.args)
                self.kwargs = Kwargs(**dict(dict.items(other.kwargs)))
            return
        self.args = Args(*args)
        self.kwargs = Kwargs(**kwargs)

    def to_args_kwargs(self):
        return tuple(self.args), dict(self.kwargs)

    def unwrap(self):
        return self.to_args_kwargs()

    def __repr__(self):
        # Only keyword arguments go through `name=value` formatting;
        # `self.items()` would also yield the positional ones.
        parts = [repr(a) for a in self.args]
        parts += [f"{k}={v!r}" for k, v in self.kwargs.items()]
        args_kwargs = ", ".join(parts)
        return f"{self.__class__.__name__}({args_kwargs})"

    def __iter__(self):
        # Values of args first, then values of kwargs (Kwargs iterates
        # over its values).
        for a in self.args:
            yield a
        for v in self.kwargs:
            yield v

    def __len__(self):
        return len(self.args) + len(self.kwargs)

    def __getitem__(self, index):
        if isinstance(index, int):
            return self.args[index]
        else:
            return self.kwargs[index]

    def keys(self):
        return self.Keys(self)

    def values(self):
        return self.Values(self)

    def items(self):
        return self.Items(self)

    class Keys:
        """View over positions of args, then names of kwargs."""

        def __init__(self, parent):
            self.parent = parent

        def __iter__(self):
            for i in range(len(self.parent.args)):
                yield i
            for k in self.parent.kwargs.keys():
                yield k

        def __repr__(self):
            return f"Keys({list(self)!r})"

    class Values:
        """View over values of args, then values of kwargs."""

        def __init__(self, parent):
            self.parent = parent

        def __iter__(self):
            for a in self.parent.args:
                yield a
            for v in self.parent.kwargs.values():
                yield v

        def __repr__(self):
            return f"Values({list(self)!r})"

    class Items:
        """View over `(position, value)` then `(name, value)` pairs."""

        def __init__(self, parent):
            self.parent = parent

        def __iter__(self):
            for i, a in enumerate(self.parent.args):
                yield (i, a)
            for k, v in self.parent.kwargs.items():
                yield (k, v)

        def __repr__(self):
            return f"Items({list(self)!r})"
518
+
519
+
520
class Returned:
    """Internal marker wrapping an object returned by `transform_tensor`.

    Marks that the wrapped object was produced at the most nested level.
    The wrapped object is available as the `obj` attribute.
    """

    def __init__(self, obj):
        # The wrapped result, stored untouched.
        self.obj = obj
525
+
526
+
527
class VirtualTensor:
    """Virtual tensor used to recursively compute final transforms.

    Holds only metadata (shape, dtype, device) and, optionally, statistics
    computed per entry of the first dimension (min, max, mean), without
    holding any data.
    """

    def __init__(self, shape, dtype=None, device=None,
                 vmin=None, vmax=None, vmean=None):
        self.shape = shape
        self.dtype = dtype
        self.device = device
        self.vmin = vmin
        self.vmax = vmax
        self.vmean = vmean

    @classmethod
    def from_tensor(cls, x, compute_stats=False):
        """Build from a real tensor.

        If `compute_stats`, min/max/mean are reduced over every dimension
        but the first (the mean is computed in floating point). This uses
        `len(x)`, so `x` must have at least one dimension.
        """
        if compute_stats:
            # One row per entry of the first dimension.
            rows = x.reshape([len(x), -1])
            vmin = rows.min(dim=-1).values
            vmax = rows.max(dim=-1).values
            vmean = x.float().mean(dim=list(range(1, x.ndim)))
        else:
            vmin = vmax = vmean = None
        # Use `cls` so that subclasses construct instances of themselves.
        return cls(x.shape, dtype=x.dtype, device=x.device,
                   vmin=vmin, vmax=vmax, vmean=vmean)

    @classmethod
    def from_virtual(cls, x):
        """Shallow copy of another virtual tensor (stats are shared)."""
        return cls(x.shape, dtype=x.dtype, device=x.device,
                   vmin=x.vmin, vmax=x.vmax, vmean=x.vmean)

    @classmethod
    def from_any(cls, x, compute_stats=False):
        """Build from a tensor, a virtual tensor, or a shape.

        Raises
        ------
        TypeError
            If `x` is none of the above.
        """
        if torch.is_tensor(x):
            return cls.from_tensor(x, compute_stats)
        elif isinstance(x, VirtualTensor):
            return cls.from_virtual(x)
        elif isinstance(x, (list, tuple, torch.Size)):
            return cls(x)
        else:
            raise TypeError(f"Don't know how to convert type {type(x)} "
                            f"to VirtualTensor")

    def __getitem__(self, index):
        # Only the shape is updated; statistics are carried over as is.
        outshape = guess_shape(index, self.shape)
        return type(self)(
            outshape, self.dtype, self.device,
            self.vmin, self.vmax, self.vmean
        )
573
+
574
+
575
# Module-level sentinel: a unique object that can only be matched with `is`.
# NOTE(review): presumably used elsewhere in the package to mean "argument
# not provided" where `None` is a meaningful value; confirm at call sites.
UNSET = object()