spacr 0.5.0__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +0 -2
- spacr/__main__.py +3 -3
- spacr/core.py +13 -106
- spacr/gui_core.py +2 -77
- spacr/gui_utils.py +1 -13
- spacr/io.py +24 -25
- spacr/mediar.py +12 -8
- spacr/plot.py +50 -135
- spacr/settings.py +42 -30
- spacr/submodules.py +11 -1
- spacr/timelapse.py +7 -79
- spacr/utils.py +152 -61
- {spacr-0.5.0.dist-info → spacr-0.9.1.dist-info}/METADATA +62 -62
- spacr-0.9.1.dist-info/RECORD +109 -0
- {spacr-0.5.0.dist-info → spacr-0.9.1.dist-info}/WHEEL +1 -1
- spacr/resources/MEDIAR/.gitignore +0 -18
- spacr/resources/MEDIAR/LICENSE +0 -21
- spacr/resources/MEDIAR/README.md +0 -189
- spacr/resources/MEDIAR/SetupDict.py +0 -39
- spacr/resources/MEDIAR/__pycache__/SetupDict.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/__pycache__/evaluate.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/__pycache__/generate_mapping.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/__pycache__/main.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/config/baseline.json +0 -60
- spacr/resources/MEDIAR/config/mediar_example.json +0 -72
- spacr/resources/MEDIAR/config/pred/pred_mediar.json +0 -17
- spacr/resources/MEDIAR/config/step1_pretraining/phase1.json +0 -55
- spacr/resources/MEDIAR/config/step1_pretraining/phase2.json +0 -58
- spacr/resources/MEDIAR/config/step2_finetuning/finetuning1.json +0 -66
- spacr/resources/MEDIAR/config/step2_finetuning/finetuning2.json +0 -66
- spacr/resources/MEDIAR/config/step3_prediction/base_prediction.json +0 -16
- spacr/resources/MEDIAR/config/step3_prediction/ensemble_tta.json +0 -23
- spacr/resources/MEDIAR/core/BasePredictor.py +0 -120
- spacr/resources/MEDIAR/core/BaseTrainer.py +0 -240
- spacr/resources/MEDIAR/core/Baseline/Predictor.py +0 -59
- spacr/resources/MEDIAR/core/Baseline/Trainer.py +0 -113
- spacr/resources/MEDIAR/core/Baseline/__init__.py +0 -2
- spacr/resources/MEDIAR/core/Baseline/__pycache__/Predictor.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/Baseline/__pycache__/Trainer.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/Baseline/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/Baseline/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/Baseline/utils.py +0 -80
- spacr/resources/MEDIAR/core/MEDIAR/EnsemblePredictor.py +0 -105
- spacr/resources/MEDIAR/core/MEDIAR/Predictor.py +0 -234
- spacr/resources/MEDIAR/core/MEDIAR/Trainer.py +0 -172
- spacr/resources/MEDIAR/core/MEDIAR/__init__.py +0 -3
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/EnsemblePredictor.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/Predictor.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/Trainer.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/MEDIAR/utils.py +0 -429
- spacr/resources/MEDIAR/core/__init__.py +0 -2
- spacr/resources/MEDIAR/core/__pycache__/BasePredictor.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/__pycache__/BaseTrainer.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/core/utils.py +0 -40
- spacr/resources/MEDIAR/evaluate.py +0 -71
- spacr/resources/MEDIAR/generate_mapping.py +0 -121
- spacr/resources/MEDIAR/image/examples/img1.tiff +0 -0
- spacr/resources/MEDIAR/image/examples/img2.tif +0 -0
- spacr/resources/MEDIAR/image/failure_cases.png +0 -0
- spacr/resources/MEDIAR/image/mediar_framework.png +0 -0
- spacr/resources/MEDIAR/image/mediar_model.PNG +0 -0
- spacr/resources/MEDIAR/image/mediar_results.png +0 -0
- spacr/resources/MEDIAR/main.py +0 -125
- spacr/resources/MEDIAR/predict.py +0 -70
- spacr/resources/MEDIAR/requirements.txt +0 -14
- spacr/resources/MEDIAR/train_tools/__init__.py +0 -3
- spacr/resources/MEDIAR/train_tools/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/__pycache__/measures.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__init__.py +0 -1
- spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/datasetter.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/transforms.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/utils.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/CellAware.py +0 -88
- spacr/resources/MEDIAR/train_tools/data_utils/custom/LoadImage.py +0 -161
- spacr/resources/MEDIAR/train_tools/data_utils/custom/NormalizeImage.py +0 -77
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__init__.py +0 -3
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/CellAware.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/LoadImage.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/NormalizeImage.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/custom/modalities.pkl +0 -0
- spacr/resources/MEDIAR/train_tools/data_utils/datasetter.py +0 -208
- spacr/resources/MEDIAR/train_tools/data_utils/transforms.py +0 -148
- spacr/resources/MEDIAR/train_tools/data_utils/utils.py +0 -84
- spacr/resources/MEDIAR/train_tools/measures.py +0 -200
- spacr/resources/MEDIAR/train_tools/models/MEDIARFormer.py +0 -102
- spacr/resources/MEDIAR/train_tools/models/__init__.py +0 -1
- spacr/resources/MEDIAR/train_tools/models/__pycache__/MEDIARFormer.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/models/__pycache__/__init__.cpython-39.pyc +0 -0
- spacr/resources/MEDIAR/train_tools/utils.py +0 -70
- spacr-0.5.0.dist-info/RECORD +0 -190
- {spacr-0.5.0.dist-info → spacr-0.9.1.dist-info}/LICENSE +0 -0
- {spacr-0.5.0.dist-info → spacr-0.9.1.dist-info}/entry_points.txt +0 -0
- {spacr-0.5.0.dist-info → spacr-0.9.1.dist-info}/top_level.txt +0 -0
@@ -1,105 +0,0 @@
|
|
1
|
-
import torch
|
2
|
-
import os, sys, copy
|
3
|
-
import numpy as np
|
4
|
-
|
5
|
-
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
|
6
|
-
|
7
|
-
from core.MEDIAR.Predictor import Predictor
|
8
|
-
|
9
|
-
__all__ = ["EnsemblePredictor"]
|
10
|
-
|
11
|
-
|
12
|
-
class EnsemblePredictor(Predictor):
    """Predictor that averages the outputs of two models.

    ``model`` is handled by the ``Predictor`` base class; ``model_aux`` is a
    second network whose prediction is averaged with it. When ``use_tta`` is
    enabled, each model's prediction is first merged over the original,
    horizontally flipped and vertically flipped inputs, and the two merged
    predictions are then averaged.
    """

    def __init__(
        self,
        model,
        model_aux,
        device,
        input_path,
        output_path,
        make_submission=False,
        exp_name=None,
        algo_params=None,
    ):
        super().__init__(
            model,
            device,
            input_path,
            output_path,
            make_submission,
            exp_name,
            algo_params,
        )
        self.model_aux = model_aux

    @torch.no_grad()
    def _inference(self, img_data):
        """Return the ensemble prediction for ``img_data`` as a squeezed CPU tensor."""
        self.model_aux.to(self.device)
        self.model_aux.eval()

        img_data = img_data.to(self.device)

        outputs_base = self._window_inference(img_data).cpu().squeeze()
        outputs_aux = self._window_inference(img_data, aux=True).cpu().squeeze()

        if not self.use_tta:
            return (outputs_base + outputs_aux) / 2

        # Horizontal / vertical flip TTA for both models.
        outputs_hflip, outputs_hflip_aux = self._flip_inference_pair(
            img_data, self.hflip_tta
        )
        outputs_vflip, outputs_vflip_aux = self._flip_inference_pair(
            img_data, self.vflip_tta
        )

        pred_mask = self._merge_tta(outputs_base, outputs_hflip, outputs_vflip)
        pred_mask_aux = self._merge_tta(
            outputs_aux, outputs_hflip_aux, outputs_vflip_aux
        )

        return (pred_mask + pred_mask_aux) / 2

    def _flip_inference_pair(self, img_data, tta):
        """Predict with both models on a flipped copy of ``img_data``.

        Each output is flipped back with ``tta.apply_deaug_mask`` before being
        moved to the CPU and squeezed.

        Returns:
            ``(outputs_main, outputs_aux)``.
        """
        img_flip = tta.apply_aug_image(img_data, apply=True)

        outputs = self._window_inference(img_flip)
        outputs = tta.apply_deaug_mask(outputs, apply=True)

        outputs_aux = self._window_inference(img_flip, aux=True)
        outputs_aux = tta.apply_deaug_mask(outputs_aux, apply=True)

        return outputs.cpu().squeeze(), outputs_aux.cpu().squeeze()

    @staticmethod
    def _merge_tta(outputs_base, outputs_hflip, outputs_vflip):
        """Average three (de-augmented) predictions channel by channel.

        Channel 0 is treated as the y-gradient and channel 1 as the x-gradient
        (layout ``[grad y, grad x, cellprob]``). Mirroring the width axis
        reverses the x-direction, so the x-gradient of the h-flipped
        prediction is negated. Likewise, the y-gradient of the v-flipped
        prediction is negated. The probability channel is averaged as-is.
        """
        merged = torch.zeros_like(outputs_base)
        merged[0] = (outputs_base[0] + outputs_hflip[0] - outputs_vflip[0]) / 3
        merged[1] = (outputs_base[1] - outputs_hflip[1] + outputs_vflip[1]) / 3
        merged[2] = (outputs_base[2] + outputs_hflip[2] + outputs_vflip[2]) / 3
        return merged
|
@@ -1,234 +0,0 @@
|
|
1
|
-
import torch
|
2
|
-
import numpy as np
|
3
|
-
import os, sys
|
4
|
-
from monai.inferers import sliding_window_inference
|
5
|
-
|
6
|
-
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
|
7
|
-
|
8
|
-
from core.BasePredictor import BasePredictor
|
9
|
-
from core.MEDIAR.utils import compute_masks
|
10
|
-
|
11
|
-
__all__ = ["Predictor"]
|
12
|
-
|
13
|
-
|
14
|
-
class Predictor(BasePredictor):
    """Single-model MEDIAR predictor.

    Runs sliding-window inference, optionally with horizontal/vertical flip
    test-time augmentation (``self.use_tta``). It then converts the predicted
    gradient flows and cell-probability logits into an instance mask with
    ``compute_masks``.
    """

    # Images with at least this many pixels are post-processed tile by tile
    # (presumably to bound the memory used by ``compute_masks``).
    WHOLE_SLIDE_PIXELS = 5000 * 5000
    # Edge length, in pixels, of the square tiles used in whole-slide mode.
    WHOLE_SLIDE_ROI_SIZE = 2000

    def __init__(
        self,
        model,
        device,
        input_path,
        output_path,
        make_submission=False,
        exp_name=None,
        algo_params=None,
    ):
        super().__init__(
            model,
            device,
            input_path,
            output_path,
            make_submission,
            exp_name,
            algo_params,
        )
        # Flip transforms used for test-time augmentation in ``_inference``.
        self.hflip_tta = HorizontalFlip()
        self.vflip_tta = VerticalFlip()

    @torch.no_grad()
    def _inference(self, img_data):
        """Conduct model prediction.

        Returns the squeezed CPU model output. When ``self.use_tta`` is set,
        the outputs for the original, horizontally flipped and vertically
        flipped inputs are merged instead.
        """
        img_data = img_data.to(self.device)
        outputs_base = self._window_inference(img_data).cpu().squeeze()

        if not self.use_tta:
            return outputs_base

        outputs_hflip = self._flip_inference(img_data, self.hflip_tta)
        outputs_vflip = self._flip_inference(img_data, self.vflip_tta)
        return self._merge_tta(outputs_base, outputs_hflip, outputs_vflip)

    def _flip_inference(self, img_data, tta):
        """Predict on a flipped copy of ``img_data`` and flip the output back."""
        img_flip = tta.apply_aug_image(img_data, apply=True)
        outputs = self._window_inference(img_flip)
        outputs = tta.apply_deaug_mask(outputs, apply=True)
        return outputs.cpu().squeeze()

    @staticmethod
    def _merge_tta(outputs_base, outputs_hflip, outputs_vflip):
        """Average three (de-augmented) predictions channel by channel.

        Channel 0 is treated as the y-gradient and channel 1 as the x-gradient
        (layout ``[grad y, grad x, cellprob]``). Mirroring the width axis
        reverses the x-direction, so the x-gradient of the h-flipped
        prediction is negated. Likewise, the y-gradient of the v-flipped
        prediction is negated. The probability channel is averaged as-is.
        """
        merged = torch.zeros_like(outputs_base)
        merged[0] = (outputs_base[0] + outputs_hflip[0] - outputs_vflip[0]) / 3
        merged[1] = (outputs_base[1] - outputs_hflip[1] + outputs_vflip[1]) / 3
        merged[2] = (outputs_base[2] + outputs_hflip[2] + outputs_vflip[2]) / 3
        return merged

    def _window_inference(self, img_data, aux=False):
        """Inference on RoI-sized windows (``self.model_aux`` when ``aux``)."""
        outputs = sliding_window_inference(
            img_data,
            roi_size=512,
            sw_batch_size=4,
            predictor=self.model if not aux else self.model_aux,
            padding_mode="constant",
            mode="gaussian",
            overlap=0.6,
        )

        return outputs

    def _compute_instance_mask(self, dP, cellprob):
        """Run ``compute_masks`` with this predictor's fixed thresholds.

        Returns the first element of ``compute_masks``' result, which is used
        as the instance mask.
        """
        return compute_masks(
            dP,
            cellprob,
            use_gpu=True,
            flow_threshold=0.4,
            device=self.device,
            cellprob_threshold=0.5,
        )[0]

    def _post_process(self, pred_mask):
        """Generate cell instance masks.

        ``pred_mask`` is indexed as ``[:2]`` -> gradient flows and ``[-1]`` ->
        cell-probability logits, with spatial size in the last two
        dimensions. Large images are processed on a grid of
        ``WHOLE_SLIDE_ROI_SIZE`` tiles. Labels are offset per tile so that
        instance IDs stay unique across the whole image.

        NOTE(review): a cell crossing a tile border is still split into two
        instances; that limitation is unchanged.
        """
        dP, cellprob = pred_mask[:2], self._sigmoid(pred_mask[-1])
        H, W = pred_mask.shape[-2], pred_mask.shape[-1]

        if H * W < self.WHOLE_SLIDE_PIXELS:
            return self._compute_instance_mask(dP, cellprob)

        print("\n[Whole Slide] Grid Prediction starting...")
        roi_size = self.WHOLE_SLIDE_ROI_SIZE

        # Number of tiles per axis (ceil division); pad up to a whole grid.
        n_H = (H + roi_size - 1) // roi_size
        n_W = (W + roi_size - 1) // roi_size
        new_H, new_W = roi_size * n_H, roi_size * n_W

        # Allocate values on the grid
        pred_pad = np.zeros((new_H, new_W), dtype=np.uint32)
        dP_pad = np.zeros((2, new_H, new_W), dtype=np.float32)
        cellprob_pad = np.zeros((new_H, new_W), dtype=np.float32)

        dP_pad[:, :H, :W], cellprob_pad[:H, :W] = dP, cellprob

        label_offset = 0
        for i in range(n_H):
            rows = slice(roi_size * i, roi_size * (i + 1))
            for j in range(n_W):
                print("Pred on Grid (%d, %d) processing..." % (i, j))
                cols = slice(roi_size * j, roi_size * (j + 1))

                tile = np.asarray(
                    self._compute_instance_mask(
                        dP_pad[:, rows, cols], cellprob_pad[rows, cols]
                    )
                )
                # Every tile is labelled independently starting from 1; shift
                # its non-background labels past those already placed so that
                # instances in different tiles do not share an ID.
                tile = np.where(tile > 0, tile.astype(np.int64) + label_offset, 0)
                pred_pad[rows, cols] = tile
                label_offset = max(label_offset, int(tile.max()))

        return pred_pad[:H, :W]

    def _sigmoid(self, z):
        """Element-wise logistic sigmoid."""
        return 1 / (1 + np.exp(-z))
|
162
|
-
|
163
|
-
|
164
|
-
"""
|
165
|
-
Adapted from the following references:
|
166
|
-
[1] https://github.com/qubvel/ttach/blob/master/ttach/transforms.py
|
167
|
-
|
168
|
-
"""
|
169
|
-
|
170
|
-
|
171
|
-
def hflip(x):
    """Mirror a batch of images left <-> right.

    Reverses dimension 3, which is the width axis for an (N, C, H, W) batch.
    """
    width_dim = 3
    return x.flip(width_dim)
|
174
|
-
|
175
|
-
|
176
|
-
def vflip(x):
    """Mirror a batch of images top <-> bottom.

    Reverses dimension 2, which is the height axis for an (N, C, H, W) batch.
    """
    height_dim = 2
    return x.flip(height_dim)
|
179
|
-
|
180
|
-
|
181
|
-
class DualTransform:
    """Base class for a test-time augmentation applied to images and undone on masks.

    Subclasses implement ``apply_aug_image`` (forward augmentation) and
    ``apply_deaug_mask`` (inverse on the prediction).

    Attributes:
        pname: name of the augmentation parameter.
        params: candidate values of that parameter.
    """

    identity_param = None

    def __init__(self, name: str, params):
        self.pname = name
        self.params = params

    def apply_aug_image(self, image, *args, **params):
        """Augment ``image``; must be overridden by subclasses."""
        raise NotImplementedError

    def apply_deaug_mask(self, mask, *args, **params):
        """Undo the augmentation on ``mask``; must be overridden by subclasses."""
        raise NotImplementedError
|
195
|
-
|
196
|
-
|
197
|
-
class HorizontalFlip(DualTransform):
    """Flip images horizontally (left -> right).

    A flip is its own inverse, so the same ``hflip`` is used both to augment
    the image and to undo the augmentation on the predicted mask.
    """

    identity_param = False

    def __init__(self):
        super().__init__("apply", [False, True])

    def apply_aug_image(self, image, apply=False, **kwargs):
        """Return ``image`` mirrored left <-> right when ``apply`` is truthy."""
        return hflip(image) if apply else image

    def apply_deaug_mask(self, mask, apply=False, **kwargs):
        """Undo the horizontal flip on ``mask`` when ``apply`` is truthy."""
        return hflip(mask) if apply else mask
|
214
|
-
|
215
|
-
|
216
|
-
class VerticalFlip(DualTransform):
    """Flip images vertically (up -> down).

    A flip is its own inverse, so the same ``vflip`` is used both to augment
    the image and to undo the augmentation on the predicted mask.
    """

    identity_param = False

    def __init__(self):
        super().__init__("apply", [False, True])

    def apply_aug_image(self, image, apply=False, **kwargs):
        """Return ``image`` mirrored top <-> bottom when ``apply`` is truthy."""
        return vflip(image) if apply else image

    def apply_deaug_mask(self, mask, apply=False, **kwargs):
        """Undo the vertical flip on ``mask`` when ``apply`` is truthy."""
        return vflip(mask) if apply else mask
|
@@ -1,172 +0,0 @@
|
|
1
|
-
import torch
|
2
|
-
import torch.nn as nn
|
3
|
-
import numpy as np
|
4
|
-
import os, sys
|
5
|
-
from tqdm import tqdm
|
6
|
-
from monai.inferers import sliding_window_inference
|
7
|
-
|
8
|
-
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
|
9
|
-
|
10
|
-
from core.BaseTrainer import BaseTrainer
|
11
|
-
from core.MEDIAR.utils import *
|
12
|
-
|
13
|
-
__all__ = ["Trainer"]
|
14
|
-
|
15
|
-
|
16
|
-
class Trainer(BaseTrainer):
    """MEDIAR trainer.

    The loss is BCE on the cell-probability channel plus MSE on the
    gradient-flow channels. Batches from a public dataset can optionally be
    concatenated to each training batch (``self.with_public``). Validation
    F1 is computed on instance masks produced by ``compute_masks``.
    """

    def __init__(
        self,
        model,
        dataloaders,
        optimizer,
        scheduler=None,
        criterion=None,
        num_epochs=100,
        device="cuda:0",
        no_valid=False,
        valid_frequency=1,
        amp=False,
        algo_params=None,
    ):
        super().__init__(
            model,
            dataloaders,
            optimizer,
            scheduler,
            criterion,
            num_epochs,
            device,
            no_valid,
            valid_frequency,
            amp,
            algo_params,
        )

        self.mse_loss = nn.MSELoss(reduction="mean")
        self.bce_loss = nn.BCEWithLogitsLoss(reduction="mean")

    def mediar_criterion(self, outputs, labels_onehot_flows):
        """Loss function between true labels and prediction outputs.

        ``labels_onehot_flows`` (output of ``labels_to_flows``, converted with
        ``torch.from_numpy``) is indexed as follows: channel 1 gives the
        cell-probability target, thresholded at 0.5; channels ``2:`` give the
        gradient flows. ``outputs`` holds the flow predictions in channels
        ``:2`` and the cell-probability logits in the last channel.
        """
        # Cell Recognition Loss
        cellprob_loss = self.bce_loss(
            outputs[:, -1],
            torch.from_numpy(labels_onehot_flows[:, 1] > 0.5).to(self.device).float(),
        )

        # Cell Distinction Loss (flow targets are scaled by 5)
        gradient_flows = torch.from_numpy(labels_onehot_flows[:, 2:]).to(self.device)
        gradflow_loss = 0.5 * self.mse_loss(outputs[:, :2], 5.0 * gradient_flows)

        return cellprob_loss + gradflow_loss

    def _next_public_batch(self):
        """Return ``(images, labels)`` from the public loader, cycling it when exhausted.

        ``AttributeError`` is caught as well as ``StopIteration``. This covers
        the case where ``self.public_iterator`` has not been created yet,
        which the original catch-all ``except`` also handled.
        """
        try:
            batch_data = next(self.public_iterator)
        except (AttributeError, StopIteration):
            # Assign memory loader if the cycle ends
            self.public_iterator = iter(self.public_loader)
            batch_data = next(self.public_iterator)
        return batch_data["img"], batch_data["label"]

    def _epoch_phase(self, phase):
        """Run one epoch of ``phase`` (``"train"`` or a validation phase).

        Returns:
            The results dict built by ``self._update_results``.
        """
        phase_results = {}
        is_train = phase == "train"

        # Set model mode
        self.model.train() if is_train else self.model.eval()

        # Epoch process
        for batch_data in tqdm(self.dataloaders[phase]):
            images, labels = batch_data["img"], batch_data["label"]

            if self.with_public:
                # Concat memory data to the batch
                images_pub, labels_pub = self._next_public_batch()
                images = torch.cat([images, images_pub], dim=0)
                labels = torch.cat([labels, labels_pub], dim=0)

            images = images.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass
            with torch.cuda.amp.autocast(enabled=self.amp):
                with torch.set_grad_enabled(is_train):
                    # Output shape is B x [grad y, grad x, cellprob] x H x W
                    outputs = self._inference(images, phase)

                    # Map label masks to gradient and onehot
                    labels_onehot_flows = labels_to_flows(
                        labels, use_gpu=True, device=self.device
                    )
                    # Calculate loss
                    loss = self.mediar_criterion(outputs, labels_onehot_flows)
                    # Store the value only; keeping the attached autograd
                    # graph for every batch is unnecessary.
                    self.loss_metric.append(loss.detach())

                    # Calculate valid statistics
                    if not is_train:
                        outputs, labels = self._post_process(outputs, labels)
                        f1_score = self._get_f1_metric(outputs, labels)
                        self.f1_metric.append(f1_score)

            # Backward pass
            if is_train:
                # For the mixed precision training
                if self.amp:
                    self.scaler.scale(loss).backward()
                    self.scaler.unscale_(self.optimizer)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()

                else:
                    loss.backward()
                    self.optimizer.step()

        # Update metrics
        # NOTE(review): the key says "dice_loss" but the loss is BCE + MSE;
        # kept as-is because consumers of phase_results may rely on it.
        phase_results = self._update_results(
            phase_results, self.loss_metric, "dice_loss", phase
        )
        if not is_train:
            phase_results = self._update_results(
                phase_results, self.f1_metric, "f1_score", phase
            )

        return phase_results

    def _inference(self, images, phase="train"):
        """Full forward pass for training; sliding-window inference otherwise."""
        if phase != "train":
            outputs = sliding_window_inference(
                images,
                roi_size=512,
                sw_batch_size=4,
                predictor=self.model,
                padding_mode="constant",
                mode="gaussian",
                overlap=0.5,
            )
        else:
            outputs = self.model(images)

        return outputs

    def _post_process(self, outputs, labels=None):
        """Predict cell instances using the gradient tracking.

        Assumes a batch of size 1 (the leading dimension is squeezed away).
        """
        outputs = outputs.squeeze(0).cpu().numpy()
        gradflows, cellprob = outputs[:2], self._sigmoid(outputs[-1])
        outputs = compute_masks(gradflows, cellprob, use_gpu=True, device=self.device)
        # compute_masks returns a tuple; its first element is the instance mask
        outputs = outputs[0]

        if labels is not None:
            labels = labels.squeeze(0).squeeze(0).cpu().numpy()

        return outputs, labels

    def _sigmoid(self, z):
        """Sigmoid function for numpy arrays"""
        return 1 / (1 + np.exp(-z))
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|