careamics 0.0.1__py3-none-any.whl → 0.1.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (48)
  1. careamics/__init__.py +7 -1
  2. careamics/bioimage/__init__.py +15 -0
  3. careamics/bioimage/docs/Noise2Void.md +5 -0
  4. careamics/bioimage/docs/__init__.py +1 -0
  5. careamics/bioimage/io.py +182 -0
  6. careamics/bioimage/rdf.py +105 -0
  7. careamics/config/__init__.py +11 -0
  8. careamics/config/algorithm.py +231 -0
  9. careamics/config/config.py +297 -0
  10. careamics/config/config_filter.py +44 -0
  11. careamics/config/data.py +194 -0
  12. careamics/config/torch_optim.py +118 -0
  13. careamics/config/training.py +534 -0
  14. careamics/dataset/__init__.py +1 -0
  15. careamics/dataset/dataset_utils.py +111 -0
  16. careamics/dataset/extraction_strategy.py +21 -0
  17. careamics/dataset/in_memory_dataset.py +202 -0
  18. careamics/dataset/patching.py +492 -0
  19. careamics/dataset/prepare_dataset.py +175 -0
  20. careamics/dataset/tiff_dataset.py +212 -0
  21. careamics/engine.py +1014 -0
  22. careamics/losses/__init__.py +4 -0
  23. careamics/losses/loss_factory.py +38 -0
  24. careamics/losses/losses.py +34 -0
  25. careamics/manipulation/__init__.py +4 -0
  26. careamics/manipulation/pixel_manipulation.py +158 -0
  27. careamics/models/__init__.py +4 -0
  28. careamics/models/layers.py +152 -0
  29. careamics/models/model_factory.py +251 -0
  30. careamics/models/unet.py +322 -0
  31. careamics/prediction/__init__.py +9 -0
  32. careamics/prediction/prediction_utils.py +106 -0
  33. careamics/utils/__init__.py +20 -0
  34. careamics/utils/ascii_logo.txt +9 -0
  35. careamics/utils/augment.py +65 -0
  36. careamics/utils/context.py +45 -0
  37. careamics/utils/logging.py +321 -0
  38. careamics/utils/metrics.py +160 -0
  39. careamics/utils/normalization.py +55 -0
  40. careamics/utils/torch_utils.py +89 -0
  41. careamics/utils/validators.py +170 -0
  42. careamics/utils/wandb.py +121 -0
  43. careamics-0.1.0rc2.dist-info/METADATA +81 -0
  44. careamics-0.1.0rc2.dist-info/RECORD +47 -0
  45. {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/WHEEL +1 -1
  46. {careamics-0.0.1.dist-info → careamics-0.1.0rc2.dist-info}/licenses/LICENSE +1 -1
  47. careamics-0.0.1.dist-info/METADATA +0 -46
  48. careamics-0.0.1.dist-info/RECORD +0 -6
@@ -0,0 +1,175 @@
1
+ """
2
+ Dataset preparation module.
3
+
4
+ Methods to set up the datasets for training, validation and prediction.
5
+ """
6
+ from pathlib import Path
7
+ from typing import List, Optional, Union
8
+
9
+ from careamics.config import Configuration
10
+ from careamics.manipulation import default_manipulate
11
+ from careamics.utils import check_tiling_validity
12
+
13
+ from .extraction_strategy import ExtractionStrategy
14
+ from .in_memory_dataset import InMemoryDataset
15
+ from .tiff_dataset import TiffDataset
16
+
17
+
18
def get_train_dataset(
    config: Configuration, train_path: Union[str, Path]
) -> Union[TiffDataset, InMemoryDataset]:
    """
    Create training dataset.

    Depending on `config.data.in_memory`, this function returns either an
    InMemoryDataset (patches extracted sequentially) or a TiffDataset (patches
    extracted randomly). In both cases `default_manipulate` is applied to every
    patch, parametrized with the masked pixel percentage and ROI size from the
    algorithm configuration.

    Parameters
    ----------
    config : Configuration
        Configuration.
    train_path : Union[str, Path]
        Path to training data.

    Returns
    -------
    Union[TiffDataset, InMemoryDataset]
        Dataset.
    """
    # Both dataset types receive the exact same arguments; only the class and
    # the patch extraction strategy depend on whether data is kept in memory.
    if config.data.in_memory:
        dataset_class = InMemoryDataset
        extraction_method = ExtractionStrategy.SEQUENTIAL
    else:
        dataset_class = TiffDataset
        extraction_method = ExtractionStrategy.RANDOM

    return dataset_class(
        data_path=train_path,
        data_format=config.data.data_format,
        axes=config.data.axes,
        mean=config.data.mean,
        std=config.data.std,
        patch_extraction_method=extraction_method,
        patch_size=config.training.patch_size,
        patch_transform=default_manipulate,
        patch_transform_params={
            "mask_pixel_percentage": config.algorithm.masked_pixel_percentage,
            "roi_size": config.algorithm.roi_size,
        },
    )
70
+
71
+
72
def get_validation_dataset(
    config: Configuration, val_path: Union[str, Path]
) -> InMemoryDataset:
    """
    Create validation dataset.

    Validation dataset is kept in memory and patches are extracted sequentially.
    Patches are manipulated with `default_manipulate` using the same masked pixel
    percentage and ROI size as the training dataset, so that the validation loss
    is computed under the same masking scheme as the training loss.

    Parameters
    ----------
    config : Configuration
        Configuration.
    val_path : Union[str, Path]
        Path to validation data.

    Returns
    -------
    InMemoryDataset
        In memory dataset.
    """
    return InMemoryDataset(
        data_path=val_path,
        data_format=config.data.data_format,
        axes=config.data.axes,
        mean=config.data.mean,
        std=config.data.std,
        patch_extraction_method=ExtractionStrategy.SEQUENTIAL,
        patch_size=config.training.patch_size,
        patch_transform=default_manipulate,
        patch_transform_params={
            "mask_pixel_percentage": config.algorithm.masked_pixel_percentage,
            # Previously omitted: validation fell back to default_manipulate's
            # own default ROI size instead of the configured one.
            "roi_size": config.algorithm.roi_size,
        },
    )
107
+
108
+
109
def get_prediction_dataset(
    config: Configuration,
    pred_path: Union[str, Path],
    *,
    tile_shape: Optional[List[int]] = None,
    overlaps: Optional[List[int]] = None,
    axes: Optional[str] = None,
) -> TiffDataset:
    """
    Create prediction dataset.

    Tiling is enabled only when both `tile_shape` and `overlaps` are provided.
    They are then validated with `check_tiling_validity`: they must have the same
    length, be divisible by 2 and greater than 0, and the overlaps must be
    smaller than the tiles. If either one is missing, no patch extraction
    strategy is used.

    Axes default to those of the configuration. Passing `axes` overrides them,
    e.g. to predict on images with different axes. The difference between
    configuration and parameter axes must be S or T, but not any of the spatial
    dimensions (e.g. 2D vs 3D).

    Parameters
    ----------
    config : Configuration
        Configuration.
    pred_path : Union[str, Path]
        Path to prediction data.
    tile_shape : Optional[List[int]], optional
        2D or 3D shape of the tiles, by default None.
    overlaps : Optional[List[int]], optional
        2D or 3D overlaps between tiles, by default None.
    axes : Optional[str], optional
        Axes of the data, by default None (use the configuration axes).

    Returns
    -------
    TiffDataset
        Dataset.
    """
    tiling_requested = tile_shape is not None and overlaps is not None
    if tiling_requested:
        # Raises if the tiles/overlaps combination is not usable.
        check_tiling_validity(tile_shape, overlaps)

    extraction = ExtractionStrategy.TILED if tiling_requested else None

    # Explicitly passed axes take precedence over the configured ones.
    data_axes = axes if axes is not None else config.data.axes

    return TiffDataset(
        data_path=pred_path,
        data_format=config.data.data_format,
        axes=data_axes,
        mean=config.data.mean,
        std=config.data.std,
        patch_size=tile_shape,
        patch_overlap=overlaps,
        patch_extraction_method=extraction,
        patch_transform=None,
    )
@@ -0,0 +1,212 @@
1
+ """
2
+ Tiff dataset module.
3
+
4
+ This module contains the implementation of the TiffDataset class, which allows loading
5
+ tiff files.
6
+ """
7
+ from pathlib import Path
8
+ from typing import Callable, Dict, Generator, List, Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ from careamics.utils import normalize
14
+ from careamics.utils.logging import get_logger
15
+
16
+ from .dataset_utils import (
17
+ list_files,
18
+ read_tiff,
19
+ )
20
+ from .extraction_strategy import ExtractionStrategy
21
+ from .patching import generate_patches
22
+
23
# Module-level logger; used to report the mean and std computed from the data.
logger = get_logger(__name__)
24
+
25
+
26
class TiffDataset(torch.utils.data.IterableDataset):
    """
    Dataset allowing extracting patches from tiff images.

    Parameters
    ----------
    data_path : Union[str, Path]
        Path to the data, must be a directory.
    data_format : str
        Extension of the files to load, without the period.
    axes : str
        Description of axes in format STCZYX.
    patch_extraction_method : Union[ExtractionStrategy, None]
        Patch extraction strategy, as defined in extraction_strategy. If None,
        no patches are extracted and each entry along the first axis of every
        image is yielded instead.
    patch_size : Optional[Union[List[int], Tuple[int]]], optional
        Size of the patches in each dimension, by default None.
    patch_overlap : Optional[Union[List[int], Tuple[int]]], optional
        Overlap of the patches in each dimension, by default None.
    mean : Optional[float], optional
        Expected mean of the dataset, by default None. If either `mean` or `std`
        is None, both are computed from the data.
    std : Optional[float], optional
        Expected standard deviation of the dataset, by default None.
    patch_transform : Optional[Callable], optional
        Patch transform callable, by default None.
    patch_transform_params : Optional[Dict], optional
        Patch transform parameters, by default None.
    """

    def __init__(
        self,
        data_path: Union[str, Path],
        data_format: str,  # TODO: TiffDataset should not know that they are tiff
        axes: str,
        patch_extraction_method: Union[ExtractionStrategy, None],
        patch_size: Optional[Union[List[int], Tuple[int]]] = None,
        patch_overlap: Optional[Union[List[int], Tuple[int]]] = None,
        mean: Optional[float] = None,
        std: Optional[float] = None,
        patch_transform: Optional[Callable] = None,
        patch_transform_params: Optional[Dict] = None,
    ) -> None:
        """
        Constructor.

        Parameters
        ----------
        data_path : Union[str, Path]
            Path to the data, must be a directory.
        data_format : str
            Extension of the files to load, without the period.
        axes : str
            Description of axes in format STCZYX.
        patch_extraction_method : Union[ExtractionStrategy, None]
            Patch extraction strategy, as defined in extraction_strategy.
        patch_size : Optional[Union[List[int], Tuple[int]]], optional
            Size of the patches in each dimension, by default None.
        patch_overlap : Optional[Union[List[int], Tuple[int]]], optional
            Overlap of the patches in each dimension, by default None.
        mean : Optional[float], optional
            Mean of the dataset, by default None.
        std : Optional[float], optional
            Standard deviation of the dataset, by default None.
        patch_transform : Optional[Callable], optional
            Patch transform callable, by default None.
        patch_transform_params : Optional[Dict], optional
            Patch transform parameters, by default None.

        Raises
        ------
        ValueError
            If data_path is not a directory.
        ValueError
            If mean and std must be computed but no image could be read.
        """
        self.data_path = Path(data_path)
        if not self.data_path.is_dir():
            raise ValueError("Path to data should be an existing folder.")

        self.data_format = data_format
        self.axes = axes

        self.files = list_files(self.data_path, self.data_format)

        # Only compute the statistics when they are actually missing. Testing
        # against None (not truthiness) keeps a legitimate mean of 0.0 from
        # triggering a full, unnecessary pass over the data.
        self.mean = mean
        self.std = std
        if mean is None or std is None:
            self.mean, self.std = self._calculate_mean_and_std()

        self.patch_size = patch_size
        self.patch_overlap = patch_overlap
        self.patch_extraction_method = patch_extraction_method
        self.patch_transform = patch_transform
        self.patch_transform_params = patch_transform_params

    def _calculate_mean_and_std(self) -> Tuple[float, float]:
        """
        Calculate mean and std of the dataset.

        The result is the average of the per-image means and the average of the
        per-image standard deviations.

        Returns
        -------
        Tuple[float, float]
            Tuple containing mean and standard deviation.

        Raises
        ------
        ValueError
            If no image was read, as the statistics would be undefined.
        """
        means, stds = 0, 0
        num_samples = 0

        for sample in self._iterate_files():
            means += sample.mean()
            stds += np.std(sample)
            num_samples += 1

        if num_samples == 0:
            raise ValueError(
                f"No image found in {self.data_path} (format "
                f"'{self.data_format}'), cannot compute mean and std."
            )

        result_mean = means / num_samples
        result_std = stds / num_samples

        logger.info(f"Calculated mean and std for {num_samples} images")
        logger.info(f"Mean: {result_mean}, std: {result_std}")
        return result_mean, result_std

    def _iterate_files(self) -> Generator:
        """
        Iterate over data source and yield whole image.

        Yields
        ------
        np.ndarray
            Image.
        """
        # When num_workers > 0, each worker process will have a different copy of
        # the dataset object. Files are split round-robin between workers so that
        # no image is returned twice.
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        num_workers = worker_info.num_workers if worker_info is not None else 1

        for i, filename in enumerate(self.files):
            if i % num_workers == worker_id:
                sample = read_tiff(filename, self.axes)
                yield sample

    def __iter__(self) -> Generator[np.ndarray, None, None]:
        """
        Iterate over data source and yield single patch.

        Yields
        ------
        np.ndarray
            Single normalized patch. If the patch generator produces tuples, only
            the first element is normalized. When a patch transform is set, its
            output is yielded instead.
        """
        assert (
            self.mean is not None and self.std is not None
        ), "Mean and std must be provided"
        for sample in self._iterate_files():
            # TODO patch_extraction_method should never be None!
            if self.patch_extraction_method:
                # TODO: move S and T unpacking logic from patch generator
                patches = generate_patches(
                    sample,
                    self.patch_extraction_method,
                    self.patch_size,
                    self.patch_overlap,
                )

                for patch in patches:
                    if isinstance(patch, tuple):
                        normalized_patch = normalize(
                            img=patch[0], mean=self.mean, std=self.std
                        )
                        patch = (normalized_patch, *patch[1:])
                    else:
                        patch = normalize(img=patch, mean=self.mean, std=self.std)

                    if self.patch_transform is not None:
                        assert self.patch_transform_params is not None
                        patch = self.patch_transform(
                            patch, **self.patch_transform_params
                        )

                    yield patch

            else:
                # If S or T dims are not empty - assume every image is a separate
                # sample in dim 0. No patch transform is applied in this branch.
                for i in range(sample.shape[0]):
                    item = np.expand_dims(sample[i], (0, 1))
                    item = normalize(img=item, mean=self.mean, std=self.std)
                    yield item