careamics 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of careamics has been flagged as potentially problematic.
Files changed (74)
  1. careamics/careamist.py +4 -3
  2. careamics/cli/utils.py +1 -1
  3. careamics/config/algorithms/n2v_algorithm_model.py +1 -1
  4. careamics/config/architectures/unet_model.py +3 -0
  5. careamics/config/callback_model.py +23 -34
  6. careamics/config/configuration.py +47 -1
  7. careamics/config/configuration_factories.py +288 -23
  8. careamics/config/data/__init__.py +2 -0
  9. careamics/config/data/data_model.py +3 -3
  10. careamics/config/data/ng_data_model.py +381 -0
  11. careamics/config/data/patching_strategies/__init__.py +14 -0
  12. careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
  13. careamics/config/data/patching_strategies/_patched_model.py +56 -0
  14. careamics/config/data/patching_strategies/random_patching_model.py +21 -0
  15. careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
  16. careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
  17. careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
  18. careamics/config/inference_model.py +6 -3
  19. careamics/config/support/supported_data.py +7 -0
  20. careamics/config/support/supported_patching_strategies.py +22 -0
  21. careamics/config/validators/validator_utils.py +4 -3
  22. careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
  23. careamics/dataset/in_memory_dataset.py +2 -1
  24. careamics/dataset/iterable_dataset.py +2 -2
  25. careamics/dataset/iterable_pred_dataset.py +2 -2
  26. careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
  27. careamics/dataset/patching/patching.py +3 -2
  28. careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
  29. careamics/dataset/tiling/tiled_patching.py +2 -1
  30. careamics/dataset_ng/dataset.py +46 -50
  31. careamics/dataset_ng/demos/bsd68_demo.ipynb +28 -23
  32. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +1 -1
  33. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +1 -1
  34. careamics/dataset_ng/demos/demo_datamodule.ipynb +50 -46
  35. careamics/dataset_ng/demos/demo_dataset.ipynb +32 -49
  36. careamics/dataset_ng/factory.py +58 -15
  37. careamics/dataset_ng/legacy_interoperability.py +3 -1
  38. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +1 -1
  39. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -0
  40. careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
  41. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
  42. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +43 -1
  43. careamics/dataset_ng/patching_strategies/random_patching.py +4 -2
  44. careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
  45. careamics/dataset_ng/patching_strategies/tiling_strategy.py +2 -1
  46. careamics/file_io/read/get_func.py +2 -1
  47. careamics/lightning/dataset_ng/__init__.py +1 -0
  48. careamics/lightning/dataset_ng/data_module.py +218 -28
  49. careamics/lightning/dataset_ng/lightning_modules/care_module.py +44 -5
  50. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +42 -3
  51. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +73 -4
  52. careamics/lightning/lightning_module.py +2 -1
  53. careamics/lightning/predict_data_module.py +2 -1
  54. careamics/lightning/train_data_module.py +2 -1
  55. careamics/losses/loss_factory.py +2 -1
  56. careamics/lvae_training/dataset/multicrop_dset.py +1 -1
  57. careamics/model_io/bioimage/bioimage_utils.py +1 -1
  58. careamics/model_io/bioimage/model_description.py +1 -1
  59. careamics/model_io/bmz_io.py +1 -1
  60. careamics/model_io/model_io_utils.py +2 -2
  61. careamics/models/activation.py +2 -1
  62. careamics/models/unet.py +16 -10
  63. careamics/prediction_utils/prediction_outputs.py +1 -1
  64. careamics/prediction_utils/stitch_prediction.py +1 -1
  65. careamics/transforms/n2v_manipulate_torch.py +15 -9
  66. careamics/transforms/pixel_manipulation_torch.py +59 -92
  67. careamics/utils/lightning_utils.py +2 -2
  68. careamics/utils/metrics.py +2 -1
  69. careamics/utils/torch_utils.py +23 -0
  70. {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/METADATA +10 -9
  71. {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/RECORD +74 -63
  72. {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/WHEEL +0 -0
  73. {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/entry_points.txt +0 -0
  74. {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/licenses/LICENSE +0 -0
@@ -16,12 +16,15 @@ class SupportedData(str, BaseEnum):
         Array data.
     TIFF : str
         TIFF image data.
+    CZI : str
+        CZI image data.
     CUSTOM : str
         Custom data.
     """

     ARRAY = "array"
     TIFF = "tiff"
+    CZI = "czi"
     CUSTOM = "custom"
     # ZARR = "zarr"

@@ -78,6 +81,8 @@ class SupportedData(str, BaseEnum):
             raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.")
         elif data_type == cls.TIFF:
             return "*.tif*"
+        elif data_type == cls.CZI:
+            return "*.czi"
         elif data_type == cls.CUSTOM:
             return "*.*"
         else:
@@ -102,6 +107,8 @@ class SupportedData(str, BaseEnum):
             raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.")
         elif data_type == cls.TIFF:
             return ".tiff"
+        elif data_type == cls.CZI:
+            return ".czi"
         elif data_type == cls.CUSTOM:
             # TODO: improve this message
             raise NotImplementedError("Custom extensions have to be passed elsewhere.")
@@ -0,0 +1,22 @@
+"""Patching strategies supported by Careamics."""
+
+from careamics.utils import BaseEnum
+
+
+class SupportedPatchingStrategy(str, BaseEnum):
+    """Patching strategies supported by Careamics."""
+
+    FIXED_RANDOM = "fixed_random"
+    """Fixed random patching strategy, used during training."""
+
+    RANDOM = "random"
+    """Random patching strategy, used during training."""
+
+    # SEQUENTIAL = "sequential"
+    # """Sequential patching strategy, used during training."""
+
+    TILED = "tiled"
+    """Tiled patching strategy, used during prediction."""
+
+    WHOLE = "whole"
+    """Whole image patching strategy, used during prediction."""
@@ -4,7 +4,8 @@ Validator functions.
 These functions are used to validate dimensions and axes of inputs.
 """

-from typing import Optional, Union
+from collections.abc import Sequence
+from typing import Optional

 _AXES = "STCZYX"

@@ -79,14 +80,14 @@ def value_ge_than_8_power_of_2(


 def patch_size_ge_than_8_power_of_2(
-    patch_list: Optional[Union[list[int], Union[tuple[int, ...]]]],
+    patch_list: Optional[Sequence[int]],
 ) -> None:
     """
     Validate that each entry is greater or equal than 8 and a power of 2.

     Parameters
     ----------
-    patch_list : list or typle of int, or None
+    patch_list : Sequence of int, or None
         Patch size.

     Raises
@@ -2,9 +2,9 @@

 from __future__ import annotations

-from collections.abc import Generator
+from collections.abc import Callable, Generator
 from pathlib import Path
-from typing import Callable, Optional, Union
+from typing import Optional, Union

 from numpy.typing import NDArray
 from torch.utils.data import get_worker_info
@@ -3,8 +3,9 @@
 from __future__ import annotations

 import copy
+from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, Union

 import numpy as np
 from torch.utils.data import Dataset
@@ -3,9 +3,9 @@
 from __future__ import annotations

 import copy
-from collections.abc import Generator
+from collections.abc import Callable, Generator
 from pathlib import Path
-from typing import Callable, Optional
+from typing import Optional

 import numpy as np
 from torch.utils.data import IterableDataset
@@ -2,9 +2,9 @@

 from __future__ import annotations

-from collections.abc import Generator
+from collections.abc import Callable, Generator
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any

 from numpy.typing import NDArray
 from torch.utils.data import IterableDataset
@@ -2,9 +2,9 @@

 from __future__ import annotations

-from collections.abc import Generator
+from collections.abc import Callable, Generator
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any

 from numpy.typing import NDArray
 from torch.utils.data import IterableDataset
@@ -1,8 +1,9 @@
 """Patching functions."""

+from collections.abc import Callable
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Callable, Union
+from typing import Union

 import numpy as np
 from numpy.typing import NDArray
@@ -89,7 +90,7 @@ def prepare_patches_supervised(
     """
     means, stds, num_samples = 0, 0, 0
     all_patches, all_targets = [], []
-    for train_filename, target_filename in zip(train_files, target_files):
+    for train_filename, target_filename in zip(train_files, target_files, strict=False):
         try:
             sample: np.ndarray = read_source_func(train_filename, axes)
             target: np.ndarray = read_source_func(target_filename, axes)
@@ -78,7 +78,9 @@ def extract_tiles(
         ...,
         *[
             slice(coords, coords + extent)
-            for coords, extent in zip(crop_coords_start, tile_size)
+            for coords, extent in zip(
+                crop_coords_start, tile_size, strict=False
+            )
         ],
     )
     tile = sample[crop_slices]
@@ -159,11 +161,14 @@ def compute_tile_info_legacy(

     # --- combine start and end
     stitch_coords = tuple(
-        (start, end) for start, end in zip(stitch_coords_start, stitch_coords_end)
+        (start, end)
+        for start, end in zip(stitch_coords_start, stitch_coords_end, strict=False)
     )
     overlap_crop_coords = tuple(
         (start, end)
-        for start, end in zip(overlap_crop_coords_start, overlap_crop_coords_end)
+        for start, end in zip(
+            overlap_crop_coords_start, overlap_crop_coords_end, strict=False
+        )
     )

     tile_info = TileInformation(
@@ -229,11 +234,14 @@ def compute_tile_info(

     # --- combine start and end
     stitch_coords = tuple(
-        (start, end) for start, end in zip(stitch_coords_start, stitch_coords_end)
+        (start, end)
+        for start, end in zip(stitch_coords_start, stitch_coords_end, strict=False)
     )
     overlap_crop_coords = tuple(
         (start, end)
-        for start, end in zip(overlap_crop_coords_start, overlap_crop_coords_end)
+        for start, end in zip(
+            overlap_crop_coords_start, overlap_crop_coords_end, strict=False
+        )
     )

     # --- Check if last tile
@@ -284,7 +292,9 @@ def compute_padding(
     pad_before = overlaps // 2
     pad_after = covered_shape - data_shape[-len(tile_size) :] - pad_before

-    return tuple((before, after) for before, after in zip(pad_before, pad_after))
+    return tuple(
+        (before, after) for before, after in zip(pad_before, pad_after, strict=False)
+    )


 def n_tiles_1d(axis_size: int, tile_size: int, overlap: int) -> int:
@@ -127,7 +127,7 @@ def extract_tiles(
     # Rearrange crop coordinates from a list of coordinate pairs per axis to a list
     # grouped by type.
     all_crop_coords, all_stitch_coords, all_overlap_crop_coords = zip(
-        *crop_and_stitch_coords_list
+        *crop_and_stitch_coords_list, strict=False
     )

     # Maximum tile index
@@ -139,6 +139,7 @@ def extract_tiles(
             itertools.product(*all_crop_coords),
             itertools.product(*all_stitch_coords),
             itertools.product(*all_overlap_crop_coords),
+            strict=False,
         )
     ):
         # Extract tile from the sample
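The repeated `strict=False` additions above simply make the pre-existing truncating behaviour of `zip` explicit; the keyword has been available since Python 3.10. A short illustration (not part of the diff):

    a = [1, 2, 3]
    b = ["x", "y"]

    # strict=False: silently stops at the shorter iterable
    assert list(zip(a, b, strict=False)) == [(1, "x"), (2, "y")]

    # strict=True: raises ValueError on a length mismatch
    try:
        list(zip(a, b, strict=True))
    except ValueError as err:
        print(err)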
@@ -8,7 +8,10 @@ from numpy.typing import NDArray
 from torch.utils.data import Dataset
 from tqdm.auto import tqdm

-from careamics.config import DataConfig, InferenceConfig
+from careamics.config.data.ng_data_model import NGDataConfig
+from careamics.config.support.supported_patching_strategies import (
+    SupportedPatchingStrategy,
+)
 from careamics.config.transformations import NormalizeModel
 from careamics.dataset.dataset_utils.running_stats import WelfordStatistics
 from careamics.dataset.patching.patching import Stats
@@ -45,7 +48,7 @@ InputType = Union[Sequence[NDArray[Any]], Sequence[Path]]
 class CareamicsDataset(Dataset, Generic[GenericImageStack]):
     def __init__(
         self,
-        data_config: Union[DataConfig, InferenceConfig],
+        data_config: NGDataConfig,
         mode: Mode,
         input_extractor: PatchExtractor[GenericImageStack],
         target_extractor: Optional[PatchExtractor[GenericImageStack]] = None,
@@ -65,33 +68,43 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
     def _initialize_patching_strategy(self) -> PatchingStrategy:
         patching_strategy: PatchingStrategy
         if self.mode == Mode.TRAINING:
-            if isinstance(self.config, InferenceConfig):
-                raise ValueError("Inference config cannot be used for training.")
+            if self.config.patching.name != SupportedPatchingStrategy.RANDOM:
+                raise ValueError(
+                    f"Only `random` patching strategy supported during training, got "
+                    f"{self.config.patching.name}."
+                )
+
             patching_strategy = RandomPatchingStrategy(
                 data_shapes=self.input_extractor.shape,
-                patch_size=self.config.patch_size,
-                # TODO: Add random seed to dataconfig
-                seed=getattr(self.config, "random_seed", 42),
+                patch_size=self.config.patching.patch_size,
+                seed=self.config.seed,
             )
         elif self.mode == Mode.VALIDATING:
-            if isinstance(self.config, InferenceConfig):
-                raise ValueError("Inference config cannot be used for validating.")
+            if self.config.patching.name != SupportedPatchingStrategy.RANDOM:
+                raise ValueError(
+                    f"Only `random` patching strategy supported during training, got "
+                    f"{self.config.patching.name}."
+                )
+
             patching_strategy = FixedRandomPatchingStrategy(
                 data_shapes=self.input_extractor.shape,
-                patch_size=self.config.patch_size,
-                # TODO: Add random seed to dataconfig
-                seed=getattr(self.config, "random_seed", 42),
+                patch_size=self.config.patching.patch_size,
+                seed=self.config.seed,
             )
         elif self.mode == Mode.PREDICTING:
-            if not isinstance(self.config, InferenceConfig):
-                raise ValueError("Inference config must be used for predicting.")
-            if (self.config.tile_size is not None) and (
-                self.config.tile_overlap is not None
+            if (
+                self.config.patching.name != SupportedPatchingStrategy.TILED
+                and self.config.patching.name != SupportedPatchingStrategy.WHOLE
             ):
+                raise ValueError(
+                    f"Only `tiled` and `whole` patching strategy supported during "
+                    f"training, got {self.config.patching.name}."
+                )
+            elif self.config.patching.name == SupportedPatchingStrategy.TILED:
                 patching_strategy = TilingStrategy(
                     data_shapes=self.input_extractor.shape,
-                    tile_size=self.config.tile_size,
-                    overlaps=self.config.tile_overlap,
+                    tile_size=self.config.patching.patch_size,
+                    overlaps=self.config.patching.overlaps,
                 )
             else:
                 patching_strategy = WholeSamplePatchingStrategy(
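To summarize the dispatch in `_initialize_patching_strategy` above, each dataset mode now accepts only specific strategy names from `SupportedPatchingStrategy` (restated here as a plain mapping, illustration only, not part of the diff):

    # mode -> patching strategies the refactored CareamicsDataset accepts
    ALLOWED_PATCHING = {
        "training": {"random"},            # RandomPatchingStrategy
        "validating": {"random"},          # FixedRandomPatchingStrategy (fixed seed)
        "predicting": {"tiled", "whole"},  # TilingStrategy / WholeSamplePatchingStrategy
    }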
@@ -103,32 +116,18 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
         return patching_strategy

     def _initialize_transforms(self) -> Optional[Compose]:
-        if isinstance(self.config, DataConfig):
-            if self.mode == Mode.TRAINING:
-                # TODO: initialize normalization separately depending on configuration
-                return Compose(
-                    transform_list=[
-                        NormalizeModel(
-                            image_means=self.input_stats.means,
-                            image_stds=self.input_stats.stds,
-                            target_means=self.target_stats.means,
-                            target_stds=self.target_stats.stds,
-                        )
-                    ]
-                    + list(self.config.transforms)
-                )
+        normalize = NormalizeModel(
+            image_means=self.input_stats.means,
+            image_stds=self.input_stats.stds,
+            target_means=self.target_stats.means,
+            target_stds=self.target_stats.stds,
+        )
+        if self.mode == Mode.TRAINING:
+            # TODO: initialize normalization separately depending on configuration
+            return Compose(transform_list=[normalize] + list(self.config.transforms))

         # TODO: add TTA
-        return Compose(
-            transform_list=[
-                NormalizeModel(
-                    image_means=self.input_stats.means,
-                    image_stds=self.input_stats.stds,
-                    target_means=self.target_stats.means,
-                    target_stds=self.target_stats.stds,
-                )
-            ]
-        )
+        return Compose(transform_list=[normalize])

     def _calculate_stats(
         self, data_extractor: PatchExtractor[GenericImageStack]
@@ -158,14 +157,11 @@ class CareamicsDataset(Dataset, Generic[GenericImageStack]):
         input_stats = self._calculate_stats(self.input_extractor)

         target_stats = Stats((), ())
-        if isinstance(self.config, DataConfig):
-            if (
-                self.config.target_means is not None
-                and self.config.target_stds is not None
-            ):
-                target_stats = Stats(self.config.target_means, self.config.target_stds)
-            elif self.target_extractor is not None:
-                target_stats = self._calculate_stats(self.target_extractor)
+
+        if self.config.target_means is not None and self.config.target_stds is not None:
+            target_stats = Stats(self.config.target_means, self.config.target_stds)
+        elif self.target_extractor is not None:
+            target_stats = self._calculate_stats(self.target_extractor)

         return input_stats, target_stats

@@ -13,8 +13,11 @@
     "import tifffile\n",
     "from careamics_portfolio import PortfolioManager\n",
     "\n",
-    "from careamics.config.configuration_factories import create_n2v_configuration\n",
-    "from careamics.config.support import SupportedTransform\n",
+    "from careamics.config.configuration_factories import (\n",
+    "    _create_ng_data_configuration,\n",
+    "    create_n2v_configuration,\n",
+    ")\n",
+    "from careamics.config.data import NGDataConfig\n",
     "from careamics.lightning.callbacks import HyperParametersCallback\n",
     "from careamics.lightning.dataset_ng.data_module import CareamicsDataModule\n",
     "from careamics.lightning.dataset_ng.lightning_modules import N2VModule"
@@ -29,7 +32,8 @@
     "# Set seeds for reproducibility\n",
     "from pytorch_lightning import seed_everything\n",
     "\n",
-    "seed_everything(42)"
+    "seed = 42\n",
+    "seed_everything(seed)"
    ]
   },
   {
@@ -110,17 +114,17 @@
     "    num_epochs=100,\n",
     ")\n",
     "\n",
-    "# Ensuring that transforms are set\n",
-    "config.data_config.transforms = [\n",
-    "    {\n",
-    "        \"name\": SupportedTransform.XY_FLIP.value,\n",
-    "        \"flip_x\": True,\n",
-    "        \"flip_y\": True,\n",
-    "    },\n",
-    "    {\n",
-    "        \"name\": SupportedTransform.XY_RANDOM_ROTATE90.value,\n",
-    "    },\n",
-    "]"
+    "# TODO until the NGDataConfig is accepted by the Confiugration, these are separte\n",
+    "ng_data_config = _create_ng_data_configuration(\n",
+    "    data_type=config.data_config.data_type,\n",
+    "    axes=config.data_config.axes,\n",
+    "    patch_size=config.data_config.patch_size,\n",
+    "    batch_size=config.data_config.batch_size,\n",
+    "    augmentations=config.data_config.transforms,\n",
+    "    train_dataloader_params=config.data_config.train_dataloader_params,\n",
+    "    val_dataloader_params=config.data_config.val_dataloader_params,\n",
+    "    seed=seed,\n",
+    ")\n"
    ]
   },
   {
@@ -137,7 +141,7 @@
    "outputs": [],
    "source": [
     "train_data_module = CareamicsDataModule(\n",
-    "    data_config=config.data_config,\n",
+    "    data_config=ng_data_config,\n",
     "    train_data=train_files,\n",
     "    val_data=val_files,\n",
     ")\n",
@@ -224,15 +228,16 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from careamics.config.inference_model import InferenceConfig\n",
     "from careamics.dataset_ng.legacy_interoperability import imageregions_to_tileinfos\n",
     "from careamics.prediction_utils import convert_outputs\n",
     "\n",
-    "config = InferenceConfig(\n",
-    "    model_config=config,\n",
+    "config = NGDataConfig(\n",
     "    data_type=\"tiff\",\n",
-    "    tile_size=(128, 128),\n",
-    "    tile_overlap=(32, 32),\n",
+    "    patching={\n",
+    "        \"name\": \"tiled\",\n",
+    "        \"patch_size\": (128, 128),\n",
+    "        \"overlaps\": (32, 32),\n",
+    "    },\n",
     "    axes=\"YX\",\n",
     "    batch_size=1,\n",
     "    image_means=train_data_module.train_dataset.input_stats.means,\n",
@@ -319,7 +324,7 @@
     "psnrs = np.zeros((len(predictions), 1))\n",
     "scale_invariant_psnrs = np.zeros((len(predictions), 1))\n",
     "\n",
-    "for i, (pred, gt) in enumerate(zip(predictions, gts)):\n",
+    "for i, (pred, gt) in enumerate(zip(predictions, gts, strict=False)):\n",
     "    psnrs[i] = psnr(gt, pred.squeeze(), data_range=gt.max() - gt.min())\n",
     "    scale_invariant_psnrs[i] = scale_invariant_psnr(gt, pred.squeeze())\n",
     "\n",
@@ -334,7 +339,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "czi",
    "language": "python",
    "name": "python3"
  },
@@ -348,7 +353,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.10.14"
  }
+  "version": "3.12.11"
 },
 "nbformat": 4,
@@ -293,7 +293,7 @@
     "psnrs = np.zeros((len(prediction), 1))\n",
     "scale_invariant_psnrs = np.zeros((len(prediction), 1))\n",
     "\n",
-    "for i, (pred, gt) in enumerate(zip(prediction, gts)):\n",
+    "for i, (pred, gt) in enumerate(zip(prediction, gts, strict=False)):\n",
     "    psnrs[i] = psnr(gt, pred.squeeze(), data_range=gt.max() - gt.min())\n",
     "    scale_invariant_psnrs[i] = scale_invariant_psnr(gt, pred.squeeze())\n",
     "\n",
@@ -698,7 +698,7 @@
     "psnrs = np.zeros((len(prediction), 1))\n",
     "scale_invariant_psnrs = np.zeros((len(prediction), 1))\n",
    "\n",
-    "for i, (pred, gt) in enumerate(zip(prediction, gts)):\n",
+    "for i, (pred, gt) in enumerate(zip(prediction, gts, strict=False)):\n",
     "    psnrs[i] = psnr(gt, pred.squeeze(), data_range=gt.max() - gt.min())\n",
     "    scale_invariant_psnrs[i] = scale_invariant_psnr(gt, pred.squeeze())\n",
     "\n",