careamics 0.0.16__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. See the registry page for this release for more details.

Files changed (36):
  1. careamics/careamist.py +7 -4
  2. careamics/config/configuration.py +6 -55
  3. careamics/config/configuration_factories.py +22 -12
  4. careamics/config/data/data_model.py +49 -9
  5. careamics/config/data/ng_data_model.py +167 -2
  6. careamics/config/data/patch_filter/__init__.py +15 -0
  7. careamics/config/data/patch_filter/filter_model.py +16 -0
  8. careamics/config/data/patch_filter/mask_filter_model.py +17 -0
  9. careamics/config/data/patch_filter/max_filter_model.py +15 -0
  10. careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
  11. careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
  12. careamics/config/support/supported_filters.py +17 -0
  13. careamics/dataset_ng/dataset.py +57 -5
  14. careamics/dataset_ng/factory.py +101 -18
  15. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  16. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  17. careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
  18. careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
  19. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  20. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  21. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  22. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  23. careamics/lightning/callbacks/data_stats_callback.py +13 -3
  24. careamics/lightning/dataset_ng/data_module.py +79 -2
  25. careamics/lightning/lightning_module.py +4 -3
  26. careamics/lightning/microsplit_data_module.py +15 -10
  27. careamics/lvae_training/eval_utils.py +46 -24
  28. careamics/models/lvae/likelihoods.py +2 -1
  29. careamics/prediction_utils/prediction_outputs.py +3 -2
  30. careamics/prediction_utils/stitch_prediction.py +17 -6
  31. careamics/utils/version.py +4 -4
  32. {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/METADATA +5 -11
  33. {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/RECORD +36 -21
  34. {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
  35. {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
  36. {careamics-0.0.16.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
@@ -39,6 +39,10 @@ class CareamicsDataModule(L.LightningDataModule):
39
39
  train_data_target : Optional[InputType]
40
40
  Training data target, can be a path to a folder,
41
41
  a list of paths, or a numpy array.
42
+ train_data_mask : InputType (when filtering is needed)
43
+ Training data mask, can be a path to a folder,
44
+ a list of paths, or a numpy array. Used for coordinate filtering.
45
+ Only required when using coordinate-based patch filtering.
42
46
  val_data : Optional[InputType]
43
47
  Validation data, can be a path to a folder,
44
48
  a list of paths, or a numpy array.
@@ -99,6 +103,9 @@ class CareamicsDataModule(L.LightningDataModule):
99
103
  train_data_target : Optional[Any]
100
104
  Training data target, can be a path to a folder, a list of paths, or a numpy
101
105
  array.
106
+ train_data_mask : Optional[Any]
107
+ Training data mask, can be a path to a folder, a list of paths, or a numpy
108
+ array.
102
109
  val_data : Optional[Any]
103
110
  Validation data, can be a path to a folder, a list of paths, or a numpy array.
104
111
  val_data_target : Optional[Any]
@@ -118,7 +125,7 @@ class CareamicsDataModule(L.LightningDataModule):
118
125
  If input and target data types are not consistent.
119
126
  """
120
127
 
121
- # standard use
128
+ # standard use (no mask)
122
129
  @overload
123
130
  def __init__(
124
131
  self,
@@ -136,7 +143,26 @@ class CareamicsDataModule(L.LightningDataModule):
136
143
  use_in_memory: bool = True,
137
144
  ) -> None: ...
138
145
 
139
- # custom read function
146
+ # with training mask for filtering
147
+ @overload
148
+ def __init__(
149
+ self,
150
+ data_config: NGDataConfig,
151
+ *,
152
+ train_data: InputType | None = None,
153
+ train_data_target: InputType | None = None,
154
+ train_data_mask: InputType,
155
+ val_data: InputType | None = None,
156
+ val_data_target: InputType | None = None,
157
+ pred_data: InputType | None = None,
158
+ pred_data_target: InputType | None = None,
159
+ extension_filter: str = "",
160
+ val_percentage: float | None = None,
161
+ val_minimum_split: int = 5,
162
+ use_in_memory: bool = True,
163
+ ) -> None: ...
164
+
165
+ # custom read function (no mask)
140
166
  @overload
141
167
  def __init__(
142
168
  self,
@@ -156,6 +182,48 @@ class CareamicsDataModule(L.LightningDataModule):
156
182
  use_in_memory: bool = True,
157
183
  ) -> None: ...
158
184
 
185
+ # custom read function with training mask
186
+ @overload
187
+ def __init__(
188
+ self,
189
+ data_config: NGDataConfig,
190
+ *,
191
+ train_data: InputType | None = None,
192
+ train_data_target: InputType | None = None,
193
+ train_data_mask: InputType,
194
+ val_data: InputType | None = None,
195
+ val_data_target: InputType | None = None,
196
+ pred_data: InputType | None = None,
197
+ pred_data_target: InputType | None = None,
198
+ read_source_func: Callable,
199
+ read_kwargs: dict[str, Any] | None = None,
200
+ extension_filter: str = "",
201
+ val_percentage: float | None = None,
202
+ val_minimum_split: int = 5,
203
+ use_in_memory: bool = True,
204
+ ) -> None: ...
205
+
206
+ # image stack loader (no mask)
207
+ @overload
208
+ def __init__(
209
+ self,
210
+ data_config: NGDataConfig,
211
+ *,
212
+ train_data: Any | None = None,
213
+ train_data_target: Any | None = None,
214
+ val_data: Any | None = None,
215
+ val_data_target: Any | None = None,
216
+ pred_data: Any | None = None,
217
+ pred_data_target: Any | None = None,
218
+ image_stack_loader: ImageStackLoader,
219
+ image_stack_loader_kwargs: dict[str, Any] | None = None,
220
+ extension_filter: str = "",
221
+ val_percentage: float | None = None,
222
+ val_minimum_split: int = 5,
223
+ use_in_memory: bool = True,
224
+ ) -> None: ...
225
+
226
+ # image stack loader with training mask
159
227
  @overload
160
228
  def __init__(
161
229
  self,
@@ -163,6 +231,7 @@ class CareamicsDataModule(L.LightningDataModule):
163
231
  *,
164
232
  train_data: Any | None = None,
165
233
  train_data_target: Any | None = None,
234
+ train_data_mask: Any,
166
235
  val_data: Any | None = None,
167
236
  val_data_target: Any | None = None,
168
237
  pred_data: Any | None = None,
@@ -181,6 +250,7 @@ class CareamicsDataModule(L.LightningDataModule):
181
250
  *,
182
251
  train_data: Any | None = None,
183
252
  train_data_target: Any | None = None,
253
+ train_data_mask: Any | None = None,
184
254
  val_data: Any | None = None,
185
255
  val_data_target: Any | None = None,
186
256
  pred_data: Any | None = None,
@@ -209,6 +279,10 @@ class CareamicsDataModule(L.LightningDataModule):
209
279
  train_data_target : Optional[InputType]
210
280
  Training data target, can be a path to a folder,
211
281
  a list of paths, or a numpy array.
282
+ train_data_mask : InputType (when filtering is needed)
283
+ Training data mask, can be a path to a folder,
284
+ a list of paths, or a numpy array. Used for coordinate filtering.
285
+ Only required when using coordinate-based patch filtering.
212
286
  val_data : Optional[InputType]
213
287
  Validation data, can be a path to a folder,
214
288
  a list of paths, or a numpy array.
@@ -268,6 +342,8 @@ class CareamicsDataModule(L.LightningDataModule):
268
342
  self.train_data, self.train_data_target = self._initialize_data_pair(
269
343
  train_data, train_data_target
270
344
  )
345
+ self.train_data_mask, _ = self._initialize_data_pair(train_data_mask, None)
346
+
271
347
  self.val_data, self.val_data_target = self._initialize_data_pair(
272
348
  val_data, val_data_target
273
349
  )
@@ -574,6 +650,7 @@ class CareamicsDataModule(L.LightningDataModule):
574
650
  mode=Mode.TRAINING,
575
651
  inputs=self.train_data,
576
652
  targets=self.train_data_target,
653
+ masks=self.train_data_mask,
577
654
  config=self.config,
578
655
  in_memory=self.use_in_memory,
579
656
  read_func=self.read_source_func,
@@ -437,7 +437,8 @@ class VAEModule(L.LightningModule):
437
437
  or self.noise_model_likelihood.data_std is None
438
438
  ):
439
439
  raise RuntimeError(
440
- "NoiseModelLikelihood: data_mean and data_std must be set before training."
440
+ "NoiseModelLikelihood: data_mean and data_std must be set before"
441
+ "training."
441
442
  )
442
443
  loss = self.loss_func(
443
444
  model_outputs=out,
@@ -541,9 +542,9 @@ class VAEModule(L.LightningModule):
541
542
  # get reconstructed img
542
543
  if self.model.predict_logvar is None:
543
544
  rec_img = rec
544
- logvar = torch.tensor([-1])
545
+ _logvar = torch.tensor([-1])
545
546
  else:
546
- rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
547
+ rec_img, _logvar = torch.chunk(rec, chunks=2, dim=1)
547
548
  rec_img_list.append(rec_img.cpu().unsqueeze(0)) # add MMSE dim
548
549
 
549
550
  # aggregate results
@@ -74,7 +74,8 @@ def load_data(datadir):
74
74
  channel_stack = np.concatenate(
75
75
  channel_images, axis=0
76
76
  ) # FIXME: this line works iff images have
77
- # a singleton channel dimension. Specify in the notebook or change with `torch.stack`??
77
+ # a singleton channel dimension. Specify in the notebook or change with
78
+ # `torch.stack`??
78
79
  channels_data.append(channel_stack)
79
80
 
80
81
  final_data = np.stack(channels_data, axis=-1)
@@ -204,7 +205,8 @@ class MicroSplitDataModule(L.LightningDataModule):
204
205
 
205
206
  def __init__(
206
207
  self,
207
- data_config: MicroSplitDataConfig, # Should be compatible with microSplit DatasetConfig
208
+ # Should be compatible with microSplit DatasetConfig
209
+ data_config: MicroSplitDataConfig,
208
210
  train_data: str,
209
211
  val_data: str | None = None,
210
212
  train_data_target: str | None = None,
@@ -301,7 +303,8 @@ class MicroSplitDataModule(L.LightningDataModule):
301
303
  """
302
304
  return DataLoader(
303
305
  self.train_dataset,
304
- batch_size=self.train_config.batch_size, # TODO should be inside dataloader params?
306
+ # TODO should be inside dataloader params?
307
+ batch_size=self.train_config.batch_size,
305
308
  **self.train_config.train_dataloader_params,
306
309
  )
307
310
 
@@ -355,7 +358,9 @@ def create_microsplit_train_datamodule(
355
358
  **dataset_kwargs,
356
359
  ) -> MicroSplitDataModule:
357
360
  """
358
- Create a MicroSplitDataModule for microSplit-style datasets, including config creation.
361
+ Create a MicroSplitDataModule for microSplit-style datasets.
362
+
363
+ This includes config creation.
359
364
 
360
365
  Parameters
361
366
  ----------
@@ -424,10 +429,10 @@ def create_microsplit_train_datamodule(
424
429
  **dataset_config_params,
425
430
  datasplit_type=DataSplitType.Train,
426
431
  )
427
- val_config = MicroSplitDataConfig(
428
- **dataset_config_params,
429
- datasplit_type=DataSplitType.Val,
430
- )
432
+ # val_config = MicroSplitDataConfig(
433
+ # **dataset_config_params,
434
+ # datasplit_type=DataSplitType.Val,
435
+ # )
431
436
  # TODO, data config is duplicated here and in configuration
432
437
 
433
438
  return MicroSplitDataModule(
@@ -578,10 +583,10 @@ def create_microsplit_predict_datamodule(
578
583
  Grid size for patch extraction.
579
584
  multiscale_count : int, optional
580
585
  Number of LC scales.
581
- tiling_mode : TilingMode, default=ShiftBoundary
582
- Tiling mode for patch extraction.
583
586
  data_stats : tuple, optional
584
587
  Data statistics, by default None.
588
+ tiling_mode : TilingMode, default=ShiftBoundary
589
+ Tiling mode for patch extraction.
585
590
  read_source_func : Callable, optional
586
591
  Function to read the source data.
587
592
  extension_filter : str, optional
@@ -32,7 +32,7 @@ class TilingMode:
32
32
  ShiftBoundary = 2
33
33
 
34
34
 
35
- # ------------------------------------------------------------------------------------------------
35
+ # ------------------------------------------------------------------------------------
36
36
  # Function of plotting: TODO -> moved them to another file, plot_utils.py
37
37
  def clean_ax(ax):
38
38
  """
@@ -68,7 +68,9 @@ def get_psnr_str(tar_hsnr, pred, col_idx):
68
68
  """
69
69
  Compute PSNR between the ground truth (`tar_hsnr`) and the predicted image (`pred`).
70
70
  """
71
- return f"{scale_invariant_psnr(tar_hsnr[col_idx][None], pred[col_idx][None]).item():.1f}"
71
+ psnr = scale_invariant_psnr(tar_hsnr[col_idx][None], pred[col_idx][None]).item()
72
+
73
+ return f"{psnr:.1f}"
72
74
 
73
75
 
74
76
  def add_psnr_str(ax_, psnr):
@@ -129,8 +131,10 @@ def show_for_one(
129
131
  baseline_preds=None,
130
132
  ):
131
133
  """
132
- Given an index, it plots the input, target, reconstructed images and the difference image.
133
- Note the the difference image is computed with respect to a ground truth image, obtained from the high SNR dataset.
134
+ Given an index, it plots the input, target, reconstructed images and the difference
135
+ image.
136
+ Note the the difference image is computed with respect to a ground truth image,
137
+ obtained from the high SNR dataset.
134
138
  """
135
139
  highsnr_val_dset.set_img_sz(patch_size, 64)
136
140
  highsnr_val_dset.disable_noise()
@@ -164,7 +168,8 @@ def plot_crops(
164
168
  for i in range(len(baseline_preds)):
165
169
  if baseline_preds[i].shape != tar_hsnr.shape:
166
170
  print(
167
- f"Baseline prediction {i} shape {baseline_preds[i].shape} does not match target shape {tar_hsnr.shape}"
171
+ f"Baseline prediction {i} shape {baseline_preds[i].shape} does not "
172
+ f"match target shape {tar_hsnr.shape}"
168
173
  )
169
174
  print("This happens when we want to predict the edges of the image.")
170
175
  return
@@ -333,14 +338,21 @@ def plot_crops(
333
338
  ax_temp.imshow(inp[0, 0].cpu().numpy(), cmap="magma")
334
339
  clean_ax(ax_temp)
335
340
 
336
- # line_ch1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0], linestyle='-', label='$C_1$')
337
- # line_ch2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1], linestyle='-', label='$C_2$')
338
- # line_pred = mlines.Line2D([0, 1], [0, 1], color=color_pred, linestyle='-', label='Pred')
339
- # line_noisych1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0], linestyle='--', label='$C^N_1$')
340
- # line_noisych2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1], linestyle='--', label='$C^N_2$')
341
- # legend_ch1 = legend_ch1_ax.legend(handles=[line_ch1, line_noisych1, line_pred], loc='upper right', frameon=False, labelcolor='white',
341
+ # line_ch1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0], linestyle='-',
342
+ # label='$C_1$')
343
+ # line_ch2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1], linestyle='-',
344
+ # label='$C_2$')
345
+ # line_pred = mlines.Line2D([0, 1], [0, 1], color=color_pred, linestyle='-',
346
+ # label='Pred')
347
+ # line_noisych1 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[0],
348
+ # linestyle='--', label='$C^N_1$')
349
+ # line_noisych2 = mlines.Line2D([0, 1], [0, 1], color=color_ch_list[1],
350
+ # linestyle='--', label='$C^N_2$')
351
+ # legend_ch1 = legend_ch1_ax.legend(handles=[line_ch1, line_noisych1, line_pred],
352
+ # loc='upper right', frameon=False, labelcolor='white',
342
353
  # prop={'size': 11})
343
- # legend_ch2 = legend_ch2_ax.legend(handles=[line_ch2, line_noisych2, line_pred], loc='upper right', frameon=False, labelcolor='white',
354
+ # legend_ch2 = legend_ch2_ax.legend(handles=[line_ch2, line_noisych2, line_pred],
355
+ # loc='upper right', frameon=False, labelcolor='white',
344
356
  # prop={'size': 11})
345
357
 
346
358
  if calibration_stats is not None:
@@ -383,7 +395,9 @@ def plot_calibration(ax, calibration_stats):
383
395
 
384
396
  def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name="shiftedcmap"):
385
397
  """
386
- Adapted from https://stackoverflow.com/questions/7404116/defining-the-midpoint-of-a-colormap-in-matplotlib
398
+ Adapted from
399
+ https://stackoverflow.com/questions/7404116/defining-the-midpoint-of-a-colormap-in-
400
+ matplotlib
387
401
 
388
402
  Function to offset the "center" of a colormap. Useful for
389
403
  data with a negative min and positive max and you want the
@@ -444,7 +458,8 @@ def get_fractional_change(target, prediction, max_val=None):
444
458
 
445
459
  def get_zero_centered_midval(error):
446
460
  """
447
- When done this way, the midval ensures that the colorbar is centered at 0. (Don't know how, but it works ;))
461
+ When done this way, the midval ensures that the colorbar is centered at 0. (Don't
462
+ know how, but it works ;))
448
463
  """
449
464
  vmax = error.max()
450
465
  vmin = error.min()
@@ -670,7 +685,7 @@ def get_single_file_mmse(
670
685
  return stitched_predictions, stitched_stds
671
686
 
672
687
 
673
- # ------------------------------------------------------------------------------------------
688
+ # ---------------------------------------------------------------------------------
674
689
  ### Classes and Functions used to stitch predictions
675
690
  class PatchLocation:
676
691
  """
@@ -698,9 +713,10 @@ def _get_location(extra_padding, hwt, pred_h, pred_w):
698
713
 
699
714
  def get_location_from_idx(dset, dset_input_idx, pred_h, pred_w):
700
715
  """
701
- For a given idx of the dataset, it returns where exactly in the dataset, does this prediction lies.
702
- Note that this prediction also has padded pixels and so a subset of it will be used in the final prediction.
703
- Which time frame, which spatial location (h_start, h_end, w_start,w_end)
716
+ For a given idx of the dataset, it returns where exactly in the dataset, does this
717
+ prediction lies. Note that this prediction also has padded pixels and so a subset of
718
+ it will be used in the final prediction. Which time frame, which spatial location
719
+ (h_start, h_end, w_start,w_end)
704
720
  Args:
705
721
  dset:
706
722
  dset_input_idx:
@@ -785,9 +801,11 @@ def stitch_predictions(predictions, dset, smoothening_pixelcount=0):
785
801
  # NOTE: don't need to compute it for every patch.
786
802
  assert (
787
803
  smoothening_pixelcount == 0
788
- ), "For smoothing,enable the get_smoothing_mask. It is disabled since I don't use it and it needs modification to work with non-square images"
804
+ ), "For smoothing,enable the get_smoothing_mask. It is disabled since I"
805
+ "don't use it and it needs modification to work with non-square images"
789
806
  mask = 1
790
- # mask = _get_smoothing_mask(cropped_pred_i.shape, smoothening_pixelcount, loc, frame_size)
807
+ # mask = _get_smoothing_mask(cropped_pred_i.shape,
808
+ # smoothening_pixelcount, loc, frame_size)
791
809
 
792
810
  cropped_pred_list.append(cropped_pred_i)
793
811
 
@@ -827,7 +845,8 @@ def stitch_predictions_new(predictions, dset):
827
845
  output = np.zeros(shape, dtype=predictions.dtype)
828
846
  # frame_shape = dset.get_data_shape()[:-1]
829
847
  for dset_idx in range(predictions.shape[0]):
830
- # loc = get_location_from_idx(dset, dset_idx, predictions.shape[-2], predictions.shape[-1])
848
+ # loc = get_location_from_idx(dset, dset_idx, predictions.shape[-2],
849
+ # predictions.shape[-1])
831
850
  # grid start, grid end
832
851
  gs = np.array(mng.get_location_from_dataset_idx(dset_idx), dtype=int)
833
852
  ge = gs + mng.grid_shape
@@ -843,7 +862,8 @@ def stitch_predictions_new(predictions, dset):
843
862
  vgs = np.array([max(0, x) for x in gs], dtype=int)
844
863
  vge = np.array([min(x, y) for x, y in zip(ge, mng.data_shape)], dtype=int)
845
864
  # assert np.all(vgs == gs)
846
- # assert np.all(vge == ge) # TODO comented out this shit cuz I have no interest to dig why it's failing at this point !
865
+ # assert np.all(vge == ge) # TODO comented out this shit cuz I have no interest
866
+ # to dig why it's failing at this point !
847
867
  # print('VGS')
848
868
  # print(gs)
849
869
  # print(ge)
@@ -898,7 +918,8 @@ def stitch_predictions_general(predictions, dset):
898
918
  # frame_shape = dset.get_data_shape()[:-1]
899
919
  for patch_idx in range(predictions.shape[0]):
900
920
  # grid start, grid end
901
- # channel_idx is 0 because during prediction we're only use one channel. # TODO revisit this
921
+ # channel_idx is 0 because during prediction we're only use one channel.
922
+ # # TODO revisit this
902
923
  # 0th dimension is sample index in the output list
903
924
  grid_coords = np.array(
904
925
  mng.get_location_from_patch_idx(channel_idx=0, patch_idx=patch_idx),
@@ -906,7 +927,8 @@ def stitch_predictions_general(predictions, dset):
906
927
  )
907
928
  sample_idx = grid_coords[0]
908
929
  grid_start = grid_coords[1:]
909
- # from here on, coordinates are relative to the sample(file in the list of inputs)
930
+ # from here on, coordinates are relative to the sample(file in the list of
931
+ # inputs)
910
932
  grid_end = grid_start + mng.grid_shape
911
933
 
912
934
  # patch start, patch end
@@ -381,7 +381,8 @@ class NoiseModelLikelihood(LikelihoodModule):
381
381
  """
382
382
  if self.data_mean is None or self.data_std is None:
383
383
  raise RuntimeError(
384
- "NoiseModelLikelihood: data_mean and data_std must be set before calling log_likelihood."
384
+ "NoiseModelLikelihood: data_mean and data_std must be set before"
385
+ "callinglog_likelihood."
385
386
  )
386
387
  self._set_params_to_same_device_as(x)
387
388
  predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
@@ -47,8 +47,9 @@ def convert_outputs_microsplit(
47
47
  """
48
48
  Convert microsplit Lightning trainer outputs using eval_utils stitching functions.
49
49
 
50
- This function processes microsplit predictions that return (tile_prediction, tile_std) tuples
51
- and stitches them back together using the same logic as get_single_file_mmse.
50
+ This function processes microsplit predictions that return (tile_prediction,
51
+ tile_std) tuples and stitches them back together using the same logic as
52
+ get_single_file_mmse.
52
53
 
53
54
  Parameters
54
55
  ----------
@@ -17,13 +17,21 @@ class TilingMode:
17
17
  ShiftBoundary = 2
18
18
 
19
19
 
20
- def stitch_prediction_vae(predictions, dset):
20
+ def stitch_prediction_vae(predictions, dset) -> NDArray:
21
21
  """
22
22
  Stitch predictions back together using dataset's index manager.
23
23
 
24
- Args:
25
- predictions: Array of predictions with shape (n_tiles, channels, height, width)
26
- dset: Dataset object with idx_manager containing tiling information
24
+ Parameters
25
+ ----------
26
+ predictions : np.ndarray
27
+ Array of predictions with shape (n_tiles, channels, height, width).
28
+ dset : Dataset
29
+ Dataset object with idx_manager containing tiling information.
30
+
31
+ Returns
32
+ -------
33
+ np.ndarray
34
+ Stitched array with shape matching the original data shape.
27
35
  """
28
36
  mng = dset.idx_manager
29
37
 
@@ -34,7 +42,8 @@ def stitch_prediction_vae(predictions, dset):
34
42
  output = np.zeros(shape, dtype=predictions.dtype)
35
43
  # frame_shape = dset.get_data_shape()[:-1]
36
44
  for dset_idx in range(predictions.shape[0]):
37
- # loc = get_location_from_idx(dset, dset_idx, predictions.shape[-2], predictions.shape[-1])
45
+ # loc = get_location_from_idx(dset, dset_idx, predictions.shape[-2],
46
+ # predictions.shape[-1])
38
47
  # grid start, grid end
39
48
  gs = np.array(mng.get_location_from_dataset_idx(dset_idx), dtype=int)
40
49
  ge = gs + mng.grid_shape
@@ -178,6 +187,8 @@ def stitch_prediction_single(
178
187
 
179
188
  # Insert cropped tile into predicted image using stitch coordinates
180
189
  image_slices = (..., *[slice(c[0], c[1]) for c in tile_info.stitch_coords])
181
- predicted_image[image_slices] = cropped_tile.astype(np.float32)
190
+
191
+ # TODO fix mypy error here, potentially due to numpy 2
192
+ predicted_image[image_slices] = cropped_tile.astype(np.float32) # type: ignore
182
193
 
183
194
  return predicted_image
@@ -21,9 +21,9 @@ def get_careamics_version() -> str:
21
21
  parts = __version__.split(".")
22
22
 
23
23
  # for local installs that do not detect the latest versions via tags
24
- # (typically our CI will install `0.1.devX<hash>` versions)
25
- if "dev" in parts[-1]:
26
- parts[-1] = "*"
24
+ # (typically our CI will install `0.X.devX<hash>.<other hash>` versions)
25
+ if "dev" in parts[2]:
26
+ parts[2] = "*"
27
27
  clean_version = ".".join(parts[:3])
28
28
 
29
29
  logger.warning(
@@ -34,5 +34,5 @@ def get_careamics_version() -> str:
34
34
  f"closest CAREamics version from PyPI or conda-forge."
35
35
  )
36
36
 
37
- # Remove any local version identifier)
37
+ # Remove any local version identifier
38
38
  return ".".join(parts[:3])
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: careamics
3
- Version: 0.0.16
3
+ Version: 0.0.17
4
4
  Summary: Toolbox for running N2V and friends.
5
5
  Project-URL: homepage, https://careamics.github.io/
6
6
  Project-URL: repository, https://github.com/CAREamics/careamics
@@ -16,30 +16,24 @@ Classifier: Programming Language :: Python :: 3.13
16
16
  Classifier: Typing :: Typed
17
17
  Requires-Python: >=3.11
18
18
  Requires-Dist: bioimageio-core>=0.9.0
19
- Requires-Dist: matplotlib<=3.10.6
19
+ Requires-Dist: matplotlib<=3.10.7
20
20
  Requires-Dist: numpy>=1.21
21
21
  Requires-Dist: numpy>=2.1.0; python_version >= '3.13'
22
22
  Requires-Dist: pillow<=11.3.0
23
23
  Requires-Dist: psutil<=7.1.0
24
- Requires-Dist: pydantic<=2.12,>=2.11
24
+ Requires-Dist: pydantic<=2.12.2,>=2.11
25
25
  Requires-Dist: pytorch-lightning<=2.5.5,>=2.2
26
26
  Requires-Dist: pyyaml!=6.0.0,<=6.0.3
27
27
  Requires-Dist: scikit-image<=0.25.2
28
- Requires-Dist: tifffile<=2025.9.30
28
+ Requires-Dist: tifffile<=2025.10.4
29
29
  Requires-Dist: torch<=2.8.0,>=2.0
30
+ Requires-Dist: torchmetrics<1.5.0,>=0.11.0
30
31
  Requires-Dist: torchvision<=0.23.0
31
32
  Requires-Dist: typer<=0.19.2,>=0.12.3
32
33
  Requires-Dist: validators<=0.35.0
33
34
  Requires-Dist: zarr<4.0.0,>=3.0.0
34
35
  Provides-Extra: czi
35
36
  Requires-Dist: pylibczirw<6.0.0,>=4.1.2; extra == 'czi'
36
- Provides-Extra: dev
37
- Requires-Dist: ml-dtypes>=0.5.0; extra == 'dev'
38
- Requires-Dist: onnx; extra == 'dev'
39
- Requires-Dist: pre-commit; extra == 'dev'
40
- Requires-Dist: pytest; extra == 'dev'
41
- Requires-Dist: pytest-cov; extra == 'dev'
42
- Requires-Dist: sybil; extra == 'dev'
43
37
  Provides-Extra: examples
44
38
  Requires-Dist: careamics-portfolio; extra == 'examples'
45
39
  Requires-Dist: jupyter; extra == 'examples'