spacr 0.5.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +0 -2
- spacr/__main__.py +3 -3
- spacr/core.py +13 -106
- spacr/gui_core.py +2 -2
- spacr/gui_utils.py +1 -13
- spacr/io.py +24 -25
- spacr/mediar.py +12 -8
- spacr/plot.py +50 -13
- spacr/settings.py +45 -6
- spacr/submodules.py +11 -1
- spacr/timelapse.py +21 -3
- spacr/utils.py +154 -15
- {spacr-0.5.0.dist-info → spacr-0.9.0.dist-info}/METADATA +62 -62
- spacr-0.9.0.dist-info/RECORD +109 -0
- {spacr-0.5.0.dist-info → spacr-0.9.0.dist-info}/WHEEL +1 -1
- spacr/resources/MEDIAR/.gitignore +0 -18
- spacr/resources/MEDIAR/LICENSE +0 -21
- spacr/resources/MEDIAR/README.md +0 -189
- spacr/resources/MEDIAR/SetupDict.py +0 -39
- spacr/resources/MEDIAR/__pycache__/SetupDict.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/__pycache__/evaluate.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/__pycache__/generate_mapping.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/__pycache__/main.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/config/baseline.json +0 -60
- spacr/resources/MEDIAR/config/mediar_example.json +0 -72
- spacr/resources/MEDIAR/config/pred/pred_mediar.json +0 -17
- spacr/resources/MEDIAR/config/step1_pretraining/phase1.json +0 -55
- spacr/resources/MEDIAR/config/step1_pretraining/phase2.json +0 -58
- spacr/resources/MEDIAR/config/step2_finetuning/finetuning1.json +0 -66
- spacr/resources/MEDIAR/config/step2_finetuning/finetuning2.json +0 -66
- spacr/resources/MEDIAR/config/step3_prediction/base_prediction.json +0 -16
- spacr/resources/MEDIAR/config/step3_prediction/ensemble_tta.json +0 -23
- spacr/resources/MEDIAR/core/BasePredictor.py +0 -120
- spacr/resources/MEDIAR/core/BaseTrainer.py +0 -240
- spacr/resources/MEDIAR/core/Baseline/Predictor.py +0 -59
- spacr/resources/MEDIAR/core/Baseline/Trainer.py +0 -113
- spacr/resources/MEDIAR/core/Baseline/__init__.py +0 -2
- spacr/resources/MEDIAR/core/Baseline/__pycache__/Predictor.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/Baseline/__pycache__/Trainer.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/Baseline/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/Baseline/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/Baseline/utils.py +0 -80
- spacr/resources/MEDIAR/core/MEDIAR/EnsemblePredictor.py +0 -105
- spacr/resources/MEDIAR/core/MEDIAR/Predictor.py +0 -234
- spacr/resources/MEDIAR/core/MEDIAR/Trainer.py +0 -172
- spacr/resources/MEDIAR/core/MEDIAR/__init__.py +0 -3
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/EnsemblePredictor.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/Predictor.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/Trainer.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/utils.py +0 -429
- spacr/resources/MEDIAR/core/__init__.py +0 -2
- spacr/resources/MEDIAR/core/__pycache__/BasePredictor.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/__pycache__/BaseTrainer.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/utils.py +0 -40
- spacr/resources/MEDIAR/evaluate.py +0 -71
- spacr/resources/MEDIAR/generate_mapping.py +0 -121
- spacr/resources/MEDIAR/image/examples/img1.tiff +0 -0
- spacr/resources/MEDIAR/image/examples/img2.tif +0 -0
- spacr/resources/MEDIAR/image/failure_cases.png +0 -0
- spacr/resources/MEDIAR/image/mediar_framework.png +0 -0
- spacr/resources/MEDIAR/image/mediar_model.PNG +0 -0
- spacr/resources/MEDIAR/image/mediar_results.png +0 -0
- spacr/resources/MEDIAR/main.py +0 -125
- spacr/resources/MEDIAR/predict.py +0 -70
- spacr/resources/MEDIAR/requirements.txt +0 -14
- spacr/resources/MEDIAR/train_tools/__init__.py +0 -3
- spacr/resources/MEDIAR/train_tools/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/__pycache__/measures.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__init__.py +0 -1
- spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/datasetter.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/transforms.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/CellAware.py +0 -88
- spacr/resources/MEDIAR/train_tools/data_utils/custom/LoadImage.py +0 -161
- spacr/resources/MEDIAR/train_tools/data_utils/custom/NormalizeImage.py +0 -77
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__init__.py +0 -3
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/CellAware.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/LoadImage.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/NormalizeImage.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/modalities.pkl +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/datasetter.py +0 -208
- spacr/resources/MEDIAR/train_tools/data_utils/transforms.py +0 -148
- spacr/resources/MEDIAR/train_tools/data_utils/utils.py +0 -84
- spacr/resources/MEDIAR/train_tools/measures.py +0 -200
- spacr/resources/MEDIAR/train_tools/models/MEDIARFormer.py +0 -102
- spacr/resources/MEDIAR/train_tools/models/__init__.py +0 -1
- spacr/resources/MEDIAR/train_tools/models/__pycache__/MEDIARFormer.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/models/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/utils.py +0 -70
- spacr-0.5.0.dist-info/RECORD +0 -190
- {spacr-0.5.0.dist-info → spacr-0.9.0.dist-info}/LICENSE +0 -0
- {spacr-0.5.0.dist-info → spacr-0.9.0.dist-info}/entry_points.txt +0 -0
- {spacr-0.5.0.dist-info → spacr-0.9.0.dist-info}/top_level.txt +0 -0
@@ -1,23 +0,0 @@
|
|
1
|
-
{
|
2
|
-
"pred_setups":{
|
3
|
-
"name": "ensemble_mediar",
|
4
|
-
"input_path":"/home/gihun/MEDIAR/data/Official/Tuning/images",
|
5
|
-
"output_path": "./results/mediar_ensemble_tta",
|
6
|
-
"make_submission": true,
|
7
|
-
"model_path1": "./weights/finetuned/from_phase1.pth",
|
8
|
-
"model_path2": "./weights/finetuned/from_phase2.pth",
|
9
|
-
"device": "cuda:0",
|
10
|
-
"model":{
|
11
|
-
"name": "mediar-former",
|
12
|
-
"params": {
|
13
|
-
"encoder_name":"mit_b5",
|
14
|
-
"decoder_channels": [1024, 512, 256, 128, 64],
|
15
|
-
"decoder_pab_channels": 256,
|
16
|
-
"in_channels":3,
|
17
|
-
"classes":3
|
18
|
-
}
|
19
|
-
},
|
20
|
-
"exp_name": "mediar_ensemble_tta",
|
21
|
-
"algo_params": {"use_tta": true}
|
22
|
-
}
|
23
|
-
}
|
@@ -1,120 +0,0 @@
|
|
1
|
-
import torch
|
2
|
-
import numpy as np
|
3
|
-
import time, os
|
4
|
-
import tifffile as tif
|
5
|
-
|
6
|
-
from datetime import datetime
|
7
|
-
from zipfile import ZipFile
|
8
|
-
from pytz import timezone
|
9
|
-
|
10
|
-
from train_tools.data_utils.transforms import get_pred_transforms
|
11
|
-
|
12
|
-
|
13
|
-
class BasePredictor:
    """Base class for folder-based inference with a segmentation model.

    Every entry of ``input_path`` is loaded with the prediction transforms,
    run through ``_inference``, converted by ``_post_process`` and written to
    ``output_path`` as ``<stem>_label.tiff``. Subclasses must implement
    ``_inference`` and ``_post_process``.
    """

    def __init__(
        self,
        model,
        device,
        input_path,
        output_path,
        make_submission=False,
        exp_name=None,
        algo_params=None,
    ):
        """
        Args:
            model: network; moved to ``device`` and set to eval mode in
                ``conduct_prediction``.
            device: device identifier passed to ``.to()``.
            input_path (str): directory whose entries are all treated as images.
            output_path (str): directory for predicted masks (created if missing).
            make_submission (bool): also zip the predictions into
                ``./submissions/<exp_name>.zip``.
            exp_name (str | None): prefix for the timestamped experiment name.
            algo_params (dict | None): algorithm-specific settings, stored as
                instance attributes.
        """
        self.model = model
        self.device = device
        self.input_path = input_path
        self.output_path = output_path
        self.make_submission = make_submission
        self.exp_name = exp_name

        # Assign algorithm-specific arguments as instance attributes
        if algo_params:
            self.__dict__.update(algo_params)

        # Prepare inference environments
        self._setups()

    @torch.no_grad()
    def conduct_prediction(self):
        """Predict, post-process and save a mask for every input image.

        Returns:
            float: seconds spent on the last processed image (0.0 when there
            were no images). The total over all images is printed.
        """
        self.model.to(self.device)
        self.model.eval()
        total_time = 0
        # Bound up-front so an empty input directory does not raise
        # UnboundLocalError at the final return.
        time_cost = 0.0

        for img_name in self.img_names:
            img_data = self._get_img_data(img_name)
            img_data = img_data.to(self.device)

            start = time.time()

            pred_mask = self._inference(img_data)
            pred_mask = self._post_process(pred_mask.squeeze(0).cpu().numpy())
            self.write_pred_mask(
                pred_mask, self.output_path, img_name, self.make_submission
            )

            end = time.time()

            time_cost = end - start
            total_time += time_cost
            print(
                f"Prediction finished: {img_name}; img size = {img_data.shape}; costing: {time_cost:.2f}s"
            )

        print(f"\n Total Time Cost: {total_time:.2f}s")

        if self.make_submission:
            fname = "%s.zip" % self.exp_name

            os.makedirs("./submissions", exist_ok=True)
            submission_path = os.path.join("./submissions", fname)

            with ZipFile(submission_path, "w") as zipObj2:
                pred_names = sorted(os.listdir(self.output_path))
                for pred_name in pred_names:
                    pred_path = os.path.join(self.output_path, pred_name)
                    zipObj2.write(pred_path)

            print("\n>>>>> Submission file is saved at: %s\n" % submission_path)

        return time_cost

    def write_pred_mask(self, pred_mask, output_dir, image_name, submission=False):
        """Write ``pred_mask`` as a zlib-compressed ``<stem>_label.tiff``.

        The stem is the part of ``image_name`` before its first ``.``. In
        submission mode a warning is printed when the largest label is <= 5.
        """
        # All images should contain at least 5 cells
        if submission:
            if not (np.max(pred_mask) > 5):
                print("[!Caution] Only %d Cells Detected!!!\n" % np.max(pred_mask))

        file_name = image_name.split(".")[0]
        file_name = file_name + "_label.tiff"
        file_path = os.path.join(output_dir, file_name)

        tif.imwrite(file_path, pred_mask, compression="zlib")

    def _setups(self):
        """Build transforms, create the output dir, stamp exp_name, list inputs."""
        self.pred_transforms = get_pred_transforms()
        os.makedirs(self.output_path, exist_ok=True)

        now = datetime.now(timezone("Asia/Seoul"))
        dt_string = now.strftime("%m%d_%H%M")
        self.exp_name = (
            self.exp_name + dt_string if self.exp_name is not None else dt_string
        )

        self.img_names = sorted(os.listdir(self.input_path))

    def _get_img_data(self, img_name):
        """Load one image through the prediction transforms and add a batch dim."""
        img_path = os.path.join(self.input_path, img_name)
        img_data = self.pred_transforms(img_path)
        img_data = img_data.unsqueeze(0)

        return img_data

    def _inference(self, img_data):
        """Run the model on a batched image tensor (subclass responsibility)."""
        raise NotImplementedError

    def _post_process(self, pred_mask):
        """Convert the raw numpy output into a label mask (subclass responsibility)."""
        raise NotImplementedError
|
@@ -1,240 +0,0 @@
|
|
1
|
-
import torch
|
2
|
-
import numpy as np
|
3
|
-
from tqdm import tqdm
|
4
|
-
from monai.inferers import sliding_window_inference
|
5
|
-
from monai.metrics import CumulativeAverage
|
6
|
-
from monai.transforms import (
|
7
|
-
Activations,
|
8
|
-
AsDiscrete,
|
9
|
-
Compose,
|
10
|
-
EnsureType,
|
11
|
-
)
|
12
|
-
|
13
|
-
import os, sys
|
14
|
-
import copy
|
15
|
-
|
16
|
-
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../")))
|
17
|
-
|
18
|
-
from core.utils import print_learning_device, print_with_logging
|
19
|
-
from train_tools.measures import evaluate_f1_score_cellseg
|
20
|
-
|
21
|
-
|
22
|
-
class BaseTrainer:
    """Abstract base class for trainer implementations.

    Runs the epoch loop with optional validation (or tuning-set evaluation)
    and keeps a deep copy of the best-scoring weights, which are loaded back
    into the model at the end of ``train``.
    """

    def __init__(
        self,
        model,
        dataloaders,
        optimizer,
        scheduler=None,
        criterion=None,
        num_epochs=100,
        device="cuda:0",
        no_valid=False,
        valid_frequency=1,
        amp=False,
        algo_params=None,
    ):
        """
        Args:
            model: network; moved to ``device`` here.
            dataloaders (dict): loaders keyed by phase ("train", "valid",
                "tuning", ...).
            optimizer: torch optimizer.
            scheduler: optional LR scheduler, stepped once per epoch.
            criterion: loss callable ``criterion(outputs, labels)``.
            num_epochs (int): number of training epochs.
            device: device for model and batches.
            no_valid (bool): if True, score on the "tuning" loader (cell count)
                instead of the "valid" loader.
            valid_frequency (int): evaluate every this many epochs.
            amp (bool): create a CUDA ``GradScaler``.
            algo_params (dict | None): extra settings stored as attributes.
        """
        self.model = model.to(device)
        self.dataloaders = dataloaders
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.num_epochs = num_epochs
        self.no_valid = no_valid
        self.valid_frequency = valid_frequency
        self.device = device
        self.amp = amp
        self.best_weights = None
        self.best_f1_score = 0.1

        # FP-16 scaler (the base ``_epoch_phase`` does not use it)
        self.scaler = torch.cuda.amp.GradScaler() if amp else None

        # Assign algorithm-specific arguments as instance attributes
        if algo_params:
            self.__dict__.update(algo_params)

        # Cumulative statistics
        self.loss_metric = CumulativeAverage()
        self.f1_metric = CumulativeAverage()

        # Post-processing functions
        self.post_pred = Compose(
            [EnsureType(), Activations(softmax=True), AsDiscrete(threshold=0.5)]
        )
        self.post_gt = Compose([EnsureType(), AsDiscrete(to_onehot=None)])

    def train(self):
        """Train the model, evaluating every ``valid_frequency`` epochs.

        After the loop ``best_f1_score`` is reset to 0 and the best recorded
        weights (if any) are loaded into the model.
        """
        # Print learning device name
        print_learning_device(self.device)

        for epoch in range(1, self.num_epochs + 1):
            print(f"[Round {epoch}/{self.num_epochs}]")

            # Train Epoch Phase
            print(">>> Train Epoch")
            train_results = self._epoch_phase("train")
            print_with_logging(train_results, epoch)

            if self.scheduler is not None:
                self.scheduler.step()

            if epoch % self.valid_frequency == 0:
                if not self.no_valid:
                    # Valid Epoch Phase
                    print(">>> Valid Epoch")
                    valid_results = self._epoch_phase("valid")
                    print_with_logging(valid_results, epoch)

                    if "Valid_F1_Score" in valid_results:
                        current_f1_score = valid_results["Valid_F1_Score"]
                        self._update_best_model(current_f1_score)
                else:
                    # Without a validation split, the total predicted cell count
                    # on the tuning set is used as the selection score.
                    print(">>> TuningSet Epoch")
                    tuning_cell_counts = self._tuningset_evaluation()
                    tuning_count_dict = {"TuningSet_Cell_Count": tuning_cell_counts}
                    print_with_logging(tuning_count_dict, epoch)

                    current_cell_count = tuning_cell_counts
                    self._update_best_model(current_cell_count)

            print("-" * 50)

        self.best_f1_score = 0

        if self.best_weights is not None:
            self.model.load_state_dict(self.best_weights)

    def _epoch_phase(self, phase):
        """Learning process for 1 Epoch (for different phases).

        Args:
            phase (str): "train", "valid", "test"

        Returns:
            dict: statistics for the phase results
        """
        phase_results = {}

        # Set model mode
        if phase == "train":
            self.model.train()
        else:
            self.model.eval()

        # Epoch process
        for batch_data in tqdm(self.dataloaders[phase]):
            images = batch_data["img"].to(self.device)
            labels = batch_data["label"].to(self.device)
            self.optimizer.zero_grad()

            # Forward pass
            with torch.set_grad_enabled(phase == "train"):
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                self.loss_metric.append(loss)

            # Backward pass
            if phase == "train":
                loss.backward()
                self.optimizer.step()

        # Update metrics
        phase_results = self._update_results(
            phase_results, self.loss_metric, "loss", phase
        )

        return phase_results

    @staticmethod
    def _count_instances(label_map):
        """Return the number of distinct non-zero labels in ``label_map``.

        Label 0 is treated as background and is not counted.
        """
        return int(np.count_nonzero(np.unique(label_map)))

    @torch.no_grad()
    def _tuningset_evaluation(self):
        """Sum the predicted cell counts over the "tuning" loader.

        Images wider than 5000 px are skipped.
        """
        cell_counts_total = []
        self.model.eval()

        for batch_data in tqdm(self.dataloaders["tuning"]):
            images = batch_data["img"].to(self.device)
            if images.shape[-1] > 5000:
                continue

            outputs = sliding_window_inference(
                images,
                roi_size=512,
                sw_batch_size=4,
                predictor=self.model,
                padding_mode="constant",
                mode="gaussian",
            )

            outputs = outputs.squeeze(0)
            outputs, _ = self._post_process(outputs, None)
            # Previously `len(np.unique(outputs) - 1)`, which subtracted 1 from
            # the label values rather than from the count, so background was
            # counted as a cell.
            count = self._count_instances(outputs)
            cell_counts_total.append(count)

        cell_counts_total_sum = np.sum(cell_counts_total)
        print("Cell Counts Total: (%d)" % (cell_counts_total_sum))

        return cell_counts_total_sum

    def _update_results(self, phase_results, metric, metric_key, phase="train"):
        """Aggregate and flush metrics

        Args:
            phase_results (dict): base dictionary to log metrics
            metric (CumulativeAverage): cumulated metrics
            metric_key (str): name of metric
            phase (str, optional): current phase name. Defaults to "train".

        Returns:
            dict: dictionary of metrics for the current phase
        """
        # Refine metrics name, e.g. ("valid", "f1_score") -> "Valid_F1_Score"
        metric_key = "_".join([phase, metric_key]).title()

        # Aggregate metrics
        metric_item = round(metric.aggregate().item(), 4)

        # Log metrics to dictionary
        phase_results[metric_key] = metric_item

        # Flush metrics
        metric.reset()

        return phase_results

    def _update_best_model(self, current_f1_score):
        """Snapshot the model weights when ``current_f1_score`` beats the best."""
        if current_f1_score > self.best_f1_score:
            self.best_weights = copy.deepcopy(self.model.state_dict())
            self.best_f1_score = current_f1_score
            print(
                "\n>>>> Update Best Model with score: {}\n".format(self.best_f1_score)
            )

    def _inference(self, images, phase="train"):
        """inference methods for different phase"""
        if phase != "train":
            outputs = sliding_window_inference(
                images,
                roi_size=512,
                sw_batch_size=4,
                predictor=self.model,
                padding_mode="reflect",
                mode="gaussian",
                overlap=0.5,
            )
        else:
            outputs = self.model(images)

        return outputs

    def _post_process(self, outputs, labels):
        """Identity post-processing; subclasses override."""
        return outputs, labels

    def _get_f1_metric(self, masks_pred, masks_true):
        """Return the F1 score (last element of ``evaluate_f1_score_cellseg``)."""
        f1_score = evaluate_f1_score_cellseg(masks_true, masks_pred)[-1]

        return f1_score
|
@@ -1,59 +0,0 @@
|
|
1
|
-
import torch
|
2
|
-
import os, sys
|
3
|
-
from skimage import morphology, measure
|
4
|
-
from monai.inferers import sliding_window_inference
|
5
|
-
|
6
|
-
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
|
7
|
-
|
8
|
-
from core.BasePredictor import BasePredictor
|
9
|
-
|
10
|
-
__all__ = ["Predictor"]
|
11
|
-
|
12
|
-
|
13
|
-
class Predictor(BasePredictor):
    """Baseline predictor.

    Runs Gaussian-weighted sliding-window inference, takes the softmax
    probability of class 1 and labels its connected components.
    """

    def __init__(
        self,
        model,
        device,
        input_path,
        output_path,
        make_submission=False,
        exp_name=None,
        algo_params=None,
    ):
        super().__init__(
            model,
            device,
            input_path,
            output_path,
            make_submission,
            exp_name,
            algo_params,
        )

    def _inference(self, img_data):
        """Tile the image into 512x512 windows (60% overlap) and stitch the logits."""
        return sliding_window_inference(
            img_data,
            roi_size=512,
            sw_batch_size=4,
            predictor=self.model,
            padding_mode="constant",
            mode="gaussian",
            overlap=0.6,
        )

    def _post_process(self, pred_mask):
        """Turn class logits (classes along axis 0) into an instance label map."""
        logits = torch.from_numpy(pred_mask)
        cell_prob = torch.softmax(logits, dim=0)[1].cpu().numpy()

        # Binarise, fill holes, drop objects smaller than 16 px, then label.
        foreground = cell_prob > 0.5
        foreground = morphology.remove_small_holes(foreground, connectivity=1)
        foreground = morphology.remove_small_objects(foreground, 16)
        return measure.label(foreground)
|
@@ -1,113 +0,0 @@
|
|
1
|
-
import torch
|
2
|
-
import os, sys
|
3
|
-
import monai
|
4
|
-
|
5
|
-
from monai.data import decollate_batch
|
6
|
-
|
7
|
-
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
|
8
|
-
|
9
|
-
from core.BaseTrainer import BaseTrainer
|
10
|
-
from core.Baseline.utils import create_interior_onehot, identify_instances_from_classmap
|
11
|
-
from train_tools.measures import evaluate_f1_score_cellseg
|
12
|
-
from tqdm import tqdm
|
13
|
-
|
14
|
-
__all__ = ["Trainer"]
|
15
|
-
|
16
|
-
|
17
|
-
class Trainer(BaseTrainer):
    """Baseline trainer for 3-class (background / interior / boundary) maps.

    Instance labels are converted to one-hot interior/boundary targets and
    optimised with Dice + cross-entropy loss. Validation additionally reports
    a cell-segmentation F1 score.
    """

    def __init__(
        self,
        model,
        dataloaders,
        optimizer,
        scheduler=None,
        criterion=None,
        num_epochs=100,
        device="cuda:0",
        no_valid=False,
        valid_frequency=1,
        amp=False,
        algo_params=None,
    ):
        super(Trainer, self).__init__(
            model,
            dataloaders,
            optimizer,
            scheduler,
            criterion,
            num_epochs,
            device,
            no_valid,
            valid_frequency,
            amp,
            algo_params,
        )

        # Dice loss as segmentation criterion (overrides any passed criterion)
        self.criterion = monai.losses.DiceCELoss(softmax=True)

    def _epoch_phase(self, phase):
        """Learning process for 1 Epoch.

        Args:
            phase (str): "train" or an evaluation phase key of ``dataloaders``.

        Returns:
            dict: aggregated loss (and F1 score for non-train phases).
        """
        phase_results = {}

        # Set model mode
        if phase == "train":
            self.model.train()
        else:
            self.model.eval()

        # Epoch process
        for batch_data in tqdm(self.dataloaders[phase]):
            images = batch_data["img"].to(self.device)
            labels = batch_data["label"].to(self.device)
            self.optimizer.zero_grad()

            # Map label masks to 3-class onehot map
            labels_onehot = create_interior_onehot(labels)

            # Forward pass
            with torch.set_grad_enabled(phase == "train"):
                # Without autocast, amp=True would only scale an fp32 loss and
                # never run the forward pass in mixed precision. With
                # amp=False this context is a no-op.
                with torch.autocast(device_type="cuda", enabled=self.amp):
                    outputs = self._inference(images, phase)
                    loss = self.criterion(outputs, labels_onehot)
                self.loss_metric.append(loss)

            if phase != "train":
                f1_score = self._get_f1_metric(outputs, labels)
                self.f1_metric.append(f1_score)

            # Backward pass
            if phase == "train":
                # For the mixed precision training
                if self.amp:
                    self.scaler.scale(loss).backward()
                    self.scaler.unscale_(self.optimizer)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()

                else:
                    loss.backward()
                    self.optimizer.step()

        # Update metrics
        phase_results = self._update_results(
            phase_results, self.loss_metric, "loss", phase
        )

        if phase != "train":
            phase_results = self._update_results(
                phase_results, self.f1_metric, "f1_score", phase
            )

        return phase_results

    def _post_process(self, outputs, labels_onehot):
        """Conduct post-processing for outputs & labels."""
        outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
        labels_onehot = [self.post_gt(i) for i in decollate_batch(labels_onehot)]

        return outputs, labels_onehot

    def _get_f1_metric(self, masks_pred, masks_true):
        """F1 score of the first sample's predicted instances vs. its label mask.

        NOTE(review): only ``masks_pred[0]`` is scored and ``masks_true`` is
        squeezed on dims 0 and 1, so this presumably assumes evaluation batch
        size 1 — confirm against the validation dataloader.
        """
        masks_pred = identify_instances_from_classmap(masks_pred[0])
        masks_true = masks_true.squeeze(0).squeeze(0).cpu().numpy()
        f1_score = evaluate_f1_score_cellseg(masks_true, masks_pred)[-1]

        return f1_score
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
@@ -1,80 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
Adapted from the following references:
|
3
|
-
[1] https://github.com/JunMa11/NeurIPS-CellSeg/blob/main/baseline/model_training_3class.py
|
4
|
-
|
5
|
-
"""
|
6
|
-
|
7
|
-
import torch
|
8
|
-
import numpy as np
|
9
|
-
from skimage import segmentation, morphology, measure
|
10
|
-
import monai
|
11
|
-
|
12
|
-
|
13
|
-
__all__ = ["create_interior_onehot", "identify_instances_from_classmap"]
|
14
|
-
|
15
|
-
|
16
|
-
@torch.no_grad()
def identify_instances_from_classmap(
    class_map, cell_class=1, threshold=0.5, from_logits=True
):
    """Identify cell instances from a per-class score map.

    Args:
        class_map (torch.Tensor): scores with classes along dim 0
            (e.g. (C, H, W)).
        cell_class (int): channel used as the cell class.
        threshold (float): probability cutoff for foreground pixels.
        from_logits (bool): apply softmax over dim 0 before thresholding.

    Returns:
        numpy.ndarray: connected-component label map (0 = background).
    """
    probabilities = torch.softmax(class_map, dim=0) if from_logits else class_map

    # Binarise the chosen class, fill holes, drop specks < 16 px, then label.
    foreground = probabilities[cell_class].cpu().numpy() > threshold
    foreground = morphology.remove_small_holes(foreground, connectivity=1)
    foreground = morphology.remove_small_objects(foreground, 16)
    return measure.label(foreground)
|
35
|
-
|
36
|
-
|
37
|
-
@torch.no_grad()
def create_interior_onehot(inst_maps):
    """Convert instance label masks into a 3-class one-hot target.

    Classes:
        0: background
        1: interior
        2: boundary (inner instance boundary dilated by a radius-1 disk)

    Args:
        inst_maps (torch.Tensor): instance labels, expected (B, 1, H, W) with
            0 as background.

    Returns:
        torch.Tensor: one-hot map with 3 classes (monai ``one_hot`` along
        dim 1, i.e. (B, 3, H, W)) on the same device as ``inst_maps``.
    """
    device = inst_maps.device

    # (B, 1, H, W) -> (B, H, W) numpy labels. int32 rather than int16: with
    # int16, label IDs above 32767 wrap to negative values and would be
    # treated as background by the `inst_map > 0` test below.
    inst_maps = inst_maps.squeeze(1).cpu().numpy().astype(np.int32)

    interior_maps = []

    for inst_map in inst_maps:
        # Create interior-edge map
        boundary = segmentation.find_boundaries(inst_map, mode="inner")

        # Thicken the boundary so neighbouring cells are well separated
        boundary = morphology.binary_dilation(boundary, morphology.disk(1))

        # Interior = foreground pixels not on a boundary, minus small fragments
        interior_temp = np.logical_and(~boundary, inst_map > 0)
        interior_temp = morphology.remove_small_objects(interior_temp, min_size=16)

        interior = np.zeros_like(inst_map, dtype=np.uint8)
        interior[interior_temp] = 1
        interior[boundary] = 2

        interior_maps.append(interior)

    # Aggregate interior_maps for batch: (B, H, W)
    interior_maps = np.stack(interior_maps, axis=0).astype(np.uint8)

    # Restore the channel dim: (B, 1, H, W)
    interior_maps = torch.from_numpy(interior_maps).unsqueeze(1).to(device)

    # Obtain one-hot map for batch
    interior_onehot = monai.networks.one_hot(interior_maps, num_classes=3)

    return interior_onehot
|