careamics 0.1.0rc3__py3-none-any.whl → 0.1.0rc5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics may be problematic; see the registry's release page for more details.

Files changed (66):
  1. careamics/__init__.py +8 -6
  2. careamics/careamist.py +30 -29
  3. careamics/config/__init__.py +12 -9
  4. careamics/config/algorithm_model.py +5 -5
  5. careamics/config/architectures/unet_model.py +1 -0
  6. careamics/config/callback_model.py +1 -0
  7. careamics/config/configuration_example.py +87 -0
  8. careamics/config/configuration_factory.py +285 -78
  9. careamics/config/configuration_model.py +22 -23
  10. careamics/config/data_model.py +62 -160
  11. careamics/config/inference_model.py +20 -21
  12. careamics/config/references/algorithm_descriptions.py +1 -0
  13. careamics/config/references/references.py +1 -0
  14. careamics/config/support/supported_extraction_strategies.py +1 -0
  15. careamics/config/support/supported_optimizers.py +3 -3
  16. careamics/config/training_model.py +2 -1
  17. careamics/config/transformations/n2v_manipulate_model.py +2 -1
  18. careamics/config/transformations/nd_flip_model.py +7 -12
  19. careamics/config/transformations/normalize_model.py +2 -1
  20. careamics/config/transformations/transform_model.py +1 -0
  21. careamics/config/transformations/xy_random_rotate90_model.py +7 -9
  22. careamics/config/validators/validator_utils.py +1 -0
  23. careamics/conftest.py +1 -0
  24. careamics/dataset/dataset_utils/__init__.py +0 -1
  25. careamics/dataset/dataset_utils/dataset_utils.py +1 -0
  26. careamics/dataset/in_memory_dataset.py +17 -48
  27. careamics/dataset/iterable_dataset.py +16 -71
  28. careamics/dataset/patching/__init__.py +0 -7
  29. careamics/dataset/patching/patching.py +1 -0
  30. careamics/dataset/patching/sequential_patching.py +6 -6
  31. careamics/dataset/patching/tiled_patching.py +10 -6
  32. careamics/lightning_datamodule.py +123 -49
  33. careamics/lightning_module.py +7 -7
  34. careamics/lightning_prediction_datamodule.py +59 -48
  35. careamics/losses/__init__.py +0 -1
  36. careamics/losses/loss_factory.py +1 -0
  37. careamics/model_io/__init__.py +0 -1
  38. careamics/model_io/bioimage/_readme_factory.py +2 -1
  39. careamics/model_io/bioimage/bioimage_utils.py +1 -0
  40. careamics/model_io/bioimage/model_description.py +4 -3
  41. careamics/model_io/bmz_io.py +8 -7
  42. careamics/model_io/model_io_utils.py +4 -4
  43. careamics/models/layers.py +1 -0
  44. careamics/models/model_factory.py +1 -0
  45. careamics/models/unet.py +91 -17
  46. careamics/prediction/stitch_prediction.py +1 -0
  47. careamics/transforms/__init__.py +2 -23
  48. careamics/transforms/compose.py +98 -0
  49. careamics/transforms/n2v_manipulate.py +18 -23
  50. careamics/transforms/nd_flip.py +38 -64
  51. careamics/transforms/normalize.py +45 -34
  52. careamics/transforms/pixel_manipulation.py +2 -2
  53. careamics/transforms/transform.py +33 -0
  54. careamics/transforms/tta.py +2 -2
  55. careamics/transforms/xy_random_rotate90.py +41 -68
  56. careamics/utils/__init__.py +0 -1
  57. careamics/utils/context.py +1 -0
  58. careamics/utils/logging.py +1 -0
  59. careamics/utils/metrics.py +1 -0
  60. careamics/utils/torch_utils.py +1 -0
  61. {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/METADATA +16 -61
  62. careamics-0.1.0rc5.dist-info/RECORD +111 -0
  63. careamics/dataset/patching/patch_transform.py +0 -44
  64. careamics-0.1.0rc3.dist-info/RECORD +0 -109
  65. {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/WHEEL +0 -0
  66. {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc5.dist-info}/licenses/LICENSE +0 -0
@@ -1,6 +1,5 @@
1
1
  """Files and arrays utils used in the datasets."""
2
2
 
3
-
4
3
  __all__ = [
5
4
  "reshape_array",
6
5
  "get_files_size",
@@ -1,4 +1,5 @@
1
1
  """Convenience methods for datasets."""
2
+
2
3
  from typing import List, Tuple
3
4
 
4
5
  import numpy as np
@@ -1,4 +1,5 @@
1
1
  """In-memory dataset module."""
2
+
2
3
  from __future__ import annotations
3
4
 
4
5
  import copy
@@ -8,11 +9,12 @@ from typing import Any, Callable, List, Optional, Tuple, Union
8
9
  import numpy as np
9
10
  from torch.utils.data import Dataset
10
11
 
11
- from ..config import DataModel, InferenceModel
12
+ from careamics.transforms import Compose
13
+
14
+ from ..config import DataConfig, InferenceConfig
12
15
  from ..config.tile_information import TileInformation
13
16
  from ..utils.logging import get_logger
14
17
  from .dataset_utils import read_tiff, reshape_array
15
- from .patching.patch_transform import get_patch_transform
16
18
  from .patching.patching import (
17
19
  prepare_patches_supervised,
18
20
  prepare_patches_supervised_array,
@@ -29,7 +31,7 @@ class InMemoryDataset(Dataset):
29
31
 
30
32
  def __init__(
31
33
  self,
32
- data_config: DataModel,
34
+ data_config: DataConfig,
33
35
  inputs: Union[np.ndarray, List[Path]],
34
36
  data_target: Optional[Union[np.ndarray, List[Path]]] = None,
35
37
  read_source_func: Callable = read_tiff,
@@ -60,18 +62,15 @@ class InMemoryDataset(Dataset):
60
62
  self.mean, self.std = computed_mean, computed_std
61
63
  logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}")
62
64
 
63
- # if the transforms are not an instance of Compose
64
- if self.data_config.has_transform_list():
65
- # update mean and std in configuration
66
- # the object is mutable and should then be recorded in the CAREamist obj
67
- self.data_config.set_mean_and_std(self.mean, self.std)
65
+ # update mean and std in configuration
66
+ # the object is mutable and should then be recorded in the CAREamist obj
67
+ self.data_config.set_mean_and_std(self.mean, self.std)
68
68
  else:
69
69
  self.mean, self.std = self.data_config.mean, self.data_config.std
70
70
 
71
71
  # get transforms
72
- self.patch_transform = get_patch_transform(
73
- patch_transforms=self.data_config.transforms,
74
- with_target=self.data_target is not None,
72
+ self.patch_transform = Compose(
73
+ transform_list=self.data_config.transforms,
75
74
  )
76
75
 
77
76
  def _prepare_patches(
@@ -166,33 +165,10 @@ class InMemoryDataset(Dataset):
166
165
  # get target
167
166
  target = self.data_targets[index]
168
167
 
169
- # Albumentations requires Channel last
170
- c_patch = np.moveaxis(patch, 0, -1)
171
- c_target = np.moveaxis(target, 0, -1)
172
-
173
- # Apply transforms
174
- transformed = self.patch_transform(image=c_patch, target=c_target)
175
-
176
- # move axes back
177
- patch = np.moveaxis(transformed["image"], -1, 0)
178
- target = np.moveaxis(transformed["target"], -1, 0)
179
-
180
- return patch, target
168
+ return self.patch_transform(patch=patch, target=target)
181
169
 
182
170
  elif self.data_config.has_n2v_manipulate():
183
- # Albumentations requires Channel last
184
- patch = np.moveaxis(patch, 0, -1)
185
-
186
- # Apply transforms
187
- transformed_patch = self.patch_transform(image=patch)["image"]
188
- manip_patch, patch, mask = transformed_patch
189
-
190
- # move C axes back
191
- manip_patch = np.moveaxis(manip_patch, -1, 0)
192
- patch = np.moveaxis(patch, -1, 0)
193
- mask = np.moveaxis(mask, -1, 0)
194
-
195
- return (manip_patch, patch, mask)
171
+ return self.patch_transform(patch=patch)
196
172
  else:
197
173
  raise ValueError(
198
174
  "Something went wrong! No target provided (not supervised training) "
@@ -279,7 +255,7 @@ class InMemoryPredictionDataset(Dataset):
279
255
 
280
256
  def __init__(
281
257
  self,
282
- prediction_config: InferenceModel,
258
+ prediction_config: InferenceConfig,
283
259
  inputs: np.ndarray,
284
260
  data_target: Optional[np.ndarray] = None,
285
261
  read_source_func: Optional[Callable] = read_tiff,
@@ -318,9 +294,8 @@ class InMemoryPredictionDataset(Dataset):
318
294
  self.mean, self.std = self.pred_config.mean, self.pred_config.std
319
295
 
320
296
  # get transforms
321
- self.patch_transform = get_patch_transform(
322
- patch_transforms=self.pred_config.transforms,
323
- with_target=self.data_target is not None,
297
+ self.patch_transform = Compose(
298
+ transform_list=self.pred_config.transforms,
324
299
  )
325
300
 
326
301
  def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
@@ -379,13 +354,7 @@ class InMemoryPredictionDataset(Dataset):
379
354
  """
380
355
  tile_array, tile_info = self.data[index]
381
356
 
382
- # Albumentations requires channel last, use the XArrayTile array
383
- patch = np.moveaxis(tile_array, 0, -1)
384
-
385
357
  # Apply transforms
386
- transformed_patch = self.patch_transform(image=patch)["image"]
387
-
388
- # move C axes back
389
- transformed_patch = np.moveaxis(transformed_patch, -1, 0)
358
+ transformed_tile, _ = self.patch_transform(patch=tile_array)
390
359
 
391
- return transformed_patch, tile_info
360
+ return transformed_tile, tile_info
@@ -7,13 +7,12 @@ from typing import Any, Callable, Generator, List, Optional, Tuple, Union
7
7
  import numpy as np
8
8
  from torch.utils.data import IterableDataset, get_worker_info
9
9
 
10
- from ..config import DataModel, InferenceModel
10
+ from careamics.transforms import Compose
11
+
12
+ from ..config import DataConfig, InferenceConfig
11
13
  from ..config.tile_information import TileInformation
12
14
  from ..utils.logging import get_logger
13
15
  from .dataset_utils import read_tiff, reshape_array
14
- from .patching import (
15
- get_patch_transform,
16
- )
17
16
  from .patching.random_patching import extract_patches_random
18
17
  from .patching.tiled_patching import extract_tiles
19
18
 
@@ -46,7 +45,7 @@ class PathIterableDataset(IterableDataset):
46
45
 
47
46
  def __init__(
48
47
  self,
49
- data_config: Union[DataModel, InferenceModel],
48
+ data_config: Union[DataConfig, InferenceConfig],
50
49
  src_files: List[Path],
51
50
  target_files: Optional[List[Path]] = None,
52
51
  read_source_func: Callable = read_tiff,
@@ -61,26 +60,15 @@ class PathIterableDataset(IterableDataset):
61
60
  if not data_config.mean or not data_config.std:
62
61
  self.mean, self.std = self._calculate_mean_and_std()
63
62
 
64
- # if the transforms are not an instance of Compose
65
- # Check if the data_config is an instance of DataModel or InferenceModel
66
- # isinstance isn't working properly here
67
- if hasattr(data_config, "has_transform_list"):
68
- if data_config.has_transform_list():
69
- # update mean and std in configuration
70
- # the object is mutable and should then be recorded in the CAREamist
71
- data_config.set_mean_and_std(self.mean, self.std)
72
- else:
73
- data_config.set_mean_and_std(self.mean, self.std)
74
-
63
+ # update mean and std in configuration
64
+ # the object is mutable and should then be recorded in the CAREamist
65
+ data_config.set_mean_and_std(self.mean, self.std)
75
66
  else:
76
67
  self.mean = data_config.mean
77
68
  self.std = data_config.std
78
69
 
79
70
  # get transforms
80
- self.patch_transform = get_patch_transform(
81
- patch_transforms=data_config.transforms,
82
- with_target=target_files is not None,
83
- )
71
+ self.patch_transform = Compose(transform_list=data_config.transforms)
84
72
 
85
73
  def _calculate_mean_and_std(self) -> Tuple[float, float]:
86
74
  """
@@ -192,49 +180,10 @@ class PathIterableDataset(IterableDataset):
192
180
  # or (patch, None) only if no target is available
193
181
  # patch is of dimensions (C)ZYX
194
182
  for patch_data in patches:
195
- # if there is a target
196
- if self.target_files is not None:
197
- # Albumentations expects the channel dimension to be last
198
- # Taking the first element because patch_data can include target
199
- c_patch = np.moveaxis(patch_data[0], 0, -1)
200
- c_target = np.moveaxis(patch_data[1], 0, -1)
201
-
202
- # apply the transform to the patch and the target
203
- transformed = self.patch_transform(
204
- image=c_patch,
205
- target=c_target,
206
- )
207
-
208
- # move the axes back to the original position
209
- c_patch = np.moveaxis(transformed["image"], -1, 0)
210
- c_target = np.moveaxis(transformed["target"], -1, 0)
211
-
212
- yield (c_patch, c_target)
213
- elif self.data_config.has_n2v_manipulate():
214
- # Albumentations expects the channel dimension to be last
215
- # Taking the first element because patch_data can include target
216
- patch = np.moveaxis(patch_data[0], 0, -1)
217
-
218
- # apply transform
219
- transformed = self.patch_transform(image=patch)
220
-
221
- # retrieve the output of ManipulateN2V
222
- results = transformed["image"]
223
- masked_patch: np.ndarray = results[0]
224
- original_patch: np.ndarray = results[1]
225
- mask: np.ndarray = results[2]
226
-
227
- # move C axes back
228
- masked_patch = np.moveaxis(masked_patch, -1, 0)
229
- original_patch = np.moveaxis(original_patch, -1, 0)
230
- mask = np.moveaxis(mask, -1, 0)
231
-
232
- yield (masked_patch, original_patch, mask)
233
- else:
234
- raise ValueError(
235
- "Something went wrong! Not target file (no supervised "
236
- "training) and no N2V transform (no n2v training either)."
237
- )
183
+ yield self.patch_transform(
184
+ patch=patch_data[0],
185
+ target=patch_data[1],
186
+ )
238
187
 
239
188
  def get_number_of_files(self) -> int:
240
189
  """
@@ -346,7 +295,7 @@ class IterablePredictionDataset(PathIterableDataset):
346
295
 
347
296
  def __init__(
348
297
  self,
349
- prediction_config: InferenceModel,
298
+ prediction_config: InferenceConfig,
350
299
  src_files: List[Path],
351
300
  read_source_func: Callable = read_tiff,
352
301
  **kwargs: Any,
@@ -367,9 +316,8 @@ class IterablePredictionDataset(PathIterableDataset):
367
316
  self.tile = self.tile_size is not None and self.tile_overlap is not None
368
317
 
369
318
  # get tta transforms
370
- self.patch_transform = get_patch_transform(
371
- patch_transforms=prediction_config.transforms,
372
- with_target=False,
319
+ self.patch_transform = Compose(
320
+ transform_list=prediction_config.transforms,
373
321
  )
374
322
 
375
323
  def __iter__(
@@ -408,9 +356,6 @@ class IterablePredictionDataset(PathIterableDataset):
408
356
 
409
357
  # apply transform to patches
410
358
  for patch_array, tile_info in patch_gen:
411
- # albumentations expects the channel dimension to be last
412
- patch = np.moveaxis(patch_array, 0, -1)
413
- transformed_patch = self.patch_transform(image=patch)
414
- transformed_patch = np.moveaxis(transformed_patch["image"], -1, 0)
359
+ transformed_patch, _ = self.patch_transform(patch=patch_array)
415
360
 
416
361
  yield transformed_patch, tile_info
@@ -1,8 +1 @@
1
1
  """Patching and tiling functions."""
2
-
3
-
4
- __all__ = [
5
- "get_patch_transform",
6
- ]
7
-
8
- from .patch_transform import get_patch_transform
@@ -3,6 +3,7 @@ Tiling submodule.
3
3
 
4
4
  These functions are used to tile images into patches or tiles.
5
5
  """
6
+
6
7
  from pathlib import Path
7
8
  from typing import Callable, List, Tuple, Union
8
9
 
@@ -135,15 +135,12 @@ def _compute_patch_views(
135
135
  arr = np.stack([arr, target], axis=0)
136
136
  window_shape = [arr.shape[0], *window_shape]
137
137
  step = (arr.shape[0], *step)
138
- output_shape = [arr.shape[0], -1, arr.shape[2], *output_shape[2:]]
138
+ output_shape = [-1, arr.shape[0], arr.shape[2], *output_shape[2:]]
139
139
 
140
140
  patches = view_as_windows(arr, window_shape=window_shape, step=step).reshape(
141
141
  *output_shape
142
142
  )
143
- if target is not None:
144
- rng.shuffle(patches, axis=1)
145
- else:
146
- rng.shuffle(patches, axis=0)
143
+ rng.shuffle(patches, axis=0)
147
144
  return patches
148
145
 
149
146
 
@@ -201,6 +198,9 @@ def extract_patches_sequential(
201
198
 
202
199
  if target is not None:
203
200
  # target was concatenated to patches in _compute_reshaped_view
204
- return (patches[0, ...], patches[1, ...]) # TODO in _compute_reshaped_view?
201
+ return (
202
+ patches[:, 0, ...],
203
+ patches[:, 1, ...],
204
+ ) # TODO in _compute_reshaped_view?
205
205
  else:
206
206
  return patches, None
@@ -43,9 +43,11 @@ def _compute_crop_and_stitch_coords_1d(
43
43
  stitch_coords.append(
44
44
  (
45
45
  i + overlap // 2 if i > 0 else 0,
46
- i + tile_size - overlap // 2
47
- if crop_coords[-1][1] < axis_size
48
- else axis_size,
46
+ (
47
+ i + tile_size - overlap // 2
48
+ if crop_coords[-1][1] < axis_size
49
+ else axis_size
50
+ ),
49
51
  )
50
52
  )
51
53
 
@@ -53,9 +55,11 @@ def _compute_crop_and_stitch_coords_1d(
53
55
  overlap_crop_coords.append(
54
56
  (
55
57
  overlap // 2 if i > 0 else 0,
56
- tile_size - overlap // 2
57
- if crop_coords[-1][1] < axis_size
58
- else tile_size,
58
+ (
59
+ tile_size - overlap // 2
60
+ if crop_coords[-1][1] < axis_size
61
+ else tile_size
62
+ ),
59
63
  )
60
64
  )
61
65