careamics 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +4 -3
- careamics/cli/utils.py +1 -1
- careamics/config/algorithms/n2v_algorithm_model.py +1 -1
- careamics/config/architectures/unet_model.py +3 -0
- careamics/config/callback_model.py +23 -34
- careamics/config/configuration.py +47 -1
- careamics/config/configuration_factories.py +288 -23
- careamics/config/data/__init__.py +2 -0
- careamics/config/data/data_model.py +3 -3
- careamics/config/data/ng_data_model.py +381 -0
- careamics/config/data/patching_strategies/__init__.py +14 -0
- careamics/config/data/patching_strategies/_overlapping_patched_model.py +103 -0
- careamics/config/data/patching_strategies/_patched_model.py +56 -0
- careamics/config/data/patching_strategies/random_patching_model.py +21 -0
- careamics/config/data/patching_strategies/sequential_patching_model.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_model.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_model.py +12 -0
- careamics/config/inference_model.py +6 -3
- careamics/config/support/supported_data.py +7 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/validators/validator_utils.py +4 -3
- careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
- careamics/dataset/in_memory_dataset.py +2 -1
- careamics/dataset/iterable_dataset.py +2 -2
- careamics/dataset/iterable_pred_dataset.py +2 -2
- careamics/dataset/iterable_tiled_pred_dataset.py +2 -2
- careamics/dataset/patching/patching.py +3 -2
- careamics/dataset/tiling/lvae_tiled_patching.py +16 -6
- careamics/dataset/tiling/tiled_patching.py +2 -1
- careamics/dataset_ng/dataset.py +46 -50
- careamics/dataset_ng/demos/bsd68_demo.ipynb +28 -23
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +1 -1
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +1 -1
- careamics/dataset_ng/demos/demo_datamodule.ipynb +50 -46
- careamics/dataset_ng/demos/demo_dataset.ipynb +32 -49
- careamics/dataset_ng/factory.py +58 -15
- careamics/dataset_ng/legacy_interoperability.py +3 -1
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +1 -1
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +2 -0
- careamics/dataset_ng/patch_extractor/image_stack/czi_image_stack.py +360 -0
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -1
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +43 -1
- careamics/dataset_ng/patching_strategies/random_patching.py +4 -2
- careamics/dataset_ng/patching_strategies/sequential_patching.py +5 -5
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +2 -1
- careamics/file_io/read/get_func.py +2 -1
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/data_module.py +218 -28
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +44 -5
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +42 -3
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +73 -4
- careamics/lightning/lightning_module.py +2 -1
- careamics/lightning/predict_data_module.py +2 -1
- careamics/lightning/train_data_module.py +2 -1
- careamics/losses/loss_factory.py +2 -1
- careamics/lvae_training/dataset/multicrop_dset.py +1 -1
- careamics/model_io/bioimage/bioimage_utils.py +1 -1
- careamics/model_io/bioimage/model_description.py +1 -1
- careamics/model_io/bmz_io.py +1 -1
- careamics/model_io/model_io_utils.py +2 -2
- careamics/models/activation.py +2 -1
- careamics/models/unet.py +16 -10
- careamics/prediction_utils/prediction_outputs.py +1 -1
- careamics/prediction_utils/stitch_prediction.py +1 -1
- careamics/transforms/n2v_manipulate_torch.py +15 -9
- careamics/transforms/pixel_manipulation_torch.py +59 -92
- careamics/utils/lightning_utils.py +2 -2
- careamics/utils/metrics.py +2 -1
- careamics/utils/torch_utils.py +23 -0
- {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/METADATA +10 -9
- {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/RECORD +74 -63
- {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/WHEEL +0 -0
- {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.12.dist-info → careamics-0.0.14.dist-info}/licenses/LICENSE +0 -0
careamics/losses/loss_factory.py
CHANGED
|
@@ -6,8 +6,9 @@ This module contains a factory function for creating loss functions.
|
|
|
6
6
|
|
|
7
7
|
from __future__ import annotations
|
|
8
8
|
|
|
9
|
+
from collections.abc import Callable
|
|
9
10
|
from dataclasses import dataclass
|
|
10
|
-
from typing import
|
|
11
|
+
from typing import Union
|
|
11
12
|
|
|
12
13
|
from torch import Tensor as tensor
|
|
13
14
|
|
|
@@ -148,7 +148,7 @@ def _create_inputs_ouputs(
|
|
|
148
148
|
inv_means = []
|
|
149
149
|
inv_stds = []
|
|
150
150
|
if means and stds:
|
|
151
|
-
for mean, std in zip(means, stds):
|
|
151
|
+
for mean, std in zip(means, stds, strict=False):
|
|
152
152
|
inv_means.append(-mean / (std + eps))
|
|
153
153
|
inv_stds.append(1 / (std + eps) - eps)
|
|
154
154
|
|
careamics/model_io/bmz_io.py
CHANGED
|
@@ -12,7 +12,7 @@ from careamics.utils import check_path_exists
|
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
def load_pretrained(
|
|
15
|
-
path: Union[Path, str]
|
|
15
|
+
path: Union[Path, str],
|
|
16
16
|
) -> tuple[Union[FCNModule, VAEModule], Configuration]:
|
|
17
17
|
"""
|
|
18
18
|
Load a pretrained model from a checkpoint or a BioImage Model Zoo model.
|
|
@@ -47,7 +47,7 @@ def load_pretrained(
|
|
|
47
47
|
|
|
48
48
|
|
|
49
49
|
def _load_checkpoint(
|
|
50
|
-
path: Union[Path, str]
|
|
50
|
+
path: Union[Path, str],
|
|
51
51
|
) -> tuple[Union[FCNModule, VAEModule], Configuration]:
|
|
52
52
|
"""
|
|
53
53
|
Load a model from a checkpoint and return both model and configuration.
|
careamics/models/activation.py
CHANGED
careamics/models/unet.py
CHANGED
|
@@ -205,16 +205,23 @@ class UnetDecoder(nn.Module):
|
|
|
205
205
|
decoder_blocks: list[nn.Module] = []
|
|
206
206
|
for n in range(depth):
|
|
207
207
|
decoder_blocks.append(upsampling)
|
|
208
|
-
|
|
209
|
-
|
|
208
|
+
|
|
209
|
+
in_channels = (num_channels_init * 2 ** (depth - n - 1)) * groups
|
|
210
|
+
# final decoder block has the same number in and out features
|
|
211
|
+
out_channels = in_channels // 2 if n != depth - 1 else in_channels
|
|
212
|
+
if not (n2v2 and (n == depth - 1)):
|
|
213
|
+
in_channels = in_channels * 2 # accounting for skip connection concat
|
|
214
|
+
|
|
210
215
|
decoder_blocks.append(
|
|
211
216
|
Conv_Block(
|
|
212
217
|
conv_dim,
|
|
213
|
-
in_channels=
|
|
214
|
-
in_channels + in_channels // 2 if n > 0 else in_channels
|
|
215
|
-
),
|
|
218
|
+
in_channels=in_channels,
|
|
216
219
|
out_channels=out_channels,
|
|
217
|
-
|
|
220
|
+
# TODO: Tensorflow n2v implementation has intermediate channel
|
|
221
|
+
# multiplication for skip_skipone=True but not skip_skipone=False
|
|
222
|
+
# this needs to be benchmarked.
|
|
223
|
+
# final decoder block doesn't multiply the intermediate features
|
|
224
|
+
intermediate_channel_multiplier=2 if n != depth - 1 else 1,
|
|
218
225
|
dropout_perc=dropout,
|
|
219
226
|
activation="ReLU",
|
|
220
227
|
use_batch_norm=use_batch_norm,
|
|
@@ -241,6 +248,7 @@ class UnetDecoder(nn.Module):
|
|
|
241
248
|
"""
|
|
242
249
|
x: torch.Tensor = features[0]
|
|
243
250
|
skip_connections: tuple[torch.Tensor, ...] = features[-1:0:-1]
|
|
251
|
+
depth = len(skip_connections)
|
|
244
252
|
|
|
245
253
|
x = self.bottleneck(x)
|
|
246
254
|
|
|
@@ -249,10 +257,8 @@ class UnetDecoder(nn.Module):
|
|
|
249
257
|
if isinstance(module, nn.Upsample):
|
|
250
258
|
# divide index by 2 because of upsampling layers
|
|
251
259
|
skip_connection: torch.Tensor = skip_connections[i // 2]
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
x = self._interleave(x, skip_connection, self.groups)
|
|
255
|
-
else:
|
|
260
|
+
# top level skip connection not added for n2v2
|
|
261
|
+
if (not self.n2v2) or (self.n2v2 and (i // 2 < depth - 1)):
|
|
256
262
|
x = self._interleave(x, skip_connection, self.groups)
|
|
257
263
|
return x
|
|
258
264
|
|
|
@@ -87,7 +87,7 @@ def combine_batches(
|
|
|
87
87
|
|
|
88
88
|
|
|
89
89
|
def _combine_tiled_batches(
|
|
90
|
-
predictions: list[tuple[NDArray, list[TileInformation]]]
|
|
90
|
+
predictions: list[tuple[NDArray, list[TileInformation]]],
|
|
91
91
|
) -> tuple[list[NDArray], list[TileInformation]]:
|
|
92
92
|
"""
|
|
93
93
|
Combine batches from tiled output.
|
|
@@ -94,7 +94,7 @@ def stitch_prediction_single(
|
|
|
94
94
|
input_shape = (1, tile_channels, *tile_infos[0].array_shape[1:])
|
|
95
95
|
predicted_image = np.zeros(input_shape, dtype=np.float32)
|
|
96
96
|
|
|
97
|
-
for tile, tile_info in zip(tiles, tile_infos):
|
|
97
|
+
for tile, tile_info in zip(tiles, tile_infos, strict=False):
|
|
98
98
|
|
|
99
99
|
# Compute coordinates for cropping predicted tile
|
|
100
100
|
crop_slices: tuple[Union[builtins.ellipsis, slice], ...] = (
|
|
@@ -27,6 +27,8 @@ class N2VManipulateTorch:
|
|
|
27
27
|
N2V manipulation configuration.
|
|
28
28
|
seed : Optional[int], optional
|
|
29
29
|
Random seed, by default None.
|
|
30
|
+
device : str
|
|
31
|
+
The device on which operations take place, e.g. "cuda", "cpu" or "mps".
|
|
30
32
|
|
|
31
33
|
Attributes
|
|
32
34
|
----------
|
|
@@ -48,6 +50,7 @@ class N2VManipulateTorch:
|
|
|
48
50
|
self,
|
|
49
51
|
n2v_manipulate_config: N2VManipulateModel,
|
|
50
52
|
seed: Optional[int] = None,
|
|
53
|
+
device: Optional[str] = None,
|
|
51
54
|
):
|
|
52
55
|
"""Constructor.
|
|
53
56
|
|
|
@@ -57,6 +60,8 @@ class N2VManipulateTorch:
|
|
|
57
60
|
N2V manipulation configuration.
|
|
58
61
|
seed : Optional[int], optional
|
|
59
62
|
Random seed, by default None.
|
|
63
|
+
device : str
|
|
64
|
+
The device on which operations take place, e.g. "cuda", "cpu" or "mps".
|
|
60
65
|
"""
|
|
61
66
|
self.masked_pixel_percentage = n2v_manipulate_config.masked_pixel_percentage
|
|
62
67
|
self.roi_size = n2v_manipulate_config.roi_size
|
|
@@ -78,15 +83,16 @@ class N2VManipulateTorch:
|
|
|
78
83
|
|
|
79
84
|
# PyTorch random generator
|
|
80
85
|
# TODO refactor into careamics.utils.torch_utils.get_device
|
|
81
|
-
if
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
86
|
+
if device is None:
|
|
87
|
+
if torch.cuda.is_available():
|
|
88
|
+
device = "cuda"
|
|
89
|
+
elif torch.backends.mps.is_available() and platform.processor() in (
|
|
90
|
+
"arm",
|
|
91
|
+
"arm64",
|
|
92
|
+
):
|
|
93
|
+
device = "mps"
|
|
94
|
+
else:
|
|
95
|
+
device = "cpu"
|
|
90
96
|
|
|
91
97
|
self.rng = (
|
|
92
98
|
torch.Generator(device=device).manual_seed(seed)
|
|
@@ -75,50 +75,23 @@ def _apply_struct_mask_torch(
|
|
|
75
75
|
return patch
|
|
76
76
|
|
|
77
77
|
|
|
78
|
-
def _odd_jitter_func_torch(step: float, rng: torch.Generator) -> torch.Tensor:
|
|
79
|
-
"""
|
|
80
|
-
Randomly sample a jitter to be applied to the masking grid.
|
|
81
|
-
|
|
82
|
-
This is done to account for cases where the step size is not an integer.
|
|
83
|
-
|
|
84
|
-
Parameters
|
|
85
|
-
----------
|
|
86
|
-
step : float
|
|
87
|
-
Step size of the grid, output of np.linspace.
|
|
88
|
-
rng : torch.Generator
|
|
89
|
-
Random number generator.
|
|
90
|
-
|
|
91
|
-
Returns
|
|
92
|
-
-------
|
|
93
|
-
torch.Tensor
|
|
94
|
-
Array of random jitter to be added to the grid.
|
|
95
|
-
"""
|
|
96
|
-
step_floor = torch.floor(torch.tensor(step))
|
|
97
|
-
odd_jitter = (
|
|
98
|
-
step_floor
|
|
99
|
-
if step_floor == step
|
|
100
|
-
else torch.randint(high=2, size=(1,), generator=rng)
|
|
101
|
-
)
|
|
102
|
-
return step_floor if odd_jitter == 0 else step_floor + 1
|
|
103
|
-
|
|
104
|
-
|
|
105
78
|
def _get_stratified_coords_torch(
|
|
106
79
|
mask_pixel_perc: float,
|
|
107
80
|
shape: tuple[int, ...],
|
|
108
|
-
rng:
|
|
81
|
+
rng: torch.Generator,
|
|
109
82
|
) -> torch.Tensor:
|
|
110
83
|
"""
|
|
111
84
|
Generate coordinates of the pixels to mask.
|
|
112
85
|
|
|
113
|
-
# TODO add more details
|
|
114
86
|
Randomly selects the coordinates of the pixels to mask in a stratified way, i.e.
|
|
115
|
-
the distance between masked pixels is approximately the same.
|
|
87
|
+
the distance between masked pixels is approximately the same. This is achieved by
|
|
88
|
+
defining a grid and sampling a pixel in each grid square. The grid is defined such
|
|
89
|
+
that the resulting density of masked pixels is the desired masked pixel percentage.
|
|
116
90
|
|
|
117
91
|
Parameters
|
|
118
92
|
----------
|
|
119
93
|
mask_pixel_perc : float
|
|
120
|
-
|
|
121
|
-
calculating the distance between masked pixels across each axis.
|
|
94
|
+
Expected value for percentage of masked pixels across the whole image.
|
|
122
95
|
shape : tuple[int, ...]
|
|
123
96
|
Shape of the input patch.
|
|
124
97
|
rng : torch.Generator or None
|
|
@@ -129,60 +102,51 @@ def _get_stratified_coords_torch(
|
|
|
129
102
|
np.ndarray
|
|
130
103
|
Array of coordinates of the masked pixels.
|
|
131
104
|
"""
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
#
|
|
136
|
-
#
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
# create a 2D meshgrid of coordinates
|
|
166
|
-
coordinate_grid_list = torch.meshgrid(*pixel_coords, indexing="ij")
|
|
167
|
-
coordinate_grid = torch.stack(
|
|
168
|
-
[g.flatten() for g in coordinate_grid_list], dim=-1
|
|
169
|
-
).to(rng.device)
|
|
170
|
-
|
|
171
|
-
# add a random jitter increment so that the coordinates do not lie on the grid
|
|
172
|
-
random_increment = torch.randint(
|
|
173
|
-
high=int(_odd_jitter_func_torch(float(max(steps)), rng)),
|
|
174
|
-
size=torch.tensor(coordinate_grid.shape).to(rng.device).tolist(),
|
|
175
|
-
generator=rng,
|
|
176
|
-
device=rng.device,
|
|
105
|
+
# Implementation logic:
|
|
106
|
+
# find a box size s.t sampling 1 pixel within the box will result in the desired
|
|
107
|
+
# pixel percentage. Make a grid of these boxes that cover the patch (the area of
|
|
108
|
+
# the grid will be greater than or equal to the area of the patch) and sample 1
|
|
109
|
+
# pixel in each box. The density of masked pixels is an intensive property therefore
|
|
110
|
+
# any subset of this area will have the desired expected masked pixel percentage.
|
|
111
|
+
# We can get our desired patch with our desired expected masked pixel percentage by
|
|
112
|
+
# simply filtering out masked pixels that lie outside of our patch bounds.
|
|
113
|
+
|
|
114
|
+
batch_size = shape[0]
|
|
115
|
+
spatial_shape = shape[1:]
|
|
116
|
+
|
|
117
|
+
n_dims = len(spatial_shape)
|
|
118
|
+
expected_area_per_pixel = 1 / (mask_pixel_perc / 100)
|
|
119
|
+
|
|
120
|
+
# keep the grid size in floats for a more accurate expected masked pixel percentage
|
|
121
|
+
grid_size = expected_area_per_pixel ** (1 / n_dims)
|
|
122
|
+
grid_dims = torch.ceil(torch.tensor(spatial_shape) / grid_size).int()
|
|
123
|
+
|
|
124
|
+
# coords on a fixed grid (top left corner)
|
|
125
|
+
coords = torch.stack(
|
|
126
|
+
torch.meshgrid(
|
|
127
|
+
torch.arange(batch_size, dtype=torch.float),
|
|
128
|
+
*[torch.arange(0, grid_dims[i].item()) * grid_size for i in range(n_dims)],
|
|
129
|
+
indexing="ij",
|
|
130
|
+
),
|
|
131
|
+
-1,
|
|
132
|
+
).reshape(-1, n_dims + 1)
|
|
133
|
+
|
|
134
|
+
# add random offset to get a random coord in each grid box
|
|
135
|
+
# also keep the offset in floats
|
|
136
|
+
offset = (
|
|
137
|
+
torch.rand((len(coords), n_dims), device=rng.device, generator=rng) * grid_size
|
|
177
138
|
)
|
|
178
|
-
|
|
139
|
+
coords = coords.to(rng.device)
|
|
140
|
+
coords[:, 1:] += offset
|
|
141
|
+
coords = torch.floor(coords).int()
|
|
179
142
|
|
|
180
|
-
#
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
torch.
|
|
184
|
-
|
|
185
|
-
|
|
143
|
+
# filter pixels out of bounds
|
|
144
|
+
out_of_bounds = (
|
|
145
|
+
coords[:, 1:]
|
|
146
|
+
>= torch.tensor(spatial_shape, device=rng.device).reshape(1, n_dims)
|
|
147
|
+
).any(1)
|
|
148
|
+
coords = coords[~out_of_bounds]
|
|
149
|
+
return coords
|
|
186
150
|
|
|
187
151
|
|
|
188
152
|
def uniform_manipulate_torch(
|
|
@@ -198,7 +162,7 @@ def uniform_manipulate_torch(
|
|
|
198
162
|
|
|
199
163
|
# TODO add more details, especially about batch
|
|
200
164
|
|
|
201
|
-
Manipulated pixels are selected
|
|
165
|
+
Manipulated pixels are selected uniformly selected in a subpatch, away from a grid
|
|
202
166
|
with an approximate uniform probability to be selected across the whole patch.
|
|
203
167
|
If `struct_params` is not None, an additional structN2V mask is applied to the
|
|
204
168
|
data, replacing the pixels in the mask with random values (excluding the pixel
|
|
@@ -254,18 +218,21 @@ def uniform_manipulate_torch(
|
|
|
254
218
|
random_increment = roi_span[
|
|
255
219
|
torch.randint(
|
|
256
220
|
low=min(roi_span),
|
|
257
|
-
high=max(roi_span) + 1,
|
|
258
|
-
|
|
221
|
+
high=max(roi_span) + 1,
|
|
222
|
+
# one less coord dim: we shouldn't add a random increment to the batch coord
|
|
223
|
+
size=(subpatch_centers.shape[0], subpatch_centers.shape[1] - 1),
|
|
259
224
|
generator=rng,
|
|
260
225
|
device=patch.device,
|
|
261
226
|
)
|
|
262
227
|
]
|
|
263
228
|
|
|
264
229
|
# compute the replacement pixel coordinates
|
|
265
|
-
replacement_coords =
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
230
|
+
replacement_coords = subpatch_centers.clone()
|
|
231
|
+
# only add random increment to the spatial dimensions, not the batch dimension
|
|
232
|
+
replacement_coords[:, 1:] = torch.clamp(
|
|
233
|
+
replacement_coords[:, 1:] + random_increment,
|
|
234
|
+
torch.zeros_like(torch.tensor(patch.shape[1:])).to(device=patch.device),
|
|
235
|
+
torch.tensor([v - 1 for v in patch.shape[1:]]).to(device=patch.device),
|
|
269
236
|
)
|
|
270
237
|
|
|
271
238
|
# replace the pixels in the patch
|
|
@@ -313,7 +280,7 @@ def median_manipulate_torch(
|
|
|
313
280
|
struct_params : StructMaskParameters or None, optional
|
|
314
281
|
Parameters for the structN2V mask (axis and span).
|
|
315
282
|
rng : torch.default_generator or None, optional
|
|
316
|
-
Random number
|
|
283
|
+
Random number generator, by default None.
|
|
317
284
|
|
|
318
285
|
Returns
|
|
319
286
|
-------
|
|
@@ -32,13 +32,13 @@ def read_csv_logger(experiment_name: str, log_folder: Union[str, Path]) -> dict:
|
|
|
32
32
|
lines = f.readlines()
|
|
33
33
|
|
|
34
34
|
header = lines[0].strip().split(",")
|
|
35
|
-
metrics = {value: [] for value in header}
|
|
35
|
+
metrics: dict[str, list] = {value: [] for value in header}
|
|
36
36
|
print(metrics)
|
|
37
37
|
|
|
38
38
|
for single_line in lines[1:]:
|
|
39
39
|
values = single_line.strip().split(",")
|
|
40
40
|
|
|
41
|
-
for k, v in zip(header, values):
|
|
41
|
+
for k, v in zip(header, values, strict=False):
|
|
42
42
|
metrics[k].append(v)
|
|
43
43
|
|
|
44
44
|
# train and val are not logged on the same row and can have different lengths
|
careamics/utils/metrics.py
CHANGED
careamics/utils/torch_utils.py
CHANGED
|
@@ -5,6 +5,7 @@ These functions are used to control certain aspects and behaviours of PyTorch.
|
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
import inspect
|
|
8
|
+
import platform
|
|
8
9
|
from typing import Union
|
|
9
10
|
|
|
10
11
|
import torch
|
|
@@ -16,6 +17,28 @@ from ..utils.logging import get_logger
|
|
|
16
17
|
logger = get_logger(__name__) # TODO are logger still needed?
|
|
17
18
|
|
|
18
19
|
|
|
20
|
+
def get_device() -> str:
|
|
21
|
+
"""
|
|
22
|
+
Get the device on which operations take place.
|
|
23
|
+
|
|
24
|
+
Returns
|
|
25
|
+
-------
|
|
26
|
+
str
|
|
27
|
+
The device on which operations take place, e.g. "cuda", "cpu" or "mps".
|
|
28
|
+
"""
|
|
29
|
+
if torch.cuda.is_available():
|
|
30
|
+
device = "cuda"
|
|
31
|
+
elif torch.backends.mps.is_available() and platform.processor() in (
|
|
32
|
+
"arm",
|
|
33
|
+
"arm64",
|
|
34
|
+
):
|
|
35
|
+
device = "mps"
|
|
36
|
+
else:
|
|
37
|
+
device = "cpu"
|
|
38
|
+
|
|
39
|
+
return device
|
|
40
|
+
|
|
41
|
+
|
|
19
42
|
def filter_parameters(
|
|
20
43
|
func: type,
|
|
21
44
|
user_params: dict,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: careamics
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.14
|
|
4
4
|
Summary: Toolbox for running N2V and friends.
|
|
5
5
|
Project-URL: homepage, https://careamics.github.io/
|
|
6
6
|
Project-URL: repository, https://github.com/CAREamics/careamics
|
|
@@ -10,27 +10,28 @@ License-File: LICENSE
|
|
|
10
10
|
Classifier: Development Status :: 3 - Alpha
|
|
11
11
|
Classifier: License :: OSI Approved :: BSD License
|
|
12
12
|
Classifier: Programming Language :: Python :: 3
|
|
13
|
-
Classifier: Programming Language :: Python :: 3.9
|
|
14
13
|
Classifier: Programming Language :: Python :: 3.10
|
|
15
14
|
Classifier: Programming Language :: Python :: 3.11
|
|
16
15
|
Classifier: Programming Language :: Python :: 3.12
|
|
17
16
|
Classifier: Typing :: Typed
|
|
18
|
-
Requires-Python: >=3.
|
|
17
|
+
Requires-Python: >=3.10
|
|
19
18
|
Requires-Dist: bioimageio-core==0.7
|
|
20
|
-
Requires-Dist: matplotlib<=3.10.
|
|
19
|
+
Requires-Dist: matplotlib<=3.10.3
|
|
21
20
|
Requires-Dist: numpy<2.0.0
|
|
22
21
|
Requires-Dist: pillow<=11.2.1
|
|
23
22
|
Requires-Dist: psutil<=7.0.0
|
|
24
23
|
Requires-Dist: pydantic<=2.12,>=2.11
|
|
25
|
-
Requires-Dist: pytorch-lightning<=2.5.
|
|
24
|
+
Requires-Dist: pytorch-lightning<=2.5.2,>=2.2
|
|
26
25
|
Requires-Dist: pyyaml!=6.0.0,<=6.0.2
|
|
27
26
|
Requires-Dist: scikit-image<=0.25.2
|
|
28
|
-
Requires-Dist: tifffile<=2025.
|
|
29
|
-
Requires-Dist: torch<=2.
|
|
30
|
-
Requires-Dist: torchvision<=0.
|
|
31
|
-
Requires-Dist: typer<=0.
|
|
27
|
+
Requires-Dist: tifffile<=2025.5.10
|
|
28
|
+
Requires-Dist: torch<=2.7.1,>=2.0
|
|
29
|
+
Requires-Dist: torchvision<=0.22.1
|
|
30
|
+
Requires-Dist: typer<=0.16.0,>=0.12.3
|
|
32
31
|
Requires-Dist: xarray<2025.3.0
|
|
33
32
|
Requires-Dist: zarr<3.0.0
|
|
33
|
+
Provides-Extra: czi
|
|
34
|
+
Requires-Dist: pylibczirw<6.0.0,>=4.1.2; extra == 'czi'
|
|
34
35
|
Provides-Extra: dev
|
|
35
36
|
Requires-Dist: onnx; extra == 'dev'
|
|
36
37
|
Requires-Dist: pre-commit; extra == 'dev'
|