careamics 0.0.9__py3-none-any.whl → 0.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +0 -4
- careamics/careamist.py +0 -1
- careamics/config/__init__.py +1 -13
- careamics/config/algorithms/care_algorithm_model.py +84 -0
- careamics/config/algorithms/n2n_algorithm_model.py +85 -0
- careamics/config/algorithms/n2v_algorithm_model.py +269 -1
- careamics/config/configuration.py +21 -13
- careamics/config/configuration_factories.py +179 -187
- careamics/config/configuration_io.py +2 -2
- careamics/config/data/__init__.py +1 -4
- careamics/config/data/data_model.py +46 -62
- careamics/config/support/supported_transforms.py +1 -1
- careamics/config/transformations/__init__.py +0 -2
- careamics/config/transformations/n2v_manipulate_model.py +15 -0
- careamics/config/transformations/transform_unions.py +0 -13
- careamics/dataset/dataset_utils/iterate_over_files.py +2 -2
- careamics/dataset/in_memory_dataset.py +3 -10
- careamics/dataset/in_memory_pred_dataset.py +3 -5
- careamics/dataset/in_memory_tiled_pred_dataset.py +2 -2
- careamics/dataset/iterable_dataset.py +2 -2
- careamics/dataset/iterable_pred_dataset.py +3 -5
- careamics/dataset/iterable_tiled_pred_dataset.py +3 -3
- careamics/dataset_ng/dataset/__init__.py +3 -0
- careamics/dataset_ng/dataset/dataset.py +184 -0
- careamics/dataset_ng/demo_dataset.ipynb +271 -0
- careamics/dataset_ng/demo_patch_extractor.py +53 -0
- careamics/dataset_ng/demo_patch_extractor_factory.py +37 -0
- careamics/dataset_ng/patch_extractor/__init__.py +10 -0
- careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +111 -0
- careamics/dataset_ng/patch_extractor/image_stack/__init__.py +9 -0
- careamics/dataset_ng/patch_extractor/image_stack/image_stack_protocol.py +53 -0
- careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +55 -0
- careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +163 -0
- careamics/dataset_ng/patch_extractor/image_stack_loader.py +140 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +29 -0
- careamics/dataset_ng/patch_extractor/patch_extractor_factory.py +208 -0
- careamics/dataset_ng/patching_strategies/__init__.py +11 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +82 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +338 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +75 -0
- careamics/lightning/lightning_module.py +78 -27
- careamics/lightning/train_data_module.py +8 -39
- careamics/losses/fcn/losses.py +17 -10
- careamics/model_io/bioimage/bioimage_utils.py +5 -3
- careamics/model_io/bioimage/model_description.py +3 -3
- careamics/model_io/bmz_io.py +2 -2
- careamics/model_io/model_io_utils.py +2 -2
- careamics/transforms/__init__.py +2 -1
- careamics/transforms/compose.py +5 -15
- careamics/transforms/n2v_manipulate_torch.py +143 -0
- careamics/transforms/pixel_manipulation.py +1 -0
- careamics/transforms/pixel_manipulation_torch.py +418 -0
- careamics/utils/version.py +38 -0
- {careamics-0.0.9.dist-info → careamics-0.0.10.dist-info}/METADATA +7 -8
- {careamics-0.0.9.dist-info → careamics-0.0.10.dist-info}/RECORD +58 -41
- careamics/config/care_configuration.py +0 -100
- careamics/config/data/n2v_data_model.py +0 -193
- careamics/config/n2n_configuration.py +0 -101
- careamics/config/n2v_configuration.py +0 -266
- {careamics-0.0.9.dist-info → careamics-0.0.10.dist-info}/WHEEL +0 -0
- {careamics-0.0.9.dist-info → careamics-0.0.10.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.9.dist-info → careamics-0.0.10.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,418 @@
|
|
|
1
|
+
"""N2V manipulation functions for PyTorch."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from .struct_mask_parameters import StructMaskParameters
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _apply_struct_mask_torch(
    patch: torch.Tensor,
    coords: torch.Tensor,
    struct_params: StructMaskParameters,
    rng: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Apply structN2V masks to patch.

    Each point in `coords` corresponds to the center of a mask. Masks are parameterized
    by `struct_params`, and pixels in the mask (with respect to `coords`) are replaced
    by a random value drawn uniformly between the minimum and maximum of `patch`. The
    center pixel itself is not modified by this function.

    The mask is a line of `struct_params.span` pixels along a single axis:
    `axis=0` extends along the last axis of `patch`, `axis=1` along the second to
    last one. All other coordinates of the masked pixels equal those of the center.

    `patch` is modified in place and also returned.

    Parameters
    ----------
    patch : torch.Tensor
        Patch to be manipulated, (batch, y, x) or (batch, z, y, x).
    coords : torch.Tensor
        Coordinates of the ROI (subpatch) centers, shape (num_points, patch.ndim).
    struct_params : StructMaskParameters
        Parameters for the structN2V mask (axis and span).
    rng : torch.Generator, optional
        Random number generator, expected on the same device as `patch`. If None, a
        new generator seeded from system entropy is created.

    Returns
    -------
    torch.Tensor
        Patch with the structN2V mask applied.
    """
    if rng is None:
        rng = torch.Generator(device=patch.device)
        # A fresh torch.Generator always starts from the same default seed; seed it
        # non-deterministically so that repeated calls do not draw identical values.
        rng.seed()

    # Axis of `patch` along which the mask extends, counted from the end
    moving_axis = -1 - struct_params.axis

    # Line of ones spanning `span` pixels along the moving axis, singleton elsewhere
    mask_shape = [1] * patch.ndim
    mask_shape[moving_axis] = struct_params.span
    mask = torch.ones(mask_shape, device=patch.device)

    center = torch.tensor(mask.shape, device=patch.device) // 2

    # The center is excluded from the struct mask
    mask[tuple(center)] = 0

    # Displacements of the masked pixels relative to the center, (ndim, num_disp)
    displacements = torch.stack(torch.where(mask == 1)) - center.unsqueeze(1)

    # Add every displacement to every center: (num_disp, ndim, num_points), then
    # flatten to absolute coordinates of shape (num_disp * num_points, ndim)
    mix = displacements.T.unsqueeze(-1) + coords.T.unsqueeze(0)
    mix = mix.permute([1, 0, 2]).reshape([mask.ndim, -1]).T

    # Only the moving axis is displaced, so (assuming `coords` lie inside the patch)
    # it is the only one that can fall out of bounds
    valid_indices = (mix[:, moving_axis] >= 0) & (
        mix[:, moving_axis] < patch.shape[moving_axis]
    )
    mix = mix[valid_indices]

    # Replace the masked pixels with values drawn uniformly in [min, max] of the patch
    random_values = torch.empty(len(mix), device=patch.device).uniform_(
        patch.min().item(), patch.max().item(), generator=rng
    )
    patch[tuple(mix.T.tolist())] = random_values

    return patch
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _odd_jitter_func_torch(step: float, rng: torch.Generator) -> torch.Tensor:
|
|
79
|
+
"""
|
|
80
|
+
Randomly sample a jitter to be applied to the masking grid.
|
|
81
|
+
|
|
82
|
+
This is done to account for cases where the step size is not an integer.
|
|
83
|
+
|
|
84
|
+
Parameters
|
|
85
|
+
----------
|
|
86
|
+
step : float
|
|
87
|
+
Step size of the grid, output of np.linspace.
|
|
88
|
+
rng : torch.Generator
|
|
89
|
+
Random number generator.
|
|
90
|
+
|
|
91
|
+
Returns
|
|
92
|
+
-------
|
|
93
|
+
torch.Tensor
|
|
94
|
+
Array of random jitter to be added to the grid.
|
|
95
|
+
"""
|
|
96
|
+
step_floor = torch.floor(torch.tensor(step))
|
|
97
|
+
odd_jitter = (
|
|
98
|
+
step_floor
|
|
99
|
+
if step_floor == step
|
|
100
|
+
else torch.randint(high=2, size=(1,), generator=rng)
|
|
101
|
+
)
|
|
102
|
+
return step_floor if odd_jitter == 0 else step_floor + 1
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _get_stratified_coords_torch(
    mask_pixel_perc: float,
    shape: tuple[int, ...],
    rng: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """
    Generate coordinates of the pixels to mask.

    Randomly selects the coordinates of the pixels to mask in a stratified way, i.e.
    the distance between masked pixels is approximately the same. For every axis of
    `shape`, `ceil(axis_size / distance)` grid positions are spread from 0, where
    `distance` is derived from `mask_pixel_perc`. A random jitter is then added
    independently to every coordinate, and the result is clamped to `shape`.

    Parameters
    ----------
    mask_pixel_perc : float
        Actual (quasi) percentage of masked pixels across the whole image. Used in
        calculating the distance between masked pixels across each axis.
    shape : tuple[int, ...]
        Shape of the input patch.
    rng : torch.Generator or None
        Random number generator. The returned coordinates live on its device. If
        None, a CPU generator seeded from system entropy is created.

    Returns
    -------
    torch.Tensor
        Integer coordinates of the masked pixels, shape (num_points, len(shape)).
    """
    if rng is None:
        rng = torch.Generator()
        # A fresh torch.Generator always starts from the same default seed, which
        # would produce the same coordinates on every call.
        rng.seed()

    # Calculate the maximum distance between masked pixels. Inversely proportional to
    # the percentage of masked pixels.
    # NOTE(review): every axis of `shape` is treated alike, including a leading batch
    # axis if the caller passes one (the exponent uses len(shape)) — confirm intent.
    mask_pixel_distance = round((100 / mask_pixel_perc) ** (1 / len(shape)))

    pixel_coords = []
    steps = []
    for axis_size in shape:
        # number of pixels to mask along the axis (exact integer ceil division)
        num_pixels = -(-axis_size // mask_pixel_distance)

        # 1D grid of coordinates for the axis
        axis_pixel_coords = torch.linspace(
            0,
            axis_size - (axis_size // num_pixels),
            num_pixels,
            dtype=torch.int32,
        )

        # step size between grid coordinates (whole axis if a single point)
        step = (
            axis_pixel_coords[1] - axis_pixel_coords[0]
            if len(axis_pixel_coords) > 1
            else axis_size
        )

        pixel_coords.append(axis_pixel_coords)
        steps.append(step)

    # Cartesian product of the per-axis grids, (num_points, len(shape))
    coordinate_grid = torch.stack(
        [g.flatten() for g in torch.meshgrid(*pixel_coords, indexing="ij")], dim=-1
    ).to(rng.device)

    # random jitter so that the coordinates do not lie on the grid
    random_increment = torch.randint(
        high=int(_odd_jitter_func_torch(float(max(steps)), rng)),
        size=coordinate_grid.shape,
        generator=rng,
        device=rng.device,
    )
    coordinate_grid += random_increment

    # make sure no coordinate lies outside the range
    lower = torch.zeros(len(shape), dtype=torch.int64, device=rng.device)
    upper = torch.tensor(shape, device=rng.device) - 1
    return torch.clamp(coordinate_grid, lower, upper)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def uniform_manipulate_torch(
    patch: torch.Tensor,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    remove_center: bool = True,
    struct_params: Optional[StructMaskParameters] = None,
    rng: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Manipulate pixels by replacing them with a neighbor value.

    Manipulated pixels are selected on a randomly jittered grid, so that every pixel
    has an approximately uniform probability of being selected across the whole
    patch. Each selected pixel is replaced by the value of a neighbor within the
    subpatch centered on it. For every axis independently, the offset to that
    neighbor is drawn uniformly from `[-(subpatch_size // 2), subpatch_size // 2]`,
    excluding 0 when `remove_center` is True. The neighbor coordinates are then
    clamped to the patch bounds.

    If `struct_params` is not None, an additional structN2V mask is applied to the
    data, replacing the pixels in the mask with random values (excluding the pixel
    already manipulated).

    Parameters
    ----------
    patch : torch.Tensor
        Image patch; every axis is used when sampling pixels and neighbors.
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the new pixel value is sampled from, by default 11.
    remove_center : bool
        Whether to exclude the zero offset when sampling the neighbor, by default
        True.
    struct_params : StructMaskParameters or None
        Parameters for the structN2V mask (axis and span).
    rng : torch.Generator or None
        Random number generator, expected on the device of `patch`. If None, a
        generator on that device seeded from system entropy is created.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        Tuple containing the manipulated patch and the corresponding mask (uint8, 1
        where the pixel value changed; a pixel replaced by an identical value is not
        marked).

    Raises
    ------
    ValueError
        If no neighbor offset is available (e.g. `subpatch_size=1` with
        `remove_center=True`).
    """
    if rng is None:
        rng = torch.Generator(device=patch.device)
        # An unseeded generator always starts from the same default seed, which would
        # yield identical masks on every call.
        rng.seed()

    # per-axis offsets from a masked pixel to its candidate replacement pixels
    roi_span_full = torch.arange(
        -(subpatch_size // 2),
        subpatch_size // 2 + 1,
        dtype=torch.int32,
        device=patch.device,
    )
    roi_span = roi_span_full[roi_span_full != 0] if remove_center else roi_span_full
    if len(roi_span) == 0:
        raise ValueError(
            f"No replacement pixel can be drawn with subpatch_size={subpatch_size} "
            f"and remove_center={remove_center}."
        )

    # create a copy of the patch
    transformed_patch = patch.clone()

    # coordinates of the pixels to be masked, (num_points, patch.ndim)
    subpatch_centers = _get_stratified_coords_torch(
        mask_pixel_percentage, patch.shape, rng
    ).to(device=patch.device)

    # Draw a uniform *index* into `roi_span` so that every offset is equally likely.
    # NOTE(review): all axes are offset, including a leading batch axis if the caller
    # passes one, so the replacement may come from another sample — confirm intent.
    random_increment = roi_span[
        torch.randint(
            low=0,
            high=len(roi_span),
            size=subpatch_centers.shape,
            generator=rng,
            device=patch.device,
        )
    ]

    # replacement pixel coordinates, clamped to the patch bounds
    lower = torch.zeros(patch.ndim, dtype=torch.int64, device=patch.device)
    upper = torch.tensor(patch.shape, device=patch.device) - 1
    replacement_coords = torch.clamp(subpatch_centers + random_increment, lower, upper)

    # replace the pixels; transpose + tuple yields one index tensor per axis
    replacement_pixels = patch[tuple(replacement_coords.T)]
    transformed_patch[tuple(subpatch_centers.T)] = replacement_pixels

    # mask of the pixels whose value changed
    mask = (transformed_patch != patch).to(dtype=torch.uint8)

    # apply structN2V mask if needed
    if struct_params is not None:
        transformed_patch = _apply_struct_mask_torch(
            transformed_patch, subpatch_centers, struct_params, rng
        )

    return transformed_patch, mask
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def median_manipulate_torch(
    batch: torch.Tensor,
    mask_pixel_percentage: float,
    subpatch_size: int = 11,
    struct_params: Optional[StructMaskParameters] = None,
    rng: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Manipulate pixels by replacing them with the median of their surrounding subpatch.

    N2V2 version, manipulated pixels are selected randomly away from a grid with an
    approximate uniform probability to be selected across the whole patch.

    The first axis of `batch` is treated as the batch axis: subpatches never extend
    across it. The median is computed over the subpatch without its center pixel or,
    when `struct_params` is given, without the pixels covered by the structN2V line.

    If `struct_params` is not None, an additional structN2V mask is applied to the data,
    replacing the pixels in the mask with random values (excluding the pixel already
    manipulated).

    Parameters
    ----------
    batch : torch.Tensor
        Batch of image patches, shape (batch, y, x) or (batch, z, y, x).
    mask_pixel_percentage : float
        Approximate percentage of pixels to be masked.
    subpatch_size : int
        Size of the subpatch the median is computed from, by default 11. An even size
        behaves like `subpatch_size + 1`.
    struct_params : StructMaskParameters or None, optional
        Parameters for the structN2V mask (axis and span).
    rng : torch.Generator or None, optional
        Random number generator, by default None, in which case a generator on the
        device of `batch`, seeded from system entropy, is created.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        Tuple containing the manipulated batch and the mask (uint8, 1 where the pixel
        value was changed by the median replacement).
    """
    if rng is None:
        rng = torch.Generator(device=batch.device)
        # An unseeded generator always starts from the same default seed, which would
        # yield identical masks on every call.
        rng.seed()

    # coordinates of the ROI centers, (num_coordinates, 1 + num_spatial_dims)
    subpatch_center_coordinates = _get_stratified_coords_torch(
        mask_pixel_percentage, batch.shape, rng
    ).to(device=batch.device)
    num_spatial_dims = subpatch_center_coordinates.shape[1] - 1

    # offsets spanning a ROI in every spatial dimension (the batch axis is not offset)
    pad_value = subpatch_size // 2
    axis_offsets = torch.arange(-pad_value, pad_value + 1, device=batch.device)
    offsets = torch.stack(
        [
            grid.flatten()
            for grid in torch.meshgrid(
                [axis_offsets] * num_spatial_dims, indexing="ij"
            )
        ],
        dim=1,
    )  # (roi_numel, num_spatial_dims)
    roi_numel = offsets.shape[0]

    # Coordinates of every ROI pixel, one (num_coordinates, roi_numel) tensor per
    # axis: the batch index is repeated, spatial indices are offset and clamped.
    coords_expands = [
        subpatch_center_coordinates[:, 0].unsqueeze(1).expand(-1, roi_numel)
    ]
    for d in range(1, num_spatial_dims + 1):
        coords_expands.append(
            (
                subpatch_center_coordinates[:, d].unsqueeze(1) + offsets[:, d - 1]
            ).clamp(0, batch.shape[d] - 1)
        )
    rois = batch[tuple(coords_expands)]  # (num_coordinates, roi_numel)

    if struct_params is not None:
        # Exclude the structN2V line: ROI pixels sharing all coordinates with the
        # center except along the moving axis (axis 0 -> last spatial axis, axis 1 ->
        # second to last), and within `halfspan` of the center along it.
        halfspan = (struct_params.span - 1) // 2
        moving_axis = num_spatial_dims - 1 - struct_params.axis
        on_line = torch.ones(roi_numel, dtype=torch.bool, device=batch.device)
        for d in range(num_spatial_dims):
            if d != moving_axis:
                on_line &= offsets[:, d] == 0
        in_struct_mask = on_line & (offsets[:, moving_axis].abs() <= halfspan)
        rois_filtered = rois[:, ~in_struct_mask]
    else:
        # remove the center pixel (middle of the flattened, odd-sized ROI)
        center_idx = roi_numel // 2
        rois_filtered = torch.cat(
            [rois[:, :center_idx], rois[:, center_idx + 1 :]], dim=1
        )

    medians = rois_filtered.median(dim=1).values  # (num_coordinates,)

    # write the medians at the ROI centers
    output_batch = batch.clone()
    output_batch[tuple(subpatch_center_coordinates.T)] = medians
    mask = torch.where(output_batch != batch, 1, 0).to(torch.uint8)

    if struct_params is not None:
        # Forward the generator for reproducibility; `uniform_` needs it on the same
        # device as the data, otherwise let the helper create its own.
        struct_rng = rng if rng.device == batch.device else None
        output_batch = _apply_struct_mask_torch(
            output_batch, subpatch_center_coordinates, struct_params, struct_rng
        )

    return output_batch, mask
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""Version utility."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
from careamics import __version__
|
|
6
|
+
|
|
7
|
+
# Module-level logger; `get_careamics_version` emits its dev-version warning here.
logger = logging.getLogger(__name__)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def get_careamics_version() -> str:
    """Get clean CAREamics version.

    This method returns the latest `Major.Minor.Patch` version of CAREamics, removing
    any local version identifier (everything after `+`, e.g. `+g1a2b3c.d20250101`).
    For development versions (e.g. `0.1.dev5+g1a2b3c`), the `dev` segment is replaced
    by `*` and a warning is logged, since the resulting version may not exist.

    Returns
    -------
    str
        Clean CAREamics version.
    """
    # Strip the local version label first: it may itself contain dots (e.g. a
    # "+g<hash>.d<date>" suffix), which would otherwise hide the "dev" segment.
    public_version = __version__.split("+", 1)[0]
    parts = public_version.split(".")

    # for local installs that do not detect the latest versions via tags
    # (typically our CI will install `0.1.devX<hash>` versions)
    if "dev" in parts[-1]:
        parts[-1] = "*"
        clean_version = ".".join(parts[:3])

        logger.warning(
            "Your CAREamics version seems to be a locally modified version "
            "(%s). The recorded version for loading models will be "
            "%s, which may not exist. If you want to ensure "
            "exporting the model with an existing version, please install the "
            "closest CAREamics version from PyPI or conda-forge.",
            __version__,
            clean_version,
        )

    # keep at most Major.Minor.Patch
    return ".".join(parts[:3])
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: careamics
|
|
3
|
-
Version: 0.0.9
|
|
3
|
+
Version: 0.0.10
|
|
4
4
|
Summary: Toolbox for running N2V and friends.
|
|
5
5
|
Project-URL: homepage, https://careamics.github.io/
|
|
6
6
|
Project-URL: repository, https://github.com/CAREamics/careamics
|
|
@@ -17,19 +17,19 @@ Classifier: Programming Language :: Python :: 3.12
|
|
|
17
17
|
Classifier: Typing :: Typed
|
|
18
18
|
Requires-Python: >=3.9
|
|
19
19
|
Requires-Dist: bioimageio-core==0.7
|
|
20
|
-
Requires-Dist: matplotlib<=3.10.
|
|
20
|
+
Requires-Dist: matplotlib<=3.10.1
|
|
21
21
|
Requires-Dist: numpy<2.0.0
|
|
22
22
|
Requires-Dist: pillow<=11.1.0
|
|
23
|
-
Requires-Dist: psutil<=
|
|
23
|
+
Requires-Dist: psutil<=7.0.0
|
|
24
24
|
Requires-Dist: pydantic<2.11,>=2.5
|
|
25
25
|
Requires-Dist: pytorch-lightning<=2.5.0.post0,>=2.2
|
|
26
26
|
Requires-Dist: pyyaml!=6.0.0,<=6.0.2
|
|
27
|
-
Requires-Dist: scikit-image<=0.25.
|
|
28
|
-
Requires-Dist: tifffile<=2025.
|
|
27
|
+
Requires-Dist: scikit-image<=0.25.2
|
|
28
|
+
Requires-Dist: tifffile<=2025.3.13
|
|
29
29
|
Requires-Dist: torch<=2.6.0,>=2.0
|
|
30
|
-
Requires-Dist: torchvision<=0.20.1
|
|
31
30
|
Requires-Dist: torchvision<=0.21.0
|
|
32
|
-
Requires-Dist: typer<=0.15.
|
|
31
|
+
Requires-Dist: typer<=0.15.2,>=0.12.3
|
|
32
|
+
Requires-Dist: xarray<2025.3.0
|
|
33
33
|
Requires-Dist: zarr<3.0.0
|
|
34
34
|
Provides-Extra: dev
|
|
35
35
|
Requires-Dist: onnx; extra == 'dev'
|
|
@@ -40,7 +40,6 @@ Requires-Dist: sybil; extra == 'dev'
|
|
|
40
40
|
Provides-Extra: examples
|
|
41
41
|
Requires-Dist: careamics-portfolio; extra == 'examples'
|
|
42
42
|
Requires-Dist: jupyter; extra == 'examples'
|
|
43
|
-
Requires-Dist: matplotlib; extra == 'examples'
|
|
44
43
|
Provides-Extra: tensorboard
|
|
45
44
|
Requires-Dist: protobuf==5.29.1; extra == 'tensorboard'
|
|
46
45
|
Requires-Dist: tensorboard; extra == 'tensorboard'
|