careamics 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (54) hide show
  1. careamics/careamist.py +20 -4
  2. careamics/config/configuration.py +10 -5
  3. careamics/config/data/data_model.py +38 -1
  4. careamics/config/optimizer_models.py +1 -3
  5. careamics/config/training_model.py +0 -2
  6. careamics/dataset/dataset_utils/running_stats.py +7 -3
  7. careamics/dataset_ng/README.md +212 -0
  8. careamics/dataset_ng/dataset.py +233 -0
  9. careamics/dataset_ng/demos/bsd68_demo.ipynb +356 -0
  10. careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
  11. careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +734 -0
  12. careamics/dataset_ng/demos/demo_datamodule.ipynb +443 -0
  13. careamics/dataset_ng/{demo_dataset.ipynb → demos/demo_dataset.ipynb} +39 -15
  14. careamics/dataset_ng/{demo_patch_extractor.py → demos/demo_patch_extractor.py} +7 -9
  15. careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +292 -0
  16. careamics/dataset_ng/factory.py +408 -0
  17. careamics/dataset_ng/legacy_interoperability.py +168 -0
  18. careamics/dataset_ng/patch_extractor/__init__.py +3 -8
  19. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +6 -4
  20. careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -1
  21. careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +5 -1
  22. careamics/dataset_ng/patch_extractor/image_stack_loader.py +5 -75
  23. careamics/dataset_ng/patch_extractor/patch_extractor.py +5 -4
  24. careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +73 -106
  25. careamics/dataset_ng/patching_strategies/__init__.py +6 -1
  26. careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +31 -0
  27. careamics/dataset_ng/patching_strategies/random_patching.py +3 -1
  28. careamics/dataset_ng/patching_strategies/tiling_strategy.py +171 -0
  29. careamics/dataset_ng/patching_strategies/whole_sample.py +36 -0
  30. careamics/lightning/dataset_ng/data_module.py +488 -0
  31. careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
  32. careamics/lightning/dataset_ng/lightning_modules/care_module.py +58 -0
  33. careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +67 -0
  34. careamics/lightning/dataset_ng/lightning_modules/unet_module.py +143 -0
  35. careamics/lightning/lightning_module.py +3 -0
  36. careamics/lvae_training/dataset/__init__.py +8 -3
  37. careamics/lvae_training/dataset/config.py +3 -3
  38. careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
  39. careamics/lvae_training/dataset/multich_dataset.py +46 -17
  40. careamics/lvae_training/dataset/multicrop_dset.py +196 -0
  41. careamics/lvae_training/dataset/types.py +3 -3
  42. careamics/lvae_training/dataset/utils/index_manager.py +259 -0
  43. careamics/lvae_training/eval_utils.py +93 -3
  44. careamics/transforms/compose.py +1 -0
  45. careamics/transforms/normalize.py +18 -7
  46. careamics/utils/lightning_utils.py +25 -11
  47. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/METADATA +3 -3
  48. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/RECORD +51 -36
  49. careamics/dataset_ng/dataset/__init__.py +0 -3
  50. careamics/dataset_ng/dataset/dataset.py +0 -184
  51. careamics/dataset_ng/demo_patch_extractor_factory.py +0 -37
  52. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/WHEEL +0 -0
  53. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/entry_points.txt +0 -0
  54. {careamics-0.0.10.dist-info → careamics-0.0.12.dist-info}/licenses/LICENSE +0 -0
@@ -335,4 +335,6 @@ def _calc_n_patches(spatial_shape: Sequence[int], patch_size: Sequence[int]) ->
335
335
  f"spatial dimensions {len(spatial_shape)}, for `patch_size={patch_size}` "
336
336
  f"and `spatial_shape={spatial_shape}`."
337
337
  )
338
- return int(np.ceil(np.prod(spatial_shape) / np.prod(patch_size)))
338
+ patches_per_dim = [np.ceil(s / p) for s, p in zip(spatial_shape, patch_size)]
339
+ total_patches = int(np.prod(patches_per_dim))
340
+ return total_patches
@@ -0,0 +1,171 @@
1
+ """Module for the `TilingStrategy` class."""
2
+
3
+ import itertools
4
+ from collections.abc import Sequence
5
+
6
+ from .patching_strategy_protocol import TileSpecs
7
+
8
+
9
class TilingStrategy:
    """
    Patching strategy that splits each sample into overlapping tiles.

    The tiling strategy should be used for prediction. The `get_patch_spec`
    method returns `TileSpecs` dictionaries that contain information on how to
    stitch the tiles back together to create the full image.
    """

    def __init__(
        self,
        data_shapes: Sequence[Sequence[int]],
        tile_size: Sequence[int],
        overlaps: Sequence[int],
    ):
        """
        Compute and store the tile specifications for all the given data shapes.

        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shapes of the underlying data. Each element is the dimension of the
            axes SC(Z)YX.
        tile_size : sequence of int
            The size of the tile. The sequence will have length 2 or 3, for 2D and 3D
            data respectively.
        overlaps : sequence of int
            How much a tile will overlap with adjacent tiles in each spatial dimension.

        Raises
        ------
        ValueError
            If, for any spatial axis, the overlap is not smaller than the tile size.
        """
        self.data_shapes = data_shapes
        self.tile_size = tile_size
        self.overlaps = overlaps
        # tile_size and overlap should have same length validated in pydantic configs
        self.tile_specs: list[TileSpecs] = self._generate_specs()

    @property
    def n_patches(self) -> int:
        """
        The number of patches that this patching strategy will return.

        It also determines the maximum index that can be given to `get_patch_spec`.
        """
        return len(self.tile_specs)

    def get_patch_spec(self, index: int) -> TileSpecs:
        """Return the tile specs for a given index.

        Parameters
        ----------
        index : int
            A patch index.

        Returns
        -------
        TileSpecs
            A dictionary that specifies a single patch in a series of `ImageStacks`.
        """
        return self.tile_specs[index]

    def _generate_specs(self) -> list[TileSpecs]:
        """
        Build the tile specifications for every sample of every data shape.

        Returns
        -------
        list of TileSpecs
            One entry per tile, ordered by data index, then sample index, then by
            the cartesian product of the per-axis tile positions.
        """
        tile_specs: list[TileSpecs] = []
        for data_idx, data_shape in enumerate(self.data_shapes):
            # the first two axes are S and C, the remaining ones are spatial
            spatial_shape = data_shape[2:]

            # spec info for each axis
            axis_specs: list[tuple[list[int], list[int], list[int], list[int]]] = [
                self._compute_1d_coords(
                    axis_size, self.tile_size[axis_idx], self.overlaps[axis_idx]
                )
                for axis_idx, axis_size in enumerate(spatial_shape)
            ]

            # regroup the per-axis results by field rather than by axis
            all_coords, all_stitch_coords, all_crop_coords, all_crop_size = zip(
                *axis_specs
            )

            # tiles are the same for each sample in a stack, so the combinations
            # over all axes are computed once and reused for every sample
            tiles = list(
                zip(
                    itertools.product(*all_coords),
                    itertools.product(*all_stitch_coords),
                    itertools.product(*all_crop_coords),
                    itertools.product(*all_crop_size),
                )
            )

            for sample_idx in range(data_shape[0]):
                for coords, stitch_coords, crop_coords, crop_size in tiles:
                    tile_specs.append(
                        {
                            # PatchSpecs
                            "data_idx": data_idx,
                            "sample_idx": sample_idx,
                            "coords": coords,
                            "patch_size": self.tile_size,
                            # TileSpecs additional fields
                            "crop_coords": crop_coords,
                            "crop_size": crop_size,
                            "stitch_coords": stitch_coords,
                        }
                    )
        return tile_specs

    @staticmethod
    def _compute_1d_coords(
        axis_size: int, tile_size: int, overlap: int
    ) -> tuple[list[int], list[int], list[int], list[int]]:
        """
        Compute the `TileSpecs` information for a single axis.

        Parameters
        ----------
        axis_size : int
            The size of the axis.
        tile_size : int
            The tile size.
        overlap : int
            The tile overlap.

        Returns
        -------
        coords: list of int
            The top-left (and first z-slice for 3D data) of a tile, in coords relative
            to the image.
        stitch_coords: list of int
            Where the tile will be stitched back into an image, taking into account
            that the tile will be cropped, in coords relative to the image.
        crop_coords: list of int
            The top-left side of where the tile will be cropped, in coordinates relative
            to the tile.
        crop_size: list of int
            The size of the cropped tile.

        Raises
        ------
        ValueError
            If `overlap` is not smaller than `tile_size`.
        """
        step = tile_size - overlap
        if step <= 0:
            # a negative step would silently produce no tiles at all
            raise ValueError(
                f"Tile overlap ({overlap}) must be smaller than the tile size "
                f"({tile_size})."
            )

        coords: list[int] = []
        stitch_coords: list[int] = []
        crop_coords: list[int] = []
        crop_size: list[int] = []

        # when the axis fits in one tile there is no neighbour to leave room for
        single_tile = axis_size <= tile_size

        for start in range(0, max(1, axis_size - overlap), step):
            if start == 0:
                # first tile: nothing is cropped on the leading side
                coords.append(0)
                crop_coords.append(0)
                stitch_coords.append(0)
                if single_tile:
                    # keep the whole axis, otherwise its end would never be stitched
                    crop_size.append(axis_size)
                else:
                    crop_size.append(tile_size - overlap // 2)
            elif start + tile_size < axis_size:
                # interior tile: crop on both sides
                coords.append(start)
                crop_coords.append(overlap // 2)
                stitch_coords.append(start + overlap // 2)
                crop_size.append(step)
            else:
                # last tile: shifted back to end at the image border, and stitched
                # right after the end of the previous cropped tile
                previous_tile_end = stitch_coords[-1] + crop_size[-1]

                coords.append(max(0, axis_size - tile_size))
                stitch_coords.append(previous_tile_end)
                crop_coords.append(previous_tile_end - coords[-1])
                crop_size.append(axis_size - previous_tile_end)

        return (
            coords,
            stitch_coords,
            crop_coords,
            crop_size,
        )
@@ -0,0 +1,36 @@
1
+ from collections.abc import Sequence
2
+
3
+ from .patching_strategy_protocol import PatchSpecs
4
+
5
+
6
class WholeSamplePatchingStrategy:
    """
    Patching strategy that returns one patch per sample, covering the whole sample.

    Each patch starts at the origin of the spatial axes and its size is the
    spatial shape of the data it belongs to.
    """

    # TODO: warn this strategy should only be used with batch size = 1
    #   for the case of multiple image stacks with different dimensions

    def __init__(self, data_shapes: Sequence[Sequence[int]]):
        """
        Store the data shapes and compute the patch specifications.

        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shapes of the underlying data. The first element of each shape is
            used as the number of samples, and the elements from index 2 onwards
            as the spatial shape.
        """
        self.data_shapes = data_shapes

        self.patch_specs: list[PatchSpecs] = self._initialize_patch_specs()

    @property
    def n_patches(self) -> int:
        """The number of patches that this patching strategy will return."""
        return len(self.patch_specs)

    def get_patch_spec(self, index: int) -> PatchSpecs:
        """
        Return the patch specs for a given index.

        Parameters
        ----------
        index : int
            A patch index.

        Returns
        -------
        PatchSpecs
            A dictionary that specifies a single whole-sample patch.
        """
        return self.patch_specs[index]

    def _initialize_patch_specs(self) -> list[PatchSpecs]:
        """
        Build one patch specification per sample of every data shape.

        Returns
        -------
        list of PatchSpecs
            The patch specifications, ordered by data index then sample index.
        """
        specs: list[PatchSpecs] = []
        for stack_number, shape in enumerate(self.data_shapes):
            extent = shape[2:]
            specs.extend(
                {
                    "data_idx": stack_number,
                    "sample_idx": sample_number,
                    "coords": (0,) * len(extent),
                    "patch_size": extent,
                }
                for sample_number in range(shape[0])
            )
        return specs