careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +163 -266
- careamics/config/algorithm_model.py +0 -15
- careamics/config/architectures/custom_model.py +3 -3
- careamics/config/configuration_example.py +0 -3
- careamics/config/configuration_factory.py +23 -25
- careamics/config/configuration_model.py +11 -11
- careamics/config/data_model.py +80 -50
- careamics/config/inference_model.py +29 -17
- careamics/config/optimizer_models.py +7 -7
- careamics/config/support/supported_transforms.py +0 -1
- careamics/config/tile_information.py +26 -58
- careamics/config/transformations/normalize_model.py +32 -4
- careamics/config/validators/validator_utils.py +1 -1
- careamics/dataset/__init__.py +12 -1
- careamics/dataset/dataset_utils/__init__.py +8 -1
- careamics/dataset/dataset_utils/file_utils.py +1 -1
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/read_tiff.py +0 -9
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +66 -171
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +92 -249
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
- careamics/dataset/patching/patching.py +54 -25
- careamics/dataset/patching/random_patching.py +9 -4
- careamics/dataset/patching/validate_patch_dimension.py +5 -3
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
- careamics/lightning_datamodule.py +1 -6
- careamics/lightning_module.py +11 -7
- careamics/lightning_prediction_datamodule.py +52 -72
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/data_modules.py +1220 -0
- careamics/lvae_training/data_utils.py +618 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +339 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/bioimage/model_description.py +40 -32
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +5 -2
- careamics/models/lvae/__init__.py +0 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +312 -0
- careamics/models/lvae/lvae.py +985 -0
- careamics/models/lvae/noise_models.py +409 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/prediction_utils/__init__.py +12 -0
- careamics/prediction_utils/create_pred_datamodule.py +185 -0
- careamics/prediction_utils/prediction_outputs.py +165 -0
- careamics/prediction_utils/stitch_prediction.py +100 -0
- careamics/transforms/n2v_manipulate.py +3 -1
- careamics/transforms/normalize.py +139 -68
- careamics/transforms/pixel_manipulation.py +33 -9
- careamics/transforms/tta.py +43 -29
- careamics/utils/ram.py +2 -2
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
- careamics/lightning_prediction_loop.py +0 -118
- careamics/prediction/__init__.py +0 -7
- careamics/prediction/stitch_prediction.py +0 -70
- careamics/utils/running_stats.py +0 -43
- {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,70 +0,0 @@
|
|
|
1
|
-
"""Prediction utility functions."""
|
|
2
|
-
|
|
3
|
-
from typing import List
|
|
4
|
-
|
|
5
|
-
import numpy as np
|
|
6
|
-
import torch
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
def stitch_prediction(
|
|
10
|
-
tiles: List[torch.Tensor],
|
|
11
|
-
stitching_data: List[List[torch.Tensor]],
|
|
12
|
-
) -> torch.Tensor:
|
|
13
|
-
"""
|
|
14
|
-
Stitch tiles back together to form a full image.
|
|
15
|
-
|
|
16
|
-
Parameters
|
|
17
|
-
----------
|
|
18
|
-
tiles : List[torch.Tensor]
|
|
19
|
-
Cropped tiles and their respective stitching coordinates.
|
|
20
|
-
stitching_data : List
|
|
21
|
-
List of information and coordinates obtained from
|
|
22
|
-
`dataset.tiled_patching.extract_tiles`.
|
|
23
|
-
|
|
24
|
-
Returns
|
|
25
|
-
-------
|
|
26
|
-
torch.Tensor
|
|
27
|
-
Full image.
|
|
28
|
-
"""
|
|
29
|
-
# retrieve whole array size, there are two cases to consider:
|
|
30
|
-
# 1. the tiles are stored in a list
|
|
31
|
-
# 2. the tiles are stored in a list with batches along the first dim
|
|
32
|
-
if tiles[0].shape[0] > 1:
|
|
33
|
-
input_shape = np.array(
|
|
34
|
-
[el.numpy() for el in stitching_data[0][0][0]], dtype=int
|
|
35
|
-
).squeeze()
|
|
36
|
-
else:
|
|
37
|
-
input_shape = np.array(
|
|
38
|
-
[el.numpy() for el in stitching_data[0][0]], dtype=int
|
|
39
|
-
).squeeze()
|
|
40
|
-
|
|
41
|
-
# TODO should use torch.zeros instead of np.zeros
|
|
42
|
-
predicted_image = torch.Tensor(np.zeros(input_shape, dtype=np.float32))
|
|
43
|
-
|
|
44
|
-
for tile_batch, (_, overlap_crop_coords_batch, stitch_coords_batch) in zip(
|
|
45
|
-
tiles, stitching_data
|
|
46
|
-
):
|
|
47
|
-
for batch_idx in range(tile_batch.shape[0]):
|
|
48
|
-
# Compute coordinates for cropping predicted tile
|
|
49
|
-
slices = tuple(
|
|
50
|
-
[
|
|
51
|
-
slice(c[0][batch_idx], c[1][batch_idx])
|
|
52
|
-
for c in overlap_crop_coords_batch
|
|
53
|
-
]
|
|
54
|
-
)
|
|
55
|
-
|
|
56
|
-
# Crop predicted tile according to overlap coordinates
|
|
57
|
-
cropped_tile = tile_batch[batch_idx].squeeze()[slices]
|
|
58
|
-
|
|
59
|
-
# Insert cropped tile into predicted image using stitch coordinates
|
|
60
|
-
predicted_image[
|
|
61
|
-
(
|
|
62
|
-
...,
|
|
63
|
-
*[
|
|
64
|
-
slice(c[0][batch_idx], c[1][batch_idx])
|
|
65
|
-
for c in stitch_coords_batch
|
|
66
|
-
],
|
|
67
|
-
)
|
|
68
|
-
] = cropped_tile.to(torch.float32)
|
|
69
|
-
|
|
70
|
-
return predicted_image
|
careamics/utils/running_stats.py
DELETED
|
@@ -1,43 +0,0 @@
|
|
|
1
|
-
"""Running stats submodule, used in the Zarr dataset."""
|
|
2
|
-
|
|
3
|
-
# from multiprocessing import Value
|
|
4
|
-
# from typing import Tuple
|
|
5
|
-
|
|
6
|
-
# import numpy as np
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
# class RunningStats:
|
|
10
|
-
# """Calculates running mean and std."""
|
|
11
|
-
|
|
12
|
-
# def __init__(self) -> None:
|
|
13
|
-
# self.reset()
|
|
14
|
-
|
|
15
|
-
# def reset(self) -> None:
|
|
16
|
-
# """Reset the running stats."""
|
|
17
|
-
# self.avg_mean = Value("d", 0)
|
|
18
|
-
# self.avg_std = Value("d", 0)
|
|
19
|
-
# self.m2 = Value("d", 0)
|
|
20
|
-
# self.count = Value("i", 0)
|
|
21
|
-
|
|
22
|
-
# def init(self, mean: float, std: float) -> None:
|
|
23
|
-
# """Initialize running stats."""
|
|
24
|
-
# with self.avg_mean.get_lock():
|
|
25
|
-
# self.avg_mean.value += mean
|
|
26
|
-
# with self.avg_std.get_lock():
|
|
27
|
-
# self.avg_std.value = std
|
|
28
|
-
|
|
29
|
-
# def compute_std(self) -> Tuple[float, float]:
|
|
30
|
-
# """Compute std."""
|
|
31
|
-
# if self.count.value >= 2:
|
|
32
|
-
# self.avg_std.value = np.sqrt(self.m2.value / self.count.value)
|
|
33
|
-
|
|
34
|
-
# def update(self, value: float) -> None:
|
|
35
|
-
# """Update running stats."""
|
|
36
|
-
# with self.count.get_lock():
|
|
37
|
-
# self.count.value += 1
|
|
38
|
-
# delta = value - self.avg_mean.value
|
|
39
|
-
# with self.avg_mean.get_lock():
|
|
40
|
-
# self.avg_mean.value += delta / self.count.value
|
|
41
|
-
# delta2 = value - self.avg_mean.value
|
|
42
|
-
# with self.m2.get_lock():
|
|
43
|
-
# self.m2.value += delta * delta2
|
|
File without changes
|