careamics 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff reflects the changes between two package versions as they were published to a supported public registry. It is provided for informational purposes only.

Potentially problematic release.
Files changed (79)
  1. careamics/careamist.py +11 -14
  2. careamics/cli/conf.py +18 -3
  3. careamics/config/__init__.py +8 -0
  4. careamics/config/algorithms/__init__.py +4 -0
  5. careamics/config/algorithms/hdn_algorithm_model.py +103 -0
  6. careamics/config/algorithms/microsplit_algorithm_model.py +103 -0
  7. careamics/config/algorithms/n2v_algorithm_model.py +1 -2
  8. careamics/config/algorithms/vae_algorithm_model.py +51 -16
  9. careamics/config/architectures/lvae_model.py +12 -8
  10. careamics/config/callback_model.py +7 -3
  11. careamics/config/configuration.py +15 -63
  12. careamics/config/configuration_factories.py +853 -29
  13. careamics/config/data/data_model.py +50 -11
  14. careamics/config/data/ng_data_model.py +168 -4
  15. careamics/config/data/patch_filter/__init__.py +15 -0
  16. careamics/config/data/patch_filter/filter_model.py +16 -0
  17. careamics/config/data/patch_filter/mask_filter_model.py +17 -0
  18. careamics/config/data/patch_filter/max_filter_model.py +15 -0
  19. careamics/config/data/patch_filter/meanstd_filter_model.py +18 -0
  20. careamics/config/data/patch_filter/shannon_filter_model.py +15 -0
  21. careamics/config/inference_model.py +1 -2
  22. careamics/config/likelihood_model.py +2 -2
  23. careamics/config/loss_model.py +6 -2
  24. careamics/config/nm_model.py +26 -1
  25. careamics/config/optimizer_models.py +1 -2
  26. careamics/config/support/supported_algorithms.py +5 -3
  27. careamics/config/support/supported_filters.py +17 -0
  28. careamics/config/support/supported_losses.py +5 -2
  29. careamics/config/training_model.py +6 -36
  30. careamics/config/transformations/normalize_model.py +1 -2
  31. careamics/dataset_ng/dataset.py +57 -5
  32. careamics/dataset_ng/factory.py +101 -18
  33. careamics/dataset_ng/patch_extractor/demo_custom_image_stack_loader.py +4 -4
  34. careamics/dataset_ng/patch_extractor/image_stack/in_memory_image_stack.py +1 -2
  35. careamics/dataset_ng/patch_extractor/image_stack/zarr_image_stack.py +33 -7
  36. careamics/dataset_ng/patch_extractor/image_stack_loader.py +2 -2
  37. careamics/dataset_ng/patch_filter/__init__.py +20 -0
  38. careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
  39. careamics/dataset_ng/patch_filter/filter_factory.py +94 -0
  40. careamics/dataset_ng/patch_filter/mask_filter.py +95 -0
  41. careamics/dataset_ng/patch_filter/max_filter.py +188 -0
  42. careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
  43. careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
  44. careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
  45. careamics/file_io/read/__init__.py +0 -1
  46. careamics/lightning/__init__.py +16 -2
  47. careamics/lightning/callbacks/__init__.py +2 -0
  48. careamics/lightning/callbacks/data_stats_callback.py +33 -0
  49. careamics/lightning/dataset_ng/data_module.py +79 -2
  50. careamics/lightning/lightning_module.py +162 -61
  51. careamics/lightning/microsplit_data_module.py +636 -0
  52. careamics/lightning/predict_data_module.py +8 -1
  53. careamics/lightning/train_data_module.py +19 -8
  54. careamics/losses/__init__.py +7 -1
  55. careamics/losses/loss_factory.py +9 -1
  56. careamics/losses/lvae/losses.py +85 -0
  57. careamics/lvae_training/dataset/__init__.py +8 -8
  58. careamics/lvae_training/dataset/config.py +56 -44
  59. careamics/lvae_training/dataset/lc_dataset.py +18 -12
  60. careamics/lvae_training/dataset/ms_dataset_ref.py +5 -5
  61. careamics/lvae_training/dataset/multich_dataset.py +24 -18
  62. careamics/lvae_training/dataset/multifile_dataset.py +6 -6
  63. careamics/lvae_training/eval_utils.py +46 -24
  64. careamics/model_io/bmz_io.py +9 -5
  65. careamics/models/lvae/likelihoods.py +31 -14
  66. careamics/models/lvae/lvae.py +2 -2
  67. careamics/models/lvae/noise_models.py +20 -14
  68. careamics/prediction_utils/__init__.py +8 -2
  69. careamics/prediction_utils/prediction_outputs.py +49 -3
  70. careamics/prediction_utils/stitch_prediction.py +83 -1
  71. careamics/transforms/xy_random_rotate90.py +1 -1
  72. careamics/utils/version.py +4 -4
  73. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/METADATA +19 -22
  74. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/RECORD +77 -60
  75. careamics/dataset/zarr_dataset.py +0 -151
  76. careamics/file_io/read/zarr.py +0 -60
  77. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/WHEEL +0 -0
  78. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/entry_points.txt +0 -0
  79. {careamics-0.0.15.dist-info → careamics-0.0.17.dist-info}/licenses/LICENSE +0 -0
careamics/lightning/microsplit_data_module.py (new file)
@@ -0,0 +1,636 @@
+"""MicroSplit data module for training and validation."""
+
+from collections.abc import Callable
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+import pytorch_lightning as L
+import tifffile
+from numpy.typing import NDArray
+from torch.utils.data import DataLoader
+
+from careamics.dataset.dataset_utils.dataset_utils import reshape_array
+from careamics.lvae_training.dataset import (
+    DataSplitType,
+    DataType,
+    LCMultiChDloader,
+    MicroSplitDataConfig,
+)
+from careamics.lvae_training.dataset.types import TilingMode
+
+
+# TODO refactor
+def load_one_file(fpath):
+    """Load a single 2D image file.
+
+    Parameters
+    ----------
+    fpath : str or Path
+        Path to the image file.
+
+    Returns
+    -------
+    numpy.ndarray
+        Reshaped image data.
+    """
+    data = tifffile.imread(fpath)
+    if len(data.shape) == 2:
+        axes = "YX"
+    elif len(data.shape) == 3:
+        axes = "SYX"
+    elif len(data.shape) == 4:
+        axes = "STYX"
+    else:
+        raise ValueError(f"Invalid data shape: {data.shape}")
+    data = reshape_array(data, axes)
+    data = data.reshape(-1, data.shape[-2], data.shape[-1])
+    return data
+
+
+# TODO refactor
+def load_data(datadir):
+    """Load data from a directory containing channel subdirectories with image files.
+
+    Parameters
+    ----------
+    datadir : str or Path
+        Path to the data directory containing channel subdirectories.
+
+    Returns
+    -------
+    numpy.ndarray
+        Stacked array of all channels' data.
+    """
+    data_path = Path(datadir)
+
+    channel_dirs = sorted(p for p in data_path.iterdir() if p.is_dir())
+    channels_data = []
+
+    for channel_dir in channel_dirs:
+        image_files = sorted(f for f in channel_dir.iterdir() if f.is_file())
+        channel_images = [load_one_file(image_path) for image_path in image_files]
+
+        channel_stack = np.concatenate(
+            channel_images, axis=0
+        )  # FIXME: this line works iff images have
+        # a singleton channel dimension. Specify in the notebook or change with
+        # `torch.stack`??
+        channels_data.append(channel_stack)
+
+    final_data = np.stack(channels_data, axis=-1)
+    return final_data
+
+
+# TODO refactor
+def get_datasplit_tuples(val_fraction, test_fraction, data_length):
+    """Get train/val/test indices for data splitting.
+
+    Parameters
+    ----------
+    val_fraction : float or None
+        Fraction of data to use for validation.
+    test_fraction : float or None
+        Fraction of data to use for testing.
+    data_length : int
+        Total length of the dataset.
+
+    Returns
+    -------
+    tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]
+        Training, validation, and test indices.
+    """
+    indices = np.arange(data_length)
+    np.random.shuffle(indices)
+
+    if val_fraction is None:
+        val_fraction = 0.0
+    if test_fraction is None:
+        test_fraction = 0.0
+
+    val_size = int(data_length * val_fraction)
+    test_size = int(data_length * test_fraction)
+    train_size = data_length - val_size - test_size
+
+    train_idx = indices[:train_size]
+    val_idx = indices[train_size : train_size + val_size]
+    test_idx = indices[train_size + val_size :]
+
+    return train_idx, val_idx, test_idx
+
+
+# TODO refactor
+def get_train_val_data(
+    data_config,
+    datadir,
+    datasplit_type: DataSplitType,
+    val_fraction=None,
+    test_fraction=None,
+    allow_generation=None,
+    **kwargs,
+):
+    """Load and split data according to configuration.
+
+    Parameters
+    ----------
+    data_config : MicroSplitDataConfig
+        Data configuration object.
+    datadir : str or Path
+        Path to the data directory.
+    datasplit_type : DataSplitType
+        Type of data split to return.
+    val_fraction : float, optional
+        Fraction of data to use for validation.
+    test_fraction : float, optional
+        Fraction of data to use for testing.
+    allow_generation : bool, optional
+        Whether to allow data generation.
+    **kwargs
+        Additional keyword arguments.
+
+    Returns
+    -------
+    numpy.ndarray
+        Split data array.
+    """
+    data = load_data(datadir)
+    train_idx, val_idx, test_idx = get_datasplit_tuples(
+        val_fraction, test_fraction, len(data)
+    )
+
+    if datasplit_type == DataSplitType.All:
+        data = data.astype(np.float64)
+    elif datasplit_type == DataSplitType.Train:
+        data = data[train_idx].astype(np.float64)
+    elif datasplit_type == DataSplitType.Val:
+        data = data[val_idx].astype(np.float64)
+    elif datasplit_type == DataSplitType.Test:
+        # TODO this is only used for prediction, and only because old dataset uses it
+        data = data[test_idx].astype(np.float64)
+    else:
+        raise Exception("invalid datasplit")
+
+    return data
+
+
+class MicroSplitDataModule(L.LightningDataModule):
+    """Lightning DataModule for MicroSplit-style datasets.
+
+    Matches the interface of TrainDataModule, but internally uses original MicroSplit
+    dataset logic.
+
+    Parameters
+    ----------
+    data_config : MicroSplitDataConfig
+        Configuration for the MicroSplit dataset.
+    train_data : str
+        Path to training data directory.
+    val_data : str, optional
+        Path to validation data directory.
+    train_data_target : str, optional
+        Path to training target data.
+    val_data_target : str, optional
+        Path to validation target data.
+    read_source_func : Callable, optional
+        Function to read source data.
+    extension_filter : str, optional
+        File extension filter.
+    val_percentage : float, optional
+        Percentage of data to use for validation, by default 0.1.
+    val_minimum_split : int, optional
+        Minimum number of samples for validation split, by default 5.
+    use_in_memory : bool, optional
+        Whether to use in-memory dataset, by default True.
+    """
+
+    def __init__(
+        self,
+        # Should be compatible with microSplit DatasetConfig
+        data_config: MicroSplitDataConfig,
+        train_data: str,
+        val_data: str | None = None,
+        train_data_target: str | None = None,
+        val_data_target: str | None = None,
+        read_source_func: Callable | None = None,
+        extension_filter: str = "",
+        val_percentage: float = 0.1,
+        val_minimum_split: int = 5,
+        use_in_memory: bool = True,
+    ):
+        """Initialize MicroSplitDataModule.
+
+        Parameters
+        ----------
+        data_config : MicroSplitDataConfig
+            Configuration for the MicroSplit dataset.
+        train_data : str
+            Path to training data directory.
+        val_data : str, optional
+            Path to validation data directory.
+        train_data_target : str, optional
+            Path to training target data.
+        val_data_target : str, optional
+            Path to validation target data.
+        read_source_func : Callable, optional
+            Function to read source data.
+        extension_filter : str, optional
+            File extension filter.
+        val_percentage : float, optional
+            Percentage of data to use for validation, by default 0.1.
+        val_minimum_split : int, optional
+            Minimum number of samples for validation split, by default 5.
+        use_in_memory : bool, optional
+            Whether to use in-memory dataset, by default True.
+        """
+        super().__init__()
+        # Dataset selection logic (adapted from create_train_val_datasets)
+        self.train_config = data_config  # SHould configs be separated?
+        self.val_config = data_config
+        self.test_config = data_config
+
+        datapath = train_data
+        load_data_func = read_source_func
+
+        dataset_class = LCMultiChDloader  # TODO hardcoded for now
+
+        # Create datasets
+        self.train_dataset = dataset_class(
+            self.train_config,
+            datapath,
+            load_data_fn=load_data_func,
+            val_fraction=val_percentage,
+            test_fraction=0.1,
+        )
+        max_val = self.train_dataset.get_max_val()
+        self.val_config.max_val = max_val
+        if self.train_config.datasplit_type == DataSplitType.All:
+            self.val_config.datasplit_type = DataSplitType.All
+            self.test_config.datasplit_type = DataSplitType.All
+        self.val_dataset = dataset_class(
+            self.val_config,
+            datapath,
+            load_data_fn=load_data_func,
+            val_fraction=val_percentage,
+            test_fraction=0.1,
+        )
+        self.test_config.max_val = max_val
+        self.test_dataset = dataset_class(
+            self.test_config,
+            datapath,
+            load_data_fn=load_data_func,
+            val_fraction=val_percentage,
+            test_fraction=0.1,
+        )
+        mean_val, std_val = self.train_dataset.compute_mean_std()
+        self.train_dataset.set_mean_std(mean_val, std_val)
+        self.val_dataset.set_mean_std(mean_val, std_val)
+        self.test_dataset.set_mean_std(mean_val, std_val)
+        data_stats = self.train_dataset.get_mean_std()
+
+        # Store data statistics
+        self.data_stats = (
+            data_stats[0],
+            data_stats[1],
+        )  # TODO repeats old logic, revisit
+
+    def train_dataloader(self):
+        """Create a dataloader for training.
+
+        Returns
+        -------
+        DataLoader
+            Training dataloader.
+        """
+        return DataLoader(
+            self.train_dataset,
+            # TODO should be inside dataloader params?
+            batch_size=self.train_config.batch_size,
+            **self.train_config.train_dataloader_params,
+        )
+
+    def val_dataloader(self):
+        """Create a dataloader for validation.
+
+        Returns
+        -------
+        DataLoader
+            Validation dataloader.
+        """
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.train_config.batch_size,
+            **self.val_config.val_dataloader_params,  # TODO duplicated
+        )
+
+    def get_data_stats(self):
+        """Get data statistics.
+
+        Returns
+        -------
+        tuple[dict, dict]
+            A tuple containing two dictionaries:
+            - data_mean: mean values for input and target
+            - data_std: standard deviation values for input and target
+        """
+        return self.data_stats, self.val_config.max_val  # TODO should be in the config?
+
+
+def create_microsplit_train_datamodule(
+    train_data: str,
+    patch_size: tuple,
+    data_type: DataType,
+    axes: str,  # TODO should be there after refactoring
+    batch_size: int,
+    val_data: str | None = None,
+    num_channels: int = 2,
+    depth3D: int = 1,
+    grid_size: tuple | None = None,
+    multiscale_count: int | None = None,
+    tiling_mode: TilingMode = TilingMode.ShiftBoundary,
+    read_source_func: Callable | None = None,  # TODO should be there after refactoring
+    extension_filter: str = "",
+    val_percentage: float = 0.1,
+    val_minimum_split: int = 5,
+    use_in_memory: bool = True,
+    transforms: list | None = None,  # TODO should it be here?
+    train_dataloader_params: dict | None = None,
+    val_dataloader_params: dict | None = None,
+    **dataset_kwargs,
+) -> MicroSplitDataModule:
+    """
+    Create a MicroSplitDataModule for microSplit-style datasets.
+
+    This includes config creation.
+
+    Parameters
+    ----------
+    train_data : str
+        Path to training data.
+    patch_size : tuple
+        Size of one patch of data.
+    data_type : DataType
+        Type of the dataset (must be a DataType enum value).
+    axes : str
+        Axes of the data (e.g., 'SYX').
+    batch_size : int
+        Batch size for dataloaders.
+    val_data : str, optional
+        Path to validation data.
+    num_channels : int, default=2
+        Number of channels in the input.
+    depth3D : int, default=1
+        Number of slices in 3D.
+    grid_size : tuple, optional
+        Grid size for patch extraction.
+    multiscale_count : int, optional
+        Number of LC scales.
+    tiling_mode : TilingMode, default=ShiftBoundary
+        Tiling mode for patch extraction.
+    read_source_func : Callable, optional
+        Function to read the source data.
+    extension_filter : str, optional
+        File extension filter.
+    val_percentage : float, default=0.1
+        Percentage of training data to use for validation.
+    val_minimum_split : int, default=5
+        Minimum number of patches/files for validation split.
+    use_in_memory : bool, default=True
+        Use in-memory dataset if possible.
+    transforms : list, optional
+        List of transforms to apply.
+    train_dataloader_params : dict, optional
+        Parameters for training dataloader.
+    val_dataloader_params : dict, optional
+        Parameters for validation dataloader.
+    **dataset_kwargs :
+        Additional arguments passed to DatasetConfig.
+
+    Returns
+    -------
+    MicroSplitDataModule
+        Configured MicroSplitDataModule instance.
+    """
+    # Create dataset configs with only valid parameters
+    dataset_config_params = {
+        "data_type": data_type,
+        "image_size": patch_size,
+        "num_channels": num_channels,
+        "depth3D": depth3D,
+        "grid_size": grid_size,
+        "multiscale_lowres_count": multiscale_count,
+        "tiling_mode": tiling_mode,
+        "batch_size": batch_size,
+        "train_dataloader_params": train_dataloader_params,
+        "val_dataloader_params": val_dataloader_params,
+        **dataset_kwargs,
+    }
+
+    train_config = MicroSplitDataConfig(
+        **dataset_config_params,
+        datasplit_type=DataSplitType.Train,
+    )
+    # val_config = MicroSplitDataConfig(
+    #     **dataset_config_params,
+    #     datasplit_type=DataSplitType.Val,
+    # )
+    # TODO, data config is duplicated here and in configuration
+
+    return MicroSplitDataModule(
+        data_config=train_config,
+        train_data=train_data,
+        val_data=val_data or train_data,
+        train_data_target=None,
+        val_data_target=None,
+        read_source_func=get_train_val_data,  # Use our wrapped function
+        extension_filter=extension_filter,
+        val_percentage=val_percentage,
+        val_minimum_split=val_minimum_split,
+        use_in_memory=use_in_memory,
+    )
+
+
+class MicroSplitPredictDataModule(L.LightningDataModule):
+    """Lightning DataModule for MicroSplit-style prediction datasets.
+
+    Matches the interface of PredictDataModule, but internally uses MicroSplit
+    dataset logic for prediction.
+
+    Parameters
+    ----------
+    pred_config : MicroSplitDataConfig
+        Configuration for MicroSplit prediction.
+    pred_data : str or Path or numpy.ndarray
+        Prediction data, can be a path to a folder, a file or a numpy array.
+    read_source_func : Callable, optional
+        Function to read custom types.
+    extension_filter : str, optional
+        Filter to filter file extensions for custom types.
+    dataloader_params : dict, optional
+        Dataloader parameters.
+    """
+
+    def __init__(
+        self,
+        pred_config: MicroSplitDataConfig,
+        pred_data: Union[str, Path, NDArray],
+        read_source_func: Callable | None = None,
+        extension_filter: str = "",
+        dataloader_params: dict | None = None,
+    ) -> None:
+        """
+        Constructor for MicroSplit prediction data module.
+
+        Parameters
+        ----------
+        pred_config : MicroSplitDataConfig
+            Configuration for MicroSplit prediction.
+        pred_data : str or Path or numpy.ndarray
+            Prediction data, can be a path to a folder, a file or a numpy array.
+        read_source_func : Callable, optional
+            Function to read custom types, by default None.
+        extension_filter : str, optional
+            Filter to filter file extensions for custom types, by default "".
+        dataloader_params : dict, optional
+            Dataloader parameters, by default {}.
+        """
+        super().__init__()
+
+        if dataloader_params is None:
+            dataloader_params = {}
+        self.pred_config = pred_config
+        self.pred_data = pred_data
+        self.read_source_func = read_source_func or get_train_val_data
+        self.extension_filter = extension_filter
+        self.dataloader_params = dataloader_params
+
+    def prepare_data(self) -> None:
+        """Hook used to prepare the data before calling `setup`."""
+        # # TODO currently data preparation is handled in dataset creation, revisit!
+        pass
+
+    def setup(self, stage: str | None = None) -> None:
+        """
+        Hook called at the beginning of predict.
+
+        Parameters
+        ----------
+        stage : Optional[str], optional
+            Stage, by default None.
+        """
+        # Create prediction dataset using LCMultiChDloader
+        self.predict_dataset = LCMultiChDloader(
+            self.pred_config,
+            self.pred_data,
+            load_data_fn=self.read_source_func,
+            val_fraction=0.0,  # No validation split for prediction
+            test_fraction=1.0,  # No test split for prediction
+        )
+        self.predict_dataset.set_mean_std(*self.pred_config.data_stats)
+
+    def predict_dataloader(self) -> DataLoader:
+        """
+        Create a dataloader for prediction.
+
+        Returns
+        -------
+        DataLoader
+            Prediction dataloader.
+        """
+        return DataLoader(
+            self.predict_dataset,
+            batch_size=self.pred_config.batch_size,
+            **self.dataloader_params,
+        )
+
+
+def create_microsplit_predict_datamodule(
+    pred_data: Union[str, Path, NDArray],
+    tile_size: tuple,
+    data_type: DataType,
+    axes: str,
+    batch_size: int = 1,
+    num_channels: int = 2,
+    depth3D: int = 1,
+    grid_size: int | None = None,
+    multiscale_count: int | None = None,
+    data_stats: tuple | None = None,
+    tiling_mode: TilingMode = TilingMode.ShiftBoundary,
+    read_source_func: Callable | None = None,
+    extension_filter: str = "",
+    dataloader_params: dict | None = None,
+    **dataset_kwargs,
+) -> MicroSplitPredictDataModule:
+    """
+    Create a MicroSplitPredictDataModule for microSplit-style prediction datasets.
+
+    Parameters
+    ----------
+    pred_data : str or Path or numpy.ndarray
+        Prediction data, can be a path to a folder, a file or a numpy array.
+    tile_size : tuple
+        Size of one tile of data.
+    data_type : DataType
+        Type of the dataset (must be a DataType enum value).
+    axes : str
+        Axes of the data (e.g., 'SYX').
+    batch_size : int, default=1
+        Batch size for prediction dataloader.
+    num_channels : int, default=2
+        Number of channels in the input.
+    depth3D : int, default=1
+        Number of slices in 3D.
+    grid_size : tuple, optional
+        Grid size for patch extraction.
+    multiscale_count : int, optional
+        Number of LC scales.
+    data_stats : tuple, optional
+        Data statistics, by default None.
+    tiling_mode : TilingMode, default=ShiftBoundary
+        Tiling mode for patch extraction.
+    read_source_func : Callable, optional
+        Function to read the source data.
+    extension_filter : str, optional
+        File extension filter.
+    dataloader_params : dict, optional
+        Parameters for prediction dataloader.
+    **dataset_kwargs :
+        Additional arguments passed to MicroSplitDataConfig.
+
+    Returns
+    -------
+    MicroSplitPredictDataModule
+        Configured MicroSplitPredictDataModule instance.
+    """
+    if dataloader_params is None:
+        dataloader_params = {}
+
+    # Create prediction config with only valid parameters
+    prediction_config_params = {
+        "data_type": data_type,
+        "image_size": tile_size,
+        "num_channels": num_channels,
+        "depth3D": depth3D,
+        "grid_size": grid_size,
+        "multiscale_lowres_count": multiscale_count,
+        "data_stats": data_stats,
+        "tiling_mode": tiling_mode,
+        "batch_size": batch_size,
+        "datasplit_type": DataSplitType.Test,  # For prediction, use all data
+        **dataset_kwargs,
+    }
+
+    pred_config = MicroSplitDataConfig(**prediction_config_params)
+
+    # Remove batch_size from dataloader_params if present
+    if "batch_size" in dataloader_params:
+        del dataloader_params["batch_size"]
+
+    return MicroSplitPredictDataModule(
+        pred_config=pred_config,
+        pred_data=pred_data,
+        read_source_func=(
+            read_source_func if read_source_func is not None else get_train_val_data
+        ),
+        extension_filter=extension_filter,
+        dataloader_params=dataloader_params,
+    )
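
For orientation, the sketch below shows how the training-side factory added in this file might be called. It is based only on the signature of create_microsplit_train_datamodule above; the directory layout, patch size, DataType member and dataloader parameters are illustrative placeholders, not values prescribed by the package.

from careamics.lightning.microsplit_data_module import (
    create_microsplit_train_datamodule,
)
from careamics.lvae_training.dataset import DataType

# Hypothetical paths and settings, for illustration only.
data_module = create_microsplit_train_datamodule(
    train_data="path/to/train_dir",  # one sub-directory per channel, as load_data expects
    patch_size=(64, 64),
    data_type=DataType.SeparateTiffData,  # assumed member name; pick the DataType matching your data
    axes="SYX",
    batch_size=32,
    num_channels=2,
    train_dataloader_params={"num_workers": 4, "shuffle": True},
    val_dataloader_params={"num_workers": 4},
)
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()
(data_mean, data_std), max_val = data_module.get_data_stats()
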
careamics/lightning/predict_data_module.py
@@ -217,11 +217,18 @@ class PredictDataModule(L.LightningDataModule):
         DataLoader
             Prediction dataloader.
         """
+        # For tiled predictions, we need to ensure tiles are processed in order
+        # to avoid stitching artifacts. Multi-worker processing can return batches
+        # out of order, so we disable it for tiled predictions.
+        dataloader_params = self.dataloader_params.copy()
+        if self.tiled:
+            dataloader_params["num_workers"] = 0
+
         return DataLoader(
             self.predict_dataset,
             batch_size=self.batch_size,
             collate_fn=collate_tiles if self.tiled else None,
-            **self.dataloader_params,
+            **dataloader_params,
         )
 
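
Similarly, a minimal sketch of the prediction-side factory defined in the new module, assuming the normalisation statistics are reused from the training data module; the path, tile size and DataType member are again placeholders rather than prescribed values.

from careamics.lightning.microsplit_data_module import (
    create_microsplit_predict_datamodule,
)
from careamics.lvae_training.dataset import DataType

# Placeholders; in practice reuse the statistics from the training data module,
# e.g. (data_mean, data_std), _ = data_module.get_data_stats().
data_mean, data_std = 0.0, 1.0

predict_module = create_microsplit_predict_datamodule(
    pred_data="path/to/test_dir",
    tile_size=(64, 64),
    data_type=DataType.SeparateTiffData,  # assumed member name; match your data layout
    axes="SYX",
    batch_size=8,
    num_channels=2,
    data_stats=(data_mean, data_std),  # consumed by set_mean_std() in setup()
    dataloader_params={"num_workers": 0},
)
predict_module.setup()
predict_loader = predict_module.predict_dataloader()

Keeping num_workers at 0 in this sketch mirrors the PredictDataModule change above, which forces single-process loading for tiled predictions so that tiles are returned in order for stitching.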