careamics 0.1.0rc3__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.

Potentially problematic release: this version of careamics might be problematic.
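
Taken together, the release is chiefly a renaming pass: the metaphorical Lightning class names and the `*Model` configuration classes are replaced by descriptive `*Data`/`*Config` names. As orientation before the hunks, a minimal migration sketch (the rename mapping is grounded in the hunks below; only the `careamics.config` and `careamics.lightning_module` import paths appear verbatim in this diff):

    # Migration sketch from 0.1.0rc3 to 0.1.0rc4, inferred from the hunks below.
    # Configuration models: "*Model" becomes "*Config".
    from careamics.config import DataConfig       # was DataModel
    from careamics.config import AlgorithmConfig  # was AlgorithmModel
    from careamics.config import InferenceConfig  # was InferenceModel

    # Lightning classes: metaphorical names become descriptive ones.
    # CAREamicsWood              -> CAREamicsTrainData
    # CAREamicsKiln              -> CAREamicsModule
    # CAREamicsClay              -> CAREamicsPredictData
    # CAREamicsTrainDataModule   -> TrainingDataWrapper
    # CAREamicsModule (wrapper)  -> CAREamicsModuleWrapper
    # CAREamicsPredictDataModule -> PredictDataWrapper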

@@ -1,3 +1,4 @@
+ """Training and validation Lightning data modules."""
  from pathlib import Path
  from typing import Any, Callable, Dict, List, Literal, Optional, Union
 
@@ -6,7 +7,7 @@ import pytorch_lightning as L
  from albumentations import Compose
  from torch.utils.data import DataLoader
 
- from careamics.config import DataModel
+ from careamics.config import DataConfig
  from careamics.config.data_model import TRANSFORMS_UNION
  from careamics.config.support import SupportedData
  from careamics.dataset.dataset_utils import (
@@ -28,9 +29,9 @@ DatasetType = Union[InMemoryDataset, PathIterableDataset]
  logger = get_logger(__name__)
 
 
- class CAREamicsWood(L.LightningDataModule):
+ class CAREamicsTrainData(L.LightningDataModule):
      """
-     LightningDataModule for training and validation datasets.
+     CAREamics Lightning training and validation data module.
 
      The data module can be used with Path, str or numpy arrays. In the case of
      numpy arrays, it loads and computes all the patches in memory. For Path and str
@@ -53,11 +54,70 @@ class CAREamicsWood(L.LightningDataModule):
 
      You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g.
      "*.czi") to filter the file extensions using `extension_filter`.
+ 
+     Parameters
+     ----------
+     data_config : DataConfig
+         Pydantic model for CAREamics data configuration.
+     train_data : Union[Path, str, np.ndarray]
+         Training data, can be a path to a folder, a file or a numpy array.
+     val_data : Optional[Union[Path, str, np.ndarray]], optional
+         Validation data, can be a path to a folder, a file or a numpy array, by
+         default None.
+     train_data_target : Optional[Union[Path, str, np.ndarray]], optional
+         Training target data, can be a path to a folder, a file or a numpy array, by
+         default None.
+     val_data_target : Optional[Union[Path, str, np.ndarray]], optional
+         Validation target data, can be a path to a folder, a file or a numpy array,
+         by default None.
+     read_source_func : Optional[Callable], optional
+         Function to read the source data, by default None. Only used for `custom`
+         data type (see DataConfig).
+     extension_filter : str, optional
+         Filter for file extensions, by default "". Only used for `custom` data types
+         (see DataConfig).
+     val_percentage : float, optional
+         Percentage of the training data to use for validation, by default 0.1. Only
+         used if `val_data` is None.
+     val_minimum_split : int, optional
+         Minimum number of patches or files to split from the training data for
+         validation, by default 5. Only used if `val_data` is None.
+     use_in_memory : bool, optional
+         Use in memory dataset if possible, by default True.
+ 
+     Attributes
+     ----------
+     data_config : DataConfig
+         CAREamics data configuration.
+     data_type : SupportedData
+         Expected data type, one of "tiff", "array" or "custom".
+     batch_size : int
+         Batch size.
+     use_in_memory : bool
+         Whether to use in memory dataset if possible.
+     train_data : Union[Path, str, np.ndarray]
+         Training data.
+     val_data : Optional[Union[Path, str, np.ndarray]]
+         Validation data.
+     train_data_target : Optional[Union[Path, str, np.ndarray]]
+         Training target data.
+     val_data_target : Optional[Union[Path, str, np.ndarray]]
+         Validation target data.
+     val_percentage : float
+         Percentage of the training data to use for validation, if no validation data
+         is provided.
+     val_minimum_split : int
+         Minimum number of patches or files to split from the training data for
+         validation, if no validation data is provided.
+     read_source_func : Optional[Callable]
+         Function to read the source data, used if `data_type` is `custom`.
+     extension_filter : str
+         Filter for file extensions, used if `data_type` is `custom`.
      """
 
      def __init__(
          self,
-         data_config: DataModel,
+         data_config: DataConfig,
          train_data: Union[Path, str, np.ndarray],
          val_data: Optional[Union[Path, str, np.ndarray]] = None,
          train_data_target: Optional[Union[Path, str, np.ndarray]] = None,
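
For context, a minimal construction sketch for the renamed training module; the exact set of `DataConfig` fields (`data_type`, `patch_size`, `axes`, `batch_size`) and the `careamics.lightning_datamodule` import path are assumptions inferred from the docstrings in this diff:

    import numpy as np

    from careamics.config import DataConfig
    from careamics.lightning_datamodule import CAREamicsTrainData  # assumed path

    # Field names are assumptions inferred from the parameter docstrings above.
    config = DataConfig(data_type="array", patch_size=[8, 8], axes="YX", batch_size=2)
    train_array = np.random.rand(64, 64).astype(np.float32)

    # With an array input, patches are computed in memory.
    data_module = CAREamicsTrainData(data_config=config, train_data=train_array)
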
@@ -98,6 +158,8 @@ class CAREamicsWood(L.LightningDataModule):
          val_minimum_split : int, optional
              Minimum number of patches or files to split from the training data for
              validation, by default 5. Only used if `val_data` is None.
+         use_in_memory : bool, optional
+             Use in memory dataset if possible, by default True.
 
          Raises
          ------
@@ -128,25 +190,30 @@ class CAREamicsWood(L.LightningDataModule):
          if data_config.data_type == SupportedData.CUSTOM and read_source_func is None:
              raise ValueError(
                  f"Data type {SupportedData.CUSTOM} is not allowed without "
-                 f"specifying a `read_source_func`."
+                 f"specifying a `read_source_func` and an `extension_filter`."
              )
 
-         # and that arrays are passed, if array type specified
-         elif data_config.data_type == SupportedData.ARRAY and not isinstance(
-             train_data, np.ndarray
+         # check correct input type
+         if (
+             isinstance(train_data, np.ndarray)
+             and data_config.data_type != SupportedData.ARRAY
          ):
              raise ValueError(
-                 f"Expected array input (see configuration.data.data_type), but got "
-                 f"{type(train_data)} instead."
+                 f"Received a numpy array as input, but the data type was set to "
+                 f"{data_config.data_type}. Set the data type in the configuration "
+                 f"to {SupportedData.ARRAY} to train on numpy arrays."
              )
 
          # and that Path or str are passed, if tiff file type specified
-         elif data_config.data_type == SupportedData.TIFF and (
-             not isinstance(train_data, Path) and not isinstance(train_data, str)
+         elif (isinstance(train_data, Path) or isinstance(train_data, str)) and (
+             data_config.data_type != SupportedData.TIFF
+             and data_config.data_type != SupportedData.CUSTOM
          ):
              raise ValueError(
-                 f"Expected Path or str input (see configuration.data.data_type), "
-                 f"but got {type(train_data)} instead."
+                 f"Received a path as input, but the data type was neither set to "
+                 f"{SupportedData.TIFF} nor {SupportedData.CUSTOM}. Set the data type "
+                 f"in the configuration to {SupportedData.TIFF} or "
+                 f"{SupportedData.CUSTOM} to train on files."
              )
 
          # configuration
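
The rewritten checks key off the type of the input rather than the configured data type, so a mismatch now fails fast with a message naming the configuration field to change. A small behaviour sketch, under the same assumed `DataConfig` fields and import path as above:

    import numpy as np

    from careamics.config import DataConfig
    from careamics.lightning_datamodule import CAREamicsTrainData  # assumed path

    # A numpy array combined with a file-based data type now raises immediately.
    tiff_config = DataConfig(data_type="tiff", patch_size=[8, 8], axes="YX", batch_size=1)
    try:
        CAREamicsTrainData(data_config=tiff_config, train_data=np.zeros((32, 32)))
    except ValueError as err:
        print(err)  # "Received a numpy array as input, but the data type was set to ..."
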
@@ -231,7 +298,15 @@ class CAREamicsWood(L.LightningDataModule):
              validate_source_target_files(self.val_files, self.val_target_files)
 
      def setup(self, *args: Any, **kwargs: Any) -> None:
-         """Hook called at the beginning of fit, validate, or predict."""
+         """Hook called at the beginning of fit, validate, or predict.
+ 
+         Parameters
+         ----------
+         *args : Any
+             Unused.
+         **kwargs : Any
+             Unused.
+         """
          # if numpy array
          if self.data_type == SupportedData.ARRAY:
              # train dataset
@@ -353,9 +428,12 @@ class CAREamicsWood(L.LightningDataModule):
          )
 
 
- class CAREamicsTrainDataModule(CAREamicsWood):
+ class TrainingDataWrapper(CAREamicsTrainData):
      """
-     LightningDataModule wrapper for training and validation datasets.
+     Wrapper around the CAREamics Lightning training data module.
+ 
+     This class is used to explicitly pass the parameters usually contained in a
+     `data_model` configuration.
 
      Since the lightning datamodule has no access to the model, make sure that the
      parameters passed to the datamodule are consistent with the model's requirements and
@@ -442,11 +520,11 @@ class CAREamicsTrainDataModule(CAREamicsWood):
 
      Examples
      --------
-     Create a CAREamicsTrainDataModule with default transforms with a numpy array:
+     Create a TrainingDataWrapper with default transforms and a numpy array:
      >>> import numpy as np
-     >>> from careamics import CAREamicsTrainDataModule
+     >>> from careamics import TrainingDataWrapper
      >>> my_array = np.arange(256).reshape(16, 16)
-     >>> data_module = CAREamicsTrainDataModule(
+     >>> data_module = TrainingDataWrapper(
      ...     train_data=my_array,
      ...     data_type="array",
      ...     patch_size=(8, 8),
@@ -457,12 +535,12 @@ class CAREamicsTrainDataModule(CAREamicsWood):
      For custom data types (those not supported by CAREamics), one can pass a read
      function and a filter for the file extensions:
      >>> import numpy as np
-     >>> from careamics import CAREamicsTrainDataModule
+     >>> from careamics import TrainingDataWrapper
      >>>
      >>> def read_npy(path):
      ...     return np.load(path)
      >>>
-     >>> data_module = CAREamicsTrainDataModule(
+     >>> data_module = TrainingDataWrapper(
      ...     train_data="path/to/data",
      ...     data_type="custom",
      ...     patch_size=(8, 8),
@@ -475,7 +553,7 @@ class CAREamicsTrainDataModule(CAREamicsWood):
      If you want to use a different set of transformations, you can pass a list of
      transforms:
      >>> import numpy as np
-     >>> from careamics import CAREamicsTrainDataModule
+     >>> from careamics import TrainingDataWrapper
      >>> from careamics.config.support import SupportedTransform
      >>> my_array = np.arange(256).reshape(16, 16)
      >>> my_transforms = [
@@ -488,7 +566,7 @@ class CAREamicsTrainDataModule(CAREamicsWood):
      ...         "name": SupportedTransform.N2V_MANIPULATE.value,
      ...     }
      ... ]
-     >>> data_module = CAREamicsTrainDataModule(
+     >>> data_module = TrainingDataWrapper(
      ...     train_data=my_array,
      ...     data_type="array",
      ...     patch_size=(8, 8),
@@ -628,7 +706,7 @@ class CAREamicsTrainDataModule(CAREamicsWood):
              data_dict["transforms"] = transforms
 
          # validate configuration
-         self.data_config = DataModel(**data_dict)
+         self.data_config = DataConfig(**data_dict)
 
          # N2V specific checks, N2V, structN2V, and transforms
          if (
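
`TrainingDataWrapper` still funnels its keyword arguments through the Pydantic model, so invalid values fail at construction time. A sketch of that validation step in isolation (field names assumed, as above):

    from careamics.config import DataConfig

    # DataConfig is a Pydantic model: building it from a plain dict validates the
    # values and raises a pydantic ValidationError on bad input.
    data_dict = {
        "data_type": "array",
        "patch_size": [8, 8],
        "axes": "YX",
        "batch_size": 4,
    }
    config = DataConfig(**data_dict)  # the same call the wrapper performs internally
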
@@ -3,7 +3,7 @@ from typing import Any, Optional, Union
  import pytorch_lightning as L
  from torch import Tensor, nn
 
- from careamics.config import AlgorithmModel
+ from careamics.config import AlgorithmConfig
  from careamics.config.support import (
      SupportedAlgorithm,
      SupportedArchitecture,
@@ -17,7 +17,7 @@ from careamics.transforms import Denormalize, ImageRestorationTTA
  from careamics.utils.torch_utils import get_optimizer, get_scheduler
 
 
- class CAREamicsKiln(L.LightningModule):
+ class CAREamicsModule(L.LightningModule):
      """
      CAREamics Lightning module.
 
@@ -38,7 +38,7 @@ class CAREamicsKiln(L.LightningModule):
          Learning rate scheduler name.
      """
 
-     def __init__(self, algorithm_config: Union[AlgorithmModel, dict]) -> None:
+     def __init__(self, algorithm_config: Union[AlgorithmConfig, dict]) -> None:
          """
          CAREamics Lightning module.
 
@@ -53,7 +53,7 @@ class CAREamicsKiln(L.LightningModule):
          super().__init__()
          # if loading from a checkpoint, AlgorithmModel needs to be instantiated
          if isinstance(algorithm_config, dict):
-             algorithm_config = AlgorithmModel(**algorithm_config)
+             algorithm_config = AlgorithmConfig(**algorithm_config)
 
          # create model and loss function
          self.model: nn.Module = model_factory(algorithm_config.model)
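
Accepting either an `AlgorithmConfig` or a plain dict keeps checkpoint loading working, since Lightning restores hyperparameters as dicts. A hypothetical sketch (the `algorithm`, `loss`, and `model` field names are assumptions):

    from careamics.lightning_module import CAREamicsModule

    # When restoring from a checkpoint, Lightning hands the saved hyperparameters
    # back as a dict; the module re-validates it into an AlgorithmConfig.
    algorithm_dict = {
        "algorithm": "n2v",                  # assumed field names and values
        "loss": "n2v",
        "model": {"architecture": "UNet"},
    }
    module = CAREamicsModule(algorithm_config=algorithm_dict)
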
@@ -192,7 +192,7 @@ class CAREamicsKiln(L.LightningModule):
          }
 
 
- class CAREamicsModule(CAREamicsKiln):
+ class CAREamicsModuleWrapper(CAREamicsModule):
      """Class defining the API for CAREamics Lightning layer.
 
      This class exposes parameters used to create an AlgorithmModel instance, triggering
@@ -287,6 +287,6 @@ class CAREamicsModule(CAREamicsKiln):
          algorithm_configuration["model"] = model_configuration
 
          # call the parent init using an AlgorithmModel instance
-         super().__init__(AlgorithmModel(**algorithm_configuration))
+         super().__init__(AlgorithmConfig(**algorithm_configuration))
 
      # TODO add load_from_checkpoint wrapper
@@ -1,3 +1,4 @@
+ """Prediction Lightning data modules."""
  from pathlib import Path
  from typing import Any, Callable, List, Literal, Optional, Tuple, Union
 
@@ -7,7 +8,7 @@ from albumentations import Compose
  from torch.utils.data import DataLoader
  from torch.utils.data.dataloader import default_collate
 
- from careamics.config import InferenceModel
+ from careamics.config import InferenceConfig
  from careamics.config.support import SupportedData
  from careamics.config.tile_information import TileInformation
  from careamics.dataset.dataset_utils import (
@@ -62,9 +63,9 @@ def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any:
      return default_collate(new_batch)
 
 
- class CAREamicsClay(L.LightningDataModule):
+ class CAREamicsPredictData(L.LightningDataModule):
      """
-     LightningDataModule for prediction dataset.
+     CAREamics Lightning prediction data module.
 
      The data module can be used with Path, str or numpy arrays. The data can be either
      a folder containing images or a single file.
@@ -79,7 +80,7 @@ class CAREamicsClay(L.LightningDataModule):
 
      Parameters
      ----------
-     prediction_config : InferenceModel
+     pred_config : InferenceConfig
          Pydantic model for CAREamics prediction configuration.
      pred_data : Union[Path, str, np.ndarray]
          Prediction data, can be a path to a folder, a file or a numpy array.
@@ -93,7 +94,7 @@ class CAREamicsClay(L.LightningDataModule):
 
      def __init__(
          self,
-         prediction_config: InferenceModel,
+         pred_config: InferenceConfig,
          pred_data: Union[Path, str, np.ndarray],
          read_source_func: Optional[Callable] = None,
          extension_filter: str = "",
@@ -115,7 +116,7 @@ class CAREamicsClay(L.LightningDataModule):
 
          Parameters
          ----------
-         prediction_config : InferenceModel
+         pred_config : InferenceConfig
              Pydantic model for CAREamics prediction configuration.
          pred_data : Union[Path, str, np.ndarray]
              Prediction data, can be a path to a folder, a file or a numpy array.
@@ -142,51 +143,53 @@ class CAREamicsClay(L.LightningDataModule):
          super().__init__()
 
          # check that a read source function is provided for custom types
-         if (
-             prediction_config.data_type == SupportedData.CUSTOM
-             and read_source_func is None
-         ):
+         if pred_config.data_type == SupportedData.CUSTOM and read_source_func is None:
              raise ValueError(
                  f"Data type {SupportedData.CUSTOM} is not allowed without "
-                 f"specifying a `read_source_func`."
+                 f"specifying a `read_source_func` and an `extension_filter`."
              )
 
-         # and that arrays are passed, if array type specified
-         elif prediction_config.data_type == SupportedData.ARRAY and not isinstance(
-             pred_data, np.ndarray
+         # check correct input type
+         if (
+             isinstance(pred_data, np.ndarray)
+             and pred_config.data_type != SupportedData.ARRAY
          ):
              raise ValueError(
-                 f"Expected array input (see configuration.data.data_type), but got "
-                 f"{type(pred_data)} instead."
+                 f"Received a numpy array as input, but the data type was set to "
+                 f"{pred_config.data_type}. Set the data type "
+                 f"to {SupportedData.ARRAY} to predict on numpy arrays."
              )
 
          # and that Path or str are passed, if tiff file type specified
-         elif prediction_config.data_type == SupportedData.TIFF and not (
-             isinstance(pred_data, Path) or isinstance(pred_data, str)
+         elif (isinstance(pred_data, Path) or isinstance(pred_data, str)) and (
+             pred_config.data_type != SupportedData.TIFF
+             and pred_config.data_type != SupportedData.CUSTOM
          ):
              raise ValueError(
-                 f"Expected Path or str input (see configuration.data.data_type), "
-                 f"but got {type(pred_data)} instead."
+                 f"Received a path as input, but the data type was neither set to "
+                 f"{SupportedData.TIFF} nor {SupportedData.CUSTOM}. Set the data type "
+                 f"to {SupportedData.TIFF} or "
+                 f"{SupportedData.CUSTOM} to predict on files."
              )
 
          # configuration data
-         self.prediction_config = prediction_config
-         self.data_type = prediction_config.data_type
-         self.batch_size = prediction_config.batch_size
+         self.prediction_config = pred_config
+         self.data_type = pred_config.data_type
+         self.batch_size = pred_config.batch_size
          self.dataloader_params = dataloader_params
 
          self.pred_data = pred_data
-         self.tile_size = prediction_config.tile_size
-         self.tile_overlap = prediction_config.tile_overlap
+         self.tile_size = pred_config.tile_size
+         self.tile_overlap = pred_config.tile_overlap
 
          # read source function
-         if prediction_config.data_type == SupportedData.CUSTOM:
+         if pred_config.data_type == SupportedData.CUSTOM:
              # mypy check
              assert read_source_func is not None
 
              self.read_source_func: Callable = read_source_func
-         elif prediction_config.data_type != SupportedData.ARRAY:
-             self.read_source_func = get_read_func(prediction_config.data_type)
+         elif pred_config.data_type != SupportedData.ARRAY:
+             self.read_source_func = get_read_func(pred_config.data_type)
 
          self.extension_filter = extension_filter
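
The prediction module mirrors the training-side checks, including the reader requirement for `custom` data. A sketch of the custom path (the `InferenceConfig` field names and the `careamics.lightning_prediction_datamodule` import path are assumptions):

    import numpy as np

    from careamics.config import InferenceConfig
    from careamics.lightning_prediction_datamodule import CAREamicsPredictData  # assumed path

    def read_npy(path):
        return np.load(path)

    # Field names are assumptions inferred from the docstrings above.
    pred_config = InferenceConfig(
        data_type="custom",
        tile_size=[64, 64],
        tile_overlap=[16, 16],
        axes="YX",
        batch_size=1,
    )
    predict_module = CAREamicsPredictData(
        pred_config=pred_config,
        pred_data="path/to/data",
        read_source_func=read_npy,  # required for "custom", per the check above
        extension_filter="*.npy",
    )
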
@@ -238,9 +241,12 @@ class CAREamicsClay(L.LightningDataModule):
          )  # TODO check workers are used
 
 
- class CAREamicsPredictDataModule(CAREamicsClay):
+ class PredictDataWrapper(CAREamicsPredictData):
      """
-     LightningDataModule wrapper of an inference dataset.
+     Wrapper around the CAREamics Lightning inference data module.
+ 
+     This class is used to explicitly pass the parameters usually contained in an
+     `inference_model` configuration.
 
      Since the lightning datamodule has no access to the model, make sure that the
      parameters passed to the datamodule are consistent with the model's requirements
@@ -329,6 +335,12 @@ class CAREamicsPredictDataModule(CAREamicsClay):
          Prediction data.
      data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
          Data type, see `SupportedData` for available options.
+     mean : float
+         Mean value for normalization, only used if Normalization is defined in the
+         transforms.
+     std : float
+         Standard deviation value for normalization, only used if Normalization is
+         defined in the transforms.
      tile_size : List[int]
          Tile size, 2D or 3D tile size.
      tile_overlap : List[int]
@@ -339,12 +351,6 @@ class CAREamicsPredictDataModule(CAREamicsClay):
          Batch size.
      tta_transforms : bool, optional
          Use test time augmentation, by default True.
-     mean : Optional[float], optional
-         Mean value for normalization, only used if Normalization is defined, by
-         default None.
-     std : Optional[float], optional
-         Standard deviation value for normalization, only used if Normalization is
-         defined, by default None.
      transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional
          List of transforms to apply to prediction patches. If None, default
          transforms are applied.
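
As on the training side, the wrapper exposes the `inference_model` fields as keyword arguments, with `mean` and `std` now documented as plain attributes feeding the Normalization transform. A usage sketch modeled on the training examples (the `careamics` re-export and the `axes` parameter are assumptions):

    import numpy as np

    from careamics import PredictDataWrapper  # assumed re-export, as for TrainingDataWrapper

    my_array = np.arange(256).reshape(16, 16).astype(np.float32)
    predict_module = PredictDataWrapper(
        pred_data=my_array,
        data_type="array",
        tile_size=[8, 8],
        tile_overlap=[2, 2],
        axes="YX",  # assumed parameter
        batch_size=1,
        mean=float(my_array.mean()),  # used by the Normalization transform
        std=float(my_array.std()),
    )
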
@@ -374,7 +380,7 @@ class CAREamicsPredictDataModule(CAREamicsClay):
              prediction_dict["transforms"] = transforms
 
          # validate configuration
-         self.prediction_config = InferenceModel(**prediction_dict)
+         self.prediction_config = InferenceConfig(**prediction_dict)
 
          # sanity check on the dataloader parameters
          if "batch_size" in dataloader_params:
@@ -382,7 +388,7 @@ class CAREamicsPredictDataModule(CAREamicsClay):
              del dataloader_params["batch_size"]
 
          super().__init__(
-             prediction_config=self.prediction_config,
+             pred_config=self.prediction_config,
              pred_data=pred_data,
              read_source_func=read_source_func,
              extension_filter=extension_filter,
@@ -26,14 +26,14 @@ from bioimageio.spec.model.v0_5 import (
      WeightsDescr,
  )
 
- from careamics.config import Configuration, DataModel
+ from careamics.config import Configuration, DataConfig
 
  from ._readme_factory import readme_factory
 
 
  def _create_axes(
      array: np.ndarray,
-     data_config: DataModel,
+     data_config: DataConfig,
      channel_names: Optional[List[str]] = None,
      is_input: bool = True,
  ) -> List[AxisBase]:
@@ -100,7 +100,7 @@ def _create_axes(
  def _create_inputs_ouputs(
      input_array: np.ndarray,
      output_array: np.ndarray,
-     data_config: DataModel,
+     data_config: DataConfig,
      input_path: Union[Path, str],
      output_path: Union[Path, str],
      channel_names: Optional[List[str]] = None,
@@ -11,7 +11,7 @@ from torch import __version__, load, save
 
  from careamics.config import Configuration, load_configuration, save_configuration
  from careamics.config.support import SupportedArchitecture
- from careamics.lightning_module import CAREamicsKiln
+ from careamics.lightning_module import CAREamicsModule
 
  from .bioimage import (
      create_env_text,
@@ -21,7 +21,7 @@ from .bioimage import (
  )
 
 
- def _export_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> Path:
+ def _export_state_dict(model: CAREamicsModule, path: Union[Path, str]) -> Path:
      """
      Export the model state dictionary to a file.
 
@@ -51,7 +51,7 @@ def _export_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> Path:
      return path
 
 
- def _load_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> None:
+ def _load_state_dict(model: CAREamicsModule, path: Union[Path, str]) -> None:
      """
      Load a model from a state dictionary.
 
@@ -73,7 +73,7 @@ def _load_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> None:
 
  # TODO break down in subfunctions
  def export_to_bmz(
-     model: CAREamicsKiln,
+     model: CAREamicsModule,
      config: Configuration,
      path: Union[Path, str],
      name: str,
@@ -185,7 +185,7 @@ def export_to_bmz(
      save_bioimageio_package(model_description, output_path=path)
 
 
- def load_from_bmz(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]:
+ def load_from_bmz(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
      """Load a model from a BioImage Model Zoo archive.
 
      Parameters
@@ -223,7 +223,7 @@ def load_from_bmz(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]
      config = load_configuration(config_path)
 
      # create careamics lightning module
-     model = CAREamicsKiln(algorithm_config=config.algorithm_config)
+     model = CAREamicsModule(algorithm_config=config.algorithm_config)
 
      # load model state dictionary
      _load_state_dict(model, weights_path)
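
The rename flows through the BioImage Model Zoo round trip unchanged. A sketch of loading an archive with the new class name (the archive path is hypothetical; `load_from_bmz` and its `careamics.model_io.bmz_io` module path appear verbatim in this diff):

    from careamics.lightning_module import CAREamicsModule
    from careamics.model_io.bmz_io import load_from_bmz

    # Returns the Lightning module plus the CAREamics Configuration stored
    # alongside the weights in the archive.
    model, config = load_from_bmz("path/to/model.zip")  # hypothetical archive path
    assert isinstance(model, CAREamicsModule)  # was CAREamicsKiln in rc3
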
@@ -6,12 +6,12 @@ from typing import Tuple, Union
  from torch import load
 
  from careamics.config import Configuration
- from careamics.lightning_module import CAREamicsKiln
+ from careamics.lightning_module import CAREamicsModule
  from careamics.model_io.bmz_io import load_from_bmz
  from careamics.utils import check_path_exists
 
 
- def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]:
+ def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
      """
      Load a pretrained model from a checkpoint or a BioImage Model Zoo model.
 
@@ -44,7 +44,7 @@ def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuratio
      )
 
 
- def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]:
+ def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsModule, Configuration]:
      """
      Load a model from a checkpoint and return both model and configuration.
 
@@ -75,6 +75,6 @@ def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configurati
              f"checkpoint: {checkpoint.keys()}"
          ) from e
 
-     model = CAREamicsKiln.load_from_checkpoint(path)
+     model = CAREamicsModule.load_from_checkpoint(path)
 
      return model, Configuration(**cfg_dict)
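
`load_pretrained` dispatches between Lightning checkpoints and BMZ archives; both branches now return the renamed `CAREamicsModule`. A final sketch (the `careamics.model_io` import path and the checkpoint path are assumptions):

    from careamics.model_io import load_pretrained  # assumed import path

    # _load_checkpoint handles .ckpt files, load_from_bmz handles BMZ archives;
    # either way a (module, configuration) pair comes back.
    model, config = load_pretrained("path/to/last.ckpt")  # hypothetical path
    print(type(model).__name__)  # CAREamicsModule (was CAREamicsKiln in rc3)
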
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: careamics
- Version: 0.1.0rc3
+ Version: 0.1.0rc4
  Summary: Toolbox for running N2V and friends.
  Project-URL: homepage, https://careamics.github.io/
  Project-URL: repository, https://github.com/CAREamics/careamics