spacr 0.5.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. spacr/__init__.py +0 -2
  2. spacr/__main__.py +3 -3
  3. spacr/core.py +13 -106
  4. spacr/gui_core.py +2 -77
  5. spacr/gui_utils.py +1 -13
  6. spacr/io.py +24 -25
  7. spacr/mediar.py +12 -8
  8. spacr/plot.py +50 -135
  9. spacr/settings.py +42 -30
  10. spacr/submodules.py +11 -1
  11. spacr/timelapse.py +7 -79
  12. spacr/utils.py +152 -61
  13. {spacr-0.5.0.dist-info → spacr-0.9.1.dist-info}/METADATA +62 -62
  14. spacr-0.9.1.dist-info/RECORD +109 -0
  15. {spacr-0.5.0.dist-info → spacr-0.9.1.dist-info}/WHEEL +1 -1
  16. spacr/resources/MEDIAR/.gitignore +0 -18
  17. spacr/resources/MEDIAR/LICENSE +0 -21
  18. spacr/resources/MEDIAR/README.md +0 -189
  19. spacr/resources/MEDIAR/SetupDict.py +0 -39
  20. spacr/resources/MEDIAR/__pycache__/SetupDict.cpython-39.pyc +0 -0
  21. spacr/resources/MEDIAR/__pycache__/evaluate.cpython-39.pyc +0 -0
  22. spacr/resources/MEDIAR/__pycache__/generate_mapping.cpython-39.pyc +0 -0
  23. spacr/resources/MEDIAR/__pycache__/main.cpython-39.pyc +0 -0
  24. spacr/resources/MEDIAR/config/baseline.json +0 -60
  25. spacr/resources/MEDIAR/config/mediar_example.json +0 -72
  26. spacr/resources/MEDIAR/config/pred/pred_mediar.json +0 -17
  27. spacr/resources/MEDIAR/config/step1_pretraining/phase1.json +0 -55
  28. spacr/resources/MEDIAR/config/step1_pretraining/phase2.json +0 -58
  29. spacr/resources/MEDIAR/config/step2_finetuning/finetuning1.json +0 -66
  30. spacr/resources/MEDIAR/config/step2_finetuning/finetuning2.json +0 -66
  31. spacr/resources/MEDIAR/config/step3_prediction/base_prediction.json +0 -16
  32. spacr/resources/MEDIAR/config/step3_prediction/ensemble_tta.json +0 -23
  33. spacr/resources/MEDIAR/core/BasePredictor.py +0 -120
  34. spacr/resources/MEDIAR/core/BaseTrainer.py +0 -240
  35. spacr/resources/MEDIAR/core/Baseline/Predictor.py +0 -59
  36. spacr/resources/MEDIAR/core/Baseline/Trainer.py +0 -113
  37. spacr/resources/MEDIAR/core/Baseline/__init__.py +0 -2
  38. spacr/resources/MEDIAR/core/Baseline/__pycache__/Predictor.cpython-39.pyc +0 -0
  39. spacr/resources/MEDIAR/core/Baseline/__pycache__/Trainer.cpython-39.pyc +0 -0
  40. spacr/resources/MEDIAR/core/Baseline/__pycache__/__init__.cpython-39.pyc +0 -0
  41. spacr/resources/MEDIAR/core/Baseline/__pycache__/utils.cpython-39.pyc +0 -0
  42. spacr/resources/MEDIAR/core/Baseline/utils.py +0 -80
  43. spacr/resources/MEDIAR/core/MEDIAR/EnsemblePredictor.py +0 -105
  44. spacr/resources/MEDIAR/core/MEDIAR/Predictor.py +0 -234
  45. spacr/resources/MEDIAR/core/MEDIAR/Trainer.py +0 -172
  46. spacr/resources/MEDIAR/core/MEDIAR/__init__.py +0 -3
  47. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/EnsemblePredictor.cpython-39.pyc +0 -0
  48. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/Predictor.cpython-39.pyc +0 -0
  49. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/Trainer.cpython-39.pyc +0 -0
  50. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/__init__.cpython-39.pyc +0 -0
  51. spacr/resources/MEDIAR/core/MEDIAR/__pycache__/utils.cpython-39.pyc +0 -0
  52. spacr/resources/MEDIAR/core/MEDIAR/utils.py +0 -429
  53. spacr/resources/MEDIAR/core/__init__.py +0 -2
  54. spacr/resources/MEDIAR/core/__pycache__/BasePredictor.cpython-39.pyc +0 -0
  55. spacr/resources/MEDIAR/core/__pycache__/BaseTrainer.cpython-39.pyc +0 -0
  56. spacr/resources/MEDIAR/core/__pycache__/__init__.cpython-39.pyc +0 -0
  57. spacr/resources/MEDIAR/core/__pycache__/utils.cpython-39.pyc +0 -0
  58. spacr/resources/MEDIAR/core/utils.py +0 -40
  59. spacr/resources/MEDIAR/evaluate.py +0 -71
  60. spacr/resources/MEDIAR/generate_mapping.py +0 -121
  61. spacr/resources/MEDIAR/image/examples/img1.tiff +0 -0
  62. spacr/resources/MEDIAR/image/examples/img2.tif +0 -0
  63. spacr/resources/MEDIAR/image/failure_cases.png +0 -0
  64. spacr/resources/MEDIAR/image/mediar_framework.png +0 -0
  65. spacr/resources/MEDIAR/image/mediar_model.PNG +0 -0
  66. spacr/resources/MEDIAR/image/mediar_results.png +0 -0
  67. spacr/resources/MEDIAR/main.py +0 -125
  68. spacr/resources/MEDIAR/predict.py +0 -70
  69. spacr/resources/MEDIAR/requirements.txt +0 -14
  70. spacr/resources/MEDIAR/train_tools/__init__.py +0 -3
  71. spacr/resources/MEDIAR/train_tools/__pycache__/__init__.cpython-39.pyc +0 -0
  72. spacr/resources/MEDIAR/train_tools/__pycache__/measures.cpython-39.pyc +0 -0
  73. spacr/resources/MEDIAR/train_tools/__pycache__/utils.cpython-39.pyc +0 -0
  74. spacr/resources/MEDIAR/train_tools/data_utils/__init__.py +0 -1
  75. spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/__init__.cpython-39.pyc +0 -0
  76. spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/datasetter.cpython-39.pyc +0 -0
  77. spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/transforms.cpython-39.pyc +0 -0
  78. spacr/resources/MEDIAR/train_tools/data_utils/__pycache__/utils.cpython-39.pyc +0 -0
  79. spacr/resources/MEDIAR/train_tools/data_utils/custom/CellAware.py +0 -88
  80. spacr/resources/MEDIAR/train_tools/data_utils/custom/LoadImage.py +0 -161
  81. spacr/resources/MEDIAR/train_tools/data_utils/custom/NormalizeImage.py +0 -77
  82. spacr/resources/MEDIAR/train_tools/data_utils/custom/__init__.py +0 -3
  83. spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/CellAware.cpython-39.pyc +0 -0
  84. spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/LoadImage.cpython-39.pyc +0 -0
  85. spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/NormalizeImage.cpython-39.pyc +0 -0
  86. spacr/resources/MEDIAR/train_tools/data_utils/custom/__pycache__/__init__.cpython-39.pyc +0 -0
  87. spacr/resources/MEDIAR/train_tools/data_utils/custom/modalities.pkl +0 -0
  88. spacr/resources/MEDIAR/train_tools/data_utils/datasetter.py +0 -208
  89. spacr/resources/MEDIAR/train_tools/data_utils/transforms.py +0 -148
  90. spacr/resources/MEDIAR/train_tools/data_utils/utils.py +0 -84
  91. spacr/resources/MEDIAR/train_tools/measures.py +0 -200
  92. spacr/resources/MEDIAR/train_tools/models/MEDIARFormer.py +0 -102
  93. spacr/resources/MEDIAR/train_tools/models/__init__.py +0 -1
  94. spacr/resources/MEDIAR/train_tools/models/__pycache__/MEDIARFormer.cpython-39.pyc +0 -0
  95. spacr/resources/MEDIAR/train_tools/models/__pycache__/__init__.cpython-39.pyc +0 -0
  96. spacr/resources/MEDIAR/train_tools/utils.py +0 -70
  97. spacr-0.5.0.dist-info/RECORD +0 -190
  98. {spacr-0.5.0.dist-info → spacr-0.9.1.dist-info}/LICENSE +0 -0
  99. {spacr-0.5.0.dist-info → spacr-0.9.1.dist-info}/entry_points.txt +0 -0
  100. {spacr-0.5.0.dist-info → spacr-0.9.1.dist-info}/top_level.txt +0 -0
@@ -1,23 +0,0 @@
1
- {
2
- "pred_setups":{
3
- "name": "ensemble_mediar",
4
- "input_path":"/home/gihun/MEDIAR/data/Official/Tuning/images",
5
- "output_path": "./results/mediar_ensemble_tta",
6
- "make_submission": true,
7
- "model_path1": "./weights/finetuned/from_phase1.pth",
8
- "model_path2": "./weights/finetuned/from_phase2.pth",
9
- "device": "cuda:0",
10
- "model":{
11
- "name": "mediar-former",
12
- "params": {
13
- "encoder_name":"mit_b5",
14
- "decoder_channels": [1024, 512, 256, 128, 64],
15
- "decoder_pab_channels": 256,
16
- "in_channels":3,
17
- "classes":3
18
- }
19
- },
20
- "exp_name": "mediar_ensemble_tta",
21
- "algo_params": {"use_tta": true}
22
- }
23
- }
@@ -1,120 +0,0 @@
1
- import torch
2
- import numpy as np
3
- import time, os
4
- import tifffile as tif
5
-
6
- from datetime import datetime
7
- from zipfile import ZipFile
8
- from pytz import timezone
9
-
10
- from train_tools.data_utils.transforms import get_pred_transforms
11
-
12
-
13
class BasePredictor:
    """Abstract driver that runs a segmentation model over a folder of images.

    Subclasses implement ``_inference`` (model forward on one image batch) and
    ``_post_process`` (raw model output -> label mask). This base class does the
    rest: device placement, listing input files, per-image timing, writing
    ``<stem>_label.tiff`` masks and, optionally, zipping every written mask into
    ``./submissions/<exp_name>.zip``.

    Args:
        model: torch module used for inference.
        device: device passed to ``model.to`` and ``tensor.to``.
        input_path: directory; every entry returned by ``os.listdir`` is
            treated as an input image.
        output_path: directory where predicted masks are written (created if
            missing).
        make_submission: if True, warn on low cell counts and build a zip
            archive of the output directory after prediction.
        exp_name: optional prefix; a ``%m%d_%H%M`` timestamp (Asia/Seoul time)
            is appended to form the final experiment name.
        algo_params: optional mapping whose items become instance attributes.
    """

    def __init__(
        self,
        model,
        device,
        input_path,
        output_path,
        make_submission=False,
        exp_name=None,
        algo_params=None,
    ):
        self.model = model
        self.device = device
        self.input_path = input_path
        self.output_path = output_path
        self.make_submission = make_submission
        self.exp_name = exp_name

        # Algorithm-specific arguments become plain attributes (e.g. use_tta).
        if algo_params:
            self.__dict__.update(algo_params.items())

        # Prepare inference environment (transforms, output dir, input list).
        self._setups()

    @torch.no_grad()
    def conduct_prediction(self):
        """Predict, post-process and save a mask for every input image.

        Returns:
            float: wall-clock seconds spent on the *last* image, or 0.0 when
            there are no input images. The total time is printed, not returned.
        """
        self.model.to(self.device)
        self.model.eval()
        total_time = 0
        # Initialised up front so an empty input directory does not raise
        # UnboundLocalError at the return statement.
        time_cost = 0.0

        for img_name in self.img_names:
            img_data = self._get_img_data(img_name).to(self.device)

            start = time.time()

            pred_mask = self._inference(img_data)
            pred_mask = self._post_process(pred_mask.squeeze(0).cpu().numpy())
            self.write_pred_mask(
                pred_mask, self.output_path, img_name, self.make_submission
            )

            time_cost = time.time() - start
            total_time += time_cost
            print(
                f"Prediction finished: {img_name}; img size = {img_data.shape}; costing: {time_cost:.2f}s"
            )

        print(f"\n Total Time Cost: {total_time:.2f}s")

        if self.make_submission:
            fname = "%s.zip" % self.exp_name

            os.makedirs("./submissions", exist_ok=True)
            submission_path = os.path.join("./submissions", fname)

            with ZipFile(submission_path, "w") as zip_file:
                for pred_name in sorted(os.listdir(self.output_path)):
                    # NOTE: the member name is the path as joined here, i.e. it
                    # includes output_path, not just the bare file name.
                    zip_file.write(os.path.join(self.output_path, pred_name))

            print("\n>>>>> Submission file is saved at: %s\n" % submission_path)

        return time_cost

    def write_pred_mask(self, pred_mask, output_dir, image_name, submission=False):
        """Save ``pred_mask`` as a zlib-compressed ``<stem>_label.tiff``.

        When ``submission`` is True, print a caution if the largest label in
        the mask (used as a proxy for the cell count) is below 5.
        """
        # All images should contain at least 5 cells
        if submission and np.max(pred_mask) < 5:
            print("[!Caution] Only %d Cells Detected!!!\n" % np.max(pred_mask))

        file_path = os.path.join(output_dir, self._pred_mask_file_name(image_name))
        tif.imwrite(file_path, pred_mask, compression="zlib")

    @staticmethod
    def _pred_mask_file_name(image_name):
        """Return the mask file name for ``image_name``.

        Only the final extension is stripped, so inputs that differ after the
        first dot (``a.1.tif`` vs ``a.2.tif``) do not overwrite each other.
        """
        return os.path.splitext(image_name)[0] + "_label.tiff"

    def _setups(self):
        """Build transforms, create the output dir, stamp exp_name, list inputs."""
        self.pred_transforms = get_pred_transforms()
        os.makedirs(self.output_path, exist_ok=True)

        now = datetime.now(timezone("Asia/Seoul"))
        dt_string = now.strftime("%m%d_%H%M")
        self.exp_name = (
            self.exp_name + dt_string if self.exp_name is not None else dt_string
        )

        self.img_names = sorted(os.listdir(self.input_path))

    def _get_img_data(self, img_name):
        """Load and transform one image, adding a leading batch dimension."""
        img_path = os.path.join(self.input_path, img_name)
        img_data = self.pred_transforms(img_path)
        img_data = img_data.unsqueeze(0)

        return img_data

    def _inference(self, img_data):
        """Run the model on a batched image tensor (implemented by subclasses)."""
        raise NotImplementedError

    def _post_process(self, pred_mask):
        """Convert raw model output to a label mask (implemented by subclasses)."""
        raise NotImplementedError
@@ -1,240 +0,0 @@
1
- import torch
2
- import numpy as np
3
- from tqdm import tqdm
4
- from monai.inferers import sliding_window_inference
5
- from monai.metrics import CumulativeAverage
6
- from monai.transforms import (
7
- Activations,
8
- AsDiscrete,
9
- Compose,
10
- EnsureType,
11
- )
12
-
13
- import os, sys
14
- import copy
15
-
16
- sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../")))
17
-
18
- from core.utils import print_learning_device, print_with_logging
19
- from train_tools.measures import evaluate_f1_score_cellseg
20
-
21
-
22
class BaseTrainer:
    """Abstract base class for trainer implementations.

    Runs a train loop with periodic evaluation, tracks averaged metrics with
    ``CumulativeAverage``, keeps a deep copy of the best-scoring weights and
    restores them when training ends. Evaluation is either a labelled
    validation phase (scored by ``Valid_F1_Score``) or, when ``no_valid`` is
    set, the total cell count predicted on the ``"tuning"`` dataloader.
    """

    def __init__(
        self,
        model,
        dataloaders,
        optimizer,
        scheduler=None,
        criterion=None,
        num_epochs=100,
        device="cuda:0",
        no_valid=False,
        valid_frequency=1,
        amp=False,
        algo_params=None,
    ):
        self.model = model.to(device)
        self.dataloaders = dataloaders
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.num_epochs = num_epochs
        self.no_valid = no_valid
        self.valid_frequency = valid_frequency
        self.device = device
        self.amp = amp
        self.best_weights = None
        # A score must be strictly greater than this to be saved as "best".
        self.best_f1_score = 0.1

        # FP-16 gradient scaler, only created when mixed precision is requested.
        self.scaler = torch.cuda.amp.GradScaler() if amp else None

        # Algorithm-specific arguments become plain attributes.
        if algo_params:
            self.__dict__.update(algo_params.items())

        # Cumulative statistics, flushed by _update_results after every phase.
        self.loss_metric = CumulativeAverage()
        self.f1_metric = CumulativeAverage()

        # Post-processing functions
        self.post_pred = Compose(
            [EnsureType(), Activations(softmax=True), AsDiscrete(threshold=0.5)]
        )
        self.post_gt = Compose([EnsureType(), AsDiscrete(to_onehot=None)])

    def train(self):
        """Train for ``num_epochs`` epochs, then load the best weights seen."""
        print_learning_device(self.device)

        for epoch in range(1, self.num_epochs + 1):
            print(f"[Round {epoch}/{self.num_epochs}]")

            # Train Epoch Phase
            print(">>> Train Epoch")
            train_results = self._epoch_phase("train")
            print_with_logging(train_results, epoch)

            if self.scheduler is not None:
                self.scheduler.step()

            if epoch % self.valid_frequency == 0:
                if not self.no_valid:
                    # Valid Epoch Phase
                    print(">>> Valid Epoch")
                    valid_results = self._epoch_phase("valid")
                    print_with_logging(valid_results, epoch)

                    if "Valid_F1_Score" in valid_results:
                        self._update_best_model(valid_results["Valid_F1_Score"])
                else:
                    # No labelled validation: the total number of cells found
                    # on the tuning set is used as the model-selection score.
                    print(">>> TuningSet Epoch")
                    tuning_cell_counts = self._tuningset_evaluation()
                    tuning_count_dict = {"TuningSet_Cell_Count": tuning_cell_counts}
                    print_with_logging(tuning_count_dict, epoch)

                    self._update_best_model(tuning_cell_counts)

            print("-" * 50)

        self.best_f1_score = 0

        if self.best_weights is not None:
            self.model.load_state_dict(self.best_weights)

    def _epoch_phase(self, phase):
        """Learning process for 1 Epoch (for different phases).

        Args:
            phase (str): "train", "valid", "test"

        Returns:
            dict: statistics for the phase results
        """
        phase_results = {}

        # Set model mode
        self.model.train() if phase == "train" else self.model.eval()

        for batch_data in tqdm(self.dataloaders[phase]):
            images = batch_data["img"].to(self.device)
            labels = batch_data["label"].to(self.device)
            self.optimizer.zero_grad()

            # Forward pass (gradients only tracked while training)
            with torch.set_grad_enabled(phase == "train"):
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                self.loss_metric.append(loss)

            # Backward pass
            if phase == "train":
                loss.backward()
                self.optimizer.step()

        # Update metrics
        phase_results = self._update_results(
            phase_results, self.loss_metric, "loss", phase
        )

        return phase_results

    @staticmethod
    def _count_instances(label_map):
        """Return the number of distinct non-zero labels in ``label_map``.

        Label 0 is treated as background and is not counted.
        """
        return int(np.count_nonzero(np.unique(label_map)))

    @torch.no_grad()
    def _tuningset_evaluation(self):
        """Sum the predicted cell counts over the "tuning" dataloader."""
        cell_counts_total = []
        self.model.eval()

        for batch_data in tqdm(self.dataloaders["tuning"]):
            images = batch_data["img"].to(self.device)
            # Very wide images are skipped during tuning-set evaluation.
            if images.shape[-1] > 5000:
                continue

            outputs = sliding_window_inference(
                images,
                roi_size=512,
                sw_batch_size=4,
                predictor=self.model,
                padding_mode="constant",
                mode="gaussian",
            )

            outputs = outputs.squeeze(0)
            outputs, _ = self._post_process(outputs, None)
            # Previously len(np.unique(outputs) - 1), which subtracted 1 from
            # every label instead of from the count and so counted background.
            count = self._count_instances(outputs)
            cell_counts_total.append(count)

        cell_counts_total_sum = np.sum(cell_counts_total)
        print("Cell Counts Total: (%d)" % (cell_counts_total_sum))

        return cell_counts_total_sum

    def _update_results(self, phase_results, metric, metric_key, phase="train"):
        """Aggregate and flush metrics

        Args:
            phase_results (dict): base dictionary to log metrics
            metric (_type_): cumulated metrics
            metric_key (_type_): name of metric
            phase (str, optional): current phase name. Defaults to "train".

        Returns:
            dict: dictionary of metrics for the current phase
        """
        # e.g. ("valid", "f1_score") -> "Valid_F1_Score"
        metric_key = "_".join([phase, metric_key]).title()

        # Aggregate, log (rounded to 4 decimals) and flush the metric.
        phase_results[metric_key] = round(metric.aggregate().item(), 4)
        metric.reset()

        return phase_results

    def _update_best_model(self, current_f1_score):
        """Keep a deep copy of the weights if the score beats the best so far."""
        if current_f1_score > self.best_f1_score:
            self.best_weights = copy.deepcopy(self.model.state_dict())
            self.best_f1_score = current_f1_score
            print(
                "\n>>>> Update Best Model with score: {}\n".format(self.best_f1_score)
            )

    def _inference(self, images, phase="train"):
        """inference methods for different phase"""
        if phase != "train":
            # Tiled inference for evaluation phases.
            outputs = sliding_window_inference(
                images,
                roi_size=512,
                sw_batch_size=4,
                predictor=self.model,
                padding_mode="reflect",
                mode="gaussian",
                overlap=0.5,
            )
        else:
            outputs = self.model(images)

        return outputs

    def _post_process(self, outputs, labels):
        """Identity post-processing; subclasses override."""
        return outputs, labels

    def _get_f1_metric(self, masks_pred, masks_true):
        """Return the last element of evaluate_f1_score_cellseg's result."""
        f1_score = evaluate_f1_score_cellseg(masks_true, masks_pred)[-1]

        return f1_score
@@ -1,59 +0,0 @@
1
- import torch
2
- import os, sys
3
- from skimage import morphology, measure
4
- from monai.inferers import sliding_window_inference
5
-
6
- sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
7
-
8
- from core.BasePredictor import BasePredictor
9
-
10
- __all__ = ["Predictor"]
11
-
12
-
13
class Predictor(BasePredictor):
    """Baseline predictor: tiled inference followed by morphological cleanup.

    ``_inference`` runs MONAI sliding-window inference with 512-pixel windows,
    4 windows per batch, constant padding, gaussian blending and 0.6 overlap.
    ``_post_process`` turns the logits into an instance label image.
    """

    def __init__(
        self,
        model,
        device,
        input_path,
        output_path,
        make_submission=False,
        exp_name=None,
        algo_params=None,
    ):
        # All setup is delegated to BasePredictor.
        super().__init__(
            model,
            device,
            input_path,
            output_path,
            make_submission=make_submission,
            exp_name=exp_name,
            algo_params=algo_params,
        )

    def _inference(self, img_data):
        """Run tiled inference on a batched image tensor and return the logits."""
        return sliding_window_inference(
            img_data,
            roi_size=512,
            sw_batch_size=4,
            predictor=self.model,
            padding_mode="constant",
            mode="gaussian",
            overlap=0.6,
        )

    def _post_process(self, pred_mask):
        """Convert a numpy logit array (classes on axis 0) to a label image.

        Class 1 is treated as "cell": its softmax probability is thresholded
        at 0.5, small holes are filled, objects under 16 px are removed, and
        connected components are labelled.
        """
        cell_prob = torch.softmax(torch.from_numpy(pred_mask), dim=0)[1].numpy()

        foreground = morphology.remove_small_holes(cell_prob > 0.5, connectivity=1)
        foreground = morphology.remove_small_objects(foreground, 16)
        return measure.label(foreground)
@@ -1,113 +0,0 @@
1
- import torch
2
- import os, sys
3
- import monai
4
-
5
- from monai.data import decollate_batch
6
-
7
- sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
8
-
9
- from core.BaseTrainer import BaseTrainer
10
- from core.Baseline.utils import create_interior_onehot, identify_instances_from_classmap
11
- from train_tools.measures import evaluate_f1_score_cellseg
12
- from tqdm import tqdm
13
-
14
- __all__ = ["Trainer"]
15
-
16
-
17
class Trainer(BaseTrainer):
    """Baseline trainer for 3-class (background / interior / boundary) maps.

    Instance labels are converted to interior/boundary one-hot targets with
    ``create_interior_onehot``, and the model is optimised with MONAI's
    ``DiceCELoss(softmax=True)``. Non-training phases additionally report an
    instance-level F1 score.
    """

    def __init__(
        self,
        model,
        dataloaders,
        optimizer,
        scheduler=None,
        criterion=None,
        num_epochs=100,
        device="cuda:0",
        no_valid=False,
        valid_frequency=1,
        amp=False,
        algo_params=None,
    ):
        super(Trainer, self).__init__(
            model,
            dataloaders,
            optimizer,
            scheduler,
            criterion,
            num_epochs,
            device,
            no_valid,
            valid_frequency,
            amp,
            algo_params,
        )

        # Dice loss as segmentation criterion (replaces any `criterion` argument)
        self.criterion = monai.losses.DiceCELoss(softmax=True)

    def _epoch_phase(self, phase):
        """Run one epoch of ``phase`` and return its aggregated metrics.

        Returns:
            dict: ``"<Phase>_Loss"`` and, for non-training phases,
            ``"<Phase>_F1_Score"``.
        """
        phase_results = {}
        is_train = phase == "train"

        # Set model mode
        self.model.train() if is_train else self.model.eval()

        for batch_data in tqdm(self.dataloaders[phase]):
            images = batch_data["img"].to(self.device)
            labels = batch_data["label"].to(self.device)
            self.optimizer.zero_grad()

            # Map label masks to 3-class onehot map
            labels_onehot = create_interior_onehot(labels)

            # Forward pass
            with torch.set_grad_enabled(is_train):
                # GradScaler alone only rescales the loss; autocast is what
                # actually runs the forward pass in mixed precision. It is
                # limited to training so evaluation numerics are unchanged.
                with torch.cuda.amp.autocast(enabled=bool(self.amp and is_train)):
                    outputs = self._inference(images, phase)
                    loss = self.criterion(outputs, labels_onehot)
                self.loss_metric.append(loss)

                if not is_train:
                    f1_score = self._get_f1_metric(outputs, labels)
                    self.f1_metric.append(f1_score)

            # Backward pass
            if is_train:
                if self.amp:
                    # For the mixed precision training
                    self.scaler.scale(loss).backward()
                    self.scaler.unscale_(self.optimizer)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    loss.backward()
                    self.optimizer.step()

        # Update metrics
        phase_results = self._update_results(
            phase_results, self.loss_metric, "loss", phase
        )

        if not is_train:
            phase_results = self._update_results(
                phase_results, self.f1_metric, "f1_score", phase
            )

        return phase_results

    def _post_process(self, outputs, labels_onehot):
        """Conduct post-processing for outputs & labels (per batch element)."""
        outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
        labels_onehot = [self.post_gt(i) for i in decollate_batch(labels_onehot)]

        return outputs, labels_onehot

    def _get_f1_metric(self, masks_pred, masks_true):
        """F1 between instances from the first prediction and the label mask.

        Only ``masks_pred[0]`` is used — presumably evaluation batches hold a
        single image (TODO confirm against the dataloader).
        """
        masks_pred = identify_instances_from_classmap(masks_pred[0])
        masks_true = masks_true.squeeze(0).squeeze(0).cpu().numpy()
        f1_score = evaluate_f1_score_cellseg(masks_true, masks_pred)[-1]

        return f1_score
@@ -1,2 +0,0 @@
1
- from .Trainer import *
2
- from .Predictor import *
@@ -1,80 +0,0 @@
1
- """
2
- Adapted from the following references:
3
- [1] https://github.com/JunMa11/NeurIPS-CellSeg/blob/main/baseline/model_training_3class.py
4
-
5
- """
6
-
7
- import torch
8
- import numpy as np
9
- from skimage import segmentation, morphology, measure
10
- import monai
11
-
12
-
13
- __all__ = ["create_interior_onehot", "identify_instances_from_classmap"]
14
-
15
-
16
@torch.no_grad()
def identify_instances_from_classmap(
    class_map, cell_class=1, threshold=0.5, from_logits=True
):
    """Label individual cells from a per-class score map.

    Args:
        class_map: tensor with classes along dim 0, i.e. (C, H, W).
        cell_class: index of the channel that represents cells.
        threshold: probability above which a pixel counts as cell.
        from_logits: if True, apply softmax over dim 0 first.

    Returns:
        Integer label image from ``skimage.measure.label`` after filling small
        holes and dropping objects smaller than 16 pixels.
    """
    probs = torch.softmax(class_map, dim=0) if from_logits else class_map

    binary = probs[cell_class].cpu().numpy() > threshold
    binary = morphology.remove_small_holes(binary, connectivity=1)
    binary = morphology.remove_small_objects(binary, 16)
    return measure.label(binary)
35
-
36
-
37
@torch.no_grad()
def create_interior_onehot(inst_maps):
    """Convert a batch of instance label maps into 3-class one-hot targets.

    Each instance map is first turned into a class map with values
        0: background
        1: interior (cell pixels off the dilated boundary; interior regions
           smaller than 16 px are dropped)
        2: boundary (inner instance boundaries dilated by a radius-1 disk)
    and the stacked batch is one-hot encoded with ``monai.networks.one_hot``
    using 3 classes.

    Args:
        inst_maps: tensor of instance labels shaped (B, 1, H, W), 0 = background.

    Returns:
        One-hot tensor on the same device as ``inst_maps``; presumably
        (B, 3, H, W) given monai's one_hot class dimension — confirm.
    """
    device = inst_maps.device

    # (B, 1, H, W) -> (B, H, W) integer array. int32 rather than int16 so that
    # instance ids above 32767 (e.g. from uint16 label images) do not wrap
    # around and corrupt the boundary computation.
    inst_maps = inst_maps.squeeze(1).cpu().numpy().astype(np.int32)

    interior_maps = []

    for inst_map in inst_maps:
        # Inner boundaries of every instance, thickened by one pixel.
        boundary = segmentation.find_boundaries(inst_map, mode="inner")
        boundary = morphology.binary_dilation(boundary, morphology.disk(1))

        # Interior = foreground pixels not on a boundary, minus tiny fragments.
        interior_temp = np.logical_and(~boundary, inst_map > 0)
        interior_temp = morphology.remove_small_objects(interior_temp, min_size=16)

        interior = np.zeros_like(inst_map, dtype=np.uint8)
        interior[interior_temp] = 1
        interior[boundary] = 2

        interior_maps.append(interior)

    # Aggregate the class maps for the batch: (B, H, W) -> (B, 1, H, W)
    interior_maps = np.stack(interior_maps, axis=0).astype(np.uint8)
    interior_maps = torch.from_numpy(interior_maps).unsqueeze(1).to(device)

    # Obtain one-hot map for batch
    interior_onehot = monai.networks.one_hot(interior_maps, num_classes=3)

    return interior_onehot