careamics 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- careamics/__init__.py +24 -0
- careamics/careamist.py +961 -0
- careamics/cli/__init__.py +5 -0
- careamics/cli/conf.py +394 -0
- careamics/cli/main.py +234 -0
- careamics/cli/utils.py +27 -0
- careamics/config/__init__.py +66 -0
- careamics/config/algorithms/__init__.py +21 -0
- careamics/config/algorithms/care_algorithm_config.py +122 -0
- careamics/config/algorithms/hdn_algorithm_config.py +103 -0
- careamics/config/algorithms/microsplit_algorithm_config.py +103 -0
- careamics/config/algorithms/n2n_algorithm_config.py +115 -0
- careamics/config/algorithms/n2v_algorithm_config.py +296 -0
- careamics/config/algorithms/pn2v_algorithm_config.py +301 -0
- careamics/config/algorithms/unet_algorithm_config.py +91 -0
- careamics/config/algorithms/vae_algorithm_config.py +178 -0
- careamics/config/architectures/__init__.py +7 -0
- careamics/config/architectures/architecture_config.py +37 -0
- careamics/config/architectures/lvae_config.py +262 -0
- careamics/config/architectures/unet_config.py +125 -0
- careamics/config/configuration.py +367 -0
- careamics/config/configuration_factories.py +2400 -0
- careamics/config/data/__init__.py +27 -0
- careamics/config/data/data_config.py +472 -0
- careamics/config/data/inference_config.py +237 -0
- careamics/config/data/ng_data_config.py +1038 -0
- careamics/config/data/patch_filter/__init__.py +15 -0
- careamics/config/data/patch_filter/filter_config.py +16 -0
- careamics/config/data/patch_filter/mask_filter_config.py +17 -0
- careamics/config/data/patch_filter/max_filter_config.py +15 -0
- careamics/config/data/patch_filter/meanstd_filter_config.py +18 -0
- careamics/config/data/patch_filter/shannon_filter_config.py +15 -0
- careamics/config/data/patching_strategies/__init__.py +15 -0
- careamics/config/data/patching_strategies/_overlapping_patched_config.py +102 -0
- careamics/config/data/patching_strategies/_patched_config.py +56 -0
- careamics/config/data/patching_strategies/random_patching_config.py +45 -0
- careamics/config/data/patching_strategies/sequential_patching_config.py +25 -0
- careamics/config/data/patching_strategies/tiled_patching_config.py +40 -0
- careamics/config/data/patching_strategies/whole_patching_config.py +12 -0
- careamics/config/data/tile_information.py +65 -0
- careamics/config/lightning/__init__.py +15 -0
- careamics/config/lightning/callbacks/__init__.py +8 -0
- careamics/config/lightning/callbacks/callback_config.py +116 -0
- careamics/config/lightning/optimizer_configs.py +186 -0
- careamics/config/lightning/training_config.py +70 -0
- careamics/config/losses/__init__.py +8 -0
- careamics/config/losses/loss_config.py +60 -0
- careamics/config/ng_configs/__init__.py +5 -0
- careamics/config/ng_configs/n2v_configuration.py +64 -0
- careamics/config/ng_configs/ng_configuration.py +256 -0
- careamics/config/ng_factories/__init__.py +9 -0
- careamics/config/ng_factories/algorithm_factory.py +120 -0
- careamics/config/ng_factories/data_factory.py +154 -0
- careamics/config/ng_factories/n2v_factory.py +256 -0
- careamics/config/ng_factories/training_factory.py +69 -0
- careamics/config/noise_model/__init__.py +12 -0
- careamics/config/noise_model/likelihood_config.py +60 -0
- careamics/config/noise_model/noise_model_config.py +149 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +40 -0
- careamics/config/support/supported_architectures.py +13 -0
- careamics/config/support/supported_data.py +122 -0
- careamics/config/support/supported_filters.py +17 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +32 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_patching_strategies.py +22 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +12 -0
- careamics/config/transformations/__init__.py +22 -0
- careamics/config/transformations/n2v_manipulate_config.py +79 -0
- careamics/config/transformations/normalize_config.py +59 -0
- careamics/config/transformations/transform_config.py +45 -0
- careamics/config/transformations/transform_unions.py +29 -0
- careamics/config/transformations/xy_flip_config.py +43 -0
- careamics/config/transformations/xy_random_rotate90_config.py +35 -0
- careamics/config/utils/__init__.py +8 -0
- careamics/config/utils/configuration_io.py +85 -0
- careamics/config/validators/__init__.py +18 -0
- careamics/config/validators/axes_validators.py +90 -0
- careamics/config/validators/model_validators.py +84 -0
- careamics/config/validators/patch_validators.py +55 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +118 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +84 -0
- careamics/dataset/dataset_utils/running_stats.py +189 -0
- careamics/dataset/in_memory_dataset.py +303 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +131 -0
- careamics/dataset/iterable_dataset.py +294 -0
- careamics/dataset/iterable_pred_dataset.py +121 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +141 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +300 -0
- careamics/dataset/patching/random_patching.py +110 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +375 -0
- careamics/dataset/tiling/tiled_patching.py +166 -0
- careamics/dataset_ng/README.md +212 -0
- careamics/dataset_ng/__init__.py +0 -0
- careamics/dataset_ng/dataset.py +365 -0
- careamics/dataset_ng/demos/bsd68_demo.ipynb +361 -0
- careamics/dataset_ng/demos/bsd68_zarr_demo.ipynb +453 -0
- careamics/dataset_ng/demos/care_U2OS_demo.ipynb +330 -0
- careamics/dataset_ng/demos/demo_custom_image_stack.ipynb +736 -0
- careamics/dataset_ng/demos/demo_datamodule.ipynb +447 -0
- careamics/dataset_ng/demos/demo_dataset.ipynb +278 -0
- careamics/dataset_ng/demos/demo_patch_extractor.py +51 -0
- careamics/dataset_ng/demos/mouse_nuclei_demo.ipynb +293 -0
- careamics/dataset_ng/factory.py +180 -0
- careamics/dataset_ng/grouped_index_sampler.py +73 -0
- careamics/dataset_ng/image_stack/__init__.py +14 -0
- careamics/dataset_ng/image_stack/czi_image_stack.py +396 -0
- careamics/dataset_ng/image_stack/file_image_stack.py +140 -0
- careamics/dataset_ng/image_stack/image_stack_protocol.py +93 -0
- careamics/dataset_ng/image_stack/image_utils/__init__.py +6 -0
- careamics/dataset_ng/image_stack/image_utils/image_stack_utils.py +125 -0
- careamics/dataset_ng/image_stack/in_memory_image_stack.py +93 -0
- careamics/dataset_ng/image_stack/zarr_image_stack.py +170 -0
- careamics/dataset_ng/image_stack_loader/__init__.py +19 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loader_protocol.py +70 -0
- careamics/dataset_ng/image_stack_loader/image_stack_loaders.py +273 -0
- careamics/dataset_ng/image_stack_loader/zarr_utils.py +130 -0
- careamics/dataset_ng/legacy_interoperability.py +175 -0
- careamics/dataset_ng/microsplit_input_synth.py +377 -0
- careamics/dataset_ng/patch_extractor/__init__.py +7 -0
- careamics/dataset_ng/patch_extractor/limit_file_extractor.py +50 -0
- careamics/dataset_ng/patch_extractor/patch_construction.py +151 -0
- careamics/dataset_ng/patch_extractor/patch_extractor.py +117 -0
- careamics/dataset_ng/patch_filter/__init__.py +20 -0
- careamics/dataset_ng/patch_filter/coordinate_filter_protocol.py +27 -0
- careamics/dataset_ng/patch_filter/filter_factory.py +95 -0
- careamics/dataset_ng/patch_filter/mask_filter.py +96 -0
- careamics/dataset_ng/patch_filter/max_filter.py +188 -0
- careamics/dataset_ng/patch_filter/mean_std_filter.py +218 -0
- careamics/dataset_ng/patch_filter/patch_filter_protocol.py +50 -0
- careamics/dataset_ng/patch_filter/shannon_filter.py +188 -0
- careamics/dataset_ng/patching_strategies/__init__.py +26 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_factory.py +50 -0
- careamics/dataset_ng/patching_strategies/patching_strategy_protocol.py +161 -0
- careamics/dataset_ng/patching_strategies/random_patching.py +393 -0
- careamics/dataset_ng/patching_strategies/sequential_patching.py +99 -0
- careamics/dataset_ng/patching_strategies/tiling_strategy.py +207 -0
- careamics/dataset_ng/patching_strategies/whole_sample.py +61 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +11 -0
- careamics/file_io/read/get_func.py +57 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +32 -0
- careamics/lightning/callbacks/__init__.py +13 -0
- careamics/lightning/callbacks/data_stats_callback.py +33 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +234 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +399 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/dataset_ng/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/__init__.py +1 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/__init__.py +29 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/cached_tiles_strategy.py +164 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/file_path_utils.py +33 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/prediction_writer_callback.py +219 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_image_strategy.py +91 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy.py +27 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_strategy_factory.py +214 -0
- careamics/lightning/dataset_ng/callbacks/prediction_writer/write_tiles_zarr_strategy.py +375 -0
- careamics/lightning/dataset_ng/data_module.py +529 -0
- careamics/lightning/dataset_ng/data_module_utils.py +395 -0
- careamics/lightning/dataset_ng/lightning_modules/__init__.py +9 -0
- careamics/lightning/dataset_ng/lightning_modules/care_module.py +97 -0
- careamics/lightning/dataset_ng/lightning_modules/n2v_module.py +106 -0
- careamics/lightning/dataset_ng/lightning_modules/unet_module.py +221 -0
- careamics/lightning/dataset_ng/prediction/__init__.py +16 -0
- careamics/lightning/dataset_ng/prediction/convert_prediction.py +198 -0
- careamics/lightning/dataset_ng/prediction/stitch_prediction.py +171 -0
- careamics/lightning/lightning_module.py +914 -0
- careamics/lightning/microsplit_data_module.py +632 -0
- careamics/lightning/predict_data_module.py +341 -0
- careamics/lightning/train_data_module.py +666 -0
- careamics/losses/__init__.py +21 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +125 -0
- careamics/losses/loss_factory.py +80 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +589 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/calibration.py +191 -0
- careamics/lvae_training/dataset/__init__.py +20 -0
- careamics/lvae_training/dataset/config.py +135 -0
- careamics/lvae_training/dataset/lc_dataset.py +274 -0
- careamics/lvae_training/dataset/ms_dataset_ref.py +1067 -0
- careamics/lvae_training/dataset/multich_dataset.py +1121 -0
- careamics/lvae_training/dataset/multicrop_dset.py +196 -0
- careamics/lvae_training/dataset/multifile_dataset.py +335 -0
- careamics/lvae_training/dataset/types.py +32 -0
- careamics/lvae_training/dataset/utils/__init__.py +0 -0
- careamics/lvae_training/dataset/utils/data_utils.py +114 -0
- careamics/lvae_training/dataset/utils/empty_patch_fetcher.py +65 -0
- careamics/lvae_training/dataset/utils/index_manager.py +491 -0
- careamics/lvae_training/dataset/utils/index_switcher.py +165 -0
- careamics/lvae_training/eval_utils.py +987 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +113 -0
- careamics/model_io/bioimage/bioimage_utils.py +56 -0
- careamics/model_io/bioimage/cover_factory.py +171 -0
- careamics/model_io/bioimage/model_description.py +341 -0
- careamics/model_io/bmz_io.py +251 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +40 -0
- careamics/models/layers.py +495 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1371 -0
- careamics/models/lvae/likelihoods.py +394 -0
- careamics/models/lvae/lvae.py +848 -0
- careamics/models/lvae/noise_models.py +738 -0
- careamics/models/lvae/stochastic.py +394 -0
- careamics/models/lvae/utils.py +404 -0
- careamics/models/model_factory.py +54 -0
- careamics/models/unet.py +449 -0
- careamics/nm_training_placeholder.py +203 -0
- careamics/prediction_utils/__init__.py +21 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +238 -0
- careamics/prediction_utils/stitch_prediction.py +193 -0
- careamics/py.typed +5 -0
- careamics/transforms/__init__.py +22 -0
- careamics/transforms/compose.py +173 -0
- careamics/transforms/n2v_manipulate.py +150 -0
- careamics/transforms/n2v_manipulate_torch.py +149 -0
- careamics/transforms/normalize.py +374 -0
- careamics/transforms/pixel_manipulation.py +406 -0
- careamics/transforms/pixel_manipulation_torch.py +388 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +131 -0
- careamics/transforms/xy_random_rotate90.py +108 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +67 -0
- careamics/utils/deprecation.py +63 -0
- careamics/utils/lightning_utils.py +71 -0
- careamics/utils/logging.py +323 -0
- careamics/utils/metrics.py +394 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/plotting.py +76 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/serializers.py +62 -0
- careamics/utils/torch_utils.py +150 -0
- careamics/utils/version.py +38 -0
- careamics-0.0.19.dist-info/METADATA +80 -0
- careamics-0.0.19.dist-info/RECORD +279 -0
- careamics-0.0.19.dist-info/WHEEL +4 -0
- careamics-0.0.19.dist-info/entry_points.txt +2 -0
- careamics-0.0.19.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
|
|
3
|
+
from .patching_strategy_protocol import PatchSpecs
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class WholeSamplePatchingStrategy:
    """
    Patching strategy in which each sample is returned whole, as a single patch.

    One patch specification is created for every sample of every image stack. Each
    patch starts at the origin of the spatial axes and spans the full spatial shape.

    Parameters
    ----------
    data_shapes : sequence of (sequence of int)
        The shape of each image stack. The length of the first axis is used as the
        number of samples and the axes from index 2 onwards are used as the spatial
        axes (presumably the shapes follow an SC(Z)YX layout; confirm with callers).

    Attributes
    ----------
    data_shapes : sequence of (sequence of int)
        The shape of each image stack, as passed to the constructor.
    patch_specs : list of PatchSpecs
        One patch specification per sample, ordered by image stack then by sample.
    """

    # TODO: warn this strategy should only be used with batch size = 1
    #   for the case of multiple image stacks with different dimensions

    def __init__(self, data_shapes: Sequence[Sequence[int]]):
        """
        Constructor.

        Parameters
        ----------
        data_shapes : sequence of (sequence of int)
            The shape of each image stack.
        """
        self.data_shapes = data_shapes

        self.patch_specs: list[PatchSpecs] = self._initialize_patch_specs()

    @property
    def n_patches(self) -> int:
        """
        The total number of patches, i.e. the total number of samples.

        Returns
        -------
        int
            Number of patch specifications held by this strategy.
        """
        return len(self.patch_specs)

    def get_patch_spec(self, index: int) -> PatchSpecs:
        """
        Return the patch specification stored at a given index.

        Parameters
        ----------
        index : int
            Index into the list of patch specifications.

        Returns
        -------
        PatchSpecs
            The patch specification at position `index`.
        """
        return self.patch_specs[index]

    # Note: this is used by the FileIterSampler
    def get_patch_indices(self, data_idx: int) -> Sequence[int]:
        """
        Get the indices of the patches that belong to a specific `image_stack`.

        The `image_stack` corresponds to the given `data_idx`.

        Parameters
        ----------
        data_idx : int
            An index that corresponds to a given `image_stack`.

        Returns
        -------
        sequence of int
            A sequence of patch indices, that when used to index the
            `CAREamicsDataset` will return a patch that comes from the `image_stack`
            corresponding to the given `data_idx`.
        """
        return [
            i
            for i, patch_spec in enumerate(self.patch_specs)
            if patch_spec["data_idx"] == data_idx
        ]

    def _initialize_patch_specs(self) -> list[PatchSpecs]:
        """
        Build one whole-sample patch specification per sample of each image stack.

        Returns
        -------
        list of PatchSpecs
            The patch specifications, ordered by image stack then by sample.
        """
        patch_specs: list[PatchSpecs] = []
        for data_idx, data_shape in enumerate(self.data_shapes):
            # Stored as a tuple so that the specs of one stack do not all share a
            # single mutable list (which is what slicing a list shape would give)
            # and so that `patch_size` has the same type as `coords`.
            spatial_shape = tuple(data_shape[2:])
            for sample_idx in range(data_shape[0]):
                patch_specs.append(
                    {
                        "data_idx": data_idx,
                        "sample_idx": sample_idx,
                        "coords": tuple(0 for _ in spatial_shape),
                        "patch_size": spatial_shape,
                    }
                )
        return patch_specs
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Functions relating to reading and writing image files."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"ReadFunc",
|
|
5
|
+
"SupportedWriteType",
|
|
6
|
+
"WriteFunc",
|
|
7
|
+
"get_read_func",
|
|
8
|
+
"get_write_func",
|
|
9
|
+
"read",
|
|
10
|
+
"write",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
from . import read, write
|
|
14
|
+
from .read import ReadFunc, get_read_func
|
|
15
|
+
from .write import SupportedWriteType, WriteFunc, get_write_func
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""Module to get read functions."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Protocol, Union
|
|
6
|
+
|
|
7
|
+
from numpy.typing import NDArray
|
|
8
|
+
|
|
9
|
+
from careamics.config.support import SupportedData
|
|
10
|
+
|
|
11
|
+
from .tiff import read_tiff
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# This is very strict, function signature has to match including arg names
|
|
15
|
+
# See WriteFunc notes
|
|
16
|
+
class ReadFunc(Protocol):
    """Protocol for type hinting read functions."""

    def __call__(self, file_path: Path, *args, **kwargs) -> NDArray:
        """
        Type hinted callables must match this function signature (not including self).

        Parameters
        ----------
        file_path : pathlib.Path
            Path to file.
        *args
            Other positional arguments.
        **kwargs
            Other keyword arguments.

        Returns
        -------
        numpy.typing.NDArray
            The array produced by the read function.
        """
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Registry of the read functions available for each supported data type.
READ_FUNCS: dict[SupportedData, ReadFunc] = {
    SupportedData.TIFF: read_tiff,
}


def get_read_func(data_type: Union[str, SupportedData]) -> Callable:
    """
    Get the read function for the data type.

    The data type is first normalised to a `SupportedData` member and then looked
    up in `READ_FUNCS`, so that strings and enum members are resolved the same way.

    Parameters
    ----------
    data_type : str or SupportedData
        Data type.

    Returns
    -------
    callable
        Read function.

    Raises
    ------
    NotImplementedError
        If `data_type` is not a valid `SupportedData` value, or if no read function
        is registered for it.
    """
    unsupported_msg = f"Data type '{data_type}' is not supported."

    # Normalise before the lookup; an unknown value is reported with the same
    # exception type as a known value that has no registered read function.
    try:
        key = SupportedData(data_type)
    except ValueError as err:
        raise NotImplementedError(unsupported_msg) from err

    if key not in READ_FUNCS:
        raise NotImplementedError(unsupported_msg)

    return READ_FUNCS[key]
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""Functions to read tiff images."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from fnmatch import fnmatch
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import tifffile
|
|
9
|
+
|
|
10
|
+
from careamics.config.support import SupportedData
|
|
11
|
+
from careamics.utils.logging import get_logger
|
|
12
|
+
|
|
13
|
+
logger = get_logger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
    """
    Read a tiff file and return a numpy array.

    Parameters
    ----------
    file_path : Path
        Path to a file.
    *args : list
        Additional arguments, accepted but not used by this function.
    **kwargs : dict
        Additional keyword arguments, accepted but not used by this function.

    Returns
    -------
    np.ndarray
        Resulting array.

    Raises
    ------
    ValueError
        If the file extension does not match the tiff extension pattern, or if
        `tifffile.imread` raises a `ValueError` while reading the file.
    OSError
        If `tifffile.imread` raises an `OSError` while reading the file.
    """
    if not fnmatch(
        file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF)
    ):
        raise ValueError(f"File {file_path} is not a valid tiff.")

    try:
        return tifffile.imread(file_path)
    except (ValueError, OSError):
        # Log through the module logger (traceback included), then re-raise the
        # original exception unchanged: the file is not silently skipped.
        logger.exception("Exception while reading file %s.", file_path)
        raise
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Functions relating to writing image files of different formats."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"SupportedWriteType",
|
|
5
|
+
"WriteFunc",
|
|
6
|
+
"get_write_func",
|
|
7
|
+
"write_tiff",
|
|
8
|
+
]
|
|
9
|
+
|
|
10
|
+
from .get_func import (
|
|
11
|
+
SupportedWriteType,
|
|
12
|
+
WriteFunc,
|
|
13
|
+
get_write_func,
|
|
14
|
+
)
|
|
15
|
+
from .tiff import write_tiff
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""Module to get write functions."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Literal, Protocol
|
|
5
|
+
|
|
6
|
+
from numpy.typing import NDArray
|
|
7
|
+
|
|
8
|
+
from careamics.config.support import SupportedData
|
|
9
|
+
|
|
10
|
+
from .tiff import write_tiff
|
|
11
|
+
|
|
12
|
+
SupportedWriteType = Literal["tiff", "zarr", "custom"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# This is very strict, arguments have to be called file_path & img
|
|
16
|
+
# Alternative? - doesn't capture *args & **kwargs
|
|
17
|
+
# WriteFunc = Callable[[Path, NDArray], None]
|
|
18
|
+
class WriteFunc(Protocol):
    """Protocol for type hinting write functions."""

    def __call__(self, file_path: Path, img: NDArray, *args, **kwargs) -> None:
        """
        Type hinted callables must match this function signature (not including self).

        Parameters
        ----------
        file_path : pathlib.Path
            Path to file.
        img : numpy.ndarray
            Image data to save.
        *args
            Other positional arguments.
        **kwargs
            Other keyword arguments.

        Returns
        -------
        None
            Write functions do not return a value.
        """
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Registry of the write functions available for each supported data type.
WRITE_FUNCS: dict[SupportedData, WriteFunc] = {SupportedData.TIFF: write_tiff}


def get_write_func(data_type: SupportedWriteType) -> WriteFunc:
    """
    Look up the write function registered for a data type.

    Parameters
    ----------
    data_type : {"tiff", "zarr", "custom"}
        Data type.

    Returns
    -------
    callable
        Write function.

    Raises
    ------
    NotImplementedError
        If no write function is registered for `data_type`.
    """
    # constructing the enum member raises if the data type is not supported
    key = SupportedData(data_type)  # separate variable for mypy

    writer = WRITE_FUNCS.get(key)
    if writer is None:
        raise NotImplementedError(f"No write function for data type '{data_type}'.")
    return writer
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""Write tiff function."""
|
|
2
|
+
|
|
3
|
+
from fnmatch import fnmatch
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import tifffile
|
|
7
|
+
from numpy.typing import NDArray
|
|
8
|
+
|
|
9
|
+
from careamics.config.support import SupportedData
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def write_tiff(file_path: Path, img: NDArray, *args, **kwargs) -> None:
    """
    Save image data to a tiff file.

    Parameters
    ----------
    file_path : pathlib.Path
        Path to file.
    img : numpy.ndarray
        Image data to save.
    *args
        Positional arguments passed to `tifffile.imwrite`.
    **kwargs
        Keyword arguments passed to `tifffile.imwrite`.

    Raises
    ------
    ValueError
        When the file extension of `file_path` does not match the Unix shell-style
        pattern returned by `SupportedData.get_extension_pattern` for tiff files.
    """
    # TODO: add link to tifffile docs for args & kwargs?
    extension = file_path.suffix
    tiff_pattern = SupportedData.get_extension_pattern(SupportedData.TIFF)

    if fnmatch(extension, tiff_pattern):
        tifffile.imwrite(file_path, img, *args, **kwargs)
        return

    raise ValueError(
        f"Unexpected extension '{file_path.suffix}' for save file type 'tiff'."
    )
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""CAREamics PyTorch Lightning modules."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"DataStatsCallback",
|
|
5
|
+
"FCNModule",
|
|
6
|
+
"HyperParametersCallback",
|
|
7
|
+
"MicroSplitDataModule",
|
|
8
|
+
"PredictDataModule",
|
|
9
|
+
"ProgressBarCallback",
|
|
10
|
+
"TrainDataModule",
|
|
11
|
+
"VAEModule",
|
|
12
|
+
"create_careamics_module",
|
|
13
|
+
"create_microsplit_predict_datamodule",
|
|
14
|
+
"create_microsplit_train_datamodule",
|
|
15
|
+
"create_predict_datamodule",
|
|
16
|
+
"create_train_datamodule",
|
|
17
|
+
"create_unet_based_module",
|
|
18
|
+
"create_vae_based_module",
|
|
19
|
+
]
|
|
20
|
+
|
|
21
|
+
from .callbacks import DataStatsCallback, HyperParametersCallback, ProgressBarCallback
|
|
22
|
+
from .lightning_module import FCNModule, VAEModule, create_careamics_module
|
|
23
|
+
from .microsplit_data_module import (
|
|
24
|
+
MicroSplitDataModule,
|
|
25
|
+
create_microsplit_predict_datamodule,
|
|
26
|
+
create_microsplit_train_datamodule,
|
|
27
|
+
)
|
|
28
|
+
from .predict_data_module import PredictDataModule, create_predict_datamodule
|
|
29
|
+
from .train_data_module import (
|
|
30
|
+
TrainDataModule,
|
|
31
|
+
create_train_datamodule,
|
|
32
|
+
)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Callbacks module."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"DataStatsCallback",
|
|
5
|
+
"HyperParametersCallback",
|
|
6
|
+
"PredictionWriterCallback",
|
|
7
|
+
"ProgressBarCallback",
|
|
8
|
+
]
|
|
9
|
+
|
|
10
|
+
from .data_stats_callback import DataStatsCallback
|
|
11
|
+
from .hyperparameters_callback import HyperParametersCallback
|
|
12
|
+
from .prediction_writer_callback import PredictionWriterCallback
|
|
13
|
+
from .progress_bar_callback import ProgressBarCallback
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""Data statistics callback."""
|
|
2
|
+
|
|
3
|
+
import pytorch_lightning as L
|
|
4
|
+
from pytorch_lightning.callbacks import Callback
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class DataStatsCallback(Callback):
    """Callback passing the datamodule's data statistics on to the model.

    During setup of the "fit" stage, the statistics returned by the datamodule's
    `get_data_stats` are handed to the model's `noise_model_likelihood`, so that
    they are available before training starts.
    """

    def setup(self, trainer: L.Trainer, module: L.LightningModule, stage: str) -> None:
        """Set the data statistics on the model when setting up for fitting.

        Parameters
        ----------
        trainer : Lightning.Trainer
            PyTorch Lightning trainer.
        module : Lightning.LightningModule
            Lightning module.
        stage : str
            Current stage (fit, validate, test, or predict).
        """
        # Nothing to do for any stage other than "fit"
        if stage != "fit":
            return

        # The first item returned by the datamodule is a (mean, std) pair, the
        # second item is not needed here
        stats, _ = trainer.datamodule.get_data_stats()
        data_mean, data_std = stats

        # Only the "target" statistics are passed to the likelihood module
        likelihood = module.noise_model_likelihood
        likelihood.set_data_stats(
            data_mean=data_mean["target"], data_std=data_std["target"]
        )
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""Callback saving CAREamics configuration as hyperparameters in the model."""
|
|
2
|
+
|
|
3
|
+
from pytorch_lightning import LightningModule, Trainer
|
|
4
|
+
from pytorch_lightning.callbacks import Callback
|
|
5
|
+
|
|
6
|
+
from careamics.config import Configuration
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class HyperParametersCallback(Callback):
    """
    Callback storing the CAREamics configuration in the model hyperparameters.

    With the configuration saved as a dictionary in the checkpoints, it can later
    be loaded again in a CAREamist instance.

    Parameters
    ----------
    config : Configuration
        CAREamics configuration to be saved as hyperparameter in the model.

    Attributes
    ----------
    config : Configuration
        CAREamics configuration to be saved as hyperparameter in the model.
    """

    def __init__(self, config: Configuration) -> None:
        """
        Constructor.

        Parameters
        ----------
        config : Configuration
            CAREamics configuration to be saved as hyperparameter in the model.
        """
        self.config = config

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """
        Add the dumped configuration to the model hyperparameters on train start.

        Parameters
        ----------
        trainer : Trainer
            PyTorch Lightning trainer, unused.
        pl_module : LightningModule
            PyTorch Lightning module.
        """
        config_as_dict = self.config.model_dump()
        pl_module.hparams.update(config_as_dict)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""A package for the `PredictionWriterCallback` class and utilities."""
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"CacheTiles",
|
|
5
|
+
"PredictionWriterCallback",
|
|
6
|
+
"WriteImage",
|
|
7
|
+
"WriteStrategy",
|
|
8
|
+
"WriteTilesZarr",
|
|
9
|
+
"create_write_strategy",
|
|
10
|
+
"select_write_extension",
|
|
11
|
+
"select_write_func",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
from .prediction_writer_callback import PredictionWriterCallback
|
|
15
|
+
from .write_strategy import CacheTiles, WriteImage, WriteStrategy, WriteTilesZarr
|
|
16
|
+
from .write_strategy_factory import (
|
|
17
|
+
create_write_strategy,
|
|
18
|
+
select_write_extension,
|
|
19
|
+
select_write_func,
|
|
20
|
+
)
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Module containing file path utilities for `WriteStrategy` to use."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
from careamics.dataset import IterablePredDataset, IterableTiledPredDataset
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# TODO: move to datasets package ?
|
|
10
|
+
def get_sample_file_path(
    dataset: Union[IterableTiledPredDataset, IterablePredDataset], sample_id: int
) -> Path:
    """
    Look up the source file path of a sample.

    Parameters
    ----------
    dataset : IterableTiledPredDataset or IterablePredDataset
        Dataset.
    sample_id : int
        Sample ID, the index of the file in the dataset `dataset`.

    Returns
    -------
    Path
        The file path corresponding to the sample with the ID `sample_id`.
    """
    # the sample ID is used directly as an index into the dataset's file list
    source_files = dataset.data_files
    return source_files[sample_id]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def create_write_file_path(
    dirpath: Path, file_path: Path, write_extension: str
) -> Path:
    """
    Create the file name for the output file.

    Takes the original file path, changes the directory to `dirpath` and changes
    the extension to `write_extension`. Only the final extension is replaced, so
    dots elsewhere in the file name are kept (e.g. "img.1.tif" becomes
    "img.1.tiff") and differently named inputs do not collapse to the same output.

    Parameters
    ----------
    dirpath : pathlib.Path
        The output directory to write file to.
    file_path : pathlib.Path
        The original file path.
    write_extension : str
        The extension that output files should have.

    Returns
    -------
    Path
        The output file path.
    """
    # `with_suffix` swaps only the last suffix; `.name` then drops the original
    # parent directories so that the file lands directly in `dirpath`.
    file_name = file_path.with_suffix(write_extension).name
    return dirpath / file_name
|