careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +163 -266
- careamics/config/algorithm_model.py +0 -15
- careamics/config/architectures/custom_model.py +3 -3
- careamics/config/configuration_example.py +0 -3
- careamics/config/configuration_factory.py +23 -25
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +80 -50
- careamics/config/inference_model.py +29 -17
- careamics/config/optimizer_models.py +7 -7
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +26 -58
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/file_utils.py +1 -1
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +0 -9
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +66 -171
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +92 -249
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +54 -25
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/lightning_datamodule.py +1 -6
- careamics/lightning_module.py +11 -7
- careamics/lightning_prediction_datamodule.py +52 -72
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
|
@@ -13,7 +13,10 @@ from .struct_mask_parameters import StructMaskParameters
|
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
def _apply_struct_mask(
|
|
16
|
-
patch: np.ndarray,
|
|
16
|
+
patch: np.ndarray,
|
|
17
|
+
coords: np.ndarray,
|
|
18
|
+
struct_params: StructMaskParameters,
|
|
19
|
+
rng: Optional[np.random.Generator] = None,
|
|
17
20
|
) -> np.ndarray:
|
|
18
21
|
"""Apply structN2V masks to patch.
|
|
19
22
|
|
|
@@ -31,12 +34,17 @@ def _apply_struct_mask(
|
|
|
31
34
|
Coordinates of the ROI(subpatch) centers.
|
|
32
35
|
struct_params : StructMaskParameters
|
|
33
36
|
Parameters for the structN2V mask (axis and span).
|
|
37
|
+
rng : np.random.Generator or None
|
|
38
|
+
Random number generator.
|
|
34
39
|
|
|
35
40
|
Returns
|
|
36
41
|
-------
|
|
37
42
|
np.ndarray
|
|
38
43
|
Patch with the structN2V mask applied.
|
|
39
44
|
"""
|
|
45
|
+
if rng is None:
|
|
46
|
+
rng = np.random.default_rng()
|
|
47
|
+
|
|
40
48
|
# relative axis
|
|
41
49
|
moving_axis = -1 - struct_params.axis
|
|
42
50
|
|
|
@@ -67,7 +75,7 @@ def _apply_struct_mask(
|
|
|
67
75
|
mix = np.delete(mix, mix[:, moving_axis] > max_bound, axis=0)
|
|
68
76
|
|
|
69
77
|
# replace neighbouring pixels with random values from flat dist
|
|
70
|
-
patch[tuple(mix.T)] =
|
|
78
|
+
patch[tuple(mix.T)] = rng.uniform(patch.min(), patch.max(), size=mix.shape[0])
|
|
71
79
|
|
|
72
80
|
return patch
|
|
73
81
|
|
|
@@ -98,7 +106,9 @@ def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray:
|
|
|
98
106
|
|
|
99
107
|
|
|
100
108
|
def _get_stratified_coords(
|
|
101
|
-
mask_pixel_perc: float,
|
|
109
|
+
mask_pixel_perc: float,
|
|
110
|
+
shape: Tuple[int, ...],
|
|
111
|
+
rng: Optional[np.random.Generator] = None,
|
|
102
112
|
) -> np.ndarray:
|
|
103
113
|
"""
|
|
104
114
|
Generate coordinates of the pixels to mask.
|
|
@@ -113,6 +123,8 @@ def _get_stratified_coords(
|
|
|
113
123
|
calculating the distance between masked pixels across each axis.
|
|
114
124
|
shape : Tuple[int, ...]
|
|
115
125
|
Shape of the input patch.
|
|
126
|
+
rng : np.random.Generator or None
|
|
127
|
+
Random number generator.
|
|
116
128
|
|
|
117
129
|
Returns
|
|
118
130
|
-------
|
|
@@ -124,7 +136,8 @@ def _get_stratified_coords(
|
|
|
124
136
|
"Calculating coordinates is only possible for 2D and 3D patches"
|
|
125
137
|
)
|
|
126
138
|
|
|
127
|
-
rng
|
|
139
|
+
if rng is None:
|
|
140
|
+
rng = np.random.default_rng()
|
|
128
141
|
|
|
129
142
|
mask_pixel_distance = np.round((100 / mask_pixel_perc) ** (1 / len(shape))).astype(
|
|
130
143
|
np.int32
|
|
@@ -228,6 +241,7 @@ def uniform_manipulate(
|
|
|
228
241
|
subpatch_size: int = 11,
|
|
229
242
|
remove_center: bool = True,
|
|
230
243
|
struct_params: Optional[StructMaskParameters] = None,
|
|
244
|
+
rng: Optional[np.random.Generator] = None,
|
|
231
245
|
) -> Tuple[np.ndarray, np.ndarray]:
|
|
232
246
|
"""
|
|
233
247
|
Manipulate pixels by replacing them with a neighbor values.
|
|
@@ -248,19 +262,23 @@ def uniform_manipulate(
|
|
|
248
262
|
Size of the subpatch the new pixel value is sampled from, by default 11.
|
|
249
263
|
remove_center : bool
|
|
250
264
|
Whether to remove the center pixel from the subpatch, by default True.
|
|
251
|
-
struct_params :
|
|
265
|
+
struct_params : StructMaskParameters or None
|
|
252
266
|
Parameters for the structN2V mask (axis and span).
|
|
267
|
+
rng : np.random.Generator or None
|
|
268
|
+
Random number generator.
|
|
253
269
|
|
|
254
270
|
Returns
|
|
255
271
|
-------
|
|
256
272
|
Tuple[np.ndarray]
|
|
257
273
|
Tuple containing the manipulated patch and the corresponding mask.
|
|
258
274
|
"""
|
|
275
|
+
if rng is None:
|
|
276
|
+
rng = np.random.default_rng()
|
|
277
|
+
|
|
259
278
|
# Get the coordinates of the pixels to be replaced
|
|
260
279
|
transformed_patch = patch.copy()
|
|
261
280
|
|
|
262
|
-
subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape)
|
|
263
|
-
rng = np.random.default_rng()
|
|
281
|
+
subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape, rng)
|
|
264
282
|
|
|
265
283
|
# Generate coordinate grid for subpatch
|
|
266
284
|
roi_span_full = np.arange(
|
|
@@ -303,6 +321,7 @@ def median_manipulate(
|
|
|
303
321
|
mask_pixel_percentage: float,
|
|
304
322
|
subpatch_size: int = 11,
|
|
305
323
|
struct_params: Optional[StructMaskParameters] = None,
|
|
324
|
+
rng: Optional[np.random.Generator] = None,
|
|
306
325
|
) -> Tuple[np.ndarray, np.ndarray]:
|
|
307
326
|
"""
|
|
308
327
|
Manipulate pixels by replacing them with the median of their surrounding subpatch.
|
|
@@ -322,18 +341,23 @@ def median_manipulate(
|
|
|
322
341
|
Approximate percentage of pixels to be masked.
|
|
323
342
|
subpatch_size : int
|
|
324
343
|
Size of the subpatch the new pixel value is sampled from, by default 11.
|
|
325
|
-
struct_params :
|
|
344
|
+
struct_params : StructMaskParameters or None, optional
|
|
326
345
|
Parameters for the structN2V mask (axis and span).
|
|
346
|
+
rng : np.random.Generator or None, optional
|
|
347
|
+
Random number generator, by default None.
|
|
327
348
|
|
|
328
349
|
Returns
|
|
329
350
|
-------
|
|
330
351
|
Tuple[np.ndarray]
|
|
331
352
|
Tuple containing the manipulated patch, the original patch and the mask.
|
|
332
353
|
"""
|
|
354
|
+
if rng is None:
|
|
355
|
+
rng = np.random.default_rng()
|
|
356
|
+
|
|
333
357
|
transformed_patch = patch.copy()
|
|
334
358
|
|
|
335
359
|
# Get the coordinates of the pixels to be replaced
|
|
336
|
-
subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape)
|
|
360
|
+
subpatch_centers = _get_stratified_coords(mask_pixel_percentage, patch.shape, rng)
|
|
337
361
|
|
|
338
362
|
# Generate coordinate grid for subpatch
|
|
339
363
|
roi_span = np.array(
|
careamics/transforms/tta.py
CHANGED
|
@@ -1,11 +1,8 @@
|
|
|
1
1
|
"""Test-time augmentations."""
|
|
2
2
|
|
|
3
|
-
from typing import List
|
|
4
|
-
|
|
5
3
|
from torch import Tensor, flip, mean, rot90, stack
|
|
6
4
|
|
|
7
5
|
|
|
8
|
-
# TODO add tests
|
|
9
6
|
class ImageRestorationTTA:
|
|
10
7
|
"""
|
|
11
8
|
Test-time augmentation for image restoration tasks.
|
|
@@ -13,62 +10,79 @@ class ImageRestorationTTA:
|
|
|
13
10
|
The augmentation is performed using all 90 deg rotations and their flipped version,
|
|
14
11
|
as well as the original image flipped.
|
|
15
12
|
|
|
16
|
-
Tensors should be of shape SC(Z)YX
|
|
13
|
+
Tensors should be of shape SC(Z)YX.
|
|
17
14
|
|
|
18
15
|
This transformation is used in the LightningModule in order to perform test-time
|
|
19
|
-
|
|
16
|
+
augmentation.
|
|
20
17
|
"""
|
|
21
18
|
|
|
22
|
-
def
|
|
23
|
-
"""Constructor."""
|
|
24
|
-
pass
|
|
25
|
-
|
|
26
|
-
def forward(self, x: Tensor) -> List[Tensor]:
|
|
19
|
+
def forward(self, input_tensor: Tensor) -> list[Tensor]:
|
|
27
20
|
"""
|
|
28
21
|
Apply test-time augmentation to the input tensor.
|
|
29
22
|
|
|
30
23
|
Parameters
|
|
31
24
|
----------
|
|
32
|
-
|
|
25
|
+
input_tensor : Tensor
|
|
33
26
|
Input tensor, shape SC(Z)YX.
|
|
34
27
|
|
|
35
28
|
Returns
|
|
36
29
|
-------
|
|
37
|
-
|
|
30
|
+
list of torch.Tensor
|
|
38
31
|
List of augmented tensors.
|
|
39
32
|
"""
|
|
33
|
+
# axes: only applies to YX axes
|
|
34
|
+
axes = (-2, -1)
|
|
35
|
+
|
|
40
36
|
augmented = [
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
rot90(
|
|
37
|
+
# original
|
|
38
|
+
input_tensor,
|
|
39
|
+
# rotations
|
|
40
|
+
rot90(input_tensor, 1, dims=axes),
|
|
41
|
+
rot90(input_tensor, 2, dims=axes),
|
|
42
|
+
rot90(input_tensor, 3, dims=axes),
|
|
43
|
+
# original flipped
|
|
44
|
+
flip(input_tensor, dims=(axes[0],)),
|
|
45
|
+
flip(input_tensor, dims=(axes[1],)),
|
|
45
46
|
]
|
|
46
|
-
augmented_flip = augmented.copy()
|
|
47
|
-
for x_ in augmented:
|
|
48
|
-
augmented_flip.append(flip(x_, dims=(-3, -1)))
|
|
49
|
-
return augmented_flip
|
|
50
47
|
|
|
51
|
-
|
|
48
|
+
# rotated once, flipped
|
|
49
|
+
augmented.extend(
|
|
50
|
+
[
|
|
51
|
+
flip(augmented[1], dims=(axes[0],)),
|
|
52
|
+
flip(augmented[1], dims=(axes[1],)),
|
|
53
|
+
]
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
return augmented
|
|
57
|
+
|
|
58
|
+
def backward(self, x: list[Tensor]) -> Tensor:
|
|
52
59
|
"""Undo the test-time augmentation.
|
|
53
60
|
|
|
54
61
|
Parameters
|
|
55
62
|
----------
|
|
56
63
|
x : list of torch.Tensor
|
|
57
|
-
List of augmented tensors.
|
|
64
|
+
List of augmented tensors of shape SC(Z)YX.
|
|
58
65
|
|
|
59
66
|
Returns
|
|
60
67
|
-------
|
|
61
68
|
torch.Tensor
|
|
62
69
|
Original tensor.
|
|
63
70
|
"""
|
|
71
|
+
axes = (-2, -1)
|
|
72
|
+
|
|
64
73
|
reverse = [
|
|
74
|
+
# original
|
|
65
75
|
x[0],
|
|
66
|
-
|
|
67
|
-
rot90(x[
|
|
68
|
-
rot90(x[
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
76
|
+
# rotated
|
|
77
|
+
rot90(x[1], -1, dims=axes),
|
|
78
|
+
rot90(x[2], -2, dims=axes),
|
|
79
|
+
rot90(x[3], -3, dims=axes),
|
|
80
|
+
# original flipped
|
|
81
|
+
flip(x[4], dims=(axes[0],)),
|
|
82
|
+
flip(x[5], dims=(axes[1],)),
|
|
83
|
+
# rotated once, flipped
|
|
84
|
+
rot90(flip(x[6], dims=(axes[0],)), -1, dims=axes),
|
|
85
|
+
rot90(flip(x[7], dims=(axes[1],)), -1, dims=axes),
|
|
73
86
|
]
|
|
87
|
+
|
|
74
88
|
return mean(stack(reverse), dim=0)
|
careamics/utils/ram.py
CHANGED
|
@@ -5,11 +5,11 @@ import psutil
|
|
|
5
5
|
|
|
6
6
|
def get_ram_size() -> int:
|
|
7
7
|
"""
|
|
8
|
-
Get RAM size in
|
|
8
|
+
Get available RAM size in mbytes.
|
|
9
9
|
|
|
10
10
|
Returns
|
|
11
11
|
-------
|
|
12
12
|
int
|
|
13
13
|
RAM size in mbytes.
|
|
14
14
|
"""
|
|
15
|
-
return psutil.virtual_memory().
|
|
15
|
+
return psutil.virtual_memory().available / 1024**2
|
|
@@ -1,31 +1,32 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: careamics
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.0rc7
|
|
4
4
|
Summary: Toolbox for running N2V and friends.
|
|
5
5
|
Project-URL: homepage, https://careamics.github.io/
|
|
6
6
|
Project-URL: repository, https://github.com/CAREamics/careamics
|
|
7
|
-
Author-email:
|
|
7
|
+
Author-email: Melisande Croft <melisande.croft@fht.org>, Joran Deschamps <joran.deschamps@fht.org>, Igor Zubarev <igor.zubarev@fht.org>
|
|
8
8
|
License: BSD-3-Clause
|
|
9
9
|
License-File: LICENSE
|
|
10
10
|
Classifier: Development Status :: 3 - Alpha
|
|
11
11
|
Classifier: License :: OSI Approved :: BSD License
|
|
12
12
|
Classifier: Programming Language :: Python :: 3
|
|
13
|
-
Classifier: Programming Language :: Python :: 3.8
|
|
14
13
|
Classifier: Programming Language :: Python :: 3.9
|
|
15
14
|
Classifier: Programming Language :: Python :: 3.10
|
|
16
15
|
Classifier: Programming Language :: Python :: 3.11
|
|
17
16
|
Classifier: Programming Language :: Python :: 3.12
|
|
18
17
|
Classifier: Typing :: Typed
|
|
19
|
-
Requires-Python: >=3.
|
|
18
|
+
Requires-Python: >=3.9
|
|
20
19
|
Requires-Dist: bioimageio-core>=0.6.0
|
|
20
|
+
Requires-Dist: numpy<2.0.0
|
|
21
21
|
Requires-Dist: psutil
|
|
22
22
|
Requires-Dist: pydantic>=2.5
|
|
23
23
|
Requires-Dist: pytorch-lightning>=2.2.0
|
|
24
24
|
Requires-Dist: pyyaml
|
|
25
|
-
Requires-Dist: scikit-image
|
|
25
|
+
Requires-Dist: scikit-image<=0.23.2
|
|
26
26
|
Requires-Dist: tifffile
|
|
27
27
|
Requires-Dist: torch>=2.0.0
|
|
28
|
-
Requires-Dist:
|
|
28
|
+
Requires-Dist: torchvision
|
|
29
|
+
Requires-Dist: zarr<3.0.0
|
|
29
30
|
Provides-Extra: dev
|
|
30
31
|
Requires-Dist: pre-commit; extra == 'dev'
|
|
31
32
|
Requires-Dist: pytest; extra == 'dev'
|
|
@@ -1,28 +1,27 @@
|
|
|
1
1
|
careamics/__init__.py,sha256=DkMGt4t9ua0gCgvZFEtb6eydvoxG976T0KUro8KnDNA,760
|
|
2
|
-
careamics/careamist.py,sha256=
|
|
2
|
+
careamics/careamist.py,sha256=BtCJWXD4zlKAo05acHj-k-r7wdBcH9eHMQQ6x0wERjo,25911
|
|
3
3
|
careamics/conftest.py,sha256=Od4WcaaP0UP-XUMrFr_oo4e6c2hi_RvNbuaRTopwlmI,911
|
|
4
|
-
careamics/lightning_datamodule.py,sha256=
|
|
5
|
-
careamics/lightning_module.py,sha256=
|
|
6
|
-
careamics/lightning_prediction_datamodule.py,sha256=
|
|
7
|
-
careamics/lightning_prediction_loop.py,sha256=qDfRVXPiCVyRz-P3l9tmlCfMT8mx9waKNfNrIMrjt3w,4599
|
|
4
|
+
careamics/lightning_datamodule.py,sha256=NoMJIaJU0BizBNTSC-dzR1mWED1urHMXdH6hIFi-QfE,32536
|
|
5
|
+
careamics/lightning_module.py,sha256=T1G_QmBAMHfZyynD6nywT9F6bFdjDDdXqJQnjqiODek,10483
|
|
6
|
+
careamics/lightning_prediction_datamodule.py,sha256=cZuiwImD-e4Fuy493PWfwhR0fWm_uqor8itjgsQCWos,14595
|
|
8
7
|
careamics/py.typed,sha256=esB4cHc6c07uVkGtqf8at7ttEnprwRxwk8obY8Qumq4,187
|
|
9
8
|
careamics/callbacks/__init__.py,sha256=spxJlDByD-6QtMl9vcIty8Wb0tyHaSTKTItozHenI44,204
|
|
10
9
|
careamics/callbacks/hyperparameters_callback.py,sha256=ODJpwwdgc1-Py8yEUpXLar8_IOAcfR7lF3--6LfSiGc,1496
|
|
11
10
|
careamics/callbacks/progress_bar_callback.py,sha256=8HvNSWZldixd6pjz0dLDo0apIbzTovv5smKmZ6tZQ8U,2444
|
|
12
11
|
careamics/config/__init__.py,sha256=SP1oJKhK3VDN9ABwnpfR3H02qRprzymjRfNYeC7kHEo,1019
|
|
13
|
-
careamics/config/algorithm_model.py,sha256=
|
|
12
|
+
careamics/config/algorithm_model.py,sha256=Lu7eYyLql5bEeTDLnJ4ms-kqSQ-EufNbt9SNa91K_Ec,5109
|
|
14
13
|
careamics/config/callback_model.py,sha256=CcamVhgRsVdskCe_9EtyWi1YbrNX5vKEplc97AYz1h8,3118
|
|
15
|
-
careamics/config/configuration_example.py,sha256=
|
|
16
|
-
careamics/config/configuration_factory.py,sha256=
|
|
17
|
-
careamics/config/configuration_model.py,sha256=
|
|
18
|
-
careamics/config/data_model.py,sha256=
|
|
19
|
-
careamics/config/inference_model.py,sha256=
|
|
20
|
-
careamics/config/optimizer_models.py,sha256
|
|
21
|
-
careamics/config/tile_information.py,sha256
|
|
14
|
+
careamics/config/configuration_example.py,sha256=fhV02Y3wm9anwWFx4Yi5y-OoP31wEYBJYggbyGONQuk,2486
|
|
15
|
+
careamics/config/configuration_factory.py,sha256=viTuBK8bUD266OG9w7B9XZDXlH39D-FaxMq6mai3k6E,21323
|
|
16
|
+
careamics/config/configuration_model.py,sha256=mmHNnWDK7BklCwd3JV6EcHU57EJ6JKBK0mK11G6Cvxo,18472
|
|
17
|
+
careamics/config/data_model.py,sha256=JANDySUFo2iU74RYPhHgjQjfzn1x2mUZlFJFLiW42TA,14257
|
|
18
|
+
careamics/config/inference_model.py,sha256=3n3jy922lTD9NhnSNA876Rmrr2MavFk2pmOE967itLQ,6447
|
|
19
|
+
careamics/config/optimizer_models.py,sha256=-XVzi7CsxfcThTUjlyn4btJa6oHbFOf-q3h9qLM9t0k,5346
|
|
20
|
+
careamics/config/tile_information.py,sha256=TAqfAthPSnIqerq72qP4KlTXPqci9XE9pCVN7J3bMJ4,2246
|
|
22
21
|
careamics/config/training_model.py,sha256=oghv91J7xIdI69wpNJGmLUAwgM9l3VhMsbsOo4USqkU,1559
|
|
23
22
|
careamics/config/architectures/__init__.py,sha256=CdnViydyTdQixus3uWHBIgbgxmu9t1_ADehqpjN_57U,444
|
|
24
23
|
careamics/config/architectures/architecture_model.py,sha256=545hlbOZU9EJNGTcSpy7eXpfzCtvIm28dDJGMo36AfQ,886
|
|
25
|
-
careamics/config/architectures/custom_model.py,sha256=
|
|
24
|
+
careamics/config/architectures/custom_model.py,sha256=K2RXK2YINm3SCzTxhxzUzFbFV-FYvWQVEDQ-i5bOIoQ,4592
|
|
26
25
|
careamics/config/architectures/register_model.py,sha256=lHH0aUPmXtI3Bq_76zkhg07_Yb_nOJZkZJLCC_G-rZM,2434
|
|
27
26
|
careamics/config/architectures/unet_model.py,sha256=sQjfqTjh1kTNi369U3_94jroU6LyLlflaIe8FwdHQvo,2892
|
|
28
27
|
careamics/config/architectures/vae_model.py,sha256=Z0satmte4udManh_bxtl93ZmQlmo6JFE1NQIuZkTsQk,926
|
|
@@ -39,56 +38,81 @@ careamics/config/support/supported_losses.py,sha256=TPsMCuDdgb64TRyDwonnwHb1R-rk
|
|
|
39
38
|
careamics/config/support/supported_optimizers.py,sha256=xxbJsyohJTlHeUz2I4eRwcE3BeACs-6PH8cpX6w2wX8,1394
|
|
40
39
|
careamics/config/support/supported_pixel_manipulations.py,sha256=rFiktUlvoFU7s1NAKEMqsXOzLw5eaw9GtCKUznvq6xc,432
|
|
41
40
|
careamics/config/support/supported_struct_axis.py,sha256=alZMA5Y-BpDymLPUEd1zqVY0xMkgl9Rv1d4ujED6sco,424
|
|
42
|
-
careamics/config/support/supported_transforms.py,sha256=
|
|
41
|
+
careamics/config/support/supported_transforms.py,sha256=4uob-bnZ5aqpN5aEI67-aa7bsmVCrKxEknzf2BAZ3W4,283
|
|
43
42
|
careamics/config/transformations/__init__.py,sha256=oqwBAL2XXbPRZZ5iOzNqalX6SyJ1M-S0lkfbDGZOzyE,378
|
|
44
43
|
careamics/config/transformations/n2v_manipulate_model.py,sha256=UTyfpm1mmMvYg_HoMzXilZhJGx_muiV-lLQ4UThCFJ0,1854
|
|
45
|
-
careamics/config/transformations/normalize_model.py,sha256=
|
|
44
|
+
careamics/config/transformations/normalize_model.py,sha256=1Rkk6IkF-7ytGU6HSzP-TpOi4RRWiQJ6fOd8zammXcg,1936
|
|
46
45
|
careamics/config/transformations/transform_model.py,sha256=i7KAtSv4nah2H7uyJFKqg7RdKF68OHIPMNNvDo0HxGY,1000
|
|
47
46
|
careamics/config/transformations/xy_flip_model.py,sha256=zU-uZ1b1zNZWckbho3onN-B7BHKhN7jbgbNZyRQhv2s,1025
|
|
48
47
|
careamics/config/transformations/xy_random_rotate90_model.py,sha256=6sYKmtCLvz0SV1qZgBSHUTH-CUjwvHnohq1HyPntbyE,894
|
|
49
48
|
careamics/config/validators/__init__.py,sha256=iv0nVI0W7j9DxFPwh0DjRCzM9P8oLQn4Gwi5rfuFrrI,180
|
|
50
|
-
careamics/config/validators/validator_utils.py,sha256=
|
|
51
|
-
careamics/dataset/__init__.py,sha256=
|
|
52
|
-
careamics/dataset/in_memory_dataset.py,sha256=
|
|
53
|
-
careamics/dataset/
|
|
49
|
+
careamics/config/validators/validator_utils.py,sha256=aNFzpBVbef3BZIt6MiNMVc2kW6MJDWqQgdYkFM8Gjig,2621
|
|
50
|
+
careamics/dataset/__init__.py,sha256=NQSWdpQu6BhqGGHUYuOt1hXJrGUN1LPNCP1A8duMY84,547
|
|
51
|
+
careamics/dataset/in_memory_dataset.py,sha256=DfFpSdsYM4aNw6FWn_yDHA6seQDSozGyt-Q57pDpJDA,9457
|
|
52
|
+
careamics/dataset/in_memory_pred_dataset.py,sha256=VvwW5D8TjgO_kR8eZinP-9qepSiI6ZsUN7FZ0Rvc8Bs,2161
|
|
53
|
+
careamics/dataset/in_memory_tiled_pred_dataset.py,sha256=DANmlnlV1ysXKdwGvmJoOYKcjlgoMhnSGSDRpeK79ZA,3552
|
|
54
|
+
careamics/dataset/iterable_dataset.py,sha256=uEmiO8n2qirJv5XkMU5lKmPMBL7rw06GRQQL6BQpfus,9694
|
|
55
|
+
careamics/dataset/iterable_pred_dataset.py,sha256=AtsNRKOEkfDG8y3wa0bi5ImEqEaU2E5LM_iXcuK4Ehw,3706
|
|
56
|
+
careamics/dataset/iterable_tiled_pred_dataset.py,sha256=Q0OkAtYWlsClRSNdl74sPIvbq5rLGZvORzywOJ_YrUw,4499
|
|
54
57
|
careamics/dataset/zarr_dataset.py,sha256=lojnK5bhiF1vyjuPtWXBrZ9sy5fT_rBvZJbbbnE-H_I,5665
|
|
55
|
-
careamics/dataset/dataset_utils/__init__.py,sha256=
|
|
58
|
+
careamics/dataset/dataset_utils/__init__.py,sha256=DuPIjndTs0VhZsUIk2IcSk6H9N0d0ARyA5U3v3Qz-hw,666
|
|
56
59
|
careamics/dataset/dataset_utils/dataset_utils.py,sha256=zYNglet5lYKxIhTeOGG2K24oujC-m5zyYlwJcQcleVA,2662
|
|
57
|
-
careamics/dataset/dataset_utils/file_utils.py,sha256=
|
|
58
|
-
careamics/dataset/dataset_utils/
|
|
60
|
+
careamics/dataset/dataset_utils/file_utils.py,sha256=s7RmmnHa7ojl4kauXfuj7hn0dAx0HB1d2ES7sUSS7IQ,4062
|
|
61
|
+
careamics/dataset/dataset_utils/iterate_over_files.py,sha256=ACiltjAH2aKR0UOEvWPuuxv68NWEd2aDMFE07caxhWo,2859
|
|
62
|
+
careamics/dataset/dataset_utils/read_tiff.py,sha256=emzQgodEaBsLB0ULH4lUUbsDd9PylR8DQ3rb7g0l2b8,1336
|
|
59
63
|
careamics/dataset/dataset_utils/read_utils.py,sha256=0nsfzHq3zr9kjm2qZZrMRKI5LC5MiRSH35xPBCYyBrQ,579
|
|
60
64
|
careamics/dataset/dataset_utils/read_zarr.py,sha256=2jzREAnJDQSv0qmsL-v00BxmiZ_sp0ijq667LZSQ_hY,1685
|
|
65
|
+
careamics/dataset/dataset_utils/running_stats.py,sha256=0uOLaXpNwmY4lIElsHg4Ezf1YRbHy9An8GHXGYOaYmg,5565
|
|
61
66
|
careamics/dataset/patching/__init__.py,sha256=7-s12oUAZNlMOwSkxSwbD7vojQINWYFzn_4qIJ87WBg,37
|
|
62
|
-
careamics/dataset/patching/patching.py,sha256=
|
|
63
|
-
careamics/dataset/patching/random_patching.py,sha256=
|
|
67
|
+
careamics/dataset/patching/patching.py,sha256=XoJMfOwYItNQNJOJmRN9swtFiFu0G2L6qvUhP7jhYes,8432
|
|
68
|
+
careamics/dataset/patching/random_patching.py,sha256=61sLxA4eJN5TIWBVIDZdJahS_CkclpM7Kc_VdPj91dU,6486
|
|
64
69
|
careamics/dataset/patching/sequential_patching.py,sha256=_l3Q2uYIhjMJMaxDdSbHC9_2kRF9eLz-Xs3r9i7j3Nc,5903
|
|
65
|
-
careamics/dataset/patching/
|
|
66
|
-
careamics/dataset/
|
|
70
|
+
careamics/dataset/patching/validate_patch_dimension.py,sha256=sQQ0-4b4uu60MNKkoWv95KxQ80J7Ku0CEk0-kAXlKeI,2134
|
|
71
|
+
careamics/dataset/tiling/__init__.py,sha256=XynyAz85hVfkLtrG0lrMr_aBQm_YEwfu5uFcXMGHlOA,190
|
|
72
|
+
careamics/dataset/tiling/collate_tiles.py,sha256=OrPZ-n-V3uGOc_7CcPnyEJqdbEVDlTfJfWmZnyBZ-HA,978
|
|
73
|
+
careamics/dataset/tiling/tiled_patching.py,sha256=Zhhc0TwXVy4P_tZxS3B5tQZK6SRhGiQwnzVr-1BC4ww,5952
|
|
67
74
|
careamics/losses/__init__.py,sha256=kVEwfZ2xXfd8x0n-VHGKm6qvzbto5pIIJYP_jN-bCtw,89
|
|
68
75
|
careamics/losses/loss_factory.py,sha256=vaMlxH5oescWTKlK1adWwbeD9tW4Ti-p7qKmc1iHCi0,1005
|
|
69
76
|
careamics/losses/losses.py,sha256=DKwHZ9ifVe6wMd3tBOiswLC-saU1bj1RCcXGOkREmKU,2328
|
|
77
|
+
careamics/lvae_training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
78
|
+
careamics/lvae_training/data_modules.py,sha256=A5Uoo4qtPdX99QSi-Zl22LzO0I1DszJbQuXMGUXGQEE,46665
|
|
79
|
+
careamics/lvae_training/data_utils.py,sha256=tRk0k0TkBLPocqlUlkwQN_dm5jzw5z74YNs2DsCuy9Y,21670
|
|
80
|
+
careamics/lvae_training/eval_utils.py,sha256=_AlXNXk4uGS2AGsF4PHJZpJoWBgq32kvQLEh7awOIvc,32405
|
|
81
|
+
careamics/lvae_training/get_config.py,sha256=-CWVxlPo71_huUSmXnmYvOmgvcvrZiv0wIpXnR32l6E,3054
|
|
82
|
+
careamics/lvae_training/lightning_module.py,sha256=ryr7iHqCMzCl5esi6_gEcnKFDQkMrw0EXK9Zfgv1Nek,27186
|
|
83
|
+
careamics/lvae_training/metrics.py,sha256=KTDAKhe3vh-YxzGibjtkIG2nnUyujbnwqX4xGwaRXwE,6718
|
|
84
|
+
careamics/lvae_training/train_lvae.py,sha256=Eu--3-RHSfhQVsJ-CTDXhUeoM1fzf_H9IGtBaNPOsHI,11044
|
|
85
|
+
careamics/lvae_training/train_utils.py,sha256=e-d4QsF-li8MmAPkAmB1daHpkuU16nBTnQFZYqpTjn4,3567
|
|
70
86
|
careamics/model_io/__init__.py,sha256=HITzjiuZQwo-rQ2_Ma3bz9l7PDANv1_S489E-tffV9s,155
|
|
71
|
-
careamics/model_io/bmz_io.py,sha256=
|
|
72
|
-
careamics/model_io/model_io_utils.py,sha256=
|
|
87
|
+
careamics/model_io/bmz_io.py,sha256=Gc6uN0aO_kEDzQnJTSTNDS7PiYC684FfDNI0X9rZm8g,7031
|
|
88
|
+
careamics/model_io/model_io_utils.py,sha256=Pxm_9uYRBDOMa8dC4ENk-Vre9CXsTIORGvMwn8mLzXY,2347
|
|
73
89
|
careamics/model_io/bioimage/__init__.py,sha256=r94nu8WDAvj0Fbu4C-iJXdOhfSQXeZBvN3UKsLG0RNI,298
|
|
74
90
|
careamics/model_io/bioimage/_readme_factory.py,sha256=LZAuEiWNBTPaD8KrLPMq16yJuOPKDZiGQuTMHKLvoT4,3514
|
|
75
91
|
careamics/model_io/bioimage/bioimage_utils.py,sha256=nlW0J1daYyLbL6yVN3QSn3HhA2joMjIG-thK64lpVTY,1085
|
|
76
|
-
careamics/model_io/bioimage/model_description.py,sha256=
|
|
92
|
+
careamics/model_io/bioimage/model_description.py,sha256=3jw4wkJDefLEW-2BbEfAml3AwyteZszL-v8JYpJRcOo,9635
|
|
77
93
|
careamics/models/__init__.py,sha256=Wty5hwQb_As33pQOZqY5j-DpDOdh5ArBH4BhQDSuXTQ,133
|
|
78
94
|
careamics/models/activation.py,sha256=xdqz4-yKV7oElG_dDrYuibS8HOiYvKdV_r9FwWPvaDE,977
|
|
79
95
|
careamics/models/layers.py,sha256=oWzpq8OdHFEJqPWC9X8IRPNe0XqAnesSqwoT6V3t1Mw,13712
|
|
80
96
|
careamics/models/model_factory.py,sha256=5YRwRRUemxb-pTRL3VWn8N61tCGyhrurqPgcFaNETb0,1360
|
|
81
97
|
careamics/models/unet.py,sha256=3pXpiCIw7WUaDV0Jmczkxi99C5-Zu3NpQpWxgRkeGL8,14321
|
|
82
|
-
careamics/
|
|
83
|
-
careamics/
|
|
98
|
+
careamics/models/lvae/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
99
|
+
careamics/models/lvae/layers.py,sha256=wFuQgmtJtB7YNuNi2dVoOEWq1ndR6ku4iGvC2u0TJlM,84991
|
|
100
|
+
careamics/models/lvae/likelihoods.py,sha256=FRFTh34FaBLGxn9OXFzqFyHhhJMSKYhgqxwG65VbGh8,10489
|
|
101
|
+
careamics/models/lvae/lvae.py,sha256=5RlK4-h55dGz9UMCh8JCbLsaaIQ5S2IKGeI9d4nD5dA,40167
|
|
102
|
+
careamics/models/lvae/noise_models.py,sha256=yotY5gkPAowbI7esOmHlzBWcSsZlH2G3U7uYIWghGwY,15703
|
|
103
|
+
careamics/models/lvae/utils.py,sha256=muy4nLHmnB3BPAI0tQbJK_vVtBZOLBvhrJigHIOx5V4,11542
|
|
104
|
+
careamics/prediction_utils/__init__.py,sha256=0rtfNXeH5RvJ9ieeCbBV9i9eyLXxo_IMpqVH1-H2N8E,359
|
|
105
|
+
careamics/prediction_utils/create_pred_datamodule.py,sha256=rv_Q0v4Es-NE1IU_nUkJYUsYo3Gh3whvOj43WlTw2hc,5846
|
|
106
|
+
careamics/prediction_utils/prediction_outputs.py,sha256=1BHJF_dpw3QwH8uF_uE1u4tVui01mm899O2VcGvWYvM,4730
|
|
107
|
+
careamics/prediction_utils/stitch_prediction.py,sha256=VRJc51KHg_3gWTCNdvQpHfrGaNqDHd9hHhnvqxg2cjE,3081
|
|
84
108
|
careamics/transforms/__init__.py,sha256=VIHIsC8sMAh1TCm67ifB816Zp-LRo6rAONPuT2Qs3bs,483
|
|
85
109
|
careamics/transforms/compose.py,sha256=mTkhoxvgvsBqNoz9RWpJ_tqsDl1CDp0-UARTjUuBRf4,3477
|
|
86
|
-
careamics/transforms/n2v_manipulate.py,sha256=
|
|
87
|
-
careamics/transforms/normalize.py,sha256=
|
|
88
|
-
careamics/transforms/pixel_manipulation.py,sha256=
|
|
110
|
+
careamics/transforms/n2v_manipulate.py,sha256=Gty7Jtu-RiFb1EnlrOi652qAOGKU5ZHvidRvykWqJxg,5438
|
|
111
|
+
careamics/transforms/normalize.py,sha256=dfGWCGPyNwyEqg5wUCAA8cGdT1MvNkpKUEpw8Cw8DfA,7274
|
|
112
|
+
careamics/transforms/pixel_manipulation.py,sha256=lNA19Vlo_3GHzRnT_4AFuv6eWQaxbie2PTYGalCY4YQ,13346
|
|
89
113
|
careamics/transforms/struct_mask_parameters.py,sha256=jE29Li9sx3olaRnqYfJsSlKi2t0WQzJmCm9aCbIQEsA,421
|
|
90
114
|
careamics/transforms/transform.py,sha256=cEqc4ci8na70i-HIGYC7udRfVa8D_8OjdRVrr3txLvQ,464
|
|
91
|
-
careamics/transforms/tta.py,sha256=
|
|
115
|
+
careamics/transforms/tta.py,sha256=78S7Df9rLHmEVSQSI1qDcRrRJGauyG3oaIrXkckCkmw,2335
|
|
92
116
|
careamics/transforms/xy_flip.py,sha256=Q1kKTa2kE3W1P3dlpT4GAVSSHM3TebnrvIyWh75Fnko,3443
|
|
93
117
|
careamics/transforms/xy_random_rotate90.py,sha256=zWdBROLLjgxTMSQEQesJr17j84BmZhKWCMVVONHU8mw,2781
|
|
94
118
|
careamics/utils/__init__.py,sha256=tO1X5QTfnthepuW0uYagz5fWehtLtwK2gPmkUeqhdOw,334
|
|
@@ -97,11 +121,10 @@ careamics/utils/context.py,sha256=Ljf70OR1FcYpsVpxb5Sr2fzmPVIZgDS1uZob_3BcELg,14
|
|
|
97
121
|
careamics/utils/logging.py,sha256=coIscjkDYpqcsGnsONuYOdIYd6_gHxdnYIZ-e9Y2Ybg,10322
|
|
98
122
|
careamics/utils/metrics.py,sha256=9YQe5Aj2Pv2h9jnRFeRbDQ_3qXAW0QHpucSqiUtwDcA,2382
|
|
99
123
|
careamics/utils/path_utils.py,sha256=8AugiG5DOmzgSnTCJI8vypXaPE0XhnR-9pzeiFUZ-0I,554
|
|
100
|
-
careamics/utils/ram.py,sha256=
|
|
124
|
+
careamics/utils/ram.py,sha256=tksyn8dVX_iJXmrDZDGub32hFZWIaNxnMheO5G1p43I,244
|
|
101
125
|
careamics/utils/receptive_field.py,sha256=Y2h4c8S6glX3qcx5KHDmO17Kkuyey9voxfoXyqcAfiM,3296
|
|
102
|
-
careamics/utils/running_stats.py,sha256=GIPMPuH9EOUKD_cYBkJFPggXRKnQEiOXx68Pq9UCCVI,1384
|
|
103
126
|
careamics/utils/torch_utils.py,sha256=g1zxdlM7_BA7mMLcCzmrxZX4LmH__KXlJibC95muVaA,3014
|
|
104
|
-
careamics-0.1.
|
|
105
|
-
careamics-0.1.
|
|
106
|
-
careamics-0.1.
|
|
107
|
-
careamics-0.1.
|
|
127
|
+
careamics-0.1.0rc7.dist-info/METADATA,sha256=ZIJHL8fCiF2MDXpkvtegSSA889deeY9USh0emDYJncM,3525
|
|
128
|
+
careamics-0.1.0rc7.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
|
129
|
+
careamics-0.1.0rc7.dist-info/licenses/LICENSE,sha256=6zdNW-k_xHRKYWUf9tDI_ZplUciFHyj0g16DYuZ2udw,1509
|
|
130
|
+
careamics-0.1.0rc7.dist-info/RECORD,,
|
|
@@ -1,118 +0,0 @@
|
|
|
1
|
-
"""Lithning prediction loop allowing tiling."""
|
|
2
|
-
|
|
3
|
-
from typing import Optional
|
|
4
|
-
|
|
5
|
-
import pytorch_lightning as L
|
|
6
|
-
from pytorch_lightning.loops.fetchers import _DataLoaderIterDataFetcher
|
|
7
|
-
from pytorch_lightning.loops.utilities import _no_grad_context
|
|
8
|
-
from pytorch_lightning.trainer import call
|
|
9
|
-
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT
|
|
10
|
-
|
|
11
|
-
from careamics.prediction import stitch_prediction
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class CAREamicsPredictionLoop(L.loops._PredictionLoop):
|
|
15
|
-
"""
|
|
16
|
-
CAREamics prediction loop.
|
|
17
|
-
|
|
18
|
-
This class extends the PyTorch Lightning `_PredictionLoop` class to include
|
|
19
|
-
the stitching of the tiles into a single prediction result.
|
|
20
|
-
"""
|
|
21
|
-
|
|
22
|
-
def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
|
|
23
|
-
"""Call `on_predict_epoch_end` hook.
|
|
24
|
-
|
|
25
|
-
Adapted from the parent method.
|
|
26
|
-
|
|
27
|
-
Returns
|
|
28
|
-
-------
|
|
29
|
-
Optional[_PREDICT_OUTPUT]
|
|
30
|
-
Prediction output.
|
|
31
|
-
"""
|
|
32
|
-
trainer = self.trainer
|
|
33
|
-
call._call_callback_hooks(trainer, "on_predict_epoch_end")
|
|
34
|
-
call._call_lightning_module_hook(trainer, "on_predict_epoch_end")
|
|
35
|
-
|
|
36
|
-
if self.return_predictions:
|
|
37
|
-
########################################################
|
|
38
|
-
################ CAREamics specific code ###############
|
|
39
|
-
if len(self.predicted_array) == 1:
|
|
40
|
-
# TODO does this make sense to here? (force numpy array)
|
|
41
|
-
return self.predicted_array[0].numpy()
|
|
42
|
-
else:
|
|
43
|
-
# TODO revisit logic
|
|
44
|
-
return [element.numpy() for element in self.predicted_array]
|
|
45
|
-
########################################################
|
|
46
|
-
return None
|
|
47
|
-
|
|
48
|
-
@_no_grad_context
|
|
49
|
-
def run(self) -> Optional[_PREDICT_OUTPUT]:
|
|
50
|
-
"""Run the prediction loop.
|
|
51
|
-
|
|
52
|
-
Adapted from the parent method in order to stitch the predictions.
|
|
53
|
-
|
|
54
|
-
Returns
|
|
55
|
-
-------
|
|
56
|
-
Optional[_PREDICT_OUTPUT]
|
|
57
|
-
Prediction output.
|
|
58
|
-
"""
|
|
59
|
-
self.setup_data()
|
|
60
|
-
if self.skip:
|
|
61
|
-
return None
|
|
62
|
-
self.reset()
|
|
63
|
-
self.on_run_start()
|
|
64
|
-
data_fetcher = self._data_fetcher
|
|
65
|
-
assert data_fetcher is not None
|
|
66
|
-
|
|
67
|
-
self.predicted_array = []
|
|
68
|
-
self.tiles = []
|
|
69
|
-
self.stitching_data = []
|
|
70
|
-
|
|
71
|
-
while True:
|
|
72
|
-
try:
|
|
73
|
-
if isinstance(data_fetcher, _DataLoaderIterDataFetcher):
|
|
74
|
-
dataloader_iter = next(data_fetcher)
|
|
75
|
-
# hook's batch_idx and dataloader_idx arguments correctness cannot
|
|
76
|
-
# be guaranteed in this setting
|
|
77
|
-
batch = data_fetcher._batch
|
|
78
|
-
batch_idx = data_fetcher._batch_idx
|
|
79
|
-
dataloader_idx = data_fetcher._dataloader_idx
|
|
80
|
-
else:
|
|
81
|
-
dataloader_iter = None
|
|
82
|
-
batch, batch_idx, dataloader_idx = next(data_fetcher)
|
|
83
|
-
self.batch_progress.is_last_batch = data_fetcher.done
|
|
84
|
-
|
|
85
|
-
# run step hooks
|
|
86
|
-
self._predict_step(batch, batch_idx, dataloader_idx, dataloader_iter)
|
|
87
|
-
|
|
88
|
-
########################################################
|
|
89
|
-
################ CAREamics specific code ###############
|
|
90
|
-
# TODO: next line is not compatible with muSplit
|
|
91
|
-
is_tiled = len(self.predictions[batch_idx]) == 2
|
|
92
|
-
if is_tiled:
|
|
93
|
-
# extract the last tile flag and the coordinates (crop and stitch)
|
|
94
|
-
last_tile, *stitch_data = self.predictions[batch_idx][1]
|
|
95
|
-
|
|
96
|
-
# append the tile and the coordinates to the lists
|
|
97
|
-
self.tiles.append(self.predictions[batch_idx][0])
|
|
98
|
-
self.stitching_data.append(stitch_data)
|
|
99
|
-
|
|
100
|
-
# if last tile, stitch the tiles and add array to the prediction
|
|
101
|
-
if any(last_tile):
|
|
102
|
-
predicted_batches = stitch_prediction(
|
|
103
|
-
self.tiles, self.stitching_data
|
|
104
|
-
)
|
|
105
|
-
self.predicted_array.append(predicted_batches)
|
|
106
|
-
self.tiles.clear()
|
|
107
|
-
self.stitching_data.clear()
|
|
108
|
-
else:
|
|
109
|
-
# simply add the prediction to the list
|
|
110
|
-
self.predicted_array.append(self.predictions[batch_idx])
|
|
111
|
-
########################################################
|
|
112
|
-
except StopIteration:
|
|
113
|
-
break
|
|
114
|
-
finally:
|
|
115
|
-
self._restarting = False
|
|
116
|
-
return self.on_run_end()
|
|
117
|
-
|
|
118
|
-
# TODO predictions aren't stacked, list returned
|