spacr 0.4.60__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. spacr/__init__.py +2 -4
  2. spacr/__main__.py +3 -3
  3. spacr/core.py +13 -107
  4. spacr/gui.py +0 -1
  5. spacr/gui_core.py +2 -2
  6. spacr/gui_utils.py +5 -14
  7. spacr/io.py +189 -200
  8. spacr/mediar.py +12 -8
  9. spacr/plot.py +50 -13
  10. spacr/settings.py +71 -14
  11. spacr/submodules.py +21 -14
  12. spacr/timelapse.py +192 -6
  13. spacr/utils.py +180 -56
  14. {spacr-0.4.60.dist-info → spacr-0.9.0.dist-info}/METADATA +64 -62
  15. {spacr-0.4.60.dist-info → spacr-0.9.0.dist-info}/RECORD +20 -72
  16. {spacr-0.4.60.dist-info → spacr-0.9.0.dist-info}/WHEEL +1 -1
  17. spacr/resources/MEDIAR/.gitignore +0 -18
  18. spacr/resources/MEDIAR/LICENSE +0 -21
  19. spacr/resources/MEDIAR/README.md +0 -189
  20. spacr/resources/MEDIAR/SetupDict.py +0 -39
  21. spacr/resources/MEDIAR/config/baseline.json +0 -60
  22. spacr/resources/MEDIAR/config/mediar_example.json +0 -72
  23. spacr/resources/MEDIAR/config/pred/pred_mediar.json +0 -17
  24. spacr/resources/MEDIAR/config/step1_pretraining/phase1.json +0 -55
  25. spacr/resources/MEDIAR/config/step1_pretraining/phase2.json +0 -58
  26. spacr/resources/MEDIAR/config/step2_finetuning/finetuning1.json +0 -66
  27. spacr/resources/MEDIAR/config/step2_finetuning/finetuning2.json +0 -66
  28. spacr/resources/MEDIAR/config/step3_prediction/base_prediction.json +0 -16
  29. spacr/resources/MEDIAR/config/step3_prediction/ensemble_tta.json +0 -23
  30. spacr/resources/MEDIAR/core/BasePredictor.py +0 -120
  31. spacr/resources/MEDIAR/core/BaseTrainer.py +0 -240
  32. spacr/resources/MEDIAR/core/Baseline/Predictor.py +0 -59
  33. spacr/resources/MEDIAR/core/Baseline/Trainer.py +0 -113
  34. spacr/resources/MEDIAR/core/Baseline/__init__.py +0 -2
  35. spacr/resources/MEDIAR/core/Baseline/utils.py +0 -80
  36. spacr/resources/MEDIAR/core/MEDIAR/EnsemblePredictor.py +0 -105
  37. spacr/resources/MEDIAR/core/MEDIAR/Predictor.py +0 -234
  38. spacr/resources/MEDIAR/core/MEDIAR/Trainer.py +0 -172
  39. spacr/resources/MEDIAR/core/MEDIAR/__init__.py +0 -3
  40. spacr/resources/MEDIAR/core/MEDIAR/utils.py +0 -429
  41. spacr/resources/MEDIAR/core/__init__.py +0 -2
  42. spacr/resources/MEDIAR/core/utils.py +0 -40
  43. spacr/resources/MEDIAR/evaluate.py +0 -71
  44. spacr/resources/MEDIAR/generate_mapping.py +0 -121
  45. spacr/resources/MEDIAR/image/examples/img1.tiff +0 -0
  46. spacr/resources/MEDIAR/image/examples/img2.tif +0 -0
  47. spacr/resources/MEDIAR/image/failure_cases.png +0 -0
  48. spacr/resources/MEDIAR/image/mediar_framework.png +0 -0
  49. spacr/resources/MEDIAR/image/mediar_model.PNG +0 -0
  50. spacr/resources/MEDIAR/image/mediar_results.png +0 -0
  51. spacr/resources/MEDIAR/main.py +0 -125
  52. spacr/resources/MEDIAR/predict.py +0 -70
  53. spacr/resources/MEDIAR/requirements.txt +0 -14
  54. spacr/resources/MEDIAR/train_tools/__init__.py +0 -3
  55. spacr/resources/MEDIAR/train_tools/data_utils/__init__.py +0 -1
  56. spacr/resources/MEDIAR/train_tools/data_utils/custom/CellAware.py +0 -88
  57. spacr/resources/MEDIAR/train_tools/data_utils/custom/LoadImage.py +0 -161
  58. spacr/resources/MEDIAR/train_tools/data_utils/custom/NormalizeImage.py +0 -77
  59. spacr/resources/MEDIAR/train_tools/data_utils/custom/__init__.py +0 -3
  60. spacr/resources/MEDIAR/train_tools/data_utils/custom/modalities.pkl +0 -0
  61. spacr/resources/MEDIAR/train_tools/data_utils/datasetter.py +0 -208
  62. spacr/resources/MEDIAR/train_tools/data_utils/transforms.py +0 -148
  63. spacr/resources/MEDIAR/train_tools/data_utils/utils.py +0 -84
  64. spacr/resources/MEDIAR/train_tools/measures.py +0 -200
  65. spacr/resources/MEDIAR/train_tools/models/MEDIARFormer.py +0 -102
  66. spacr/resources/MEDIAR/train_tools/models/__init__.py +0 -1
  67. spacr/resources/MEDIAR/train_tools/utils.py +0 -70
  68. spacr/stats.py +0 -221
  69. spacr/{cellpose.py → spacr_cellpose.py} +0 -0
  70. {spacr-0.4.60.dist-info → spacr-0.9.0.dist-info}/LICENSE +0 -0
  71. {spacr-0.4.60.dist-info → spacr-0.9.0.dist-info}/entry_points.txt +0 -0
  72. {spacr-0.4.60.dist-info → spacr-0.9.0.dist-info}/top_level.txt +0 -0
@@ -1,234 +0,0 @@
1
- import torch
2
- import numpy as np
3
- import os, sys
4
- from monai.inferers import sliding_window_inference
5
-
6
- sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
7
-
8
- from core.BasePredictor import BasePredictor
9
- from core.MEDIAR.utils import compute_masks
10
-
11
- __all__ = ["Predictor"]
12
-
13
-
14
class Predictor(BasePredictor):
    """MEDIAR predictor: sliding-window inference, optional flip TTA, and
    flow-based conversion of the network output into instance label masks.

    The network output is handled as three channels: channels 0 and 1 are the
    y- and x-flows (see the sign handling of the TTA merge in ``_inference``)
    and the last channel is a cell-probability logit (passed through a
    sigmoid in ``_post_process``).
    """

    # Images with at least this many pixels are converted to masks tile by
    # tile instead of with a single ``compute_masks`` call.
    _WHOLE_SLIDE_PIXELS = 5000 * 5000
    # Tile side length used for whole-slide post-processing.
    _GRID_ROI_SIZE = 2000

    def __init__(
        self,
        model,
        device,
        input_path,
        output_path,
        make_submission=False,
        exp_name=None,
        algo_params=None,
    ):
        super().__init__(
            model,
            device,
            input_path,
            output_path,
            make_submission,
            exp_name,
            algo_params,
        )
        self.hflip_tta = HorizontalFlip()
        self.vflip_tta = VerticalFlip()

    @torch.no_grad()
    def _inference(self, img_data):
        """Conduct model prediction.

        Returns the squeezed CPU output of ``_window_inference``. When
        ``self.use_tta`` is set, predictions on horizontally and vertically
        flipped copies are un-flipped and averaged with the base prediction.
        Flipping an image reverses the flow component along the flipped axis,
        so that component is subtracted instead of added.
        """
        img_data = img_data.to(self.device)
        outputs_base = self._window_inference(img_data).cpu().squeeze()

        if not self.use_tta:
            return outputs_base

        # HorizontalFlip TTA
        img_hflip = self.hflip_tta.apply_aug_image(img_data, apply=True)
        outputs_hflip = self._window_inference(img_hflip)
        del img_hflip  # drop the flipped copy before the next forward pass
        outputs_hflip = self.hflip_tta.apply_deaug_mask(outputs_hflip, apply=True)
        outputs_hflip = outputs_hflip.cpu().squeeze()

        # VerticalFlip TTA
        img_vflip = self.vflip_tta.apply_aug_image(img_data, apply=True)
        outputs_vflip = self._window_inference(img_vflip)
        del img_vflip
        outputs_vflip = self.vflip_tta.apply_deaug_mask(outputs_vflip, apply=True)
        outputs_vflip = outputs_vflip.cpu().squeeze()

        # Merge results: channel 0 is sign-flipped in the v-flipped prediction,
        # channel 1 in the h-flipped prediction; the last channel is averaged.
        pred_mask = torch.zeros_like(outputs_base)
        pred_mask[0] = (outputs_base[0] + outputs_hflip[0] - outputs_vflip[0]) / 3
        pred_mask[1] = (outputs_base[1] - outputs_hflip[1] + outputs_vflip[1]) / 3
        pred_mask[2] = (outputs_base[2] + outputs_hflip[2] + outputs_vflip[2]) / 3

        return pred_mask

    def _window_inference(self, img_data, aux=False):
        """Inference on RoI-sized windows (512 px, 60% overlap, gaussian blending).

        ``aux=True`` runs ``self.model_aux`` instead of ``self.model``
        (``model_aux`` is not set in this class; presumably provided elsewhere).
        """
        outputs = sliding_window_inference(
            img_data,
            roi_size=512,
            sw_batch_size=4,
            predictor=self.model if not aux else self.model_aux,
            padding_mode="constant",
            mode="gaussian",
            overlap=0.6,
        )

        return outputs

    def _post_process(self, pred_mask):
        """Generate cell instance masks from the (flows, cellprob-logit) output.

        Images smaller than ``_WHOLE_SLIDE_PIXELS`` are processed in one
        ``compute_masks`` call; larger ones are processed on a grid of
        ``_GRID_ROI_SIZE`` tiles.
        """
        dP, cellprob = pred_mask[:2], self._sigmoid(pred_mask[-1])
        H, W = pred_mask.shape[-2], pred_mask.shape[-1]

        if H * W < self._WHOLE_SLIDE_PIXELS:
            return self._masks_from_flows(dP, cellprob)

        print("\n[Whole Slide] Grid Prediction starting...")
        return self._grid_post_process(dP, cellprob, H, W)

    def _masks_from_flows(self, dP, cellprob):
        """Run ``compute_masks`` with this predictor's fixed thresholds; return its first output."""
        return compute_masks(
            dP,
            cellprob,
            use_gpu=True,
            flow_threshold=0.4,
            device=self.device,
            cellprob_threshold=0.5,
        )[0]

    def _grid_post_process(self, dP, cellprob, H, W):
        """Compute instance masks tile by tile and stitch them into an (H, W) uint32 mask.

        The inputs are zero-padded up to a multiple of the tile size. Each
        tile is labelled by an independent ``compute_masks`` call, so its ids
        are shifted past the largest id already placed; otherwise the same id
        could appear in several tiles and unrelated cells would be merged.
        Cells crossing a tile border are still split into separate instances.
        """
        roi_size = self._GRID_ROI_SIZE

        # Number of tiles per axis = ceil(size / roi_size)
        n_H = -(-H // roi_size)
        n_W = -(-W // roi_size)
        new_H, new_W = n_H * roi_size, n_W * roi_size

        # Allocate values on the grid
        pred_pad = np.zeros((new_H, new_W), dtype=np.uint32)
        dP_pad = np.zeros((2, new_H, new_W), dtype=np.float32)
        cellprob_pad = np.zeros((new_H, new_W), dtype=np.float32)

        dP_pad[:, :H, :W], cellprob_pad[:H, :W] = dP, cellprob

        label_offset = 0
        for i in range(n_H):
            rows = slice(roi_size * i, roi_size * (i + 1))
            for j in range(n_W):
                cols = slice(roi_size * j, roi_size * (j + 1))
                print("Pred on Grid (%d, %d) processing..." % (i, j))

                tile_mask = np.asarray(
                    self._masks_from_flows(dP_pad[:, rows, cols], cellprob_pad[rows, cols])
                )

                foreground = tile_mask > 0
                if foreground.any():
                    # Basic slicing yields a view, so this writes into pred_pad.
                    tile_region = pred_pad[rows, cols]
                    tile_region[foreground] = (
                        tile_mask[foreground].astype(np.uint32) + label_offset
                    )
                    label_offset = int(tile_region.max())

        return pred_pad[:H, :W]

    def _sigmoid(self, z):
        """Element-wise logistic sigmoid."""
        return 1 / (1 + np.exp(-z))
162
-
163
-
164
- """
165
- Adapted from the following references:
166
- [1] https://github.com/qubvel/ttach/blob/master/ttach/transforms.py
167
-
168
- """
169
-
170
-
171
def hflip(x):
    """Mirror a batch of images left-to-right.

    Flips dim 3, i.e. the width axis when ``x`` is laid out as B x C x H x W.
    ``x`` must provide a torch-style ``flip`` method (e.g. ``torch.Tensor``).
    """
    width_dim = 3
    return x.flip(width_dim)
174
-
175
-
176
def vflip(x):
    """Mirror a batch of images top-to-bottom.

    Flips dim 2, i.e. the height axis when ``x`` is laid out as B x C x H x W.
    ``x`` must provide a torch-style ``flip`` method (e.g. ``torch.Tensor``).
    """
    height_dim = 2
    return x.flip(height_dim)
179
-
180
-
181
class DualTransform:
    """Base class for a test-time augmentation applied to an input image and
    undone on the corresponding model output.

    Subclasses override ``apply_aug_image`` (augment the input) and
    ``apply_deaug_mask`` (invert the augmentation on the prediction).
    """

    # Parameter value for which the transform leaves its input unchanged;
    # subclasses override it.
    identity_param = None

    def __init__(
        self, name: str, params,
    ):
        # Name of the parameter controlling the transform (e.g. "apply").
        self.pname = name
        # Candidate values of that parameter.
        self.params = params

    def apply_aug_image(self, image, *args, **params):
        """Augment ``image``. Must be implemented by subclasses."""
        raise NotImplementedError

    def apply_deaug_mask(self, mask, *args, **params):
        """Undo the augmentation on ``mask``. Must be implemented by subclasses."""
        raise NotImplementedError
195
-
196
-
197
class HorizontalFlip(DualTransform):
    """Flip images horizontally (left -> right).

    A horizontal flip is its own inverse, so de-augmentation flips again.
    """

    identity_param = False

    def __init__(self):
        super().__init__("apply", [False, True])

    def apply_aug_image(self, image, apply=False, **kwargs):
        """Return ``image`` flipped horizontally if ``apply`` is truthy, else unchanged."""
        return hflip(image) if apply else image

    def apply_deaug_mask(self, mask, apply=False, **kwargs):
        """Undo ``apply_aug_image`` on ``mask`` (flip horizontally again if ``apply``)."""
        return hflip(mask) if apply else mask
214
-
215
-
216
class VerticalFlip(DualTransform):
    """Flip images vertically (up -> down).

    A vertical flip is its own inverse, so de-augmentation flips again.
    """

    identity_param = False

    def __init__(self):
        super().__init__("apply", [False, True])

    def apply_aug_image(self, image, apply=False, **kwargs):
        """Return ``image`` flipped vertically if ``apply`` is truthy, else unchanged."""
        return vflip(image) if apply else image

    def apply_deaug_mask(self, mask, apply=False, **kwargs):
        """Undo ``apply_aug_image`` on ``mask`` (flip vertically again if ``apply``)."""
        return vflip(mask) if apply else mask
@@ -1,172 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import numpy as np
4
- import os, sys
5
- from tqdm import tqdm
6
- from monai.inferers import sliding_window_inference
7
-
8
- sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
9
-
10
- from core.BaseTrainer import BaseTrainer
11
- from core.MEDIAR.utils import *
12
-
13
- __all__ = ["Trainer"]
14
-
15
-
16
class Trainer(BaseTrainer):
    """MEDIAR trainer.

    Each epoch phase optionally concatenates a batch drawn from
    ``self.public_loader`` to every labelled batch. It converts label masks to
    flow targets with ``labels_to_flows`` and optimizes ``mediar_criterion``.
    Non-train phases also convert predictions to instance masks and record an
    F1 score per batch.
    """

    def __init__(
        self,
        model,
        dataloaders,
        optimizer,
        scheduler=None,
        criterion=None,
        num_epochs=100,
        device="cuda:0",
        no_valid=False,
        valid_frequency=1,
        amp=False,
        algo_params=None,
    ):
        super().__init__(
            model,
            dataloaders,
            optimizer,
            scheduler,
            criterion,
            num_epochs,
            device,
            no_valid,
            valid_frequency,
            amp,
            algo_params,
        )

        self.mse_loss = nn.MSELoss(reduction="mean")
        self.bce_loss = nn.BCEWithLogitsLoss(reduction="mean")

    def mediar_criterion(self, outputs, labels_onehot_flows):
        """Loss between the label-derived targets and the prediction outputs.

        Args:
            outputs: model output tensor; the last channel is a cell-probability
                logit and the first two channels are flows.
            labels_onehot_flows: NumPy array (it goes through ``torch.from_numpy``)
                from ``labels_to_flows``. Channel 1, binarized at 0.5, is the cell
                target; channels 2: are the target flows.

        Returns:
            ``BCEWithLogits(cellprob) + 0.5 * MSE(flows, 5 * target_flows)``.
        """
        # Cell recognition loss
        cellprob_loss = self.bce_loss(
            outputs[:, -1],
            torch.from_numpy(labels_onehot_flows[:, 1] > 0.5).to(self.device).float(),
        )

        # Cell distinction loss (flow targets are scaled by 5)
        gradient_flows = torch.from_numpy(labels_onehot_flows[:, 2:]).to(self.device)
        gradflow_loss = 0.5 * self.mse_loss(outputs[:, :2], 5.0 * gradient_flows)

        return cellprob_loss + gradflow_loss

    def _next_public_batch(self):
        """Return the next batch of the public loader, restarting it when exhausted.

        Also (re)creates the iterator if ``self.public_iterator`` is missing or
        None. Errors raised by the loader itself are propagated, not swallowed.
        """
        iterator = getattr(self, "public_iterator", None)
        if iterator is not None:
            try:
                return next(iterator)
            except StopIteration:
                pass  # cycle finished: restart below

        self.public_iterator = iter(self.public_loader)
        return next(self.public_iterator)

    def _epoch_phase(self, phase):
        """Run one pass over ``self.dataloaders[phase]`` and return aggregated metrics.

        Parameters are updated only when ``phase == "train"``; other phases run
        without gradients and additionally record an F1 score per batch.
        """
        phase_results = {}
        is_train = phase == "train"

        # Set model mode
        if is_train:
            self.model.train()
        else:
            self.model.eval()

        # Epoch process
        for batch_data in tqdm(self.dataloaders[phase]):
            images, labels = batch_data["img"], batch_data["label"]

            if self.with_public:
                # Concat a batch of public data to the labelled batch
                public_batch = self._next_public_batch()
                images = torch.cat([images, public_batch["img"]], dim=0)
                labels = torch.cat([labels, public_batch["label"]], dim=0)

            images = images.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass
            with torch.cuda.amp.autocast(enabled=self.amp):
                with torch.set_grad_enabled(is_train):
                    # Output shape is B x [grad y, grad x, cellprob] x H x W
                    outputs = self._inference(images, phase)

                    # Map label masks to gradient flows and onehot
                    labels_onehot_flows = labels_to_flows(
                        labels, use_gpu=True, device=self.device
                    )
                    # Calculate loss; store a detached value so the metric
                    # buffer does not keep each batch's autograd graph alive.
                    loss = self.mediar_criterion(outputs, labels_onehot_flows)
                    self.loss_metric.append(loss.detach())

                    # Calculate valid statistics
                    if not is_train:
                        outputs, labels = self._post_process(outputs, labels)
                        f1_score = self._get_f1_metric(outputs, labels)
                        self.f1_metric.append(f1_score)

            # Backward pass (outside autocast, as recommended for AMP)
            if is_train:
                if self.amp:
                    # Mixed precision: scale the loss, unscale before stepping
                    self.scaler.scale(loss).backward()
                    self.scaler.unscale_(self.optimizer)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    loss.backward()
                    self.optimizer.step()

        # Update metrics. The "dice_loss" key name is kept as-is; the stored
        # value is the mediar_criterion loss (BCE + flow MSE).
        phase_results = self._update_results(
            phase_results, self.loss_metric, "dice_loss", phase
        )
        if not is_train:
            phase_results = self._update_results(
                phase_results, self.f1_metric, "f1_score", phase
            )

        return phase_results

    def _inference(self, images, phase="train"):
        """Forward pass: plain call when training, 512-px sliding window otherwise."""
        if phase != "train":
            outputs = sliding_window_inference(
                images,
                roi_size=512,
                sw_batch_size=4,
                predictor=self.model,
                padding_mode="constant",
                mode="gaussian",
                overlap=0.5,
            )
        else:
            outputs = self.model(images)

        return outputs

    def _post_process(self, outputs, labels=None):
        """Predict cell instances using gradient tracking.

        Drops the leading (batch) dim of ``outputs`` with ``squeeze(0)``, so a
        batch of one is assumed. Returns the first item of ``compute_masks``
        and, if given, ``labels`` with two leading singleton dims squeezed, as
        NumPy.
        """
        outputs = outputs.squeeze(0).cpu().numpy()
        gradflows, cellprob = outputs[:2], self._sigmoid(outputs[-1])
        outputs = compute_masks(gradflows, cellprob, use_gpu=True, device=self.device)
        outputs = outputs[0]  # compute_masks returns a tuple; keep its first item (the mask)

        if labels is not None:
            labels = labels.squeeze(0).squeeze(0).cpu().numpy()

        return outputs, labels

    def _sigmoid(self, z):
        """Sigmoid function for numpy arrays"""
        return 1 / (1 + np.exp(-z))
@@ -1,3 +0,0 @@
1
- from .Trainer import *
2
- from .Predictor import *
3
- from .EnsemblePredictor import *