careamics 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.

Files changed (92)
  1. careamics/careamist.py +55 -61
  2. careamics/cli/conf.py +24 -9
  3. careamics/cli/main.py +8 -8
  4. careamics/cli/utils.py +2 -4
  5. careamics/config/__init__.py +8 -0
  6. careamics/config/algorithms/__init__.py +4 -0
  7. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  8. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  9. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  10. careamics/config/algorithms/vae_algorithm_model.py +53 -18
  11. careamics/config/architectures/lvae_model.py +12 -8
  12. careamics/config/callback_model.py +15 -11
  13. careamics/config/configuration.py +9 -8
  14. careamics/config/configuration_factories.py +892 -78
  15. careamics/config/data/data_model.py +7 -14
  16. careamics/config/data/ng_data_model.py +8 -15
  17. careamics/config/data/patching_strategies/_overlapping_patched_model.py +4 -5
  18. careamics/config/inference_model.py +6 -11
  19. careamics/config/likelihood_model.py +4 -4
  20. careamics/config/loss_model.py +6 -2
  21. careamics/config/nm_model.py +30 -7
  22. careamics/config/optimizer_models.py +1 -2
  23. careamics/config/support/supported_algorithms.py +5 -3
  24. careamics/config/support/supported_losses.py +5 -2
  25. careamics/config/training_model.py +8 -38
  26. careamics/config/transformations/normalize_model.py +3 -4
  27. careamics/config/transformations/xy_flip_model.py +2 -2
  28. careamics/config/transformations/xy_random_rotate90_model.py +2 -2
  29. careamics/config/validators/validator_utils.py +1 -2
  30. careamics/dataset/dataset_utils/iterate_over_files.py +3 -3
  31. careamics/dataset/in_memory_dataset.py +2 -2
  32. careamics/dataset/iterable_dataset.py +1 -2
  33. careamics/dataset/patching/random_patching.py +6 -6
  34. careamics/dataset/patching/sequential_patching.py +4 -4
  35. careamics/dataset/tiling/lvae_tiled_patching.py +2 -2
  36. careamics/dataset_ng/dataset.py +3 -3
  37. careamics/dataset_ng/factory.py +19 -19
  38. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  39. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  40. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  41. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  42. careamics/dataset_ng/patching_strategies/random_patching.py +2 -3
  43. careamics/dataset_ng/patching_strategies/sequential_patching.py +1 -2
  44. careamics/file_io/read/__init__.py +0 -1
  45. careamics/lightning/__init__.py +16 -2
  46. careamics/lightning/callbacks/__init__.py +2 -0
  47. careamics/lightning/callbacks/data_stats_callback.py +23 -0
  48. careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +5 -5
  49. careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +5 -5
  50. careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +8 -8
  51. careamics/lightning/dataset_ng/data_module.py +43 -43
  52. careamics/lightning/lightning_module.py +166 -68
  53. careamics/lightning/microsplit_data_module.py +631 -0
  54. careamics/lightning/predict_data_module.py +16 -9
  55. careamics/lightning/train_data_module.py +29 -18
  56. careamics/losses/__init__.py +7 -1
  57. careamics/losses/loss_factory.py +9 -1
  58. careamics/losses/lvae/losses.py +94 -9
  59. careamics/lvae_training/dataset/__init__.py +8 -8
  60. careamics/lvae_training/dataset/config.py +56 -44
  61. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  62. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  63. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  64. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  65. careamics/model_io/bioimage/model_description.py +12 -11
  66. careamics/model_io/bmz_io.py +12 -8
  67. careamics/models/layers.py +5 -5
  68. careamics/models/lvae/likelihoods.py +30 -14
  69. careamics/models/lvae/lvae.py +2 -2
  70. careamics/models/lvae/noise_models.py +20 -14
  71. careamics/prediction_utils/__init__.py +8 -2
  72. careamics/prediction_utils/lvae_prediction.py +5 -5
  73. careamics/prediction_utils/prediction_outputs.py +48 -3
  74. careamics/prediction_utils/stitch_prediction.py +71 -0
  75. careamics/transforms/compose.py +9 -9
  76. careamics/transforms/n2v_manipulate.py +3 -3
  77. careamics/transforms/n2v_manipulate_torch.py +4 -4
  78. careamics/transforms/normalize.py +4 -6
  79. careamics/transforms/pixel_manipulation.py +6 -8
  80. careamics/transforms/pixel_manipulation_torch.py +5 -7
  81. careamics/transforms/xy_flip.py +3 -5
  82. careamics/transforms/xy_random_rotate90.py +4 -6
  83. careamics/utils/logging.py +8 -8
  84. careamics/utils/metrics.py +2 -2
  85. careamics/utils/plotting.py +1 -3
  86. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/METADATA +18 -16
  87. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/RECORD +90 -88
  88. careamics/dataset/zarr_dataset.py +0 -151
  89. careamics/file_io/read/zarr.py +0 -60
  90. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/WHEEL +0 -0
  91. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/entry_points.txt +0 -0
  92. {careamics-0.0.14.dist-info → careamics-0.0.16.dist-info}/licenses/LICENSE +0 -0
careamics/lightning/microsplit_data_module.py (new file)
@@ -0,0 +1,631 @@
+"""MicroSplit data module for training and validation."""
+
+from collections.abc import Callable
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+import pytorch_lightning as L
+import tifffile
+from numpy.typing import NDArray
+from torch.utils.data import DataLoader
+
+from careamics.dataset.dataset_utils.dataset_utils import reshape_array
+from careamics.lvae_training.dataset import (
+    DataSplitType,
+    DataType,
+    LCMultiChDloader,
+    MicroSplitDataConfig,
+)
+from careamics.lvae_training.dataset.types import TilingMode
+
+
+# TODO refactor
+def load_one_file(fpath):
+    """Load a single 2D image file.
+
+    Parameters
+    ----------
+    fpath : str or Path
+        Path to the image file.
+
+    Returns
+    -------
+    numpy.ndarray
+        Reshaped image data.
+    """
+    data = tifffile.imread(fpath)
+    if len(data.shape) == 2:
+        axes = "YX"
+    elif len(data.shape) == 3:
+        axes = "SYX"
+    elif len(data.shape) == 4:
+        axes = "STYX"
+    else:
+        raise ValueError(f"Invalid data shape: {data.shape}")
+    data = reshape_array(data, axes)
+    data = data.reshape(-1, data.shape[-2], data.shape[-1])
+    return data
+
+
+# TODO refactor
+def load_data(datadir):
+    """Load data from a directory containing channel subdirectories with image files.
+
+    Parameters
+    ----------
+    datadir : str or Path
+        Path to the data directory containing channel subdirectories.
+
+    Returns
+    -------
+    numpy.ndarray
+        Stacked array of all channels' data.
+    """
+    data_path = Path(datadir)
+
+    channel_dirs = sorted(p for p in data_path.iterdir() if p.is_dir())
+    channels_data = []
+
+    for channel_dir in channel_dirs:
+        image_files = sorted(f for f in channel_dir.iterdir() if f.is_file())
+        channel_images = [load_one_file(image_path) for image_path in image_files]
+
+        channel_stack = np.concatenate(
+            channel_images, axis=0
+        )  # FIXME: this line works iff images have
+        # a singleton channel dimension. Specify in the notebook or change with `torch.stack`??
+        channels_data.append(channel_stack)
+
+    final_data = np.stack(channels_data, axis=-1)
+    return final_data
+
+
+# TODO refactor
+def get_datasplit_tuples(val_fraction, test_fraction, data_length):
+    """Get train/val/test indices for data splitting.
+
+    Parameters
+    ----------
+    val_fraction : float or None
+        Fraction of data to use for validation.
+    test_fraction : float or None
+        Fraction of data to use for testing.
+    data_length : int
+        Total length of the dataset.
+
+    Returns
+    -------
+    tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]
+        Training, validation, and test indices.
+    """
+    indices = np.arange(data_length)
+    np.random.shuffle(indices)
+
+    if val_fraction is None:
+        val_fraction = 0.0
+    if test_fraction is None:
+        test_fraction = 0.0
+
+    val_size = int(data_length * val_fraction)
+    test_size = int(data_length * test_fraction)
+    train_size = data_length - val_size - test_size
+
+    train_idx = indices[:train_size]
+    val_idx = indices[train_size : train_size + val_size]
+    test_idx = indices[train_size + val_size :]
+
+    return train_idx, val_idx, test_idx
+
+
+# TODO refactor
+def get_train_val_data(
+    data_config,
+    datadir,
+    datasplit_type: DataSplitType,
+    val_fraction=None,
+    test_fraction=None,
+    allow_generation=None,
+    **kwargs,
+):
+    """Load and split data according to configuration.
+
+    Parameters
+    ----------
+    data_config : MicroSplitDataConfig
+        Data configuration object.
+    datadir : str or Path
+        Path to the data directory.
+    datasplit_type : DataSplitType
+        Type of data split to return.
+    val_fraction : float, optional
+        Fraction of data to use for validation.
+    test_fraction : float, optional
+        Fraction of data to use for testing.
+    allow_generation : bool, optional
+        Whether to allow data generation.
+    **kwargs
+        Additional keyword arguments.
+
+    Returns
+    -------
+    numpy.ndarray
+        Split data array.
+    """
+    data = load_data(datadir)
+    train_idx, val_idx, test_idx = get_datasplit_tuples(
+        val_fraction, test_fraction, len(data)
+    )
+
+    if datasplit_type == DataSplitType.All:
+        data = data.astype(np.float64)
+    elif datasplit_type == DataSplitType.Train:
+        data = data[train_idx].astype(np.float64)
+    elif datasplit_type == DataSplitType.Val:
+        data = data[val_idx].astype(np.float64)
+    elif datasplit_type == DataSplitType.Test:
+        # TODO this is only used for prediction, and only because old dataset uses it
+        data = data[test_idx].astype(np.float64)
+    else:
+        raise Exception("invalid datasplit")
+
+    return data
+
+
+class MicroSplitDataModule(L.LightningDataModule):
+    """Lightning DataModule for MicroSplit-style datasets.
+
+    Matches the interface of TrainDataModule, but internally uses original MicroSplit
+    dataset logic.
+
+    Parameters
+    ----------
+    data_config : MicroSplitDataConfig
+        Configuration for the MicroSplit dataset.
+    train_data : str
+        Path to training data directory.
+    val_data : str, optional
+        Path to validation data directory.
+    train_data_target : str, optional
+        Path to training target data.
+    val_data_target : str, optional
+        Path to validation target data.
+    read_source_func : Callable, optional
+        Function to read source data.
+    extension_filter : str, optional
+        File extension filter.
+    val_percentage : float, optional
+        Percentage of data to use for validation, by default 0.1.
+    val_minimum_split : int, optional
+        Minimum number of samples for validation split, by default 5.
+    use_in_memory : bool, optional
+        Whether to use in-memory dataset, by default True.
+    """
+
+    def __init__(
+        self,
+        data_config: MicroSplitDataConfig,  # Should be compatible with microSplit DatasetConfig
+        train_data: str,
+        val_data: str | None = None,
+        train_data_target: str | None = None,
+        val_data_target: str | None = None,
+        read_source_func: Callable | None = None,
+        extension_filter: str = "",
+        val_percentage: float = 0.1,
+        val_minimum_split: int = 5,
+        use_in_memory: bool = True,
+    ):
+        """Initialize MicroSplitDataModule.
+
+        Parameters
+        ----------
+        data_config : MicroSplitDataConfig
+            Configuration for the MicroSplit dataset.
+        train_data : str
+            Path to training data directory.
+        val_data : str, optional
+            Path to validation data directory.
+        train_data_target : str, optional
+            Path to training target data.
+        val_data_target : str, optional
+            Path to validation target data.
+        read_source_func : Callable, optional
+            Function to read source data.
+        extension_filter : str, optional
+            File extension filter.
+        val_percentage : float, optional
+            Percentage of data to use for validation, by default 0.1.
+        val_minimum_split : int, optional
+            Minimum number of samples for validation split, by default 5.
+        use_in_memory : bool, optional
+            Whether to use in-memory dataset, by default True.
+        """
+        super().__init__()
+        # Dataset selection logic (adapted from create_train_val_datasets)
+        self.train_config = data_config  # SHould configs be separated?
+        self.val_config = data_config
+        self.test_config = data_config
+
+        datapath = train_data
+        load_data_func = read_source_func
+
+        dataset_class = LCMultiChDloader  # TODO hardcoded for now
+
+        # Create datasets
+        self.train_dataset = dataset_class(
+            self.train_config,
+            datapath,
+            load_data_fn=load_data_func,
+            val_fraction=val_percentage,
+            test_fraction=0.1,
+        )
+        max_val = self.train_dataset.get_max_val()
+        self.val_config.max_val = max_val
+        if self.train_config.datasplit_type == DataSplitType.All:
+            self.val_config.datasplit_type = DataSplitType.All
+            self.test_config.datasplit_type = DataSplitType.All
+        self.val_dataset = dataset_class(
+            self.val_config,
+            datapath,
+            load_data_fn=load_data_func,
+            val_fraction=val_percentage,
+            test_fraction=0.1,
+        )
+        self.test_config.max_val = max_val
+        self.test_dataset = dataset_class(
+            self.test_config,
+            datapath,
+            load_data_fn=load_data_func,
+            val_fraction=val_percentage,
+            test_fraction=0.1,
+        )
+        mean_val, std_val = self.train_dataset.compute_mean_std()
+        self.train_dataset.set_mean_std(mean_val, std_val)
+        self.val_dataset.set_mean_std(mean_val, std_val)
+        self.test_dataset.set_mean_std(mean_val, std_val)
+        data_stats = self.train_dataset.get_mean_std()
+
+        # Store data statistics
+        self.data_stats = (
+            data_stats[0],
+            data_stats[1],
+        )  # TODO repeats old logic, revisit
+
+    def train_dataloader(self):
+        """Create a dataloader for training.
+
+        Returns
+        -------
+        DataLoader
+            Training dataloader.
+        """
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.train_config.batch_size,  # TODO should be inside dataloader params?
+            **self.train_config.train_dataloader_params,
+        )
+
+    def val_dataloader(self):
+        """Create a dataloader for validation.
+
+        Returns
+        -------
+        DataLoader
+            Validation dataloader.
+        """
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.train_config.batch_size,
+            **self.val_config.val_dataloader_params,  # TODO duplicated
+        )
+
+    def get_data_stats(self):
+        """Get data statistics.
+
+        Returns
+        -------
+        tuple[dict, dict]
+            A tuple containing two dictionaries:
+            - data_mean: mean values for input and target
+            - data_std: standard deviation values for input and target
+        """
+        return self.data_stats, self.val_config.max_val  # TODO should be in the config?
+
+
+def create_microsplit_train_datamodule(
+    train_data: str,
+    patch_size: tuple,
+    data_type: DataType,
+    axes: str,  # TODO should be there after refactoring
+    batch_size: int,
+    val_data: str | None = None,
+    num_channels: int = 2,
+    depth3D: int = 1,
+    grid_size: tuple | None = None,
+    multiscale_count: int | None = None,
+    tiling_mode: TilingMode = TilingMode.ShiftBoundary,
+    read_source_func: Callable | None = None,  # TODO should be there after refactoring
+    extension_filter: str = "",
+    val_percentage: float = 0.1,
+    val_minimum_split: int = 5,
+    use_in_memory: bool = True,
+    transforms: list | None = None,  # TODO should it be here?
+    train_dataloader_params: dict | None = None,
+    val_dataloader_params: dict | None = None,
+    **dataset_kwargs,
+) -> MicroSplitDataModule:
+    """
+    Create a MicroSplitDataModule for microSplit-style datasets, including config creation.
+
+    Parameters
+    ----------
+    train_data : str
+        Path to training data.
+    patch_size : tuple
+        Size of one patch of data.
+    data_type : DataType
+        Type of the dataset (must be a DataType enum value).
+    axes : str
+        Axes of the data (e.g., 'SYX').
+    batch_size : int
+        Batch size for dataloaders.
+    val_data : str, optional
+        Path to validation data.
+    num_channels : int, default=2
+        Number of channels in the input.
+    depth3D : int, default=1
+        Number of slices in 3D.
+    grid_size : tuple, optional
+        Grid size for patch extraction.
+    multiscale_count : int, optional
+        Number of LC scales.
+    tiling_mode : TilingMode, default=ShiftBoundary
+        Tiling mode for patch extraction.
+    read_source_func : Callable, optional
+        Function to read the source data.
+    extension_filter : str, optional
+        File extension filter.
+    val_percentage : float, default=0.1
+        Percentage of training data to use for validation.
+    val_minimum_split : int, default=5
+        Minimum number of patches/files for validation split.
+    use_in_memory : bool, default=True
+        Use in-memory dataset if possible.
+    transforms : list, optional
+        List of transforms to apply.
+    train_dataloader_params : dict, optional
+        Parameters for training dataloader.
+    val_dataloader_params : dict, optional
+        Parameters for validation dataloader.
+    **dataset_kwargs :
+        Additional arguments passed to DatasetConfig.
+
+    Returns
+    -------
+    MicroSplitDataModule
+        Configured MicroSplitDataModule instance.
+    """
+    # Create dataset configs with only valid parameters
+    dataset_config_params = {
+        "data_type": data_type,
+        "image_size": patch_size,
+        "num_channels": num_channels,
+        "depth3D": depth3D,
+        "grid_size": grid_size,
+        "multiscale_lowres_count": multiscale_count,
+        "tiling_mode": tiling_mode,
+        "batch_size": batch_size,
+        "train_dataloader_params": train_dataloader_params,
+        "val_dataloader_params": val_dataloader_params,
+        **dataset_kwargs,
+    }
+
+    train_config = MicroSplitDataConfig(
+        **dataset_config_params,
+        datasplit_type=DataSplitType.Train,
+    )
+    val_config = MicroSplitDataConfig(
+        **dataset_config_params,
+        datasplit_type=DataSplitType.Val,
+    )
+    # TODO, data config is duplicated here and in configuration
+
+    return MicroSplitDataModule(
+        data_config=train_config,
+        train_data=train_data,
+        val_data=val_data or train_data,
+        train_data_target=None,
+        val_data_target=None,
+        read_source_func=get_train_val_data,  # Use our wrapped function
+        extension_filter=extension_filter,
+        val_percentage=val_percentage,
+        val_minimum_split=val_minimum_split,
+        use_in_memory=use_in_memory,
+    )
+
+
+class MicroSplitPredictDataModule(L.LightningDataModule):
+    """Lightning DataModule for MicroSplit-style prediction datasets.
+
+    Matches the interface of PredictDataModule, but internally uses MicroSplit
+    dataset logic for prediction.
+
+    Parameters
+    ----------
+    pred_config : MicroSplitDataConfig
+        Configuration for MicroSplit prediction.
+    pred_data : str or Path or numpy.ndarray
+        Prediction data, can be a path to a folder, a file or a numpy array.
+    read_source_func : Callable, optional
+        Function to read custom types.
+    extension_filter : str, optional
+        Filter to filter file extensions for custom types.
+    dataloader_params : dict, optional
+        Dataloader parameters.
+    """
+
+    def __init__(
+        self,
+        pred_config: MicroSplitDataConfig,
+        pred_data: Union[str, Path, NDArray],
+        read_source_func: Callable | None = None,
+        extension_filter: str = "",
+        dataloader_params: dict | None = None,
+    ) -> None:
+        """
+        Constructor for MicroSplit prediction data module.
+
+        Parameters
+        ----------
+        pred_config : MicroSplitDataConfig
+            Configuration for MicroSplit prediction.
+        pred_data : str or Path or numpy.ndarray
+            Prediction data, can be a path to a folder, a file or a numpy array.
+        read_source_func : Callable, optional
+            Function to read custom types, by default None.
+        extension_filter : str, optional
+            Filter to filter file extensions for custom types, by default "".
+        dataloader_params : dict, optional
+            Dataloader parameters, by default {}.
+        """
+        super().__init__()
+
+        if dataloader_params is None:
+            dataloader_params = {}
+        self.pred_config = pred_config
+        self.pred_data = pred_data
+        self.read_source_func = read_source_func or get_train_val_data
+        self.extension_filter = extension_filter
+        self.dataloader_params = dataloader_params
+
+    def prepare_data(self) -> None:
+        """Hook used to prepare the data before calling `setup`."""
+        # # TODO currently data preparation is handled in dataset creation, revisit!
+        pass
+
+    def setup(self, stage: str | None = None) -> None:
+        """
+        Hook called at the beginning of predict.
+
+        Parameters
+        ----------
+        stage : Optional[str], optional
+            Stage, by default None.
+        """
+        # Create prediction dataset using LCMultiChDloader
+        self.predict_dataset = LCMultiChDloader(
+            self.pred_config,
+            self.pred_data,
+            load_data_fn=self.read_source_func,
+            val_fraction=0.0,  # No validation split for prediction
+            test_fraction=1.0,  # No test split for prediction
+        )
+        self.predict_dataset.set_mean_std(*self.pred_config.data_stats)
+
+    def predict_dataloader(self) -> DataLoader:
+        """
+        Create a dataloader for prediction.
+
+        Returns
+        -------
+        DataLoader
+            Prediction dataloader.
+        """
+        return DataLoader(
+            self.predict_dataset,
+            batch_size=self.pred_config.batch_size,
+            **self.dataloader_params,
+        )
+
+
+def create_microsplit_predict_datamodule(
+    pred_data: Union[str, Path, NDArray],
+    tile_size: tuple,
+    data_type: DataType,
+    axes: str,
+    batch_size: int = 1,
+    num_channels: int = 2,
+    depth3D: int = 1,
+    grid_size: int | None = None,
+    multiscale_count: int | None = None,
+    data_stats: tuple | None = None,
+    tiling_mode: TilingMode = TilingMode.ShiftBoundary,
+    read_source_func: Callable | None = None,
+    extension_filter: str = "",
+    dataloader_params: dict | None = None,
+    **dataset_kwargs,
+) -> MicroSplitPredictDataModule:
+    """
+    Create a MicroSplitPredictDataModule for microSplit-style prediction datasets.
+
+    Parameters
+    ----------
+    pred_data : str or Path or numpy.ndarray
+        Prediction data, can be a path to a folder, a file or a numpy array.
+    tile_size : tuple
+        Size of one tile of data.
+    data_type : DataType
+        Type of the dataset (must be a DataType enum value).
+    axes : str
+        Axes of the data (e.g., 'SYX').
+    batch_size : int, default=1
+        Batch size for prediction dataloader.
+    num_channels : int, default=2
+        Number of channels in the input.
+    depth3D : int, default=1
+        Number of slices in 3D.
+    grid_size : tuple, optional
+        Grid size for patch extraction.
+    multiscale_count : int, optional
+        Number of LC scales.
+    tiling_mode : TilingMode, default=ShiftBoundary
+        Tiling mode for patch extraction.
+    data_stats : tuple, optional
+        Data statistics, by default None.
+    read_source_func : Callable, optional
+        Function to read the source data.
+    extension_filter : str, optional
+        File extension filter.
+    dataloader_params : dict, optional
+        Parameters for prediction dataloader.
+    **dataset_kwargs :
+        Additional arguments passed to MicroSplitDataConfig.
+
+    Returns
+    -------
+    MicroSplitPredictDataModule
+        Configured MicroSplitPredictDataModule instance.
+    """
+    if dataloader_params is None:
+        dataloader_params = {}
+
+    # Create prediction config with only valid parameters
+    prediction_config_params = {
+        "data_type": data_type,
+        "image_size": tile_size,
+        "num_channels": num_channels,
+        "depth3D": depth3D,
+        "grid_size": grid_size,
+        "multiscale_lowres_count": multiscale_count,
+        "data_stats": data_stats,
+        "tiling_mode": tiling_mode,
+        "batch_size": batch_size,
+        "datasplit_type": DataSplitType.Test,  # For prediction, use all data
+        **dataset_kwargs,
+    }
+
+    pred_config = MicroSplitDataConfig(**prediction_config_params)
+
+    # Remove batch_size from dataloader_params if present
+    if "batch_size" in dataloader_params:
+        del dataloader_params["batch_size"]
+
+    return MicroSplitPredictDataModule(
+        pred_config=pred_config,
+        pred_data=pred_data,
+        read_source_func=(
+            read_source_func if read_source_func is not None else get_train_val_data
+        ),
+        extension_filter=extension_filter,
+        dataloader_params=dataloader_params,
+    )
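
The sketch below is not part of the package; it shows how the new training factory might be wired up. The directory layout (one sub-directory per channel, each containing TIFF files) follows the load_data docstring above, and the DataType member shown is an assumption; replace it with whichever value matches your data.

# Example (hypothetical usage, not shipped with the package)
from careamics.lightning.microsplit_data_module import (
    create_microsplit_train_datamodule,
)
from careamics.lvae_training.dataset import DataType

data_module = create_microsplit_train_datamodule(
    train_data="path/to/data",            # one sub-directory per channel, TIFFs inside
    patch_size=(64, 64),
    data_type=DataType.SeparateTiffData,  # assumed enum member, adjust to your data
    axes="SYX",
    batch_size=32,
    num_channels=2,
    train_dataloader_params={"num_workers": 4, "shuffle": True},
    val_dataloader_params={"num_workers": 4},
)

# get_data_stats returns ((mean, std), max_val) as built in __init__ above
(data_mean, data_std), max_val = data_module.get_data_stats()
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()

Note that the module builds train, validation and test datasets from the same directory and normalises all three with the mean and standard deviation computed on the training split.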
careamics/lightning/predict_data_module.py
@@ -2,7 +2,7 @@
 
 from collections.abc import Callable
 from pathlib import Path
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Union
 
 import numpy as np
 import pytorch_lightning as L
@@ -65,9 +65,9 @@ class PredictDataModule(L.LightningDataModule):
         self,
         pred_config: InferenceConfig,
         pred_data: Union[Path, str, NDArray],
-        read_source_func: Optional[Callable] = None,
+        read_source_func: Callable | None = None,
         extension_filter: str = "",
-        dataloader_params: Optional[dict] = None,
+        dataloader_params: dict | None = None,
     ) -> None:
         """
         Constructor.
@@ -173,7 +173,7 @@ class PredictDataModule(L.LightningDataModule):
            self.pred_data, self.data_type, self.extension_filter
         )
 
-    def setup(self, stage: Optional[str] = None) -> None:
+    def setup(self, stage: str | None = None) -> None:
         """
         Hook called at the beginning of predict.
 
@@ -217,11 +217,18 @@ class PredictDataModule(L.LightningDataModule):
         DataLoader
             Prediction dataloader.
         """
+        # For tiled predictions, we need to ensure tiles are processed in order
+        # to avoid stitching artifacts. Multi-worker processing can return batches
+        # out of order, so we disable it for tiled predictions.
+        dataloader_params = self.dataloader_params.copy()
+        if self.tiled:
+            dataloader_params["num_workers"] = 0
+
         return DataLoader(
             self.predict_dataset,
             batch_size=self.batch_size,
             collate_fn=collate_tiles if self.tiled else None,
-            **self.dataloader_params,
+            **dataloader_params,
         )
 
 
@@ -231,13 +238,13 @@ def create_predict_datamodule(
     axes: str,
     image_means: list[float],
     image_stds: list[float],
-    tile_size: Optional[tuple[int, ...]] = None,
-    tile_overlap: Optional[tuple[int, ...]] = None,
+    tile_size: tuple[int, ...] | None = None,
+    tile_overlap: tuple[int, ...] | None = None,
     batch_size: int = 1,
     tta_transforms: bool = True,
-    read_source_func: Optional[Callable] = None,
+    read_source_func: Callable | None = None,
     extension_filter: str = "",
-    dataloader_params: Optional[dict] = None,
+    dataloader_params: dict | None = None,
 ) -> PredictDataModule:
     """Create a CAREamics prediction Lightning datamodule.
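
The practical effect of the predict_dataloader change is that tiled prediction always runs with a single-process dataloader, so tiles come back in submission order for stitching. A minimal sketch follows; the leading pred_data and data_type arguments fall outside the hunks shown above and are assumptions here.

# Example (hypothetical usage): any user-supplied num_workers is overridden
# for tiled prediction by the change shown above.
import numpy as np

from careamics.lightning.predict_data_module import create_predict_datamodule

datamodule = create_predict_datamodule(
    pred_data=np.random.rand(512, 512).astype(np.float32),  # assumed leading argument
    data_type="array",                                       # assumed leading argument
    axes="YX",
    image_means=[0.5],
    image_stds=[0.1],
    tile_size=(128, 128),
    tile_overlap=(32, 32),
    batch_size=2,
    dataloader_params={"num_workers": 4},  # forced back to 0 because tiling is enabled
)
datamodule.prepare_data()
datamodule.setup()
loader = datamodule.predict_dataloader()  # single-process, tiles arrive in order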