careamics 0.1.0rc5__py3-none-any.whl → 0.1.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.

Files changed (118)
  1. careamics/callbacks/hyperparameters_callback.py +10 -3
  2. careamics/callbacks/progress_bar_callback.py +37 -4
  3. careamics/careamist.py +164 -231
  4. careamics/config/algorithm_model.py +5 -18
  5. careamics/config/architectures/architecture_model.py +7 -0
  6. careamics/config/architectures/custom_model.py +11 -4
  7. careamics/config/architectures/register_model.py +3 -1
  8. careamics/config/architectures/unet_model.py +2 -0
  9. careamics/config/architectures/vae_model.py +2 -0
  10. careamics/config/callback_model.py +3 -15
  11. careamics/config/configuration_example.py +4 -5
  12. careamics/config/configuration_factory.py +27 -41
  13. careamics/config/configuration_model.py +11 -11
  14. careamics/config/data_model.py +89 -63
  15. careamics/config/inference_model.py +28 -81
  16. careamics/config/optimizer_models.py +11 -11
  17. careamics/config/support/__init__.py +0 -2
  18. careamics/config/support/supported_activations.py +2 -0
  19. careamics/config/support/supported_algorithms.py +3 -1
  20. careamics/config/support/supported_architectures.py +2 -0
  21. careamics/config/support/supported_data.py +2 -0
  22. careamics/config/support/supported_loggers.py +2 -0
  23. careamics/config/support/supported_losses.py +2 -0
  24. careamics/config/support/supported_optimizers.py +2 -0
  25. careamics/config/support/supported_pixel_manipulations.py +3 -3
  26. careamics/config/support/supported_struct_axis.py +2 -0
  27. careamics/config/support/supported_transforms.py +4 -16
  28. careamics/config/tile_information.py +28 -58
  29. careamics/config/transformations/__init__.py +3 -2
  30. careamics/config/transformations/normalize_model.py +32 -4
  31. careamics/config/transformations/xy_flip_model.py +43 -0
  32. careamics/config/transformations/xy_random_rotate90_model.py +11 -3
  33. careamics/config/validators/validator_utils.py +1 -1
  34. careamics/conftest.py +12 -0
  35. careamics/dataset/__init__.py +12 -1
  36. careamics/dataset/dataset_utils/__init__.py +8 -1
  37. careamics/dataset/dataset_utils/dataset_utils.py +4 -4
  38. careamics/dataset/dataset_utils/file_utils.py +4 -3
  39. careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
  40. careamics/dataset/dataset_utils/read_tiff.py +6 -11
  41. careamics/dataset/dataset_utils/read_utils.py +2 -0
  42. careamics/dataset/dataset_utils/read_zarr.py +11 -7
  43. careamics/dataset/dataset_utils/running_stats.py +186 -0
  44. careamics/dataset/in_memory_dataset.py +88 -154
  45. careamics/dataset/in_memory_pred_dataset.py +88 -0
  46. careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
  47. careamics/dataset/iterable_dataset.py +121 -191
  48. careamics/dataset/iterable_pred_dataset.py +121 -0
  49. careamics/dataset/iterable_tiled_pred_dataset.py +139 -0
  50. careamics/dataset/patching/patching.py +109 -39
  51. careamics/dataset/patching/random_patching.py +17 -6
  52. careamics/dataset/patching/sequential_patching.py +14 -8
  53. careamics/dataset/patching/validate_patch_dimension.py +7 -3
  54. careamics/dataset/tiling/__init__.py +10 -0
  55. careamics/dataset/tiling/collate_tiles.py +33 -0
  56. careamics/dataset/{patching → tiling}/tiled_patching.py +7 -5
  57. careamics/dataset/zarr_dataset.py +2 -0
  58. careamics/lightning_datamodule.py +46 -25
  59. careamics/lightning_module.py +19 -9
  60. careamics/lightning_prediction_datamodule.py +54 -84
  61. careamics/losses/__init__.py +2 -3
  62. careamics/losses/loss_factory.py +1 -1
  63. careamics/losses/losses.py +11 -7
  64. careamics/lvae_training/__init__.py +0 -0
  65. careamics/lvae_training/data_modules.py +1220 -0
  66. careamics/lvae_training/data_utils.py +618 -0
  67. careamics/lvae_training/eval_utils.py +905 -0
  68. careamics/lvae_training/get_config.py +84 -0
  69. careamics/lvae_training/lightning_module.py +701 -0
  70. careamics/lvae_training/metrics.py +214 -0
  71. careamics/lvae_training/train_lvae.py +339 -0
  72. careamics/lvae_training/train_utils.py +121 -0
  73. careamics/model_io/bioimage/model_description.py +40 -32
  74. careamics/model_io/bmz_io.py +3 -3
  75. careamics/model_io/model_io_utils.py +5 -2
  76. careamics/models/activation.py +2 -0
  77. careamics/models/layers.py +121 -25
  78. careamics/models/lvae/__init__.py +0 -0
  79. careamics/models/lvae/layers.py +1998 -0
  80. careamics/models/lvae/likelihoods.py +312 -0
  81. careamics/models/lvae/lvae.py +985 -0
  82. careamics/models/lvae/noise_models.py +409 -0
  83. careamics/models/lvae/utils.py +395 -0
  84. careamics/models/model_factory.py +1 -1
  85. careamics/models/unet.py +35 -14
  86. careamics/prediction_utils/__init__.py +12 -0
  87. careamics/prediction_utils/create_pred_datamodule.py +185 -0
  88. careamics/prediction_utils/prediction_outputs.py +165 -0
  89. careamics/prediction_utils/stitch_prediction.py +100 -0
  90. careamics/transforms/__init__.py +2 -2
  91. careamics/transforms/compose.py +33 -7
  92. careamics/transforms/n2v_manipulate.py +52 -14
  93. careamics/transforms/normalize.py +171 -48
  94. careamics/transforms/pixel_manipulation.py +35 -11
  95. careamics/transforms/struct_mask_parameters.py +3 -1
  96. careamics/transforms/transform.py +10 -19
  97. careamics/transforms/tta.py +43 -29
  98. careamics/transforms/xy_flip.py +123 -0
  99. careamics/transforms/xy_random_rotate90.py +38 -5
  100. careamics/utils/base_enum.py +28 -0
  101. careamics/utils/path_utils.py +2 -0
  102. careamics/utils/ram.py +4 -2
  103. careamics/utils/receptive_field.py +93 -87
  104. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/METADATA +8 -6
  105. careamics-0.1.0rc7.dist-info/RECORD +130 -0
  106. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/WHEEL +1 -1
  107. careamics/config/noise_models.py +0 -162
  108. careamics/config/support/supported_extraction_strategies.py +0 -25
  109. careamics/config/transformations/nd_flip_model.py +0 -27
  110. careamics/lightning_prediction_loop.py +0 -116
  111. careamics/losses/noise_model_factory.py +0 -40
  112. careamics/losses/noise_models.py +0 -524
  113. careamics/prediction/__init__.py +0 -7
  114. careamics/prediction/stitch_prediction.py +0 -74
  115. careamics/transforms/nd_flip.py +0 -67
  116. careamics/utils/running_stats.py +0 -43
  117. careamics-0.1.0rc5.dist-info/RECORD +0 -111
  118. {careamics-0.1.0rc5.dist-info → careamics-0.1.0rc7.dist-info}/licenses/LICENSE +0 -0
careamics/model_io/bioimage/model_description.py

@@ -13,8 +13,8 @@ from bioimageio.spec.model.v0_5 import (
     ChannelAxis,
     EnvironmentFileDescr,
     FileDescr,
+    FixedZeroMeanUnitVarianceAlongAxisKwargs,
     FixedZeroMeanUnitVarianceDescr,
-    FixedZeroMeanUnitVarianceKwargs,
     Identifier,
     InputTensorDescr,
     ModelDescr,
@@ -134,44 +134,52 @@ def _create_inputs_ouputs(
     output_axes = _create_axes(output_array, data_config, channel_names, False)
 
     # mean and std
-    assert data_config.mean is not None, "Mean cannot be None."
-    assert data_config.std is not None, "Std cannot be None."
-    mean = data_config.mean
-    std = data_config.std
+    assert data_config.image_means is not None, "Mean cannot be None."
+    assert data_config.image_stds is not None, "Std cannot be None."
+    means = data_config.image_means
+    stds = data_config.image_stds
 
     # and the mean and std required to invert the normalization
     # CAREamics denormalization: x = y * (std + eps) + mean
     # BMZ normalization : x = (y - mean') / (std' + eps)
     # to apply the BMZ normalization as a denormalization step, we need:
     eps = 1e-6
-    inv_mean = -mean / (std + eps)
-    inv_std = 1 / (std + eps) - eps
-
-    # create input/output descriptions
-    input_descr = InputTensorDescr(
-        id=TensorId("input"),
-        axes=input_axes,
-        test_tensor=FileDescr(source=input_path),
-        preprocessing=[
-            FixedZeroMeanUnitVarianceDescr(
-                kwargs=FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std)
-            )
-        ],
-    )
-    output_descr = OutputTensorDescr(
-        id=TensorId("prediction"),
-        axes=output_axes,
-        test_tensor=FileDescr(source=output_path),
-        postprocessing=[
-            FixedZeroMeanUnitVarianceDescr(
-                kwargs=FixedZeroMeanUnitVarianceKwargs(  # invert normalization
-                    mean=inv_mean, std=inv_std
-                )
-            )
-        ],
-    )
+    inv_means = []
+    inv_stds = []
+    if means and stds:
+        for mean, std in zip(means, stds):
+            inv_means.append(-mean / (std + eps))
+            inv_stds.append(1 / (std + eps) - eps)
+
+        # create input/output descriptions
+        input_descr = InputTensorDescr(
+            id=TensorId("input"),
+            axes=input_axes,
+            test_tensor=FileDescr(source=input_path),
+            preprocessing=[
+                FixedZeroMeanUnitVarianceDescr(
+                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(
+                        mean=means, std=stds, axis="channel"
+                    )
+                )
+            ],
+        )
+        output_descr = OutputTensorDescr(
+            id=TensorId("prediction"),
+            axes=output_axes,
+            test_tensor=FileDescr(source=output_path),
+            postprocessing=[
+                FixedZeroMeanUnitVarianceDescr(
+                    kwargs=FixedZeroMeanUnitVarianceAlongAxisKwargs(  # invert norm
+                        mean=inv_means, std=inv_stds, axis="channel"
+                    )
+                )
+            ],
+        )
 
-    return input_descr, output_descr
+        return input_descr, output_descr
+    else:
+        raise ValueError("Mean and std cannot be None.")
 
 
 def create_model_description(
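
A note on the inversion arithmetic above: since std' + eps = 1 / (std + eps), the BMZ formula (y - mean') / (std' + eps) expands to y * (std + eps) + mean, i.e. exactly the CAREamics denormalization. A minimal NumPy sketch (illustrative only, not package code) verifying the identity for one channel:

import numpy as np

eps = 1e-6
mean, std = 128.0, 32.0        # example per-channel statistics
y = np.linspace(-2.0, 2.0, 5)  # normalized prediction values

careamics_denorm = y * (std + eps) + mean

inv_mean = -mean / (std + eps)   # mean' used by the BMZ postprocessing
inv_std = 1 / (std + eps) - eps  # std' used by the BMZ postprocessing
bmz_as_denorm = (y - inv_mean) / (inv_std + eps)

assert np.allclose(careamics_denorm, bmz_as_denorm)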
@@ -280,7 +288,7 @@ def create_model_description(
         "bioimageio": {
             "test_kwargs": {
                 "pytorch_state_dict": {
-                    "decimals": 2,  # ...so we relax the constraints on the decimals
+                    "decimals": 0,  # ...so we relax the constraints on the decimals
                 }
             }
         }
careamics/model_io/bmz_io.py

@@ -104,9 +104,9 @@ def export_to_bmz(
     authors : List[dict]
         Authors of the model.
     input_array : np.ndarray
-        Input array.
+        Input array, should not have been normalized.
     output_array : np.ndarray
-        Output array.
+        Output array, should have been denormalized.
     channel_names : Optional[List[str]], optional
         Channel names, by default None.
     data_description : Optional[str], optional
@@ -178,7 +178,7 @@ def export_to_bmz(
     )
 
     # test model description
-    summary: ValidationSummary = test_model(model_description, decimal=0)
+    summary: ValidationSummary = test_model(model_description, decimal=1)
     if summary.status == "failed":
         raise ValueError(f"Model description test failed: {summary}")
 
careamics/model_io/model_io_utils.py

@@ -3,7 +3,7 @@
 from pathlib import Path
 from typing import Tuple, Union
 
-from torch import load
+import torch
 
 from careamics.config import Configuration
 from careamics.lightning_module import CAREamicsModule
@@ -64,7 +64,10 @@ def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configura
         If the checkpoint file does not contain hyper parameters (configuration).
     """
     # load checkpoint
-    checkpoint: dict = load(path)
+    # here we might run into issues between devices
+    # see https://pytorch.org/tutorials/recipes/recipes/save_load_across_devices.html
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    checkpoint: dict = torch.load(path, map_location=device)
 
     # attempt to load configuration
     try:
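
The map_location change above matters when a checkpoint saved on a GPU machine is reloaded on a CPU-only host: a bare torch.load(path) tries to restore tensors to their original CUDA device and raises an error if it is unavailable. A minimal sketch of the pattern (the file name is hypothetical):

import torch

# saved earlier, e.g. on a CUDA machine:
# torch.save({"state_dict": ..., "hyper_parameters": ...}, "checkpoint.ckpt")

# load anywhere, remapping storages to whatever device is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load("checkpoint.ckpt", map_location=device)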
careamics/models/activation.py

@@ -1,3 +1,5 @@
+"""Activations for CAREamics models."""
+
 from typing import Callable, Union
 
 import torch.nn as nn
careamics/models/layers.py

@@ -162,6 +162,18 @@ def _unpack_kernel_size(
     """Unpack kernel_size to a tuple of ints.
 
     Inspired by Kornia implementation. TODO: link
+
+    Parameters
+    ----------
+    kernel_size : Union[Tuple[int, ...], int]
+        Kernel size.
+    dim : int
+        Number of dimensions.
+
+    Returns
+    -------
+    Tuple[int, ...]
+        Kernel size tuple.
     """
     if isinstance(kernel_size, int):
         kernel_dims = tuple([kernel_size for _ in range(dim)])
@@ -173,7 +185,20 @@
 def _compute_zero_padding(
     kernel_size: Union[Tuple[int, ...], int], dim: int
 ) -> Tuple[int, ...]:
-    """Utility function that computes zero padding tuple."""
+    """Utility function that computes zero padding tuple.
+
+    Parameters
+    ----------
+    kernel_size : Union[Tuple[int, ...], int]
+        Kernel size.
+    dim : int
+        Number of dimensions.
+
+    Returns
+    -------
+    Tuple[int, ...]
+        Zero padding tuple.
+    """
     kernel_dims = _unpack_kernel_size(kernel_size, dim)
     return tuple([(kd - 1) // 2 for kd in kernel_dims])
 
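As a quick sanity check of the formula in _compute_zero_padding (illustrative only): (k - 1) // 2 is the usual "same" padding for odd kernels at stride 1.

print([(k - 1) // 2 for k in (3, 5, 7)])  # [1, 2, 3]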
@@ -191,14 +216,19 @@ def get_pascal_kernel_1d(
 
     Parameters
     ----------
-    kernel_size: height and width of the kernel.
-    norm: if to normalize the kernel or not. Default: False.
-    device: tensor device
-    dtype: tensor dtype
+    kernel_size : int
+        Kernel size.
+    norm : bool
+        Normalize the kernel, by default False.
+    device : Optional[torch.device]
+        Device of the tensor, by default None.
+    dtype : Optional[torch.dtype]
+        Data type of the tensor, by default None.
 
     Returns
     -------
-    kernel shaped as :math:`(kernel_size,)`
+    torch.Tensor
+        Pascal kernel.
 
     Examples
     --------
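
For context, the "pascal" kernel is a row of Pascal's triangle, i.e. binomial coefficients, a standard discrete approximation of a Gaussian blur. A sketch of the values get_pascal_kernel_1d produces (an independent re-derivation, not the careamics implementation):

from math import comb

import torch

def pascal_row(kernel_size: int, norm: bool = False) -> torch.Tensor:
    # entry k is C(kernel_size - 1, k), row (kernel_size - 1) of Pascal's triangle
    row = torch.tensor(
        [comb(kernel_size - 1, k) for k in range(kernel_size)],
        dtype=torch.float32,
    )
    return row / row.sum() if norm else row

print(pascal_row(3))        # tensor([1., 2., 1.])
print(pascal_row(5, True))  # tensor([0.0625, 0.2500, 0.3750, 0.2500, 0.0625])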
@@ -245,19 +275,28 @@ def _get_pascal_kernel_nd(
 ) -> torch.Tensor:
     """Generate pascal filter kernel by kernel size.
 
+    If kernel_size is an integer the kernel will be shaped as (kernel_size, kernel_size)
+    otherwise the kernel will be shaped as kernel_size
+
     Inspired by Kornia implementation.
 
     Parameters
     ----------
-    kernel_size: height and width of the kernel.
-    norm: if to normalize the kernel or not. Default: True.
-    device: tensor device
-    dtype: tensor dtype
+    kernel_size : Union[Tuple[int, int], int]
+        Kernel size for the pascal kernel.
+    norm : bool
+        Normalize the kernel, by default True.
+    dim : int
+        Number of dimensions, by default 2.
+    device : Optional[torch.device]
+        Device of the tensor, by default None.
+    dtype : Optional[torch.dtype]
+        Data type of the tensor, by default None.
 
     Returns
     -------
-    if kernel_size is an integer the kernel will be shaped as (kernel_size, kernel_size)
-    otherwise the kernel will be shaped as kernel_size
+    torch.Tensor
+        Pascal kernel.
 
     Examples
     --------
@@ -303,6 +342,24 @@ def _max_blur_pool_by_kernel2d(
     """Compute max_blur_pool by a given :math:`CxC_(out, None)xNxN` kernel.
 
     Inspired by Kornia implementation.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Input tensor.
+    kernel : torch.Tensor
+        Kernel tensor.
+    stride : int
+        Stride.
+    max_pool_size : int
+        Maximum pool size.
+    ceil_mode : bool
+        Ceil mode, by default False. Set to True to match output size of conv2d.
+
+    Returns
+    -------
+    torch.Tensor
+        Output tensor.
     """
     # compute local maxima
     x = F.max_pool2d(
@@ -323,6 +380,24 @@ def _max_blur_pool_by_kernel3d(
     """Compute max_blur_pool by a given :math:`CxC_(out, None)xNxNxN` kernel.
 
     Inspired by Kornia implementation.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Input tensor.
+    kernel : torch.Tensor
+        Kernel tensor.
+    stride : int
+        Stride.
+    max_pool_size : int
+        Maximum pool size.
+    ceil_mode : bool
+        Ceil mode, by default False. Set to True to match output size of conv2d.
+
+    Returns
+    -------
+    torch.Tensor
+        Output tensor.
     """
     # compute local maxima
     x = F.max_pool3d(
@@ -343,21 +418,16 @@ class MaxBlurPool(nn.Module):
 
     Parameters
     ----------
-    dim: int
-        Toggles between 2D and 3D
-    kernel_size: Union[Tuple[int, int], int]
+    dim : int
+        Toggles between 2D and 3D.
+    kernel_size : Union[Tuple[int, int], int]
         Kernel size for max pooling.
-    stride: int
+    stride : int
         Stride for pooling.
-    max_pool_size: int
+    max_pool_size : int
         Max kernel size for max pooling.
-    ceil_mode: bool
-        Should be true to match output size of conv2d with same kernel size.
-
-    Returns
-    -------
-    torch.Tensor
-        The pooled and blurred tensor.
+    ceil_mode : bool
+        Ceil mode, by default False. Set to True to match output size of conv2d.
     """
 
     def __init__(
@@ -368,6 +438,21 @@
         max_pool_size: int = 2,
         ceil_mode: bool = False,
     ) -> None:
+        """Constructor.
+
+        Parameters
+        ----------
+        dim : int
+            Dimension of the convolution.
+        kernel_size : Union[Tuple[int, int], int]
+            Kernel size for max pooling.
+        stride : int, optional
+            Stride, by default 2.
+        max_pool_size : int, optional
+            Maximum pool size, by default 2.
+        ceil_mode : bool, optional
+            Ceil mode, by default False. Set to True to match output size of conv2d.
+        """
         super().__init__()
         self.dim = dim
         self.kernel_size = kernel_size
@@ -377,7 +462,18 @@
         self.kernel = _get_pascal_kernel_nd(kernel_size, norm=True, dim=self.dim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward pass of the function."""
+        """Forward pass of the function.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor.
+        """
         self.kernel = torch.as_tensor(self.kernel, device=x.device, dtype=x.dtype)
         if self.dim == 2:
             return _max_blur_pool_by_kernel2d(
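
Taken together, the layer implements anti-aliased max pooling in the style of Kornia's MaxBlurPool2D (after Zhang, "Making Convolutional Networks Shift-Invariant Again", 2019): a stride-1 max pool extracts local maxima, then the normalized pascal kernel blurs and downsamples. A usage sketch, assuming the class is importable from careamics.models.layers:

import torch

from careamics.models.layers import MaxBlurPool

pool = MaxBlurPool(dim=2, kernel_size=3, stride=2, max_pool_size=2)
x = torch.rand(1, 16, 64, 64)  # (batch, channels, height, width)
y = pool(x)
print(y.shape)  # torch.Size([1, 16, 32, 32]), i.e. downsampled by the stride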