careamics 0.1.0rc1__py3-none-any.whl → 0.1.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/bioimage/__init__.py +3 -3
- careamics/bioimage/io.py +82 -171
- careamics/bioimage/rdf.py +105 -0
- careamics/config/config.py +4 -3
- careamics/config/data.py +2 -2
- careamics/dataset/dataset_utils.py +1 -5
- careamics/dataset/in_memory_dataset.py +3 -2
- careamics/dataset/patching.py +5 -6
- careamics/dataset/prepare_dataset.py +4 -3
- careamics/dataset/tiff_dataset.py +4 -3
- careamics/engine.py +170 -110
- careamics/losses/loss_factory.py +3 -2
- careamics/models/model_factory.py +30 -11
- careamics/prediction/prediction_utils.py +16 -12
- careamics/utils/__init__.py +3 -4
- careamics/utils/torch_utils.py +61 -65
- careamics/utils/validators.py +34 -20
- careamics/utils/wandb.py +1 -1
- {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc2.dist-info}/METADATA +4 -3
- {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc2.dist-info}/RECORD +22 -21
- {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc2.dist-info}/WHEEL +1 -1
- {careamics-0.1.0rc1.dist-info → careamics-0.1.0rc2.dist-info}/licenses/LICENSE +0 -0
careamics/engine.py
CHANGED
|
@@ -10,12 +10,11 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
10
10
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import torch
|
|
13
|
-
from bioimageio.spec.model.raw_nodes import Model as BioimageModel
|
|
14
13
|
from torch.utils.data import DataLoader, TensorDataset
|
|
15
14
|
|
|
16
15
|
from .bioimage import (
|
|
17
|
-
build_zip_model,
|
|
18
16
|
get_default_model_specs,
|
|
17
|
+
save_bioimage_model,
|
|
19
18
|
)
|
|
20
19
|
from .config import Configuration, load_configuration
|
|
21
20
|
from .dataset.prepare_dataset import (
|
|
@@ -32,7 +31,7 @@ from .prediction import (
|
|
|
32
31
|
)
|
|
33
32
|
from .utils import (
|
|
34
33
|
MetricTracker,
|
|
35
|
-
|
|
34
|
+
add_axes,
|
|
36
35
|
denormalize,
|
|
37
36
|
get_device,
|
|
38
37
|
normalize,
|
|
@@ -40,7 +39,6 @@ from .utils import (
|
|
|
40
39
|
from .utils.logging import ProgressBar, get_logger
|
|
41
40
|
|
|
42
41
|
|
|
43
|
-
# TODO: refactor private methods and bioimage.io to other modules
|
|
44
42
|
class Engine:
|
|
45
43
|
"""
|
|
46
44
|
Class allowing training of a model and subsequent prediction.
|
|
@@ -165,47 +163,52 @@ class Engine:
|
|
|
165
163
|
self.scaler,
|
|
166
164
|
self.cfg,
|
|
167
165
|
) = create_model(config=self.cfg, model_path=model_path, device=self.device)
|
|
166
|
+
assert self.cfg is not None
|
|
168
167
|
|
|
169
168
|
# create loss function
|
|
170
|
-
|
|
171
|
-
|
|
169
|
+
self.loss_func = create_loss_function(self.cfg)
|
|
170
|
+
|
|
171
|
+
# Set logging
|
|
172
|
+
log_path = self.cfg.working_directory / "log.txt"
|
|
173
|
+
self.logger = get_logger(__name__, log_path=log_path)
|
|
172
174
|
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
self.logger = get_logger(__name__, log_path=log_path)
|
|
175
|
+
# wandb
|
|
176
|
+
self.use_wandb = self.cfg.training.use_wandb
|
|
176
177
|
|
|
177
|
-
|
|
178
|
-
|
|
178
|
+
if self.use_wandb:
|
|
179
|
+
try:
|
|
180
|
+
from wandb.errors import UsageError
|
|
181
|
+
|
|
182
|
+
from careamics.utils.wandb import WandBLogging
|
|
179
183
|
|
|
180
|
-
if self.use_wandb:
|
|
181
184
|
try:
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
log_path=self.cfg.working_directory,
|
|
190
|
-
config=self.cfg,
|
|
191
|
-
model_to_watch=self.model,
|
|
192
|
-
)
|
|
193
|
-
except UsageError as e:
|
|
194
|
-
self.logger.warning(
|
|
195
|
-
f"Wandb usage error, using default logger. Check whether "
|
|
196
|
-
f"wandb correctly configured:\n"
|
|
197
|
-
f"{e}"
|
|
198
|
-
)
|
|
199
|
-
self.use_wandb = False
|
|
200
|
-
|
|
201
|
-
except ModuleNotFoundError:
|
|
185
|
+
self.wandb = WandBLogging(
|
|
186
|
+
experiment_name=self.cfg.experiment_name,
|
|
187
|
+
log_path=self.cfg.working_directory,
|
|
188
|
+
config=self.cfg,
|
|
189
|
+
model_to_watch=self.model,
|
|
190
|
+
)
|
|
191
|
+
except UsageError as e:
|
|
202
192
|
self.logger.warning(
|
|
203
|
-
"Wandb
|
|
204
|
-
"wandb"
|
|
193
|
+
f"Wandb usage error, using default logger. Check whether "
|
|
194
|
+
f"wandb correctly configured:\n"
|
|
195
|
+
f"{e}"
|
|
205
196
|
)
|
|
206
197
|
self.use_wandb = False
|
|
207
|
-
|
|
208
|
-
|
|
198
|
+
|
|
199
|
+
except ModuleNotFoundError:
|
|
200
|
+
self.logger.warning(
|
|
201
|
+
"Wandb not installed, using default logger. Try pip install "
|
|
202
|
+
"wandb"
|
|
203
|
+
)
|
|
204
|
+
self.use_wandb = False
|
|
205
|
+
|
|
206
|
+
# BMZ inputs/outputs placeholders, filled during validation
|
|
207
|
+
self._input = None
|
|
208
|
+
self._outputs = None
|
|
209
|
+
|
|
210
|
+
# torch version
|
|
211
|
+
self.torch_version = torch.__version__
|
|
209
212
|
|
|
210
213
|
def train(
|
|
211
214
|
self,
|
|
@@ -387,6 +390,12 @@ class Engine:
|
|
|
387
390
|
|
|
388
391
|
with torch.no_grad():
|
|
389
392
|
for patch, *auxillary in val_loader:
|
|
393
|
+
# if inputs is None, record a single patch
|
|
394
|
+
if self._input is None:
|
|
395
|
+
# patch has dimension SC(Z)YX
|
|
396
|
+
self._input = patch.clone().detach().cpu().numpy()
|
|
397
|
+
|
|
398
|
+
# evaluate
|
|
390
399
|
outputs = self.model(patch.to(self.device))
|
|
391
400
|
loss = self.loss_func(
|
|
392
401
|
outputs, *[a.to(self.device) for a in auxillary], self.device
|
|
@@ -409,10 +418,18 @@ class Engine:
|
|
|
409
418
|
The Engine must have previously been trained and mean/std be specified in
|
|
410
419
|
its configuration.
|
|
411
420
|
|
|
421
|
+
Data should be compatible with the axes, either from the configuration or
|
|
422
|
+
as passed using the `axes` parameter. If the batch and channel dimensions are
|
|
423
|
+
missing, then singleton dimensions are added.
|
|
424
|
+
|
|
412
425
|
To use tiling, both `tile_shape` and `overlaps` must be specified, have same
|
|
413
426
|
length, be divisible by 2 and greater than 0. Finally, the overlaps must be
|
|
414
427
|
smaller than the tiles.
|
|
415
428
|
|
|
429
|
+
By setting `tta` to `True`, the prediction is performed using test time
|
|
430
|
+
augmentation, meaning that the input is augmented and the prediction is averaged
|
|
431
|
+
over the augmentations.
|
|
432
|
+
|
|
416
433
|
Parameters
|
|
417
434
|
----------
|
|
418
435
|
input : Union[np.ndarra, str, Path]
|
|
@@ -497,6 +514,18 @@ class Engine:
|
|
|
497
514
|
UserWarning
|
|
498
515
|
If the samples have different shapes, the prediction then returns a list.
|
|
499
516
|
"""
|
|
517
|
+
# checks are done here to satisfy mypy
|
|
518
|
+
# check that configuration exists
|
|
519
|
+
if self.cfg is None:
|
|
520
|
+
raise ValueError("Configuration is not defined, cannot predict.")
|
|
521
|
+
|
|
522
|
+
# Check that the mean and std are there (= has been trained)
|
|
523
|
+
if not self.cfg.data.mean or not self.cfg.data.std:
|
|
524
|
+
raise ValueError(
|
|
525
|
+
"Mean or std are not specified in the configuration, prediction cannot "
|
|
526
|
+
"be performed."
|
|
527
|
+
)
|
|
528
|
+
|
|
500
529
|
prediction = []
|
|
501
530
|
tiles = []
|
|
502
531
|
stitching_data = []
|
|
@@ -513,9 +542,7 @@ class Engine:
|
|
|
513
542
|
predicted_augments = []
|
|
514
543
|
for augmented_tile in augmented_tiles:
|
|
515
544
|
augmented_pred = self.model(augmented_tile.to(self.device))
|
|
516
|
-
predicted_augments.append(
|
|
517
|
-
augmented_pred.squeeze().cpu().numpy()
|
|
518
|
-
)
|
|
545
|
+
predicted_augments.append(augmented_pred.cpu())
|
|
519
546
|
tiles.append(tta_backward(predicted_augments).squeeze())
|
|
520
547
|
else:
|
|
521
548
|
tiles.append(
|
|
@@ -529,8 +556,8 @@ class Engine:
|
|
|
529
556
|
predicted_sample = stitch_prediction(tiles, stitching_data)
|
|
530
557
|
predicted_sample = denormalize(
|
|
531
558
|
predicted_sample,
|
|
532
|
-
float(self.cfg.data.mean),
|
|
533
|
-
float(self.cfg.data.std),
|
|
559
|
+
float(self.cfg.data.mean),
|
|
560
|
+
float(self.cfg.data.std),
|
|
534
561
|
)
|
|
535
562
|
prediction.append(predicted_sample)
|
|
536
563
|
tiles.clear()
|
|
@@ -566,6 +593,18 @@ class Engine:
|
|
|
566
593
|
np.ndarray
|
|
567
594
|
Predicted image.
|
|
568
595
|
"""
|
|
596
|
+
# checks are done here to satisfy mypy
|
|
597
|
+
# check that configuration exists
|
|
598
|
+
if self.cfg is None:
|
|
599
|
+
raise ValueError("Configuration is not defined, cannot predict.")
|
|
600
|
+
|
|
601
|
+
# Check that the mean and std are there (= has been trained)
|
|
602
|
+
if not self.cfg.data.mean or not self.cfg.data.std:
|
|
603
|
+
raise ValueError(
|
|
604
|
+
"Mean or std are not specified in the configuration, prediction cannot "
|
|
605
|
+
"be performed."
|
|
606
|
+
)
|
|
607
|
+
|
|
569
608
|
prediction = []
|
|
570
609
|
with torch.no_grad():
|
|
571
610
|
for i, sample in enumerate(pred_loader):
|
|
@@ -574,9 +613,7 @@ class Engine:
|
|
|
574
613
|
predicted_augments = []
|
|
575
614
|
for augmented_pred in augmented_preds:
|
|
576
615
|
augmented_pred = self.model(augmented_pred.to(self.device))
|
|
577
|
-
predicted_augments.append(
|
|
578
|
-
augmented_pred.squeeze().cpu().numpy()
|
|
579
|
-
)
|
|
616
|
+
predicted_augments.append(augmented_pred.cpu())
|
|
580
617
|
prediction.append(tta_backward(predicted_augments).squeeze())
|
|
581
618
|
else:
|
|
582
619
|
prediction.append(
|
|
@@ -584,7 +621,9 @@ class Engine:
|
|
|
584
621
|
)
|
|
585
622
|
progress_bar.update(i, 1)
|
|
586
623
|
output = denormalize(
|
|
587
|
-
np.stack(prediction)
|
|
624
|
+
np.stack(prediction).squeeze(),
|
|
625
|
+
float(self.cfg.data.mean),
|
|
626
|
+
float(self.cfg.data.std),
|
|
588
627
|
)
|
|
589
628
|
return output
|
|
590
629
|
|
|
@@ -696,14 +735,13 @@ class Engine:
|
|
|
696
735
|
)
|
|
697
736
|
|
|
698
737
|
if input is None:
|
|
699
|
-
raise ValueError("
|
|
738
|
+
raise ValueError("Input cannot be None.")
|
|
700
739
|
|
|
701
740
|
# Create dataset
|
|
702
741
|
if isinstance(input, np.ndarray): # np.ndarray
|
|
703
|
-
#
|
|
742
|
+
# Validate axes and add missing dimensions (S)C if necessary
|
|
704
743
|
img_axes = self.cfg.data.axes if axes is None else axes
|
|
705
|
-
|
|
706
|
-
check_array_validity(input, img_axes)
|
|
744
|
+
input_expanded = add_axes(input, img_axes)
|
|
707
745
|
|
|
708
746
|
# Check if tiling requested
|
|
709
747
|
tiled = tile_shape is not None and overlaps is not None
|
|
@@ -714,11 +752,9 @@ class Engine:
|
|
|
714
752
|
"Tiling with in memory array is currently not implemented."
|
|
715
753
|
)
|
|
716
754
|
|
|
717
|
-
# check_tiling_validity(tile_shape, overlaps)
|
|
718
|
-
|
|
719
755
|
# Normalize input and cast to float32
|
|
720
756
|
normalized_input = normalize(
|
|
721
|
-
img=
|
|
757
|
+
img=input_expanded, mean=self.cfg.data.mean, std=self.cfg.data.std
|
|
722
758
|
)
|
|
723
759
|
normalized_input = normalized_input.astype(np.float32)
|
|
724
760
|
|
|
@@ -812,7 +848,66 @@ class Engine:
|
|
|
812
848
|
self.logger.removeHandler(handler)
|
|
813
849
|
handler.close()
|
|
814
850
|
|
|
815
|
-
def
|
|
851
|
+
def _get_sample_io_files(
|
|
852
|
+
self,
|
|
853
|
+
input_array: Optional[np.ndarray] = None,
|
|
854
|
+
axes: Optional[str] = None,
|
|
855
|
+
) -> Tuple[List[str], List[str]]:
|
|
856
|
+
"""
|
|
857
|
+
Create numpy format for use as inputs and outputs in the bioimage.io archive.
|
|
858
|
+
|
|
859
|
+
Parameters
|
|
860
|
+
----------
|
|
861
|
+
input_array : Optional[np.ndarray], optional
|
|
862
|
+
Input array to use for the bioimage.io model zoo, by default None.
|
|
863
|
+
axes : Optional[str], optional
|
|
864
|
+
Axes from the configuration.
|
|
865
|
+
|
|
866
|
+
Returns
|
|
867
|
+
-------
|
|
868
|
+
Tuple[List[str], List[str]]
|
|
869
|
+
Tuple of input and output file paths.
|
|
870
|
+
|
|
871
|
+
Raises
|
|
872
|
+
------
|
|
873
|
+
ValueError
|
|
874
|
+
If the configuration is not defined.
|
|
875
|
+
"""
|
|
876
|
+
if self.cfg is not None and self._input is not None:
|
|
877
|
+
# use the input array if provided, otherwise use the first validation sample
|
|
878
|
+
if input_array is not None:
|
|
879
|
+
array_in = input_array
|
|
880
|
+
|
|
881
|
+
# add axes to be compatible with the axes declared in the RDF specs
|
|
882
|
+
add_axes(array_in, axes)
|
|
883
|
+
else:
|
|
884
|
+
array_in = self._input
|
|
885
|
+
|
|
886
|
+
# predict (no tta since BMZ does not apply it)
|
|
887
|
+
array_out = self.predict(array_in, tta=False)
|
|
888
|
+
|
|
889
|
+
# add singleton dimensions (for compatibility with model axes)
|
|
890
|
+
# indeed, BMZ applies the model but CAREamics function are meant
|
|
891
|
+
# to work on user data (potentially with no S or C axe)
|
|
892
|
+
array_out = array_out[np.newaxis, np.newaxis, ...]
|
|
893
|
+
|
|
894
|
+
# save numpy files
|
|
895
|
+
workdir = self.cfg.working_directory
|
|
896
|
+
in_file = workdir.joinpath("test_inputs.npy")
|
|
897
|
+
np.save(in_file, array_in)
|
|
898
|
+
out_file = workdir.joinpath("test_outputs.npy")
|
|
899
|
+
np.save(out_file, array_out)
|
|
900
|
+
|
|
901
|
+
return [str(in_file.absolute())], [str(out_file.absolute())]
|
|
902
|
+
else:
|
|
903
|
+
raise ValueError("Configuration is not defined or model was not trained.")
|
|
904
|
+
|
|
905
|
+
def _generate_rdf(
|
|
906
|
+
self,
|
|
907
|
+
*,
|
|
908
|
+
model_specs: Optional[dict] = None,
|
|
909
|
+
input_array: Optional[np.ndarray] = None,
|
|
910
|
+
) -> dict:
|
|
816
911
|
"""
|
|
817
912
|
Generate rdf data for bioimage.io format export.
|
|
818
913
|
|
|
@@ -820,6 +915,8 @@ class Engine:
|
|
|
820
915
|
----------
|
|
821
916
|
model_specs : Optional[dict], optional
|
|
822
917
|
Custom specs if different than the default ones, by default None.
|
|
918
|
+
input_array : Optional[np.ndarray], optional
|
|
919
|
+
Input array to use for the bioimage.io model zoo, by default None.
|
|
823
920
|
|
|
824
921
|
Returns
|
|
825
922
|
-------
|
|
@@ -848,7 +945,9 @@ class Engine:
|
|
|
848
945
|
axes = "b" + axes
|
|
849
946
|
|
|
850
947
|
# get in/out samples' files
|
|
851
|
-
test_inputs, test_outputs = self._get_sample_io_files(
|
|
948
|
+
test_inputs, test_outputs = self._get_sample_io_files(
|
|
949
|
+
input_array, self.cfg.data.axes
|
|
950
|
+
)
|
|
852
951
|
|
|
853
952
|
specs = get_default_model_specs(
|
|
854
953
|
"Noise2Void",
|
|
@@ -861,7 +960,6 @@ class Engine:
|
|
|
861
960
|
|
|
862
961
|
specs.update(
|
|
863
962
|
{
|
|
864
|
-
"architecture": "careamics.models.unet",
|
|
865
963
|
"test_inputs": test_inputs,
|
|
866
964
|
"test_outputs": test_outputs,
|
|
867
965
|
"input_axes": [axes],
|
|
@@ -870,14 +968,21 @@ class Engine:
|
|
|
870
968
|
)
|
|
871
969
|
return specs
|
|
872
970
|
else:
|
|
873
|
-
raise ValueError("Configuration is not defined.")
|
|
971
|
+
raise ValueError("Configuration is not defined or model was not trained.")
|
|
874
972
|
|
|
875
973
|
def save_as_bioimage(
|
|
876
|
-
self,
|
|
877
|
-
|
|
974
|
+
self,
|
|
975
|
+
output_zip: Union[Path, str],
|
|
976
|
+
model_specs: Optional[dict] = None,
|
|
977
|
+
input_array: Optional[np.ndarray] = None,
|
|
978
|
+
) -> None:
|
|
878
979
|
"""
|
|
879
980
|
Export the current model to BioImage.io model zoo format.
|
|
880
981
|
|
|
982
|
+
Custom specs can be passed in `model_specs (e.g. maintainers). For a description
|
|
983
|
+
of the model RDF, refer to
|
|
984
|
+
github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/model_spec_latest.md.
|
|
985
|
+
|
|
881
986
|
Parameters
|
|
882
987
|
----------
|
|
883
988
|
output_zip : Union[Path, str]
|
|
@@ -885,11 +990,10 @@ class Engine:
|
|
|
885
990
|
model_specs : Optional[dict]
|
|
886
991
|
A dictionary with keys being the bioimage-core build_model parameters. If
|
|
887
992
|
None then it will be populated by the model default specs.
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
Bioimage.io model object.
|
|
993
|
+
input_array : Optional[np.ndarray]
|
|
994
|
+
An array to use as input for the bioimage.io model zoo. If None then the
|
|
995
|
+
first validation sample will be used. Note that the array must have S and
|
|
996
|
+
C dimensions (e.g. SCYX), even if only singleton dimensions.
|
|
893
997
|
|
|
894
998
|
Raises
|
|
895
999
|
------
|
|
@@ -898,57 +1002,13 @@ class Engine:
|
|
|
898
1002
|
"""
|
|
899
1003
|
if self.cfg is not None:
|
|
900
1004
|
# Generate specs
|
|
901
|
-
specs = self._generate_rdf(model_specs)
|
|
1005
|
+
specs = self._generate_rdf(model_specs=model_specs, input_array=input_array)
|
|
902
1006
|
|
|
903
1007
|
# Build model
|
|
904
|
-
|
|
1008
|
+
save_bioimage_model(
|
|
905
1009
|
path=output_zip,
|
|
906
1010
|
config=self.cfg,
|
|
907
|
-
|
|
1011
|
+
specs=specs,
|
|
908
1012
|
)
|
|
909
|
-
|
|
910
|
-
return raw_model
|
|
911
|
-
else:
|
|
912
|
-
raise ValueError("Configuration is not defined.")
|
|
913
|
-
|
|
914
|
-
def _get_sample_io_files(self, axes: str) -> Tuple[List[str], List[str]]:
|
|
915
|
-
"""
|
|
916
|
-
Create numpy format for use as inputs and outputs in the bioimage.io archive.
|
|
917
|
-
|
|
918
|
-
Parameters
|
|
919
|
-
----------
|
|
920
|
-
axes : str
|
|
921
|
-
Input and output axes.
|
|
922
|
-
|
|
923
|
-
Returns
|
|
924
|
-
-------
|
|
925
|
-
Tuple[List[str], List[str]]
|
|
926
|
-
Tuple of input and output file paths.
|
|
927
|
-
|
|
928
|
-
Raises
|
|
929
|
-
------
|
|
930
|
-
ValueError
|
|
931
|
-
If the configuration is not defined.
|
|
932
|
-
"""
|
|
933
|
-
# input:
|
|
934
|
-
if self.cfg is not None:
|
|
935
|
-
sample_input = np.random.randn(*self.cfg.training.patch_size)
|
|
936
|
-
# if there are more input axes (like channel, ...),
|
|
937
|
-
# then expand the sample dimensions.
|
|
938
|
-
len_diff = len(axes) - len(self.cfg.training.patch_size)
|
|
939
|
-
if len_diff > 0:
|
|
940
|
-
sample_input = np.expand_dims(
|
|
941
|
-
sample_input, axis=tuple(i for i in range(len_diff))
|
|
942
|
-
)
|
|
943
|
-
sample_output = np.random.randn(*sample_input.shape)
|
|
944
|
-
|
|
945
|
-
# save numpy files
|
|
946
|
-
workdir = self.cfg.working_directory
|
|
947
|
-
in_file = workdir.joinpath("test_inputs.npy")
|
|
948
|
-
np.save(in_file, sample_input)
|
|
949
|
-
out_file = workdir.joinpath("test_outputs.npy")
|
|
950
|
-
np.save(out_file, sample_output)
|
|
951
|
-
|
|
952
|
-
return [str(in_file.absolute())], [str(out_file.absolute())]
|
|
953
1013
|
else:
|
|
954
1014
|
raise ValueError("Configuration is not defined.")
|
careamics/losses/loss_factory.py
CHANGED
|
@@ -5,8 +5,9 @@ This module contains a factory function for creating loss functions.
|
|
|
5
5
|
"""
|
|
6
6
|
from typing import Callable
|
|
7
7
|
|
|
8
|
-
from
|
|
9
|
-
from
|
|
8
|
+
from careamics.config import Configuration
|
|
9
|
+
from careamics.config.algorithm import Loss
|
|
10
|
+
|
|
10
11
|
from .losses import n2v_loss
|
|
11
12
|
|
|
12
13
|
|
|
@@ -8,10 +8,11 @@ from typing import Dict, Optional, Tuple, Union
|
|
|
8
8
|
|
|
9
9
|
import torch
|
|
10
10
|
|
|
11
|
-
from
|
|
12
|
-
from
|
|
13
|
-
from
|
|
14
|
-
from
|
|
11
|
+
from careamics.bioimage import import_bioimage_model
|
|
12
|
+
from careamics.config import Configuration
|
|
13
|
+
from careamics.config.algorithm import Models
|
|
14
|
+
from careamics.utils.logging import get_logger
|
|
15
|
+
|
|
15
16
|
from .unet import UNet
|
|
16
17
|
|
|
17
18
|
logger = get_logger(__name__)
|
|
@@ -49,16 +50,27 @@ def create_model(
|
|
|
49
50
|
model_path: Optional[Union[str, Path]] = None,
|
|
50
51
|
config: Optional[Configuration] = None,
|
|
51
52
|
device: Optional[torch.device] = None,
|
|
52
|
-
) ->
|
|
53
|
+
) -> Tuple[
|
|
54
|
+
torch.nn.Module,
|
|
55
|
+
torch.optim.Optimizer,
|
|
56
|
+
Union[
|
|
57
|
+
torch.optim.lr_scheduler.LRScheduler,
|
|
58
|
+
torch.optim.lr_scheduler.ReduceLROnPlateau, # not a subclass of LRScheduler
|
|
59
|
+
],
|
|
60
|
+
torch.cuda.amp.GradScaler,
|
|
61
|
+
Configuration,
|
|
62
|
+
]:
|
|
53
63
|
"""
|
|
54
|
-
Instantiate a model from a
|
|
64
|
+
Instantiate a model from a model path or configuration.
|
|
55
65
|
|
|
56
|
-
If both
|
|
66
|
+
If both path and configuration are provided, the model path is used. The model
|
|
67
|
+
path should point to either a checkpoint (created during training) or a model
|
|
68
|
+
exported to the bioimage.io format.
|
|
57
69
|
|
|
58
70
|
Parameters
|
|
59
71
|
----------
|
|
60
72
|
model_path : Optional[Union[str, Path]], optional
|
|
61
|
-
Path to a checkpoint, by default None.
|
|
73
|
+
Path to a checkpoint or bioimage.io archive, by default None.
|
|
62
74
|
config : Optional[Configuration], optional
|
|
63
75
|
Configuration, by default None.
|
|
64
76
|
device : Optional[torch.device], optional
|
|
@@ -103,12 +115,13 @@ def create_model(
|
|
|
103
115
|
raise ValueError("Invalid checkpoint format, no configuration found.")
|
|
104
116
|
|
|
105
117
|
# Create model
|
|
106
|
-
model = model_registry(model_name)(
|
|
118
|
+
model: torch.nn.Module = model_registry(model_name)(
|
|
107
119
|
depth=model_config.depth,
|
|
108
120
|
conv_dim=algo_config.get_conv_dim(),
|
|
109
121
|
num_channels_init=model_config.num_channels_init,
|
|
110
122
|
)
|
|
111
123
|
model.to(device)
|
|
124
|
+
|
|
112
125
|
# Load the model state dict
|
|
113
126
|
if "model_state_dict" in checkpoint:
|
|
114
127
|
model.load_state_dict(checkpoint["model_state_dict"])
|
|
@@ -141,13 +154,19 @@ def create_model(
|
|
|
141
154
|
|
|
142
155
|
else:
|
|
143
156
|
raise ValueError("Either config or model_path must be provided")
|
|
144
|
-
|
|
157
|
+
|
|
145
158
|
return model, optimizer, scheduler, scaler, config
|
|
146
159
|
|
|
147
160
|
|
|
148
161
|
def get_optimizer_and_scheduler(
|
|
149
162
|
config: Configuration, model: torch.nn.Module, state_dict: Optional[Dict] = None
|
|
150
|
-
) -> Tuple[
|
|
163
|
+
) -> Tuple[
|
|
164
|
+
torch.optim.Optimizer,
|
|
165
|
+
Union[
|
|
166
|
+
torch.optim.lr_scheduler.LRScheduler,
|
|
167
|
+
torch.optim.lr_scheduler.ReduceLROnPlateau, # not a subclass of LRScheduler
|
|
168
|
+
],
|
|
169
|
+
]:
|
|
151
170
|
"""
|
|
152
171
|
Create optimizer and learning rate schedulers.
|
|
153
172
|
|
|
@@ -46,16 +46,18 @@ def stitch_prediction(
|
|
|
46
46
|
return predicted_image
|
|
47
47
|
|
|
48
48
|
|
|
49
|
-
def tta_forward(x:
|
|
49
|
+
def tta_forward(x: torch.Tensor) -> List[torch.Tensor]:
|
|
50
50
|
"""
|
|
51
51
|
Augment 8-fold an array.
|
|
52
52
|
|
|
53
53
|
The augmentation is performed using all 90 deg rotations and their flipped version,
|
|
54
54
|
as well as the original image flipped.
|
|
55
55
|
|
|
56
|
+
Tensors should be of shape SC(Z)YX, with S and C potentially singleton dimensions.
|
|
57
|
+
|
|
56
58
|
Parameters
|
|
57
59
|
----------
|
|
58
|
-
x : torch.
|
|
60
|
+
x : torch.Tensor
|
|
59
61
|
Data to augment.
|
|
60
62
|
|
|
61
63
|
Returns
|
|
@@ -75,13 +77,15 @@ def tta_forward(x: np.ndarray) -> List:
|
|
|
75
77
|
return x_aug_flip
|
|
76
78
|
|
|
77
79
|
|
|
78
|
-
def tta_backward(x_aug: List) -> np.ndarray:
|
|
80
|
+
def tta_backward(x_aug: List[torch.Tensor]) -> np.ndarray:
|
|
79
81
|
"""
|
|
80
82
|
Invert `tta_forward` and average the 8 images.
|
|
81
83
|
|
|
84
|
+
The function takes a list of torch tensors and returns a numpy array.
|
|
85
|
+
|
|
82
86
|
Parameters
|
|
83
87
|
----------
|
|
84
|
-
x_aug : List
|
|
88
|
+
x_aug : List[torch.Tensor]
|
|
85
89
|
Stack of 8-fold augmented images.
|
|
86
90
|
|
|
87
91
|
Returns
|
|
@@ -90,13 +94,13 @@ def tta_backward(x_aug: List) -> np.ndarray:
|
|
|
90
94
|
Average of de-augmented x_aug.
|
|
91
95
|
"""
|
|
92
96
|
x_deaug = [
|
|
93
|
-
x_aug[0],
|
|
94
|
-
np.rot90(x_aug[1], -1),
|
|
95
|
-
np.rot90(x_aug[2], -2),
|
|
96
|
-
np.rot90(x_aug[3], -3),
|
|
97
|
-
np.
|
|
98
|
-
np.rot90(np.
|
|
99
|
-
np.rot90(np.
|
|
100
|
-
np.rot90(np.
|
|
97
|
+
x_aug[0].numpy(),
|
|
98
|
+
np.rot90(x_aug[1], -1, axes=(2, 3)),
|
|
99
|
+
np.rot90(x_aug[2], -2, axes=(2, 3)),
|
|
100
|
+
np.rot90(x_aug[3], -3, axes=(2, 3)),
|
|
101
|
+
np.flip(x_aug[4].numpy(), axis=(1, 3)),
|
|
102
|
+
np.rot90(np.flip(x_aug[5].numpy(), axis=(1, 3)), -1, axes=(2, 3)),
|
|
103
|
+
np.rot90(np.flip(x_aug[6].numpy(), axis=(1, 3)), -2, axes=(2, 3)),
|
|
104
|
+
np.rot90(np.flip(x_aug[7].numpy(), axis=(1, 3)), -3, axes=(2, 3)),
|
|
101
105
|
]
|
|
102
106
|
return np.mean(x_deaug, 0)
|
careamics/utils/__init__.py
CHANGED
|
@@ -5,11 +5,10 @@ __all__ = [
|
|
|
5
5
|
"denormalize",
|
|
6
6
|
"normalize",
|
|
7
7
|
"get_device",
|
|
8
|
-
"check_array_validity",
|
|
9
8
|
"check_axes_validity",
|
|
9
|
+
"add_axes",
|
|
10
10
|
"check_tiling_validity",
|
|
11
11
|
"cwd",
|
|
12
|
-
"compile_model",
|
|
13
12
|
"MetricTracker",
|
|
14
13
|
]
|
|
15
14
|
|
|
@@ -17,5 +16,5 @@ __all__ = [
|
|
|
17
16
|
from .context import cwd
|
|
18
17
|
from .metrics import MetricTracker
|
|
19
18
|
from .normalization import denormalize, normalize
|
|
20
|
-
from .torch_utils import
|
|
21
|
-
from .validators import
|
|
19
|
+
from .torch_utils import get_device
|
|
20
|
+
from .validators import add_axes, check_axes_validity, check_tiling_validity
|