careamics 0.1.0rc4__py3-none-any.whl → 0.1.0rc5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (62)
  1. careamics/careamist.py +12 -11
  2. careamics/config/__init__.py +0 -1
  3. careamics/config/architectures/unet_model.py +1 -0
  4. careamics/config/callback_model.py +1 -0
  5. careamics/config/configuration_example.py +0 -2
  6. careamics/config/configuration_factory.py +112 -42
  7. careamics/config/configuration_model.py +14 -16
  8. careamics/config/data_model.py +59 -157
  9. careamics/config/inference_model.py +19 -20
  10. careamics/config/references/algorithm_descriptions.py +1 -0
  11. careamics/config/references/references.py +1 -0
  12. careamics/config/support/supported_extraction_strategies.py +1 -0
  13. careamics/config/training_model.py +1 -0
  14. careamics/config/transformations/n2v_manipulate_model.py +1 -0
  15. careamics/config/transformations/nd_flip_model.py +6 -11
  16. careamics/config/transformations/normalize_model.py +1 -0
  17. careamics/config/transformations/transform_model.py +1 -0
  18. careamics/config/transformations/xy_random_rotate90_model.py +6 -8
  19. careamics/config/validators/validator_utils.py +1 -0
  20. careamics/conftest.py +1 -0
  21. careamics/dataset/dataset_utils/__init__.py +0 -1
  22. careamics/dataset/dataset_utils/dataset_utils.py +1 -0
  23. careamics/dataset/in_memory_dataset.py +14 -45
  24. careamics/dataset/iterable_dataset.py +13 -68
  25. careamics/dataset/patching/__init__.py +0 -7
  26. careamics/dataset/patching/patching.py +1 -0
  27. careamics/dataset/patching/sequential_patching.py +6 -6
  28. careamics/dataset/patching/tiled_patching.py +10 -6
  29. careamics/lightning_datamodule.py +20 -24
  30. careamics/lightning_module.py +1 -1
  31. careamics/lightning_prediction_datamodule.py +15 -10
  32. careamics/losses/__init__.py +0 -1
  33. careamics/losses/loss_factory.py +1 -0
  34. careamics/model_io/__init__.py +0 -1
  35. careamics/model_io/bioimage/_readme_factory.py +2 -1
  36. careamics/model_io/bioimage/bioimage_utils.py +1 -0
  37. careamics/model_io/bioimage/model_description.py +1 -0
  38. careamics/model_io/bmz_io.py +2 -1
  39. careamics/models/layers.py +1 -0
  40. careamics/models/model_factory.py +1 -0
  41. careamics/models/unet.py +91 -17
  42. careamics/prediction/stitch_prediction.py +1 -0
  43. careamics/transforms/__init__.py +2 -23
  44. careamics/transforms/compose.py +98 -0
  45. careamics/transforms/n2v_manipulate.py +18 -23
  46. careamics/transforms/nd_flip.py +38 -64
  47. careamics/transforms/normalize.py +45 -34
  48. careamics/transforms/pixel_manipulation.py +2 -2
  49. careamics/transforms/transform.py +33 -0
  50. careamics/transforms/tta.py +2 -2
  51. careamics/transforms/xy_random_rotate90.py +41 -68
  52. careamics/utils/__init__.py +0 -1
  53. careamics/utils/context.py +1 -0
  54. careamics/utils/logging.py +1 -0
  55. careamics/utils/metrics.py +1 -0
  56. careamics/utils/torch_utils.py +1 -0
  57. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc5.dist-info}/METADATA +16 -61
  58. careamics-0.1.0rc5.dist-info/RECORD +111 -0
  59. careamics/dataset/patching/patch_transform.py +0 -44
  60. careamics-0.1.0rc4.dist-info/RECORD +0 -110
  61. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc5.dist-info}/WHEEL +0 -0
  62. {careamics-0.1.0rc4.dist-info → careamics-0.1.0rc5.dist-info}/licenses/LICENSE +0 -0
careamics/dataset/in_memory_dataset.py

@@ -1,4 +1,5 @@
 """In-memory dataset module."""
+
 from __future__ import annotations
 
 import copy
@@ -8,11 +9,12 @@ from typing import Any, Callable, List, Optional, Tuple, Union
 import numpy as np
 from torch.utils.data import Dataset
 
+from careamics.transforms import Compose
+
 from ..config import DataConfig, InferenceConfig
 from ..config.tile_information import TileInformation
 from ..utils.logging import get_logger
 from .dataset_utils import read_tiff, reshape_array
-from .patching.patch_transform import get_patch_transform
 from .patching.patching import (
     prepare_patches_supervised,
     prepare_patches_supervised_array,
@@ -60,18 +62,15 @@ class InMemoryDataset(Dataset):
             self.mean, self.std = computed_mean, computed_std
             logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}")
 
-            # if the transforms are not an instance of Compose
-            if self.data_config.has_transform_list():
-                # update mean and std in configuration
-                # the object is mutable and should then be recorded in the CAREamist obj
-                self.data_config.set_mean_and_std(self.mean, self.std)
+            # update mean and std in configuration
+            # the object is mutable and should then be recorded in the CAREamist obj
+            self.data_config.set_mean_and_std(self.mean, self.std)
         else:
             self.mean, self.std = self.data_config.mean, self.data_config.std
 
         # get transforms
-        self.patch_transform = get_patch_transform(
-            patch_transforms=self.data_config.transforms,
-            with_target=self.data_target is not None,
+        self.patch_transform = Compose(
+            transform_list=self.data_config.transforms,
         )
 
     def _prepare_patches(
@@ -166,33 +165,10 @@
             # get target
             target = self.data_targets[index]
 
-            # Albumentations requires Channel last
-            c_patch = np.moveaxis(patch, 0, -1)
-            c_target = np.moveaxis(target, 0, -1)
-
-            # Apply transforms
-            transformed = self.patch_transform(image=c_patch, target=c_target)
-
-            # move axes back
-            patch = np.moveaxis(transformed["image"], -1, 0)
-            target = np.moveaxis(transformed["target"], -1, 0)
-
-            return patch, target
+            return self.patch_transform(patch=patch, target=target)
 
         elif self.data_config.has_n2v_manipulate():
-            # Albumentations requires Channel last
-            patch = np.moveaxis(patch, 0, -1)
-
-            # Apply transforms
-            transformed_patch = self.patch_transform(image=patch)["image"]
-            manip_patch, patch, mask = transformed_patch
-
-            # move C axes back
-            manip_patch = np.moveaxis(manip_patch, -1, 0)
-            patch = np.moveaxis(patch, -1, 0)
-            mask = np.moveaxis(mask, -1, 0)
-
-            return (manip_patch, patch, mask)
+            return self.patch_transform(patch=patch)
         else:
             raise ValueError(
                 "Something went wrong! No target provided (not supervised training) "
@@ -318,9 +294,8 @@ class InMemoryPredictionDataset(Dataset):
         self.mean, self.std = self.pred_config.mean, self.pred_config.std
 
         # get transforms
-        self.patch_transform = get_patch_transform(
-            patch_transforms=self.pred_config.transforms,
-            with_target=self.data_target is not None,
+        self.patch_transform = Compose(
+            transform_list=self.pred_config.transforms,
         )
 
     def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]:
@@ -379,13 +354,7 @@
         """
         tile_array, tile_info = self.data[index]
 
-        # Albumentations requires channel last, use the XArrayTile array
-        patch = np.moveaxis(tile_array, 0, -1)
-
         # Apply transforms
-        transformed_patch = self.patch_transform(image=patch)["image"]
-
-        # move C axes back
-        transformed_patch = np.moveaxis(transformed_patch, -1, 0)
+        transformed_tile, _ = self.patch_transform(patch=tile_array)
 
-        return transformed_patch, tile_info
+        return transformed_tile, tile_info
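The recurring pattern in the hunks above is the move from albumentations (channel-last arrays, dict-style results) to the new careamics.transforms.Compose (added in careamics/transforms/compose.py), which takes channel-first arrays as keyword arguments and returns a tuple. A minimal usage sketch follows; the DataConfig fields and the "NDFlip" transform name are assumptions made for illustration, only the Compose calls mirror the diff:

```python
import numpy as np

from careamics.config import DataConfig
from careamics.transforms import Compose

# Hypothetical minimal configuration; field names and values are assumptions.
config = DataConfig(
    data_type="array",
    patch_size=[64, 64],
    axes="YX",
    transforms=[{"name": "NDFlip"}],
)

# Compose wraps the configured transform list directly (see the hunks above).
compose = Compose(transform_list=config.transforms)

patch = np.random.rand(1, 64, 64).astype(np.float32)   # channel-first (C)YX
target = np.random.rand(1, 64, 64).astype(np.float32)

# Supervised: patch and target are transformed together, no moveaxis dance.
aug_patch, aug_target = compose(patch=patch, target=target)

# Prediction: the second element of the returned tuple is discarded, as in
# InMemoryPredictionDataset.__getitem__ above. With N2VManipulate configured,
# the call instead returns (masked, original, mask).
transformed, _ = compose(patch=patch)
```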
careamics/dataset/iterable_dataset.py

@@ -7,13 +7,12 @@ from typing import Any, Callable, Generator, List, Optional, Tuple, Union
 import numpy as np
 from torch.utils.data import IterableDataset, get_worker_info
 
+from careamics.transforms import Compose
+
 from ..config import DataConfig, InferenceConfig
 from ..config.tile_information import TileInformation
 from ..utils.logging import get_logger
 from .dataset_utils import read_tiff, reshape_array
-from .patching import (
-    get_patch_transform,
-)
 from .patching.random_patching import extract_patches_random
 from .patching.tiled_patching import extract_tiles
 
@@ -61,26 +60,15 @@ class PathIterableDataset(IterableDataset):
         if not data_config.mean or not data_config.std:
             self.mean, self.std = self._calculate_mean_and_std()
 
-            # if the transforms are not an instance of Compose
-            # Check if the data_config is an instance of DataModel or InferenceModel
-            # isinstance isn't working properly here
-            if hasattr(data_config, "has_transform_list"):
-                if data_config.has_transform_list():
-                    # update mean and std in configuration
-                    # the object is mutable and should then be recorded in the CAREamist
-                    data_config.set_mean_and_std(self.mean, self.std)
-            else:
-                data_config.set_mean_and_std(self.mean, self.std)
-
+            # update mean and std in configuration
+            # the object is mutable and should then be recorded in the CAREamist
+            data_config.set_mean_and_std(self.mean, self.std)
         else:
             self.mean = data_config.mean
             self.std = data_config.std
 
         # get transforms
-        self.patch_transform = get_patch_transform(
-            patch_transforms=data_config.transforms,
-            with_target=target_files is not None,
-        )
+        self.patch_transform = Compose(transform_list=data_config.transforms)
 
     def _calculate_mean_and_std(self) -> Tuple[float, float]:
         """
@@ -192,49 +180,10 @@
         # or (patch, None) only if no target is available
         # patch is of dimensions (C)ZYX
         for patch_data in patches:
-            # if there is a target
-            if self.target_files is not None:
-                # Albumentations expects the channel dimension to be last
-                # Taking the first element because patch_data can include target
-                c_patch = np.moveaxis(patch_data[0], 0, -1)
-                c_target = np.moveaxis(patch_data[1], 0, -1)
-
-                # apply the transform to the patch and the target
-                transformed = self.patch_transform(
-                    image=c_patch,
-                    target=c_target,
-                )
-
-                # move the axes back to the original position
-                c_patch = np.moveaxis(transformed["image"], -1, 0)
-                c_target = np.moveaxis(transformed["target"], -1, 0)
-
-                yield (c_patch, c_target)
-            elif self.data_config.has_n2v_manipulate():
-                # Albumentations expects the channel dimension to be last
-                # Taking the first element because patch_data can include target
-                patch = np.moveaxis(patch_data[0], 0, -1)
-
-                # apply transform
-                transformed = self.patch_transform(image=patch)
-
-                # retrieve the output of ManipulateN2V
-                results = transformed["image"]
-                masked_patch: np.ndarray = results[0]
-                original_patch: np.ndarray = results[1]
-                mask: np.ndarray = results[2]
-
-                # move C axes back
-                masked_patch = np.moveaxis(masked_patch, -1, 0)
-                original_patch = np.moveaxis(original_patch, -1, 0)
-                mask = np.moveaxis(mask, -1, 0)
-
-                yield (masked_patch, original_patch, mask)
-            else:
-                raise ValueError(
-                    "Something went wrong! Not target file (no supervised "
-                    "training) and no N2V transform (no n2v training either)."
-                )
+            yield self.patch_transform(
+                patch=patch_data[0],
+                target=patch_data[1],
+            )
 
     def get_number_of_files(self) -> int:
         """
@@ -367,9 +316,8 @@ class IterablePredictionDataset(PathIterableDataset):
         self.tile = self.tile_size is not None and self.tile_overlap is not None
 
         # get tta transforms
-        self.patch_transform = get_patch_transform(
-            patch_transforms=prediction_config.transforms,
-            with_target=False,
+        self.patch_transform = Compose(
+            transform_list=prediction_config.transforms,
         )
 
     def __iter__(
@@ -408,9 +356,6 @@
 
         # apply transform to patches
         for patch_array, tile_info in patch_gen:
-            # albumentations expects the channel dimension to be last
-            patch = np.moveaxis(patch_array, 0, -1)
-            transformed_patch = self.patch_transform(image=patch)
-            transformed_patch = np.moveaxis(transformed_patch["image"], -1, 0)
+            transformed_patch, _ = self.patch_transform(patch=patch_array)
 
             yield transformed_patch, tile_info
careamics/dataset/patching/__init__.py

@@ -1,8 +1 @@
 """Patching and tiling functions."""
-
-
-__all__ = [
-    "get_patch_transform",
-]
-
-from .patch_transform import get_patch_transform
careamics/dataset/patching/patching.py

@@ -3,6 +3,7 @@ Tiling submodule.
 
 These functions are used to tile images into patches or tiles.
 """
+
 from pathlib import Path
 from typing import Callable, List, Tuple, Union
 
careamics/dataset/patching/sequential_patching.py

@@ -135,15 +135,12 @@ def _compute_patch_views(
         arr = np.stack([arr, target], axis=0)
         window_shape = [arr.shape[0], *window_shape]
         step = (arr.shape[0], *step)
-        output_shape = [arr.shape[0], -1, arr.shape[2], *output_shape[2:]]
+        output_shape = [-1, arr.shape[0], arr.shape[2], *output_shape[2:]]
 
     patches = view_as_windows(arr, window_shape=window_shape, step=step).reshape(
         *output_shape
     )
-    if target is not None:
-        rng.shuffle(patches, axis=1)
-    else:
-        rng.shuffle(patches, axis=0)
+    rng.shuffle(patches, axis=0)
     return patches
 
 
@@ -201,6 +198,9 @@ def extract_patches_sequential(
 
     if target is not None:
         # target was concatenated to patches in _compute_reshaped_view
-        return (patches[0, ...], patches[1, ...])  # TODO in _compute_reshaped_view?
+        return (
+            patches[:, 0, ...],
+            patches[:, 1, ...],
+        )  # TODO in _compute_reshaped_view?
     else:
         return patches, None
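The two hunks above fix patch/target alignment: with output_shape = [-1, arr.shape[0], ...] the patch axis comes first and the source/target axis second, so a single shuffle along axis 0 permutes whole pairs, which are then unpacked as patches[:, 0, ...] and patches[:, 1, ...]. A small numpy sketch (toy shapes, not the library's) showing why this layout preserves the pairing:

```python
import numpy as np

rng = np.random.default_rng(0)

# Three toy source patches and their matching targets (shapes arbitrary).
patches = np.arange(6, dtype=float).reshape(3, 1, 2)
targets = patches * 10

# Stack source/target on axis 1 so axis 0 indexes (patch, target) pairs,
# mirroring the reshaped view with output_shape = [-1, arr.shape[0], ...].
stacked = np.stack([patches, targets], axis=1)  # shape (3, 2, 1, 2)

# Shuffling axis 0 permutes whole pairs; the members stay aligned.
rng.shuffle(stacked, axis=0)
assert np.array_equal(stacked[:, 1, ...], stacked[:, 0, ...] * 10)
```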
careamics/dataset/patching/tiled_patching.py

@@ -43,9 +43,11 @@ def _compute_crop_and_stitch_coords_1d(
             stitch_coords.append(
                 (
                     i + overlap // 2 if i > 0 else 0,
-                    i + tile_size - overlap // 2
-                    if crop_coords[-1][1] < axis_size
-                    else axis_size,
+                    (
+                        i + tile_size - overlap // 2
+                        if crop_coords[-1][1] < axis_size
+                        else axis_size
+                    ),
                 )
             )
 
@@ -53,9 +55,11 @@
             overlap_crop_coords.append(
                 (
                     overlap // 2 if i > 0 else 0,
-                    tile_size - overlap // 2
-                    if crop_coords[-1][1] < axis_size
-                    else tile_size,
+                    (
+                        tile_size - overlap // 2
+                        if crop_coords[-1][1] < axis_size
+                        else tile_size
+                    ),
                 )
             )
 
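The two hunks above only re-wrap the ternaries in black-style parentheses; the underlying rule is that each tile contributes its central tile_size - overlap region to the stitched output, while the first and last tiles keep their outer borders. A condensed, hypothetical re-derivation of the 1D stitch-coordinate logic (not the library function), assuming tiles start every tile_size - overlap pixels with a final tile clamped to the axis end:

```python
def stitch_bounds_1d(axis_size: int, tile_size: int, overlap: int) -> list:
    """Hypothetical condensation of the stitch-coordinate rule above."""
    step = tile_size - overlap
    bounds = []
    for i in range(0, max(1, axis_size - overlap), step):
        if i + tile_size <= axis_size:
            # interior tiles keep their center; edge tiles keep the border
            lo = i + overlap // 2 if i > 0 else 0
            hi = i + tile_size - overlap // 2 if i + tile_size < axis_size else axis_size
            bounds.append((lo, hi))
        else:
            # clamped final tile: resume where the previous stitch region ended
            bounds.append((bounds[-1][1] if bounds else 0, axis_size))
            break
    return bounds

# axis_size=128, tile_size=64, overlap=32 gives (0, 48), (48, 80), (80, 128):
# the stitched regions partition the axis with no gaps or double-counting.
assert stitch_bounds_1d(128, 64, 32) == [(0, 48), (48, 80), (80, 128)]
```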
careamics/lightning_datamodule.py

@@ -1,10 +1,10 @@
 """Training and validation Lightning data modules."""
+
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Literal, Optional, Union
 
 import numpy as np
 import pytorch_lightning as L
-from albumentations import Compose
 from torch.utils.data import DataLoader
 
 from careamics.config import DataConfig
@@ -341,9 +341,9 @@ class CAREamicsTrainData(L.LightningDataModule):
             self.train_dataset = InMemoryDataset(
                 data_config=self.data_config,
                 inputs=self.train_files,
-                data_target=self.train_target_files
-                if self.train_data_target
-                else None,
+                data_target=(
+                    self.train_target_files if self.train_data_target else None
+                ),
                 read_source_func=self.read_source_func,
             )
 
@@ -352,9 +352,9 @@
                 self.val_dataset = InMemoryDataset(
                     data_config=self.data_config,
                     inputs=self.val_files,
-                    data_target=self.val_target_files
-                    if self.val_data_target
-                    else None,
+                    data_target=(
+                        self.val_target_files if self.val_data_target else None
+                    ),
                     read_source_func=self.read_source_func,
                 )
             else:
@@ -370,9 +370,9 @@
             self.train_dataset = PathIterableDataset(
                 data_config=self.data_config,
                 src_files=self.train_files,
-                target_files=self.train_target_files
-                if self.train_data_target
-                else None,
+                target_files=(
+                    self.train_target_files if self.train_data_target else None
+                ),
                 read_source_func=self.read_source_func,
             )
 
@@ -382,9 +382,9 @@
                 self.val_dataset = PathIterableDataset(
                     data_config=self.data_config,
                     src_files=self.val_files,
-                    target_files=self.val_target_files
-                    if self.val_data_target
-                    else None,
+                    target_files=(
+                        self.val_target_files if self.val_data_target else None
+                    ),
                     read_source_func=self.read_source_func,
                 )
             elif len(self.train_files) <= self.val_minimum_split:
@@ -452,8 +452,7 @@ class TrainingDataWrapper(CAREamicsTrainData):
     In particular, N2V requires a specific transformation (N2V manipulates), which is
     not compatible with supervised training. The default transformations applied to the
     training patches are defined in `careamics.config.data_model`. To use different
-    transformations, pass a list of transforms or an albumentation `Compose` as
-    `transforms` parameter. See examples for more details.
+    transformations, pass a list of transforms. See examples for more details.
 
     By default, CAREamics only supports types defined in
     `careamics.config.support.SupportedData`. To read custom data types, you can set
@@ -488,7 +487,7 @@
         Batch size.
     val_data : Optional[Union[str, Path]], optional
         Validation data, by default None.
-    transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
+    transforms : List[TRANSFORMS_UNION], optional
        List of transforms to apply to training patches. If None, default transforms
        are applied.
    train_target_data : Optional[Union[str, Path]], optional
@@ -584,7 +583,7 @@
         axes: str,
         batch_size: int,
         val_data: Optional[Union[str, Path]] = None,
-        transforms: Optional[Union[List[TRANSFORMS_UNION], Compose]] = None,
+        transforms: Optional[List[TRANSFORMS_UNION]] = None,
         train_target_data: Optional[Union[str, Path]] = None,
         val_target_data: Optional[Union[str, Path]] = None,
         read_source_func: Optional[Callable] = None,
@@ -617,8 +616,8 @@
         In particular, N2V requires a specific transformation (N2V manipulates), which
         is not compatible with supervised training. The default transformations applied
         to the training patches are defined in `careamics.config.data_model`. To use
-        different transformations, pass a list of transforms or an albumentation
-        `Compose` as `transforms` parameter. See examples for more details.
+        different transformations, pass a list of transforms. See examples for more
+        details.
 
         By default, CAREamics only supports types defined in
         `careamics.config.support.SupportedData`. To read custom data types, you can set
@@ -655,7 +654,7 @@
             Batch size.
         val_data : Optional[Union[str, Path]], optional
             Validation data, by default None.
-        transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
+        transforms : Optional[List[TRANSFORMS_UNION]], optional
             List of transforms to apply to training patches. If None, default transforms
             are applied.
         train_target_data : Optional[Union[str, Path]], optional
@@ -709,10 +708,7 @@
         self.data_config = DataConfig(**data_dict)
 
         # N2V specific checks, N2V, structN2V, and transforms
-        if (
-            self.data_config.has_transform_list()
-            and self.data_config.has_n2v_manipulate()
-        ):
+        if self.data_config.has_n2v_manipulate():
             # there is not target, n2v2 and structN2V can be changed
             if train_target_data is None:
                 self.data_config.set_N2V2(use_n2v2)
careamics/lightning_module.py

@@ -162,7 +162,7 @@ class CAREamicsModule(L.LightningModule):
             mean=self._trainer.datamodule.predict_dataset.mean,
             std=self._trainer.datamodule.predict_dataset.std,
         )
-        denormalized_output = denorm(image=output)["image"]
+        denormalized_output, _ = denorm(patch=output)
 
         if len(aux) > 0:
             return denormalized_output, aux
careamics/lightning_prediction_datamodule.py

@@ -1,10 +1,10 @@
 """Prediction Lightning data modules."""
+
 from pathlib import Path
 from typing import Any, Callable, List, Literal, Optional, Tuple, Union
 
 import numpy as np
 import pytorch_lightning as L
-from albumentations import Compose
 from torch.utils.data import DataLoader
 from torch.utils.data.dataloader import default_collate
 
@@ -39,7 +39,7 @@ def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
 
     Parameters
     ----------
-    batch : Tuple[Tuple[np.ndarray, TileInformation], ...]
+    batch : List[Tuple[np.ndarray, TileInformation], ...]
         Batch of tiles.
 
     Returns
@@ -257,14 +257,13 @@ class PredictDataWrapper(CAREamicsPredictData):
 
     The default transformations applied to the images are defined in
     `careamics.config.inference_model`. To use different transformations, pass a list
-    of transforms or an albumentation `Compose` as `transforms` parameter. See examples
+    of transforms. See examples
     for more details.
 
     The `mean` and `std` parameters are only used if Normalization is defined either
-    in the default transformations or in the `transforms` parameter, but not with
-    a `Compose` object. If you pass a `Normalization` transform in a list as
-    `transforms`, then the mean and std parameters will be overwritten by those passed
-    to this method.
+    in the default transformations or in the `transforms` parameter. If you pass a
+    `Normalization` transform in a list as `transforms`, then the mean and std
+    parameters will be overwritten by those passed to this method.
 
     By default, CAREamics only supports types defined in
     `careamics.config.support.SupportedData`. To read custom data types, you can set
@@ -276,6 +275,12 @@
     dataloaders, except for `batch_size`, which is set by the `batch_size`
     parameter.
 
+    Note that if you are using a UNet model and tiling, the tile size must be
+    divisible in every dimension by 2**d, where d is the depth of the model. This
+    avoids artefacts arising from the broken shift invariance induced by the
+    pooling layers of the UNet. If your image has less dimensions, as it may
+    happen in the Z dimension, consider padding your image.
+
     Parameters
     ----------
     pred_data : Union[str, Path, np.ndarray]
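The divisibility constraint in the added docstring note can be checked up front; below is a small hypothetical helper (not part of careamics) that rounds a tile size down to a multiple of 2**depth:

```python
def round_tile_size(tile_size: int, depth: int) -> int:
    """Round a tile size down to a multiple of 2**depth (hypothetical helper)."""
    factor = 2**depth
    return max(factor, (tile_size // factor) * factor)

assert round_tile_size(200, depth=3) == 200  # 200 % 8 == 0, valid as-is
assert round_tile_size(197, depth=3) == 192  # rounded down to a multiple of 8
```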
@@ -298,7 +303,7 @@
         Batch size.
     tta_transforms : bool, optional
         Use test time augmentation, by default True.
-    transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
+    transforms : List, optional
         List of transforms to apply to prediction patches. If None, default
         transforms are applied.
     read_source_func : Optional[Callable], optional
@@ -321,7 +326,7 @@
         axes: str = "YX",
         batch_size: int = 1,
         tta_transforms: bool = True,
-        transforms: Optional[Union[List, Compose]] = None,
+        transforms: Optional[List] = None,
         read_source_func: Optional[Callable] = None,
         extension_filter: str = "",
         dataloader_params: Optional[dict] = None,
@@ -351,7 +356,7 @@
             Batch size.
         tta_transforms : bool, optional
             Use test time augmentation, by default True.
-        transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
+        transforms : Optional[List], optional
             List of transforms to apply to prediction patches. If None, default
             transforms are applied.
         read_source_func : Optional[Callable], optional
careamics/losses/__init__.py

@@ -1,6 +1,5 @@
 """Losses module."""
 
-
 from .loss_factory import loss_factory
 
 # from .noise_model_factory import noise_model_factory as noise_model_factory
careamics/losses/loss_factory.py

@@ -3,6 +3,7 @@ Loss factory module.
 
 This module contains a factory function for creating loss functions.
 """
+
 from typing import Callable, Union
 
 from ..config.support import SupportedLoss
careamics/model_io/__init__.py

@@ -1,6 +1,5 @@
 """Model I/O utilities."""
 
-
 __all__ = ["load_pretrained", "export_to_bmz"]
 
 
careamics/model_io/bioimage/_readme_factory.py

@@ -1,4 +1,5 @@
 """Functions used to create a README.md file for BMZ export."""
+
 from pathlib import Path
 from typing import Optional
 
@@ -117,4 +118,4 @@ def readme_factory(
 
     readme.write_text("".join(description))
 
-    return readme
+    return readme.absolute()
careamics/model_io/bioimage/bioimage_utils.py

@@ -1,4 +1,5 @@
 """Bioimage.io utils."""
+
 from pathlib import Path
 from typing import Union
 
careamics/model_io/bioimage/model_description.py

@@ -1,4 +1,5 @@
 """Module use to build BMZ model description."""
+
 from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
careamics/model_io/bmz_io.py

@@ -1,4 +1,5 @@
 """Function to export to the BioImage Model Zoo format."""
+
 import tempfile
 from pathlib import Path
 from typing import List, Optional, Tuple, Union
@@ -177,7 +178,7 @@ def export_to_bmz(
     )
 
     # test model description
-    summary: ValidationSummary = test_model(model_description)
+    summary: ValidationSummary = test_model(model_description, decimal=0)
     if summary.status == "failed":
         raise ValueError(f"Model description test failed: {summary}")
 
careamics/models/layers.py

@@ -3,6 +3,7 @@ Layer module.
 
 This submodule contains layers used in the CAREamics models.
 """
+
 from typing import List, Optional, Tuple, Union
 
 import torch
careamics/models/model_factory.py

@@ -3,6 +3,7 @@ Model factory.
 
 Model creation factory functions.
 """
+
 from typing import Union
 
 import torch