careamics 0.0.16__py3-none-any.whl → 0.0.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/careamist.py +7 -4
- careamics/config/configuration.py +6 -55
- careamics/config/configuration_factories.py +22 -12
- careamics/config/data/data_model.py +49 -9
- careamics/config/data/ng_data_model.py +167 -2
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_model.py +16 -0
- careamics/config/data/patch_filter/mask_filter_model.py +17 -0
- careamics/config/data/patch_filter/max_filter_model.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/dataset_ng/dataset.py +57 -5
- careamics/dataset_ng/factory.py +101 -18
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/lightning/callbacks/data_stats_callback.py +13 -3
- careamics/lightning/dataset_ng/data_module.py +79 -2
- careamics/lightning/lightning_module.py +4 -3
- careamics/lightning/microsplit_data_module.py +15 -10
- careamics/lvae_training/eval_utils.py +46 -24
- careamics/models/lvae/likelihoods.py +2 -1
- careamics/prediction_utils/prediction_outputs.py +3 -2
- careamics/prediction_utils/stitch_prediction.py +17 -6
- careamics/utils/version.py +4 -4
- {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/METADATA +5 -11
- {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/RECORD +36 -21
- {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
- {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
- {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
|
@@ -39,6 +39,10 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
39
39
|
train_data_target : Optional[InputType]
|
|
40
40
|
Training data target, can be a path to a folder,
|
|
41
41
|
a list of paths, or a numpy array.
|
|
42
|
+
train_data_mask : InputType (when filtering is needed)
|
|
43
|
+
Training data mask, can be a path to a folder,
|
|
44
|
+
a list of paths, or a numpy array. Used for coordinate filtering.
|
|
45
|
+
Only required when using coordinate-based patch filtering.
|
|
42
46
|
val_data : Optional[InputType]
|
|
43
47
|
Validation data, can be a path to a folder,
|
|
44
48
|
a list of paths, or a numpy array.
|
|
@@ -99,6 +103,9 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
99
103
|
train_data_target : Optional[Any]
|
|
100
104
|
Training data target, can be a path to a folder, a list of paths, or a numpy
|
|
101
105
|
array.
|
|
106
|
+
train_data_mask : Optional[Any]
|
|
107
|
+
Training data mask, can be a path to a folder, a list of paths, or a numpy
|
|
108
|
+
array.
|
|
102
109
|
val_data : Optional[Any]
|
|
103
110
|
Validation data, can be a path to a folder, a list of paths, or a numpy array.
|
|
104
111
|
val_data_target : Optional[Any]
|
|
@@ -118,7 +125,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
118
125
|
If input and target data types are not consistent.
|
|
119
126
|
"""
|
|
120
127
|
|
|
121
|
-
# standard use
|
|
128
|
+
# standard use (no mask)
|
|
122
129
|
@overload
|
|
123
130
|
def __init__(
|
|
124
131
|
self,
|
|
@@ -136,7 +143,26 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
136
143
|
use_in_memory: bool = True,
|
|
137
144
|
) -> None: ...
|
|
138
145
|
|
|
139
|
-
#
|
|
146
|
+
# with training mask for filtering
|
|
147
|
+
@overload
|
|
148
|
+
def __init__(
|
|
149
|
+
self,
|
|
150
|
+
data_config: NGDataConfig,
|
|
151
|
+
*,
|
|
152
|
+
train_data: InputType | None = None,
|
|
153
|
+
train_data_target: InputType | None = None,
|
|
154
|
+
train_data_mask: InputType,
|
|
155
|
+
val_data: InputType | None = None,
|
|
156
|
+
val_data_target: InputType | None = None,
|
|
157
|
+
pred_data: InputType | None = None,
|
|
158
|
+
pred_data_target: InputType | None = None,
|
|
159
|
+
extension_filter: str = "",
|
|
160
|
+
val_percentage: float | None = None,
|
|
161
|
+
val_minimum_split: int = 5,
|
|
162
|
+
use_in_memory: bool = True,
|
|
163
|
+
) -> None: ...
|
|
164
|
+
|
|
165
|
+
# custom read function (no mask)
|
|
140
166
|
@overload
|
|
141
167
|
def __init__(
|
|
142
168
|
self,
|
|
@@ -156,6 +182,48 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
156
182
|
use_in_memory: bool = True,
|
|
157
183
|
) -> None: ...
|
|
158
184
|
|
|
185
|
+
# custom read function with training mask
|
|
186
|
+
@overload
|
|
187
|
+
def __init__(
|
|
188
|
+
self,
|
|
189
|
+
data_config: NGDataConfig,
|
|
190
|
+
*,
|
|
191
|
+
train_data: InputType | None = None,
|
|
192
|
+
train_data_target: InputType | None = None,
|
|
193
|
+
train_data_mask: InputType,
|
|
194
|
+
val_data: InputType | None = None,
|
|
195
|
+
val_data_target: InputType | None = None,
|
|
196
|
+
pred_data: InputType | None = None,
|
|
197
|
+
pred_data_target: InputType | None = None,
|
|
198
|
+
read_source_func: Callable,
|
|
199
|
+
read_kwargs: dict[str, Any] | None = None,
|
|
200
|
+
extension_filter: str = "",
|
|
201
|
+
val_percentage: float | None = None,
|
|
202
|
+
val_minimum_split: int = 5,
|
|
203
|
+
use_in_memory: bool = True,
|
|
204
|
+
) -> None: ...
|
|
205
|
+
|
|
206
|
+
# image stack loader (no mask)
|
|
207
|
+
@overload
|
|
208
|
+
def __init__(
|
|
209
|
+
self,
|
|
210
|
+
data_config: NGDataConfig,
|
|
211
|
+
*,
|
|
212
|
+
train_data: Any | None = None,
|
|
213
|
+
train_data_target: Any | None = None,
|
|
214
|
+
val_data: Any | None = None,
|
|
215
|
+
val_data_target: Any | None = None,
|
|
216
|
+
pred_data: Any | None = None,
|
|
217
|
+
pred_data_target: Any | None = None,
|
|
218
|
+
image_stack_loader: ImageStackLoader,
|
|
219
|
+
image_stack_loader_kwargs: dict[str, Any] | None = None,
|
|
220
|
+
extension_filter: str = "",
|
|
221
|
+
val_percentage: float | None = None,
|
|
222
|
+
val_minimum_split: int = 5,
|
|
223
|
+
use_in_memory: bool = True,
|
|
224
|
+
) -> None: ...
|
|
225
|
+
|
|
226
|
+
# image stack loader with training mask
|
|
159
227
|
@overload
|
|
160
228
|
def __init__(
|
|
161
229
|
self,
|
|
@@ -163,6 +231,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
163
231
|
*,
|
|
164
232
|
train_data: Any | None = None,
|
|
165
233
|
train_data_target: Any | None = None,
|
|
234
|
+
train_data_mask: Any,
|
|
166
235
|
val_data: Any | None = None,
|
|
167
236
|
val_data_target: Any | None = None,
|
|
168
237
|
pred_data: Any | None = None,
|
|
@@ -181,6 +250,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
181
250
|
*,
|
|
182
251
|
train_data: Any | None = None,
|
|
183
252
|
train_data_target: Any | None = None,
|
|
253
|
+
train_data_mask: Any | None = None,
|
|
184
254
|
val_data: Any | None = None,
|
|
185
255
|
val_data_target: Any | None = None,
|
|
186
256
|
pred_data: Any | None = None,
|
|
@@ -209,6 +279,10 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
209
279
|
train_data_target : Optional[InputType]
|
|
210
280
|
Training data target, can be a path to a folder,
|
|
211
281
|
a list of paths, or a numpy array.
|
|
282
|
+
train_data_mask : InputType (when filtering is needed)
|
|
283
|
+
Training data mask, can be a path to a folder,
|
|
284
|
+
a list of paths, or a numpy array. Used for coordinate filtering.
|
|
285
|
+
Only required when using coordinate-based patch filtering.
|
|
212
286
|
val_data : Optional[InputType]
|
|
213
287
|
Validation data, can be a path to a folder,
|
|
214
288
|
a list of paths, or a numpy array.
|
|
@@ -268,6 +342,8 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
268
342
|
self.train_data, self.train_data_target = self._initialize_data_pair(
|
|
269
343
|
train_data, train_data_target
|
|
270
344
|
)
|
|
345
|
+
self.train_data_mask, _ = self._initialize_data_pair(train_data_mask, None)
|
|
346
|
+
|
|
271
347
|
self.val_data, self.val_data_target = self._initialize_data_pair(
|
|
272
348
|
val_data, val_data_target
|
|
273
349
|
)
|
|
@@ -574,6 +650,7 @@ class CareamicsDataModule(L.LightningDataModule):
|
|
|
574
650
|
mode=Mode.TRAINING,
|
|
575
651
|
inputs=self.train_data,
|
|
576
652
|
targets=self.train_data_target,
|
|
653
|
+
masks=self.train_data_mask,
|
|
577
654
|
config=self.config,
|
|
578
655
|
in_memory=self.use_in_memory,
|
|
579
656
|
read_func=self.read_source_func,
|
|
@@ -437,7 +437,8 @@ class VAEModule(L.LightningModule):
|
|
|
437
437
|
or self.noise_model_likelihood.data_std is None
|
|
438
438
|
):
|
|
439
439
|
raise RuntimeError(
|
|
440
|
-
"NoiseModelLikelihood: data_mean and data_std must be set before
|
|
440
|
+
"NoiseModelLikelihood: data_mean and data_std must be set before"
|
|
441
|
+
"training."
|
|
441
442
|
)
|
|
442
443
|
loss = self.loss_func(
|
|
443
444
|
model_outputs=out,
|
|
@@ -541,9 +542,9 @@ class VAEModule(L.LightningModule):
|
|
|
541
542
|
# get reconstructed img
|
|
542
543
|
if self.model.predict_logvar is None:
|
|
543
544
|
rec_img = rec
|
|
544
|
-
|
|
545
|
+
_logvar = torch.tensor([-1])
|
|
545
546
|
else:
|
|
546
|
-
rec_img,
|
|
547
|
+
rec_img, _logvar = torch.chunk(rec, chunks=2, dim=1)
|
|
547
548
|
rec_img_list.append(rec_img.cpu().unsqueeze(0)) # add MMSE dim
|
|
548
549
|
|
|
549
550
|
# aggregate results
|
|
@@ -74,7 +74,8 @@ def load_data(datadir):
|
|
|
74
74
|
channel_stack = np.concatenate(
|
|
75
75
|
channel_images, axis=0
|
|
76
76
|
) # FIXME: this line works iff images have
|
|
77
|
-
# a singleton channel dimension. Specify in the notebook or change with
|
|
77
|
+
# a singleton channel dimension. Specify in the notebook or change with
|
|
78
|
+
# `torch.stack`??
|
|
78
79
|
channels_data.append(channel_stack)
|
|
79
80
|
|
|
80
81
|
final_data = np.stack(channels_data, axis=-1)
|
|
@@ -204,7 +205,8 @@ class MicroSplitDataModule(L.LightningDataModule):
|
|
|
204
205
|
|
|
205
206
|
def __init__(
|
|
206
207
|
self,
|
|
207
|
-
|
|
208
|
+
# Should be compatible with microSplit DatasetConfig
|
|
209
|
+
data_config: MicroSplitDataConfig,
|
|
208
210
|
train_data: str,
|
|
209
211
|
val_data: str | None = None,
|
|
210
212
|
train_data_target: str | None = None,
|
|
@@ -301,7 +303,8 @@ class MicroSplitDataModule(L.LightningDataModule):
|
|
|
301
303
|
"""
|
|
302
304
|
return DataLoader(
|
|
303
305
|
self.train_dataset,
|
|
304
|
-
|
|
306
|
+
# TODO should be inside dataloader params?
|
|
307
|
+
batch_size=self.train_config.batch_size,
|
|
305
308
|
**self.train_config.train_dataloader_params,
|
|
306
309
|
)
|
|
307
310
|
|
|
@@ -355,7 +358,9 @@ def create_microsplit_train_datamodule(
|
|
|
355
358
|
**dataset_kwargs,
|
|
356
359
|
) -> MicroSplitDataModule:
|
|
357
360
|
"""
|
|
358
|
-
Create a MicroSplitDataModule for microSplit-style datasets
|
|
361
|
+
Create a MicroSplitDataModule for microSplit-style datasets.
|
|
362
|
+
|
|
363
|
+
This includes config creation.
|
|
359
364
|
|
|
360
365
|
Parameters
|
|
361
366
|
----------
|
|
@@ -424,10 +429,10 @@ def create_microsplit_train_datamodule(
|
|
|
424
429
|
**dataset_config_params,
|
|
425
430
|
datasplit_type=DataSplitType.Train,
|
|
426
431
|
)
|
|
427
|
-
val_config = MicroSplitDataConfig(
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
)
|
|
432
|
+
# val_config = MicroSplitDataConfig(
|
|
433
|
+
# **dataset_config_params,
|
|
434
|
+
# datasplit_type=DataSplitType.Val,
|
|
435
|
+
# )
|
|
431
436
|
# TODO, data config is duplicated here and in configuration
|
|
432
437
|
|
|
433
438
|
return MicroSplitDataModule(
|
|
@@ -578,10 +583,10 @@ def create_microsplit_predict_datamodule(
|
|
|
578
583
|
Grid size for patch extraction.
|
|
579
584
|
multiscale_count : int, optional
|
|
580
585
|
Number of LC scales.
|
|
581
|
-
tiling_mode : TilingMode, default=ShiftBoundary
|
|
582
|
-
Tiling mode for patch extraction.
|
|
583
586
|
data_stats : tuple, optional
|
|
584
587
|
Data statistics, by default None.
|
|
588
|
+
tiling_mode : TilingMode, default=ShiftBoundary
|
|
589
|
+
Tiling mode for patch extraction.
|
|
585
590
|
read_source_func : Callable, optional
|
|
586
591
|
Function to read the source data.
|
|
587
592
|
extension_filter : str, optional
|
|
@@ -32,7 +32,7 @@ class TilingMode:
|
|
|
32
32
|
ShiftBoundary = 2
|
|
33
33
|
|
|
34
34
|
|
|
35
|
-
#
|
|
35
|
+
# ------------------------------------------------------------------------------------
|
|
36
36
|
# Function of plotting: TODO -> moved them to another file, plot_utils.py
|
|
37
37
|
def clean_ax(ax):
|
|
38
38
|
"""
|
|
@@ -68,7 +68,9 @@ def get_psnr_str(tar_hsnr, pred, col_idx):
|
|
|
68
68
|
"""
|
|
69
69
|
Compute PSNR between the ground truth (`tar_hsnr`) and the predicted image (`pred`).
|
|
70
70
|
"""
|
|
71
|
-
|
|
71
|
+
psnr = scale_invariant_psnr(tar_hsnr[col_idx][None], pred[col_idx][None]).item()
|
|
72
|
+
|
|
73
|
+
return f"{psnr:.1f}"
|
|
72
74
|
|
|
73
75
|
|
|
74
76
|
def add_psnr_str(ax_, psnr):
|
|
@@ -129,8 +131,10 @@ def show_for_one(
|
|
|
129
131
|
baseline_preds=None,
|
|
130
132
|
):
|
|
131
133
|
"""
|
|
132
|
-
Given an index, it plots the input, target, reconstructed images and the difference
|
|
133
|
-
|
|
134
|
+
Given an index, it plots the input, target, reconstructed images and the difference
|
|
135
|
+
image.
|
|
136
|
+
Note the the difference image is computed with respect to a ground truth image,
|
|
137
|
+
obtained from the high SNR dataset.
|
|
134
138
|
"""
|
|
135
139
|
highsnr_val_dset.set_img_sz(patch_size, 64)
|
|
136
140
|
highsnr_val_dset.disable_noise()
|
|
@@ -164,7 +168,8 @@ def plot_crops(
|
|
|
164
168
|
for i in range(len(baseline_preds)):
|
|
165
169
|
if baseline_preds[i].shape != tar_hsnr.shape:
|
|
166
170
|
print(
|
|
167
|
-
f"Baseline prediction {i} shape {baseline_preds[i].shape} does not
|
|
171
|
+
f"Baseline prediction {i} shape {baseline_preds[i].shape} does not "
|
|
172
|
+
f"match target shape {tar_hsnr.shape}"
|
|
168
173
|
)
|
|
169
174
|
print("This happens when we want to predict the edges of the image.")
|
|
170
175
|
return
|
|
@@ -333,14 +338,21 @@ def plot_crops(
|
|
|
333
338
|
ax_temp.imshow(inp[0, 0].cpu().numpy(), cmap="magma")
|
|
334
339
|
clean_ax(ax_temp)
|
|
335
340
|
|
|
336
|
-
# line_ch1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0], linestyle='-',
|
|
337
|
-
#
|
|
338
|
-
#
|
|
339
|
-
#
|
|
340
|
-
#
|
|
341
|
-
#
|
|
341
|
+
# line_ch1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0], linestyle='-',
|
|
342
|
+
# label='$C_1$')
|
|
343
|
+
# line_ch2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1], linestyle='-',
|
|
344
|
+
# label='$C_2$')
|
|
345
|
+
# line_pred = mlines.Line2D([0, 1], [0, 1], color=color_pred, linestyle='-',
|
|
346
|
+
# label='Pred')
|
|
347
|
+
# line_noisych1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0],
|
|
348
|
+
# linestyle='--', label='$C^N_1$')
|
|
349
|
+
# line_noisych2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1],
|
|
350
|
+
# linestyle='--', label='$C^N_2$')
|
|
351
|
+
# legend_ch1 = legend_ch1_ax.legend(handles=[line_ch1, line_noisych1, line_pred],
|
|
352
|
+
# loc='upper right', frameon=False, labelcolor='white',
|
|
342
353
|
# prop={'size': 11})
|
|
343
|
-
# legend_ch2 = legend_ch2_ax.legend(handles=[line_ch2, line_noisych2, line_pred],
|
|
354
|
+
# legend_ch2 = legend_ch2_ax.legend(handles=[line_ch2, line_noisych2, line_pred],
|
|
355
|
+
# loc='upper right', frameon=False, labelcolor='white',
|
|
344
356
|
# prop={'size': 11})
|
|
345
357
|
|
|
346
358
|
if calibration_stats is not None:
|
|
@@ -383,7 +395,9 @@ def plot_calibration(ax, calibration_stats):
|
|
|
383
395
|
|
|
384
396
|
def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name="shiftedcmap"):
|
|
385
397
|
"""
|
|
386
|
-
Adapted from
|
|
398
|
+
Adapted from
|
|
399
|
+
https://stackoverflow.com/questions/7404116/defining-the-midpoint-of-a-colormap-in-
|
|
400
|
+
matplotlib
|
|
387
401
|
|
|
388
402
|
Function to offset the "center" of a colormap. Useful for
|
|
389
403
|
data with a negative min and positive max and you want the
|
|
@@ -444,7 +458,8 @@ def get_fractional_change(target, prediction, max_val=None):
|
|
|
444
458
|
|
|
445
459
|
def get_zero_centered_midval(error):
|
|
446
460
|
"""
|
|
447
|
-
When done this way, the midval ensures that the colorbar is centered at 0. (Don't
|
|
461
|
+
When done this way, the midval ensures that the colorbar is centered at 0. (Don't
|
|
462
|
+
know how, but it works ;))
|
|
448
463
|
"""
|
|
449
464
|
vmax = error.max()
|
|
450
465
|
vmin = error.min()
|
|
@@ -670,7 +685,7 @@ def get_single_file_mmse(
|
|
|
670
685
|
return stitched_predictions, stitched_stds
|
|
671
686
|
|
|
672
687
|
|
|
673
|
-
#
|
|
688
|
+
# ---------------------------------------------------------------------------------
|
|
674
689
|
### Classes and Functions used to stitch predictions
|
|
675
690
|
class PatchLocation:
|
|
676
691
|
"""
|
|
@@ -698,9 +713,10 @@ def _get_location(extra_padding, hwt, pred_h, pred_w):
|
|
|
698
713
|
|
|
699
714
|
def get_location_from_idx(dset, dset_input_idx, pred_h, pred_w):
|
|
700
715
|
"""
|
|
701
|
-
For a given idx of the dataset, it returns where exactly in the dataset, does this
|
|
702
|
-
Note that this prediction also has padded pixels and so a subset of
|
|
703
|
-
Which time frame, which spatial location
|
|
716
|
+
For a given idx of the dataset, it returns where exactly in the dataset, does this
|
|
717
|
+
prediction lies. Note that this prediction also has padded pixels and so a subset of
|
|
718
|
+
it will be used in the final prediction. Which time frame, which spatial location
|
|
719
|
+
(h_start, h_end, w_start,w_end)
|
|
704
720
|
Args:
|
|
705
721
|
dset:
|
|
706
722
|
dset_input_idx:
|
|
@@ -785,9 +801,11 @@ def stitch_predictions(predictions, dset, smoothening_pixelcount=0):
|
|
|
785
801
|
# NOTE: don't need to compute it for every patch.
|
|
786
802
|
assert (
|
|
787
803
|
smoothening_pixelcount == 0
|
|
788
|
-
), "For smoothing,enable the get_smoothing_mask. It is disabled since I
|
|
804
|
+
), "For smoothing,enable the get_smoothing_mask. It is disabled since I"
|
|
805
|
+
"don't use it and it needs modification to work with non-square images"
|
|
789
806
|
mask = 1
|
|
790
|
-
# mask = _get_smoothing_mask(cropped_pred_i.shape,
|
|
807
|
+
# mask = _get_smoothing_mask(cropped_pred_i.shape,
|
|
808
|
+
# smoothening_pixelcount, loc, frame_size)
|
|
791
809
|
|
|
792
810
|
cropped_pred_list.append(cropped_pred_i)
|
|
793
811
|
|
|
@@ -827,7 +845,8 @@ def stitch_predictions_new(predictions, dset):
|
|
|
827
845
|
output = np.zeros(shape, dtype=predictions.dtype)
|
|
828
846
|
# frame_shape = dset.get_data_shape()[:-1]
|
|
829
847
|
for dset_idx in range(predictions.shape[0]):
|
|
830
|
-
# loc = get_location_from_idx(dset, dset_idx, predictions.shape[-2],
|
|
848
|
+
# loc = get_location_from_idx(dset, dset_idx, predictions.shape[-2],
|
|
849
|
+
# predictions.shape[-1])
|
|
831
850
|
# grid start, grid end
|
|
832
851
|
gs = np.array(mng.get_location_from_dataset_idx(dset_idx), dtype=int)
|
|
833
852
|
ge = gs + mng.grid_shape
|
|
@@ -843,7 +862,8 @@ def stitch_predictions_new(predictions, dset):
|
|
|
843
862
|
vgs = np.array([max(0, x) for x in gs], dtype=int)
|
|
844
863
|
vge = np.array([min(x, y) for x, y in zip(ge, mng.data_shape)], dtype=int)
|
|
845
864
|
# assert np.all(vgs == gs)
|
|
846
|
-
# assert np.all(vge == ge) # TODO comented out this shit cuz I have no interest
|
|
865
|
+
# assert np.all(vge == ge) # TODO comented out this shit cuz I have no interest
|
|
866
|
+
# to dig why it's failing at this point !
|
|
847
867
|
# print('VGS')
|
|
848
868
|
# print(gs)
|
|
849
869
|
# print(ge)
|
|
@@ -898,7 +918,8 @@ def stitch_predictions_general(predictions, dset):
|
|
|
898
918
|
# frame_shape = dset.get_data_shape()[:-1]
|
|
899
919
|
for patch_idx in range(predictions.shape[0]):
|
|
900
920
|
# grid start, grid end
|
|
901
|
-
# channel_idx is 0 because during prediction we're only use one channel.
|
|
921
|
+
# channel_idx is 0 because during prediction we're only use one channel.
|
|
922
|
+
# # TODO revisit this
|
|
902
923
|
# 0th dimension is sample index in the output list
|
|
903
924
|
grid_coords = np.array(
|
|
904
925
|
mng.get_location_from_patch_idx(channel_idx=0, patch_idx=patch_idx),
|
|
@@ -906,7 +927,8 @@ def stitch_predictions_general(predictions, dset):
|
|
|
906
927
|
)
|
|
907
928
|
sample_idx = grid_coords[0]
|
|
908
929
|
grid_start = grid_coords[1:]
|
|
909
|
-
# from here on, coordinates are relative to the sample(file in the list of
|
|
930
|
+
# from here on, coordinates are relative to the sample(file in the list of
|
|
931
|
+
# inputs)
|
|
910
932
|
grid_end = grid_start + mng.grid_shape
|
|
911
933
|
|
|
912
934
|
# patch start, patch end
|
|
@@ -381,7 +381,8 @@ class NoiseModelLikelihood(LikelihoodModule):
|
|
|
381
381
|
"""
|
|
382
382
|
if self.data_mean is None or self.data_std is None:
|
|
383
383
|
raise RuntimeError(
|
|
384
|
-
"NoiseModelLikelihood: data_mean and data_std must be set before
|
|
384
|
+
"NoiseModelLikelihood: data_mean and data_std must be set before"
|
|
385
|
+
"callinglog_likelihood."
|
|
385
386
|
)
|
|
386
387
|
self._set_params_to_same_device_as(x)
|
|
387
388
|
predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
|
|
@@ -47,8 +47,9 @@ def convert_outputs_microsplit(
|
|
|
47
47
|
"""
|
|
48
48
|
Convert microsplit Lightning trainer outputs using eval_utils stitching functions.
|
|
49
49
|
|
|
50
|
-
This function processes microsplit predictions that return (tile_prediction,
|
|
51
|
-
and stitches them back together using the same logic as
|
|
50
|
+
This function processes microsplit predictions that return (tile_prediction,
|
|
51
|
+
tile_std) tuples and stitches them back together using the same logic as
|
|
52
|
+
get_single_file_mmse.
|
|
52
53
|
|
|
53
54
|
Parameters
|
|
54
55
|
----------
|
|
@@ -17,13 +17,21 @@ class TilingMode:
|
|
|
17
17
|
ShiftBoundary = 2
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
def stitch_prediction_vae(predictions, dset):
|
|
20
|
+
def stitch_prediction_vae(predictions, dset) -> NDArray:
|
|
21
21
|
"""
|
|
22
22
|
Stitch predictions back together using dataset's index manager.
|
|
23
23
|
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
24
|
+
Parameters
|
|
25
|
+
----------
|
|
26
|
+
predictions : np.ndarray
|
|
27
|
+
Array of predictions with shape (n_tiles, channels, height, width).
|
|
28
|
+
dset : Dataset
|
|
29
|
+
Dataset object with idx_manager containing tiling information.
|
|
30
|
+
|
|
31
|
+
Returns
|
|
32
|
+
-------
|
|
33
|
+
np.ndarray
|
|
34
|
+
Stitched array with shape matching the original data shape.
|
|
27
35
|
"""
|
|
28
36
|
mng = dset.idx_manager
|
|
29
37
|
|
|
@@ -34,7 +42,8 @@ def stitch_prediction_vae(predictions, dset):
|
|
|
34
42
|
output = np.zeros(shape, dtype=predictions.dtype)
|
|
35
43
|
# frame_shape = dset.get_data_shape()[:-1]
|
|
36
44
|
for dset_idx in range(predictions.shape[0]):
|
|
37
|
-
# loc = get_location_from_idx(dset, dset_idx, predictions.shape[-2],
|
|
45
|
+
# loc = get_location_from_idx(dset, dset_idx, predictions.shape[-2],
|
|
46
|
+
# predictions.shape[-1])
|
|
38
47
|
# grid start, grid end
|
|
39
48
|
gs = np.array(mng.get_location_from_dataset_idx(dset_idx), dtype=int)
|
|
40
49
|
ge = gs + mng.grid_shape
|
|
@@ -178,6 +187,8 @@ def stitch_prediction_single(
|
|
|
178
187
|
|
|
179
188
|
# Insert cropped tile into predicted image using stitch coordinates
|
|
180
189
|
image_slices = (..., *[slice(c[0], c[1]) for c in tile_info.stitch_coords])
|
|
181
|
-
|
|
190
|
+
|
|
191
|
+
# TODO fix mypy error here, potentially due to numpy 2
|
|
192
|
+
predicted_image[image_slices] = cropped_tile.astype(np.float32) # type: ignore
|
|
182
193
|
|
|
183
194
|
return predicted_image
|
careamics/utils/version.py
CHANGED
|
@@ -21,9 +21,9 @@ def get_careamics_version() -> str:
|
|
|
21
21
|
parts = __version__.split(".")
|
|
22
22
|
|
|
23
23
|
# for local installs that do not detect the latest versions via tags
|
|
24
|
-
# (typically our CI will install `0.
|
|
25
|
-
if "dev" in parts[
|
|
26
|
-
parts[
|
|
24
|
+
# (typically our CI will install `0.X.devX<hash>.<other hash>` versions)
|
|
25
|
+
if "dev" in parts[2]:
|
|
26
|
+
parts[2] = "*"
|
|
27
27
|
clean_version = ".".join(parts[:3])
|
|
28
28
|
|
|
29
29
|
logger.warning(
|
|
@@ -34,5 +34,5 @@ def get_careamics_version() -> str:
|
|
|
34
34
|
f"closest CAREamics version from PyPI or conda-forge."
|
|
35
35
|
)
|
|
36
36
|
|
|
37
|
-
# Remove any local version identifier
|
|
37
|
+
# Remove any local version identifier
|
|
38
38
|
return ".".join(parts[:3])
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: careamics
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.17
|
|
4
4
|
Summary: Toolbox for running N2V and friends.
|
|
5
5
|
Project-URL: homepage, https://careamics.github.io/
|
|
6
6
|
Project-URL: repository, https://github.com/CAREamics/careamics
|
|
@@ -16,30 +16,24 @@ Classifier: Programming Language :: Python :: 3.13
|
|
|
16
16
|
Classifier: Typing :: Typed
|
|
17
17
|
Requires-Python: >=3.11
|
|
18
18
|
Requires-Dist: bioimageio-core>=0.9.0
|
|
19
|
-
Requires-Dist: matplotlib<=3.10.
|
|
19
|
+
Requires-Dist: matplotlib<=3.10.7
|
|
20
20
|
Requires-Dist: numpy>=1.21
|
|
21
21
|
Requires-Dist: numpy>=2.1.0; python_version >= '3.13'
|
|
22
22
|
Requires-Dist: pillow<=11.3.0
|
|
23
23
|
Requires-Dist: psutil<=7.1.0
|
|
24
|
-
Requires-Dist: pydantic<=2.12,>=2.11
|
|
24
|
+
Requires-Dist: pydantic<=2.12.2,>=2.11
|
|
25
25
|
Requires-Dist: pytorch-lightning<=2.5.5,>=2.2
|
|
26
26
|
Requires-Dist: pyyaml!=6.0.0,<=6.0.3
|
|
27
27
|
Requires-Dist: scikit-image<=0.25.2
|
|
28
|
-
Requires-Dist: tifffile<=2025.
|
|
28
|
+
Requires-Dist: tifffile<=2025.10.4
|
|
29
29
|
Requires-Dist: torch<=2.8.0,>=2.0
|
|
30
|
+
Requires-Dist: torchmetrics<1.5.0,>=0.11.0
|
|
30
31
|
Requires-Dist: torchvision<=0.23.0
|
|
31
32
|
Requires-Dist: typer<=0.19.2,>=0.12.3
|
|
32
33
|
Requires-Dist: validators<=0.35.0
|
|
33
34
|
Requires-Dist: zarr<4.0.0,>=3.0.0
|
|
34
35
|
Provides-Extra: czi
|
|
35
36
|
Requires-Dist: pylibczirw<6.0.0,>=4.1.2; extra == 'czi'
|
|
36
|
-
Provides-Extra: dev
|
|
37
|
-
Requires-Dist: ml-dtypes>=0.5.0; extra == 'dev'
|
|
38
|
-
Requires-Dist: onnx; extra == 'dev'
|
|
39
|
-
Requires-Dist: pre-commit; extra == 'dev'
|
|
40
|
-
Requires-Dist: pytest; extra == 'dev'
|
|
41
|
-
Requires-Dist: pytest-cov; extra == 'dev'
|
|
42
|
-
Requires-Dist: sybil; extra == 'dev'
|
|
43
37
|
Provides-Extra: examples
|
|
44
38
|
Requires-Dist: careamics-portfolio; extra == 'examples'
|
|
45
39
|
Requires-Dist: jupyter; extra == 'examples'
|