careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (69)
  1. careamics/careamist.py +163 -266
  2. careamics/config/algorithm_model.py +0 -15
  3. careamics/config/architectures/custom_model.py +3 -3
  4. careamics/config/configuration_example.py +0 -3
  5. careamics/config/configuration_factory.py +23 -25
  6. careamics/config/configuration_model.py +11 -11
  7. careamics/config/data_model.py +80 -50
  8. careamics/config/inference_model.py +29 -17
  9. careamics/config/optimizer_models.py +7 -7
  10. careamics/config/support/supported_transforms.py +0 -1
  11. careamics/config/tile_information.py +26 -58
  12. careamics/config/transformations/normalize_model.py +32 -4
  13. careamics/config/validators/validator_utils.py +1 -1
  14. careamics/dataset/__init__.py +12 -1
  15. careamics/dataset/dataset_utils/__init__.py +8 -1
  16. careamics/dataset/dataset_utils/file_utils.py +1 -1
  17. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  18. careamics/dataset/dataset_utils/read_tiff.py +0 -9
  19. careamics/dataset/dataset_utils/running_stats.py +186 -0
  20. careamics/dataset/in_memory_dataset.py +66 -171
  21. careamics/dataset/in_memory_pred_dataset.py +88 -0
  22. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  23. careamics/dataset/iterable_dataset.py +92 -249
  24. careamics/dataset/iterable_pred_dataset.py +121 -0
  25. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  26. careamics/dataset/patching/patching.py +54 -25
  27. careamics/dataset/patching/random_patching.py +9 -4
  28. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  29. careamics/dataset/tiling/__init__.py +10 -0
  30. careamics/dataset/tiling/collate_tiles.py +33 -0
  31. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  32. careamics/lightning_datamodule.py +1 -6
  33. careamics/lightning_module.py +11 -7
  34. careamics/lightning_prediction_datamodule.py +52 -72
  35. careamics/lvae_training/__init__.py +0 -0
  36. careamics/lvae_training/data_modules.py +1220 -0
  37. careamics/lvae_training/data_utils.py +618 -0
  38. careamics/lvae_training/eval_utils.py +905 -0
  39. careamics/lvae_training/get_config.py +84 -0
  40. careamics/lvae_training/lightning_module.py +701 -0
  41. careamics/lvae_training/metrics.py +214 -0
  42. careamics/lvae_training/train_lvae.py +339 -0
  43. careamics/lvae_training/train_utils.py +121 -0
  44. careamics/model_io/bioimage/model_description.py +40 -32
  45. careamics/model_io/bmz_io.py +1 -1
  46. careamics/model_io/model_io_utils.py +5 -2
  47. careamics/models/lvae/__init__.py +0 -0
  48. careamics/models/lvae/layers.py +1998 -0
  49. careamics/models/lvae/likelihoods.py +312 -0
  50. careamics/models/lvae/lvae.py +985 -0
  51. careamics/models/lvae/noise_models.py +409 -0
  52. careamics/models/lvae/utils.py +395 -0
  53. careamics/prediction_utils/__init__.py +12 -0
  54. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  55. careamics/prediction_utils/prediction_outputs.py +165 -0
  56. careamics/prediction_utils/stitch_prediction.py +100 -0
  57. careamics/transforms/n2v_manipulate.py +3 -1
  58. careamics/transforms/normalize.py +139 -68
  59. careamics/transforms/pixel_manipulation.py +33 -9
  60. careamics/transforms/tta.py +43 -29
  61. careamics/utils/ram.py +2 -2
  62. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +7 -6
  63. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/RECORD +65 -42
  64. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  65. careamics/lightning_prediction_loop.py +0 -118
  66. careamics/prediction/__init__.py +0 -7
  67. careamics/prediction/stitch_prediction.py +0 -70
  68. careamics/utils/running_stats.py +0 -43
  69. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
@@ -1,70 +0,0 @@
1
- """Prediction utility functions."""
2
-
3
- from typing import List
4
-
5
- import numpy as np
6
- import torch
7
-
8
-
9
def stitch_prediction(
    tiles: List[torch.Tensor],
    stitching_data: List[List[torch.Tensor]],
) -> torch.Tensor:
    """
    Stitch tiles back together to form a full image.

    Parameters
    ----------
    tiles : List[torch.Tensor]
        Predicted tiles, one tensor per batch with the batch dimension first.
    stitching_data : List
        One entry per element of `tiles`. Each entry holds the input shape,
        the overlap crop coordinates and the stitch coordinates obtained from
        `dataset.tiled_patching.extract_tiles`.

    Returns
    -------
    torch.Tensor
        Full image as a float32 tensor whose shape is the stored input shape.
    """
    # Retrieve the whole array size. There are two cases to consider:
    # 1. the tiles are stored in a list,
    # 2. the tiles are stored in a list with batches along the first dim.
    # NOTE(review): only the first batch is inspected to decide the layout;
    # this assumes every batch shares that structure.
    if tiles[0].shape[0] > 1:
        shape_source = stitching_data[0][0][0]
    else:
        shape_source = stitching_data[0][0]
    input_shape = np.array([el.numpy() for el in shape_source], dtype=int).squeeze()

    # Allocate the output directly in torch. `squeeze` can collapse a
    # single-dimension shape to a 0-d array, so promote it back to 1-D.
    predicted_image = torch.zeros(
        tuple(int(size) for size in np.atleast_1d(input_shape)),
        dtype=torch.float32,
    )

    # NOTE: zip stops at the shorter of `tiles` / `stitching_data`.
    for tile_batch, (_, overlap_crop_coords_batch, stitch_coords_batch) in zip(
        tiles, stitching_data
    ):
        for batch_idx in range(tile_batch.shape[0]):
            # Coordinates for cropping the predicted tile (overlap coordinates).
            crop_slices = tuple(
                slice(c[0][batch_idx], c[1][batch_idx])
                for c in overlap_crop_coords_batch
            )
            # Coordinates in the full image where the cropped tile is written.
            stitch_slices = tuple(
                slice(c[0][batch_idx], c[1][batch_idx]) for c in stitch_coords_batch
            )

            # All singleton dimensions of the tile are squeezed before cropping.
            cropped_tile = tile_batch[batch_idx].squeeze()[crop_slices]

            # Insert the cropped tile; leading dims are covered by the ellipsis.
            predicted_image[(..., *stitch_slices)] = cropped_tile.to(torch.float32)

    return predicted_image
@@ -1,43 +0,0 @@
1
- """Running stats submodule, used in the Zarr dataset."""
2
-
3
- # from multiprocessing import Value
4
- # from typing import Tuple
5
-
6
- # import numpy as np
7
-
8
-
9
- # class RunningStats:
10
- # """Calculates running mean and std."""
11
-
12
- # def __init__(self) -> None:
13
- # self.reset()
14
-
15
- # def reset(self) -> None:
16
- # """Reset the running stats."""
17
- # self.avg_mean = Value("d", 0)
18
- # self.avg_std = Value("d", 0)
19
- # self.m2 = Value("d", 0)
20
- # self.count = Value("i", 0)
21
-
22
- # def init(self, mean: float, std: float) -> None:
23
- # """Initialize running stats."""
24
- # with self.avg_mean.get_lock():
25
- # self.avg_mean.value += mean
26
- # with self.avg_std.get_lock():
27
- # self.avg_std.value = std
28
-
29
- # def compute_std(self) -> Tuple[float, float]:
30
- # """Compute std."""
31
- # if self.count.value >= 2:
32
- # self.avg_std.value = np.sqrt(self.m2.value / self.count.value)
33
-
34
- # def update(self, value: float) -> None:
35
- # """Update running stats."""
36
- # with self.count.get_lock():
37
- # self.count.value += 1
38
- # delta = value - self.avg_mean.value
39
- # with self.avg_mean.get_lock():
40
- # self.avg_mean.value += delta / self.count.value
41
- # delta2 = value - self.avg_mean.value
42
- # with self.m2.get_lock():
43
- # self.m2.value += delta * delta2