careamics 0.1.0rc3__py3-none-any.whl → 0.1.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +8 -6
- careamics/careamist.py +18 -18
- careamics/config/__init__.py +12 -8
- careamics/config/algorithm_model.py +5 -5
- careamics/config/configuration_example.py +89 -0
- careamics/config/configuration_factory.py +187 -50
- careamics/config/configuration_model.py +8 -7
- careamics/config/data_model.py +3 -3
- careamics/config/inference_model.py +1 -1
- careamics/config/support/supported_optimizers.py +3 -3
- careamics/config/training_model.py +1 -1
- careamics/config/transformations/n2v_manipulate_model.py +1 -1
- careamics/config/transformations/nd_flip_model.py +1 -1
- careamics/config/transformations/normalize_model.py +1 -1
- careamics/config/transformations/xy_random_rotate90_model.py +1 -1
- careamics/dataset/in_memory_dataset.py +3 -3
- careamics/dataset/iterable_dataset.py +3 -3
- careamics/lightning_datamodule.py +103 -25
- careamics/lightning_module.py +6 -6
- careamics/lightning_prediction_datamodule.py +44 -38
- careamics/model_io/bioimage/model_description.py +3 -3
- careamics/model_io/bmz_io.py +6 -6
- careamics/model_io/model_io_utils.py +4 -4
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc4.dist-info}/METADATA +1 -1
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc4.dist-info}/RECORD +27 -26
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc4.dist-info}/WHEEL +0 -0
- {careamics-0.1.0rc3.dist-info → careamics-0.1.0rc4.dist-info}/licenses/LICENSE +0 -0
careamics/__init__.py
CHANGED
|
@@ -9,16 +9,18 @@ except PackageNotFoundError:
|
|
|
9
9
|
|
|
10
10
|
__all__ = [
|
|
11
11
|
"CAREamist",
|
|
12
|
-
"
|
|
12
|
+
"CAREamicsModuleWrapper",
|
|
13
|
+
"CAREamicsPredictData",
|
|
14
|
+
"CAREamicsTrainData",
|
|
13
15
|
"Configuration",
|
|
14
16
|
"load_configuration",
|
|
15
17
|
"save_configuration",
|
|
16
|
-
"
|
|
17
|
-
"
|
|
18
|
+
"TrainingDataWrapper",
|
|
19
|
+
"PredictDataWrapper",
|
|
18
20
|
]
|
|
19
21
|
|
|
20
22
|
from .careamist import CAREamist
|
|
21
23
|
from .config import Configuration, load_configuration, save_configuration
|
|
22
|
-
from .lightning_datamodule import
|
|
23
|
-
from .lightning_module import
|
|
24
|
-
from .lightning_prediction_datamodule import
|
|
24
|
+
from .lightning_datamodule import CAREamicsTrainData, TrainingDataWrapper
|
|
25
|
+
from .lightning_module import CAREamicsModuleWrapper
|
|
26
|
+
from .lightning_prediction_datamodule import CAREamicsPredictData, PredictDataWrapper
|
careamics/careamist.py
CHANGED
|
@@ -20,9 +20,9 @@ from careamics.config import (
|
|
|
20
20
|
)
|
|
21
21
|
from careamics.config.inference_model import TRANSFORMS_UNION
|
|
22
22
|
from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
|
|
23
|
-
from careamics.lightning_datamodule import
|
|
24
|
-
from careamics.lightning_module import
|
|
25
|
-
from careamics.lightning_prediction_datamodule import
|
|
23
|
+
from careamics.lightning_datamodule import CAREamicsTrainData
|
|
24
|
+
from careamics.lightning_module import CAREamicsModule
|
|
25
|
+
from careamics.lightning_prediction_datamodule import CAREamicsPredictData
|
|
26
26
|
from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
|
|
27
27
|
from careamics.model_io import export_to_bmz, load_pretrained
|
|
28
28
|
from careamics.utils import check_path_exists, get_logger
|
|
@@ -140,7 +140,7 @@ class CAREamist:
|
|
|
140
140
|
self.cfg = source
|
|
141
141
|
|
|
142
142
|
# instantiate model
|
|
143
|
-
self.model =
|
|
143
|
+
self.model = CAREamicsModule(
|
|
144
144
|
algorithm_config=self.cfg.algorithm_config,
|
|
145
145
|
)
|
|
146
146
|
|
|
@@ -156,7 +156,7 @@ class CAREamist:
|
|
|
156
156
|
self.cfg = load_configuration(source)
|
|
157
157
|
|
|
158
158
|
# instantiate model
|
|
159
|
-
self.model =
|
|
159
|
+
self.model = CAREamicsModule(
|
|
160
160
|
algorithm_config=self.cfg.algorithm_config,
|
|
161
161
|
)
|
|
162
162
|
|
|
@@ -193,8 +193,8 @@ class CAREamist:
|
|
|
193
193
|
self.trainer.predict_loop = CAREamicsPredictionLoop(self.trainer)
|
|
194
194
|
|
|
195
195
|
# place holder for the datamodules
|
|
196
|
-
self.train_datamodule: Optional[
|
|
197
|
-
self.pred_datamodule: Optional[
|
|
196
|
+
self.train_datamodule: Optional[CAREamicsTrainData] = None
|
|
197
|
+
self.pred_datamodule: Optional[CAREamicsPredictData] = None
|
|
198
198
|
|
|
199
199
|
def _define_callbacks(self) -> List[Callback]:
|
|
200
200
|
"""
|
|
@@ -227,7 +227,7 @@ class CAREamist:
|
|
|
227
227
|
def train(
|
|
228
228
|
self,
|
|
229
229
|
*,
|
|
230
|
-
datamodule: Optional[
|
|
230
|
+
datamodule: Optional[CAREamicsTrainData] = None,
|
|
231
231
|
train_source: Optional[Union[Path, str, np.ndarray]] = None,
|
|
232
232
|
val_source: Optional[Union[Path, str, np.ndarray]] = None,
|
|
233
233
|
train_target: Optional[Union[Path, str, np.ndarray]] = None,
|
|
@@ -360,7 +360,7 @@ class CAREamist:
|
|
|
360
360
|
f"instance (got {type(train_source)})."
|
|
361
361
|
)
|
|
362
362
|
|
|
363
|
-
def _train_on_datamodule(self, datamodule:
|
|
363
|
+
def _train_on_datamodule(self, datamodule: CAREamicsTrainData) -> None:
|
|
364
364
|
"""
|
|
365
365
|
Train the model on the provided datamodule.
|
|
366
366
|
|
|
@@ -402,7 +402,7 @@ class CAREamist:
|
|
|
402
402
|
Minimum number of patches to use for validation, by default 5.
|
|
403
403
|
"""
|
|
404
404
|
# create datamodule
|
|
405
|
-
datamodule =
|
|
405
|
+
datamodule = CAREamicsTrainData(
|
|
406
406
|
data_config=self.cfg.data_config,
|
|
407
407
|
train_data=train_data,
|
|
408
408
|
val_data=val_data,
|
|
@@ -458,7 +458,7 @@ class CAREamist:
|
|
|
458
458
|
path_to_val_target = check_path_exists(path_to_val_target)
|
|
459
459
|
|
|
460
460
|
# create datamodule
|
|
461
|
-
datamodule =
|
|
461
|
+
datamodule = CAREamicsTrainData(
|
|
462
462
|
data_config=self.cfg.data_config,
|
|
463
463
|
train_data=path_to_train_data,
|
|
464
464
|
val_data=path_to_val_data,
|
|
@@ -475,7 +475,7 @@ class CAREamist:
|
|
|
475
475
|
@overload
|
|
476
476
|
def predict( # numpydoc ignore=GL08
|
|
477
477
|
self,
|
|
478
|
-
source:
|
|
478
|
+
source: CAREamicsPredictData,
|
|
479
479
|
*,
|
|
480
480
|
checkpoint: Optional[Literal["best", "last"]] = None,
|
|
481
481
|
) -> Union[list, np.ndarray]:
|
|
@@ -519,7 +519,7 @@ class CAREamist:
|
|
|
519
519
|
|
|
520
520
|
def predict(
|
|
521
521
|
self,
|
|
522
|
-
source: Union[
|
|
522
|
+
source: Union[CAREamicsPredictData, Path, str, np.ndarray],
|
|
523
523
|
*,
|
|
524
524
|
batch_size: int = 1,
|
|
525
525
|
tile_size: Optional[Tuple[int, ...]] = None,
|
|
@@ -587,7 +587,7 @@ class CAREamist:
|
|
|
587
587
|
ValueError
|
|
588
588
|
If the input is not a CAREamicsClay instance, a path or a numpy array.
|
|
589
589
|
"""
|
|
590
|
-
if isinstance(source,
|
|
590
|
+
if isinstance(source, CAREamicsPredictData):
|
|
591
591
|
# record datamodule
|
|
592
592
|
self.pred_datamodule = source
|
|
593
593
|
|
|
@@ -623,8 +623,8 @@ class CAREamist:
|
|
|
623
623
|
source_path = check_path_exists(source)
|
|
624
624
|
|
|
625
625
|
# create datamodule
|
|
626
|
-
datamodule =
|
|
627
|
-
|
|
626
|
+
datamodule = CAREamicsPredictData(
|
|
627
|
+
pred_config=prediction_config,
|
|
628
628
|
pred_data=source_path,
|
|
629
629
|
read_source_func=read_source_func,
|
|
630
630
|
extension_filter=extension_filter,
|
|
@@ -640,8 +640,8 @@ class CAREamist:
|
|
|
640
640
|
|
|
641
641
|
elif isinstance(source, np.ndarray):
|
|
642
642
|
# create datamodule
|
|
643
|
-
datamodule =
|
|
644
|
-
|
|
643
|
+
datamodule = CAREamicsPredictData(
|
|
644
|
+
pred_config=prediction_config,
|
|
645
645
|
pred_data=source,
|
|
646
646
|
dataloader_params=dataloader_params,
|
|
647
647
|
)
|
careamics/config/__init__.py
CHANGED
|
@@ -2,15 +2,17 @@
|
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
__all__ = [
|
|
5
|
-
"
|
|
6
|
-
"
|
|
5
|
+
"AlgorithmConfig",
|
|
6
|
+
"DataConfig",
|
|
7
7
|
"Configuration",
|
|
8
8
|
"CheckpointModel",
|
|
9
|
-
"
|
|
9
|
+
"InferenceConfig",
|
|
10
10
|
"load_configuration",
|
|
11
11
|
"save_configuration",
|
|
12
|
-
"
|
|
12
|
+
"TrainingConfig",
|
|
13
13
|
"create_n2v_configuration",
|
|
14
|
+
"create_n2n_configuration",
|
|
15
|
+
"create_care_configuration",
|
|
14
16
|
"register_model",
|
|
15
17
|
"CustomModel",
|
|
16
18
|
"create_inference_configuration",
|
|
@@ -18,11 +20,13 @@ __all__ = [
|
|
|
18
20
|
"ConfigurationInformation",
|
|
19
21
|
]
|
|
20
22
|
|
|
21
|
-
from .algorithm_model import
|
|
23
|
+
from .algorithm_model import AlgorithmConfig
|
|
22
24
|
from .architectures import CustomModel, clear_custom_models, register_model
|
|
23
25
|
from .callback_model import CheckpointModel
|
|
24
26
|
from .configuration_factory import (
|
|
27
|
+
create_care_configuration,
|
|
25
28
|
create_inference_configuration,
|
|
29
|
+
create_n2n_configuration,
|
|
26
30
|
create_n2v_configuration,
|
|
27
31
|
)
|
|
28
32
|
from .configuration_model import (
|
|
@@ -30,6 +34,6 @@ from .configuration_model import (
|
|
|
30
34
|
load_configuration,
|
|
31
35
|
save_configuration,
|
|
32
36
|
)
|
|
33
|
-
from .data_model import
|
|
34
|
-
from .inference_model import
|
|
35
|
-
from .training_model import
|
|
37
|
+
from .data_model import DataConfig
|
|
38
|
+
from .inference_model import InferenceConfig
|
|
39
|
+
from .training_model import TrainingConfig
|
|
@@ -10,7 +10,7 @@ from .architectures import CustomModel, UNetModel, VAEModel
|
|
|
10
10
|
from .optimizer_models import LrSchedulerModel, OptimizerModel
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
class
|
|
13
|
+
class AlgorithmConfig(BaseModel):
|
|
14
14
|
"""Algorithm configuration.
|
|
15
15
|
|
|
16
16
|
This Pydantic model validates the parameters governing the components of the
|
|
@@ -45,7 +45,7 @@ class AlgorithmModel(BaseModel):
|
|
|
45
45
|
Examples
|
|
46
46
|
--------
|
|
47
47
|
Minimum example:
|
|
48
|
-
>>> from careamics.config import
|
|
48
|
+
>>> from careamics.config import AlgorithmConfig
|
|
49
49
|
>>> config_dict = {
|
|
50
50
|
... "algorithm": "n2v",
|
|
51
51
|
... "loss": "n2v",
|
|
@@ -53,11 +53,11 @@ class AlgorithmModel(BaseModel):
|
|
|
53
53
|
... "architecture": "UNet",
|
|
54
54
|
... }
|
|
55
55
|
... }
|
|
56
|
-
>>> config =
|
|
56
|
+
>>> config = AlgorithmConfig(**config_dict)
|
|
57
57
|
|
|
58
58
|
Using a custom model:
|
|
59
59
|
>>> from torch import nn, ones
|
|
60
|
-
>>> from careamics.config import
|
|
60
|
+
>>> from careamics.config import AlgorithmConfig, register_model
|
|
61
61
|
...
|
|
62
62
|
>>> @register_model(name="linear_model")
|
|
63
63
|
... class LinearModel(nn.Module):
|
|
@@ -80,7 +80,7 @@ class AlgorithmModel(BaseModel):
|
|
|
80
80
|
... "out_features": 5,
|
|
81
81
|
... }
|
|
82
82
|
... }
|
|
83
|
-
>>> config =
|
|
83
|
+
>>> config = AlgorithmConfig(**config_dict)
|
|
84
84
|
"""
|
|
85
85
|
|
|
86
86
|
# Pydantic class configuration
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
from .algorithm_model import AlgorithmConfig
|
|
2
|
+
from .architectures import UNetModel
|
|
3
|
+
from .configuration_model import Configuration
|
|
4
|
+
from .data_model import DataConfig
|
|
5
|
+
from .optimizer_models import LrSchedulerModel, OptimizerModel
|
|
6
|
+
from .support import (
|
|
7
|
+
SupportedActivation,
|
|
8
|
+
SupportedAlgorithm,
|
|
9
|
+
SupportedArchitecture,
|
|
10
|
+
SupportedData,
|
|
11
|
+
SupportedLogger,
|
|
12
|
+
SupportedLoss,
|
|
13
|
+
SupportedOptimizer,
|
|
14
|
+
SupportedPixelManipulation,
|
|
15
|
+
SupportedScheduler,
|
|
16
|
+
SupportedTransform,
|
|
17
|
+
)
|
|
18
|
+
from .training_model import TrainingConfig
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def full_configuration_example() -> Configuration:
|
|
22
|
+
"""Returns a dictionnary representing a full configuration example.
|
|
23
|
+
|
|
24
|
+
Returns
|
|
25
|
+
-------
|
|
26
|
+
Configuration
|
|
27
|
+
Full configuration example.
|
|
28
|
+
"""
|
|
29
|
+
experiment_name = "Full example"
|
|
30
|
+
algorithm_model = AlgorithmConfig(
|
|
31
|
+
algorithm=SupportedAlgorithm.N2V.value,
|
|
32
|
+
loss=SupportedLoss.N2V.value,
|
|
33
|
+
model=UNetModel(
|
|
34
|
+
architecture=SupportedArchitecture.UNET.value,
|
|
35
|
+
in_channels=1,
|
|
36
|
+
num_classes=1,
|
|
37
|
+
depth=2,
|
|
38
|
+
num_channels_init=32,
|
|
39
|
+
final_activation=SupportedActivation.NONE.value,
|
|
40
|
+
n2v2=True,
|
|
41
|
+
),
|
|
42
|
+
optimizer=OptimizerModel(
|
|
43
|
+
name=SupportedOptimizer.ADAM.value, parameters={"lr": 0.0001}
|
|
44
|
+
),
|
|
45
|
+
lr_scheduler=LrSchedulerModel(
|
|
46
|
+
name=SupportedScheduler.REDUCE_LR_ON_PLATEAU.value,
|
|
47
|
+
),
|
|
48
|
+
)
|
|
49
|
+
data_model = DataConfig(
|
|
50
|
+
data_type=SupportedData.ARRAY.value,
|
|
51
|
+
patch_size=(256, 256),
|
|
52
|
+
batch_size=8,
|
|
53
|
+
axes="YX",
|
|
54
|
+
transforms=[
|
|
55
|
+
{
|
|
56
|
+
"name": SupportedTransform.NORMALIZE.value,
|
|
57
|
+
},
|
|
58
|
+
{
|
|
59
|
+
"name": SupportedTransform.NDFLIP.value,
|
|
60
|
+
"is_3D": False,
|
|
61
|
+
},
|
|
62
|
+
{
|
|
63
|
+
"name": SupportedTransform.XY_RANDOM_ROTATE90.value,
|
|
64
|
+
"is_3D": False,
|
|
65
|
+
},
|
|
66
|
+
{
|
|
67
|
+
"name": SupportedTransform.N2V_MANIPULATE.value,
|
|
68
|
+
"roi_size": 11,
|
|
69
|
+
"masked_pixel_percentage": 0.2,
|
|
70
|
+
"strategy": SupportedPixelManipulation.MEDIAN.value,
|
|
71
|
+
},
|
|
72
|
+
],
|
|
73
|
+
mean=0.485,
|
|
74
|
+
std=0.229,
|
|
75
|
+
dataloader_params={
|
|
76
|
+
"num_workers": 4,
|
|
77
|
+
},
|
|
78
|
+
)
|
|
79
|
+
training_model = TrainingConfig(
|
|
80
|
+
num_epochs=30,
|
|
81
|
+
logger=SupportedLogger.WANDB.value,
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
return Configuration(
|
|
85
|
+
experiment_name=experiment_name,
|
|
86
|
+
algorithm_config=algorithm_model,
|
|
87
|
+
data_config=data_model,
|
|
88
|
+
training_config=training_model,
|
|
89
|
+
)
|