careamics 0.1.0rc6__py3-none-any.whl → 0.1.0rc8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (91)
  1. careamics/__init__.py +1 -14
  2. careamics/careamist.py +212 -294
  3. careamics/config/__init__.py +0 -3
  4. careamics/config/algorithm_model.py +8 -15
  5. careamics/config/architectures/architecture_model.py +1 -0
  6. careamics/config/architectures/custom_model.py +5 -3
  7. careamics/config/architectures/unet_model.py +19 -0
  8. careamics/config/architectures/vae_model.py +1 -0
  9. careamics/config/callback_model.py +76 -34
  10. careamics/config/configuration_factory.py +18 -98
  11. careamics/config/configuration_model.py +23 -18
  12. careamics/config/data_model.py +103 -54
  13. careamics/config/inference_model.py +41 -19
  14. careamics/config/optimizer_models.py +13 -7
  15. careamics/config/support/supported_data.py +29 -4
  16. careamics/config/support/supported_transforms.py +0 -1
  17. careamics/config/tile_information.py +36 -58
  18. careamics/config/training_model.py +5 -1
  19. careamics/config/transformations/normalize_model.py +32 -4
  20. careamics/config/validators/validator_utils.py +1 -1
  21. careamics/dataset/__init__.py +12 -1
  22. careamics/dataset/dataset_utils/__init__.py +8 -7
  23. careamics/dataset/dataset_utils/file_utils.py +2 -2
  24. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  25. careamics/dataset/dataset_utils/running_stats.py +186 -0
  26. careamics/dataset/in_memory_dataset.py +84 -173
  27. careamics/dataset/in_memory_pred_dataset.py +88 -0
  28. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  29. careamics/dataset/iterable_dataset.py +97 -250
  30. careamics/dataset/iterable_pred_dataset.py +122 -0
  31. careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
  32. careamics/dataset/patching/patching.py +97 -52
  33. careamics/dataset/patching/random_patching.py +9 -4
  34. careamics/dataset/patching/validate_patch_dimension.py +5 -3
  35. careamics/dataset/tiling/__init__.py +10 -0
  36. careamics/dataset/tiling/collate_tiles.py +33 -0
  37. careamics/dataset/{patching → tiling}/tiled_patching.py +4 -4
  38. careamics/file_io/__init__.py +7 -0
  39. careamics/file_io/read/__init__.py +11 -0
  40. careamics/file_io/read/get_func.py +56 -0
  41. careamics/{dataset/dataset_utils/read_tiff.py → file_io/read/tiff.py} +3 -10
  42. careamics/file_io/write/__init__.py +9 -0
  43. careamics/file_io/write/get_func.py +59 -0
  44. careamics/file_io/write/tiff.py +39 -0
  45. careamics/lightning/__init__.py +17 -0
  46. careamics/{lightning_module.py → lightning/lightning_module.py} +69 -92
  47. careamics/{lightning_prediction_datamodule.py → lightning/predict_data_module.py} +120 -178
  48. careamics/{lightning_datamodule.py → lightning/train_data_module.py} +135 -220
  49. careamics/lvae_training/__init__.py +0 -0
  50. careamics/lvae_training/data_modules.py +1220 -0
  51. careamics/lvae_training/data_utils.py +618 -0
  52. careamics/lvae_training/eval_utils.py +905 -0
  53. careamics/lvae_training/get_config.py +84 -0
  54. careamics/lvae_training/lightning_module.py +701 -0
  55. careamics/lvae_training/metrics.py +214 -0
  56. careamics/lvae_training/train_lvae.py +339 -0
  57. careamics/lvae_training/train_utils.py +121 -0
  58. careamics/model_io/bioimage/model_description.py +40 -32
  59. careamics/model_io/bmz_io.py +2 -2
  60. careamics/model_io/model_io_utils.py +6 -3
  61. careamics/models/lvae/__init__.py +0 -0
  62. careamics/models/lvae/layers.py +1998 -0
  63. careamics/models/lvae/likelihoods.py +312 -0
  64. careamics/models/lvae/lvae.py +985 -0
  65. careamics/models/lvae/noise_models.py +409 -0
  66. careamics/models/lvae/utils.py +395 -0
  67. careamics/prediction_utils/__init__.py +10 -0
  68. careamics/prediction_utils/prediction_outputs.py +137 -0
  69. careamics/prediction_utils/stitch_prediction.py +103 -0
  70. careamics/transforms/n2v_manipulate.py +3 -1
  71. careamics/transforms/normalize.py +139 -68
  72. careamics/transforms/pixel_manipulation.py +33 -9
  73. careamics/transforms/tta.py +43 -29
  74. careamics/utils/__init__.py +2 -0
  75. careamics/utils/autocorrelation.py +40 -0
  76. careamics/utils/ram.py +2 -2
  77. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/METADATA +7 -6
  78. careamics-0.1.0rc8.dist-info/RECORD +135 -0
  79. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/WHEEL +1 -1
  80. careamics/config/configuration_example.py +0 -89
  81. careamics/dataset/dataset_utils/read_utils.py +0 -27
  82. careamics/lightning_prediction_loop.py +0 -118
  83. careamics/prediction/__init__.py +0 -7
  84. careamics/prediction/stitch_prediction.py +0 -70
  85. careamics/utils/running_stats.py +0 -43
  86. careamics-0.1.0rc6.dist-info/RECORD +0 -107
  87. /careamics/{dataset/dataset_utils/read_zarr.py → file_io/read/zarr.py} +0 -0
  88. /careamics/{callbacks → lightning/callbacks}/__init__.py +0 -0
  89. /careamics/{callbacks → lightning/callbacks}/hyperparameters_callback.py +0 -0
  90. /careamics/{callbacks → lightning/callbacks}/progress_bar_callback.py +0 -0
  91. {careamics-0.1.0rc6.dist-info → careamics-0.1.0rc8.dist-info}/licenses/LICENSE +0 -0
@@ -13,8 +13,8 @@ from bioimageio.spec.model.v0_5 import (
13
13
  ChannelAxis,
14
14
  EnvironmentFileDescr,
15
15
  FileDescr,
16
+ FixedZeroMeanUnitVarianceAlongAxisKwargs,
16
17
  FixedZeroMeanUnitVarianceDescr,
17
- FixedZeroMeanUnitVarianceKwargs,
18
18
  Identifier,
19
19
  InputTensorDescr,
20
20
  ModelDescr,
@@ -134,44 +134,52 @@ def _create_inputs_ouputs(
134
134
  output_axes = _create_axes(output_array, data_config, channel_names, False)
135
135
 
136
136
  # mean and std
137
- assert data_config.mean is not None, "Mean cannot be None."
138
- assert data_config.std is not None, "Std cannot be None."
139
- mean = data_config.mean
140
- std = data_config.std
137
+ assert data_config.image_means is not None, "Mean cannot be None."
138
+ assert data_config.image_means is not None, "Std cannot be None."
139
+ means = data_config.image_means
140
+ stds = data_config.image_stds
141
141
 
142
142
  # and the mean and std required to invert the normalization
143
143
  # CAREamics denormalization: x = y * (std + eps) + mean
144
144
  # BMZ normalization : x = (y - mean') / (std' + eps)
145
145
  # to apply the BMZ normalization as a denormalization step, we need:
146
146
  eps = 1e-6
147
- inv_mean = -mean / (std + eps)
148
- inv_std = 1 / (std + eps) - eps
149
-
150
- # create input/output descriptions
151
- input_descr = InputTensorDescr(
152
- id=TensorId("input"),
153
- axes=input_axes,
154
- test_tensor=FileDescr(source=input_path),
155
- preprocessing=[
156
- FixedZeroMeanUnitVarianceDescr(
157
- kwargs=FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std)
158
- )
159
- ],
160
- )
161
- output_descr = OutputTensorDescr(
162
- id=TensorId("prediction"),
163
- axes=output_axes,
164
- test_tensor=FileDescr(source=output_path),
165
- postprocessing=[
166
- FixedZeroMeanUnitVarianceDescr(
167
- kwargs=FixedZeroMeanUnitVarianceKwargs( # invert normalization
168
- mean=inv_mean, std=inv_std
147
+ inv_means = []
148
+ inv_stds = []
149
+ if means and stds:
150
+ for mean, std in zip(means, stds):
151
+ inv_means.append(-mean / (std + eps))
152
+ inv_stds.append(1 / (std + eps) - eps)
153
+
154
+ # create input/output descriptions
155
+ input_descr = InputTensorDescr(
156
+ id=TensorId("input"),
157
+ axes=input_axes,
158
+ test_tensor=FileDescr(source=input_path),
159
+ preprocessing=[
160
+ FixedZeroMeanUnitVarianceDescr(
161
+ kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
162
+ mean=means, std=stds, axis="channel"
163
+ )
169
164
  )
170
- )
171
- ],
172
- )
165
+ ],
166
+ )
167
+ output_descr = OutputTensorDescr(
168
+ id=TensorId("prediction"),
169
+ axes=output_axes,
170
+ test_tensor=FileDescr(source=output_path),
171
+ postprocessing=[
172
+ FixedZeroMeanUnitVarianceDescr(
173
+ kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs( # invert norm
174
+ mean=inv_means, std=inv_stds, axis="channel"
175
+ )
176
+ )
177
+ ],
178
+ )
173
179
 
174
- return input_descr, output_descr
180
+ return input_descr, output_descr
181
+ else:
182
+ raise ValueError("Mean and std cannot be None.")
175
183
 
176
184
 
177
185
  def create_model_description(
@@ -280,7 +288,7 @@ def create_model_description(
280
288
  "bioimageio": {
281
289
  "test_kwargs": {
282
290
  "pytorch_state_dict": {
283
- "decimals": 2, # ...so we relax the constraints on the decimals
291
+ "decimals": 0, # ...so we relax the constraints on the decimals
284
292
  }
285
293
  }
286
294
  }
@@ -12,7 +12,7 @@ from torch import __version__, load, save
12
12
 
13
13
  from careamics.config import Configuration, load_configuration, save_configuration
14
14
  from careamics.config.support import SupportedArchitecture
15
- from careamics.lightning_module import CAREamicsModule
15
+ from careamics.lightning.lightning_module import CAREamicsModule
16
16
 
17
17
  from .bioimage import (
18
18
  create_env_text,
@@ -178,7 +178,7 @@ def export_to_bmz(
178
178
  )
179
179
 
180
180
  # test model description
181
- summary: ValidationSummary = test_model(model_description, decimal=2)
181
+ summary: ValidationSummary = test_model(model_description, decimal=1)
182
182
  if summary.status == "failed":
183
183
  raise ValueError(f"Model description test failed: {summary}")
184
184
 
@@ -3,10 +3,10 @@
3
3
  from pathlib import Path
4
4
  from typing import Tuple, Union
5
5
 
6
- from torch import load
6
+ import torch
7
7
 
8
8
  from careamics.config import Configuration
9
- from careamics.lightning_module import CAREamicsModule
9
+ from careamics.lightning.lightning_module import CAREamicsModule
10
10
  from careamics.model_io.bmz_io import load_from_bmz
11
11
  from careamics.utils import check_path_exists
12
12
 
@@ -64,7 +64,10 @@ def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configura
64
64
  If the checkpoint file does not contain hyper parameters (configuration).
65
65
  """
66
66
  # load checkpoint
67
- checkpoint: dict = load(path)
67
+ # here we might run into issues between devices
68
+ # see https://pytorch.org/tutorials/recipes/recipes/save_load_across_devices.html
69
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
70
+ checkpoint: dict = torch.load(path, map_location=device)
68
71
 
69
72
  # attempt to load configuration
70
73
  try:
File without changes