octopi 1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of octopi might be problematic.
- octopi/__init__.py +0 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +84 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +429 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +253 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +80 -0
- octopi/entry_points/create_slurm_submission.py +243 -0
- octopi/entry_points/run_create_targets.py +281 -0
- octopi/entry_points/run_evaluate.py +65 -0
- octopi/entry_points/run_extract_mb_picks.py +141 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +222 -0
- octopi/entry_points/run_optuna.py +139 -0
- octopi/entry_points/run_segment_predict.py +166 -0
- octopi/entry_points/run_train.py +201 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +254 -0
- octopi/extract/membranebound_extract.py +262 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/io.py +457 -0
- octopi/losses.py +86 -0
- octopi/main.py +101 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +62 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +106 -0
- octopi/processing/downsample.py +129 -0
- octopi/processing/evaluate.py +289 -0
- octopi/processing/importers.py +213 -0
- octopi/processing/my_metrics.py +26 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/processing/writers.py +102 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +243 -0
- octopi/pytorch/model_search_submitter.py +290 -0
- octopi/pytorch/segmentation.py +317 -0
- octopi/pytorch/trainer.py +438 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/stopping_criteria.py +143 -0
- octopi/submit_slurm.py +95 -0
- octopi/utils.py +238 -0
- octopi/visualization_tools.py +201 -0
- octopi-1.0.dist-info/LICENSE +41 -0
- octopi-1.0.dist-info/METADATA +209 -0
- octopi-1.0.dist-info/RECORD +59 -0
- octopi-1.0.dist-info/WHEEL +4 -0
- octopi-1.0.dist-info/entry_points.txt +4 -0
octopi/pytorch_lightning/optuna_pl_ddp.py
@@ -0,0 +1,273 @@
+import os
+import argparse
+import copick
+import torch
+from tqdm import tqdm
+from typing import Optional, Union, Tuple, List
+from collections import defaultdict
+import pytorch_lightning as pl
+import torch.distributed as dist
+from pytorch_lightning import Trainer
+from pytorch_lightning.loggers import MLFlowLogger
+from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
+from monai.data import DataLoader, Dataset, CacheDataset, decollate_batch
+from dotenv import load_dotenv
+from monai.transforms import (
+    Compose,
+    EnsureChannelFirstd,
+    Orientationd,
+    AsDiscrete,
+    RandFlipd,
+    RandRotate90d,
+    NormalizeIntensityd,
+    RandCropByLabelClassesd,
+)
+from monai.networks.nets import UNet
+from monai.losses import TverskyLoss
+from monai.metrics import DiceMetric, ConfusionMatrixMetric
+import optuna
+from optuna.integration import PyTorchLightningPruningCallback
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        description="Hyperparameter tuning using PyTorch Lightning distributed data-parallel and Optuna."
+    )
+    parser.add_argument('--copick_config_path', type=str, default='copick_config_dataportal_10439.json')
+    parser.add_argument('--copick_user_name', type=str, default='user0')
+    parser.add_argument('--copick_segmentation_name', type=str, default='paintedPicks')
+    parser.add_argument('--train_batch_size', type=int, default=1)
+    parser.add_argument('--val_batch_size', type=int, default=1)
+    parser.add_argument('--num_random_samples_per_batch', type=int, default=16)
+    parser.add_argument('--learning_rate', type=float, default=1e-4)
+    parser.add_argument('--num_epochs', type=int, default=20)
+    parser.add_argument('--num_gpus', type=int, default=1)
+    parser.add_argument('--num_optuna_trials', type=int, default=10)
+    parser.add_argument('--pruning', action="store_true", help="Activate the pruning feature. `MedianPruner` stops unpromising trials at the early stages of training.")
+    return parser.parse_args()
+
+
+class Model(pl.LightningModule):
+    def __init__(
+            self,
+            spatial_dims: int = 3,
+            in_channels: int = 1,
+            out_channels: int = 8,
+            channels: Union[Tuple[int, ...], List[int]] = (48, 64, 80, 80),
+            strides: Union[Tuple[int, ...], List[int]] = (2, 2, 1),
+            num_res_units: int = 1,
+            lr: float = 1e-3):
+
+        super().__init__()
+        self.save_hyperparameters()
+        self.model = UNet(
+            spatial_dims=self.hparams.spatial_dims,
+            in_channels=self.hparams.in_channels,
+            out_channels=self.hparams.out_channels,
+            channels=self.hparams.channels,
+            strides=self.hparams.strides,
+            num_res_units=self.hparams.num_res_units,
+        )
+        self.loss_fn = TverskyLoss(include_background=True, to_onehot_y=True, softmax=True)  # softmax=True for multiclass
+        self.metric_fn = DiceMetric(include_background=False, reduction="mean", ignore_empty=True)
+
+    def forward(self, x):
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch['image'], batch['label']
+        y_hat = self(x)
+        loss = self.loss_fn(y_hat, y)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        with torch.no_grad():  # This ensures that gradients are not stored in memory
+            x, y = batch['image'], batch['label']
+            y_hat = self(x)
+            metric_val_outputs = [AsDiscrete(argmax=True, to_onehot=self.hparams.out_channels)(i) for i in decollate_batch(y_hat)]
+            metric_val_labels = [AsDiscrete(to_onehot=self.hparams.out_channels)(i) for i in decollate_batch(y)]
+
+            # compute metric for current iteration
+            self.metric_fn(y_pred=metric_val_outputs, y=metric_val_labels)
+            metrics = self.metric_fn.aggregate(reduction="mean_batch")
+            for i, m in enumerate(metrics):
+                self.log(f"validation metric class {i+1}", m, prog_bar=True, on_epoch=True, sync_dist=True)
+            metric = torch.mean(metrics)  # cannot log ndarray
+            self.log('val_metric', metric, prog_bar=True, on_epoch=True, sync_dist=True)  # sync_dist=True for distributed training
+            return {'val_metric': metric}
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
+
+
+class CopickDataModule(pl.LightningDataModule):
+    def __init__(
+            self,
+            copick_config_path: str,
+            train_batch_size: int,
+            val_batch_size: int,
+            num_random_samples_per_batch: int):
+
+        super().__init__()
+        self.train_batch_size = train_batch_size
+        self.val_batch_size = val_batch_size
+
+        self.data_dicts, self.nclasses = self.data_from_copick(copick_config_path)
+        self.train_files = self.data_dicts[:len(self.data_dicts) // 2]
+        self.val_files = self.data_dicts[len(self.data_dicts) // 2:]
+        print(f"Number of training samples: {len(self.train_files)}")
+        print(f"Number of validation samples: {len(self.val_files)}")
+
+        # Non-random transforms to be cached
+        self.non_random_transforms = Compose([
+            EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
+            NormalizeIntensityd(keys="image"),
+            Orientationd(keys=["image", "label"], axcodes="RAS")
+        ])
+
+        # Random transforms to be applied during training
+        self.random_transforms = Compose([
+            RandCropByLabelClassesd(
+                keys=["image", "label"],
+                label_key="label",
+                spatial_size=[96, 96, 96],
+                num_classes=self.nclasses,
+                num_samples=num_random_samples_per_batch
+            ),
+            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
+            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
+        ])
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        self.train_ds = CacheDataset(data=self.train_files, transform=self.non_random_transforms, cache_rate=1.0)
+        self.train_ds = Dataset(data=self.train_ds, transform=self.random_transforms)
+        self.val_ds = CacheDataset(data=self.val_files, transform=self.non_random_transforms, cache_rate=1.0)
+        self.val_ds = Dataset(data=self.val_ds, transform=self.random_transforms)
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(
+            self.train_ds,
+            batch_size=self.train_batch_size,
+            shuffle=True,
+            num_workers=4,
+            persistent_workers=True,
+            pin_memory=torch.cuda.is_available(),
+        )
+
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(
+            self.val_ds,
+            batch_size=self.val_batch_size,
+            shuffle=False,  # Ensure the data order remains consistent
+            num_workers=4,
+            persistent_workers=True,
+            pin_memory=torch.cuda.is_available(),
+        )
+
+    @staticmethod
+    def data_from_copick(copick_config_path):
+        root = copick.from_file(copick_config_path)
+        nclasses = len(root.pickable_objects) + 1
+        target_objects = defaultdict(dict)
+        for object in root.pickable_objects:
+            if object.is_particle:
+                target_objects[object.name]['label'] = object.label
+                target_objects[object.name]['radius'] = object.radius
+
+        data_dicts = []
+        for run in tqdm(root.runs[:8]):
+            tomogram = run.get_voxel_spacing(10).get_tomogram('wbp').numpy()
+            # NOTE: the segmentation name/user_id are hardcoded here; the
+            # --copick_user_name/--copick_segmentation_name CLI flags are unused in this script.
+            segmentation = run.get_segmentations(name='paintedPicks', user_id='user0', voxel_size=10, is_multilabel=True)[0].numpy()
+            membrane_seg = run.get_segmentations(name='membrane', user_id="data-portal")[0].numpy()
+            segmentation[membrane_seg == 1] = 1
+            data_dicts.append({"image": tomogram, "label": segmentation})
+
+        return data_dicts, nclasses
+
+
+def objective(trial: optuna.trial.Trial) -> float:
+    args = get_args()
+    mlf_logger = MLFlowLogger(experiment_name='training-3D-UNet-model-for-the-cryoET-ML-Challenge',
+                              tracking_uri='http://mlflow.mlflow.svc.cluster.local:5000',
+                              #run_name='test1'
+                              )
+
+    # Trainer callbacks
+    checkpoint_callback = ModelCheckpoint(monitor='val_metric', save_top_k=1, mode='max')
+    lr_monitor = LearningRateMonitor(logging_interval='epoch')
+
+    # Detect distributed training environment
+    devices = list(range(args.num_gpus))
+
+    #channels, strides_pattern, num_res_units = sync_hyperparameters(trial)
+    # We optimize the number of layers, strides, and number of residual units
+    num_layers = trial.suggest_int("num_layers", 3, 5)
+    base_channel = trial.suggest_categorical("base_channel", [8, 16, 32, 64])
+    channels = [base_channel * (2 ** i) for i in range(num_layers)]
+    num_downsampling_layers = trial.suggest_int("num_downsampling_layers", 1, num_layers - 1)
+    strides_pattern = [2] * num_downsampling_layers + [1] * (num_layers - num_downsampling_layers - 1)
+    num_res_units = trial.suggest_int("num_res_units", 1, 3)
+
+    model = Model(channels=channels, strides=strides_pattern, num_res_units=num_res_units, lr=args.learning_rate)
+    datamodule = CopickDataModule(args.copick_config_path, args.train_batch_size, args.val_batch_size, args.num_random_samples_per_batch)
+    callback = PyTorchLightningPruningCallback(trial, monitor="val_metric")
+
+    # Prioritize performance over precision
+    torch.set_float32_matmul_precision('medium')  # or torch.set_float32_matmul_precision('high')
+
+    # Trainer for distributed training with DDP
+    trainer = Trainer(
+        max_epochs=args.num_epochs,
+        logger=mlf_logger,
+        callbacks=[checkpoint_callback, lr_monitor, callback],  # the pruning callback must be registered for Optuna to prune trials
+        strategy="ddp_spawn",
+        accelerator="gpu",
+        devices=devices,
+        num_nodes=1,  # int(os.environ.get("WORLD_SIZE", 1)) // args.num_gpus,
+        log_every_n_steps=1
+    )
+
+    hyperparameters = dict(op_num_layers=num_layers, op_base_channel=base_channel, op_num_downsampling_layers=num_downsampling_layers, op_num_res_units=num_res_units)
+    trainer.logger.log_hyperparams(hyperparameters)
+    trainer.fit(model, datamodule=datamodule)
+    callback.check_pruned()
+    return trainer.callback_metrics["val_metric"].item()
+
+
+if __name__ == "__main__":
+    args = get_args()
+    # MLflow setup
+    username = os.getenv('MLFLOW_TRACKING_USERNAME')
+    password = os.getenv('MLFLOW_TRACKING_PASSWORD')
+    if not password or not username:
+        print("Password not found in environment, loading from .env file...")
+        load_dotenv()  # Loads environment variables from a .env file
+        username = os.getenv('MLFLOW_TRACKING_USERNAME')
+        password = os.getenv('MLFLOW_TRACKING_PASSWORD')
+
+    # Check again after loading .env file
+    if not password:
+        raise ValueError("Password is not set in environment variables or .env file!")
+    else:
+        print("Password loaded successfully")
+    os.environ['MLFLOW_TRACKING_USERNAME'] = username
+    os.environ['MLFLOW_TRACKING_PASSWORD'] = password
+
+    pruner: optuna.pruners.BasePruner = (
+        optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner()
+    )
+    storage = "sqlite:///example.db"
+    study = optuna.create_study(
+        study_name="pl_ddp",
+        storage=storage,
+        direction="maximize",
+        load_if_exists=True,
+        pruner=pruner
+    )
+    study.optimize(objective, n_trials=args.num_optuna_trials)
+
+    # Print the best hyperparameters
+    print(f"Best trial: {study.best_trial.value}")
+    print(f"Best hyperparameters: {study.best_trial.params}")
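For concreteness, here is how the Optuna search space in `objective` expands for one arbitrary trial (the trial values below are illustrative, not taken from a real run):

num_layers = 4               # trial.suggest_int("num_layers", 3, 5)
base_channel = 16            # trial.suggest_categorical("base_channel", [8, 16, 32, 64])
num_downsampling_layers = 2  # trial.suggest_int("num_downsampling_layers", 1, num_layers - 1)

channels = [base_channel * (2 ** i) for i in range(num_layers)]
strides = [2] * num_downsampling_layers + [1] * (num_layers - num_downsampling_layers - 1)

print(channels)  # [16, 32, 64, 128]
print(strides)   # [2, 2, 1] -- always one entry fewer than channels, as MONAI's UNet requires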
octopi/pytorch_lightning/train_pl.py
@@ -0,0 +1,244 @@
+import os
+import argparse
+from typing import Optional, Union, Tuple, List
+from collections import defaultdict
+import torch
+import pytorch_lightning as pl
+from pytorch_lightning import Trainer
+from pytorch_lightning.loggers import MLFlowLogger
+from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
+from pytorch_lightning.strategies import DDPStrategy
+import copick
+from tqdm import tqdm
+from monai.data import DataLoader, Dataset, CacheDataset, decollate_batch
+from dotenv import load_dotenv
+from monai.transforms import (
+    Compose,
+    EnsureChannelFirstd,
+    Orientationd,
+    AsDiscrete,
+    RandFlipd,
+    RandRotate90d,
+    NormalizeIntensityd,
+    RandCropByLabelClassesd,
+)
+from monai.networks.nets import UNet
+from monai.losses import TverskyLoss
+from monai.metrics import DiceMetric, ConfusionMatrixMetric
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        description="Train a 3D U-Net model with PyTorch Lightning, supporting distributed training strategies."
+    )
+    parser.add_argument('--copick_config_path', type=str, default='copick_config_dataportal_10439.json')
+    parser.add_argument('--copick_user_name', type=str, default='user0')
+    parser.add_argument('--copick_segmentation_name', type=str, default='paintedPicks')
+    parser.add_argument('--train_batch_size', type=int, default=1)
+    parser.add_argument('--val_batch_size', type=int, default=1)
+    parser.add_argument('--num_random_samples_per_batch', type=int, default=16)
+    parser.add_argument('--learning_rate', type=float, default=1e-4)
+    parser.add_argument('--num_epochs', type=int, default=20)
+    parser.add_argument('--num_gpus', type=int, default=1)
+    return parser.parse_args()
+
+
+class Model(pl.LightningModule):
+    def __init__(
+            self,
+            spatial_dims: int = 3,
+            in_channels: int = 1,
+            out_channels: int = 8,
+            channels: Union[Tuple[int, ...], List[int]] = (48, 64, 80, 80),
+            strides: Union[Tuple[int, ...], List[int]] = (2, 2, 1),
+            num_res_units: int = 1,
+            lr: float = 1e-3):
+
+        super().__init__()
+        self.save_hyperparameters()
+
+        self.model = UNet(
+            spatial_dims=self.hparams.spatial_dims,
+            in_channels=self.hparams.in_channels,
+            out_channels=self.hparams.out_channels,
+            channels=self.hparams.channels,
+            strides=self.hparams.strides,
+            num_res_units=self.hparams.num_res_units,
+        )
+        self.loss_fn = TverskyLoss(include_background=True, to_onehot_y=True, softmax=True)  # softmax=True for multiclass
+        self.metric_fn = DiceMetric(include_background=False, reduction="mean", ignore_empty=True)
+
+    def forward(self, x):
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch['image'], batch['label']
+        y_hat = self(x)
+        loss = self.loss_fn(y_hat, y)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch['image'], batch['label']
+        y_hat = self(x)
+        metric_val_outputs = [AsDiscrete(argmax=True, to_onehot=self.hparams.out_channels)(i) for i in decollate_batch(y_hat)]
+        metric_val_labels = [AsDiscrete(to_onehot=self.hparams.out_channels)(i) for i in decollate_batch(y)]
+
+        # compute metric for current iteration
+        self.metric_fn(y_pred=metric_val_outputs, y=metric_val_labels)
+        metrics = self.metric_fn.aggregate(reduction="mean_batch")
+        for i, m in enumerate(metrics):
+            self.log(f"validation metric class {i+1}", m, prog_bar=True, on_epoch=True, sync_dist=True)
+        metric = torch.mean(metrics)  # cannot log ndarray
+        self.log('val_metric', metric, prog_bar=True, on_epoch=True, sync_dist=True)  # sync_dist=True for distributed training
+        return {'val_metric': metric}
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
+
+
+class CopickDataModule(pl.LightningDataModule):
+    def __init__(
+            self,
+            copick_config_path: str,
+            copick_user_name: str,
+            copick_segmentation_name: str,
+            train_batch_size: int,
+            val_batch_size: int,
+            num_random_samples_per_batch: int):
+
+        super().__init__()
+        self.train_batch_size = train_batch_size
+        self.val_batch_size = val_batch_size
+
+        self.data_dicts, self.nclasses = self.data_from_copick(copick_config_path, copick_user_name, copick_segmentation_name)
+        self.train_files = self.data_dicts[:len(self.data_dicts) // 2]
+        self.val_files = self.data_dicts[len(self.data_dicts) // 2:]
+        print(f"Number of training samples: {len(self.train_files)}")
+        print(f"Number of validation samples: {len(self.val_files)}")
+
+        # Non-random transforms to be cached
+        self.non_random_transforms = Compose([
+            EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
+            NormalizeIntensityd(keys="image"),
+            Orientationd(keys=["image", "label"], axcodes="RAS")
+        ])
+
+        # Random transforms to be applied during training
+        self.random_transforms = Compose([
+            RandCropByLabelClassesd(
+                keys=["image", "label"],
+                label_key="label",
+                spatial_size=[96, 96, 96],
+                num_classes=self.nclasses,
+                num_samples=num_random_samples_per_batch
+            ),
+            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
+            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
+        ])
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        self.train_ds = CacheDataset(data=self.train_files, transform=self.non_random_transforms, cache_rate=1.0)
+        self.train_ds = Dataset(data=self.train_ds, transform=self.random_transforms)
+        self.val_ds = CacheDataset(data=self.val_files, transform=self.non_random_transforms, cache_rate=1.0)
+        self.val_ds = Dataset(data=self.val_ds, transform=self.random_transforms)
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(
+            self.train_ds,
+            batch_size=self.train_batch_size,
+            shuffle=True,
+            num_workers=4,
+            pin_memory=torch.cuda.is_available()
+        )
+
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(
+            self.val_ds,
+            batch_size=self.val_batch_size,
+            num_workers=4,
+            pin_memory=torch.cuda.is_available(),
+            shuffle=False,  # Ensure the data order remains consistent
+        )
+
+    @staticmethod
+    def data_from_copick(copick_config_path, copick_user_name, copick_segmentation_name):
+        root = copick.from_file(copick_config_path)
+        nclasses = len(root.pickable_objects) + 1
+        target_objects = defaultdict(dict)
+        for object in root.pickable_objects:
+            if object.is_particle:
+                target_objects[object.name]['label'] = object.label
+                target_objects[object.name]['radius'] = object.radius
+
+        data_dicts = []
+        for run in tqdm(root.runs[:2]):
+            tomogram = run.get_voxel_spacing(10).get_tomogram('wbp').numpy()
+            segmentation = run.get_segmentations(name=copick_segmentation_name, user_id=copick_user_name, voxel_size=10, is_multilabel=True)[0].numpy()
+            membrane_segs = run.get_segmentations(name='membrane', user_id="data-portal")
+            if membrane_segs:  # merge the membrane segmentation into the label volume when one exists
+                membrane_seg = membrane_segs[0].numpy()
+                segmentation[membrane_seg == 1] = 1
+            data_dicts.append({"image": tomogram, "label": segmentation})
+
+        return data_dicts, nclasses
+
+
+def train():
+    args = get_args()
+    mlf_logger = MLFlowLogger(experiment_name='training-3D-UNet-model-for-the-cryoET-ML-Challenge',
+                              tracking_uri='http://mlflow.mlflow.svc.cluster.local:5000',
+                              #run_name='test1'
+                              )
+    # Trainer callbacks
+    checkpoint_callback = ModelCheckpoint(monitor='val_metric', save_top_k=1, mode='max')  # val_metric is a Dice score, so higher is better
+    lr_monitor = LearningRateMonitor(logging_interval='epoch')
+
+    # Detect distributed training environment
+    devices = list(range(args.num_gpus))
+
+    # Initialize model
+    model = Model(lr=args.learning_rate)
+    datamodule = CopickDataModule(args.copick_config_path, args.copick_user_name, args.copick_segmentation_name,
+                                  args.train_batch_size, args.val_batch_size, args.num_random_samples_per_batch)
+
+    # Prioritize performance over precision
+    torch.set_float32_matmul_precision('medium')  # or torch.set_float32_matmul_precision('high')
+
+    # Trainer for distributed training with DDP
+    trainer = Trainer(
+        max_epochs=args.num_epochs,
+        logger=mlf_logger,
+        callbacks=[checkpoint_callback, lr_monitor],
+        strategy=DDPStrategy(find_unused_parameters=False),
+        accelerator="gpu",
+        devices=devices,
+        num_nodes=1,  # int(os.environ.get("WORLD_SIZE", 1)) // args.num_gpus,
+        log_every_n_steps=1
+    )
+
+    trainer.fit(model, datamodule=datamodule)
+
+
+if __name__ == "__main__":
+    # MLflow setup
+    username = os.getenv('MLFLOW_TRACKING_USERNAME')
+    password = os.getenv('MLFLOW_TRACKING_PASSWORD')
+    if not password or not username:
+        print("Password not found in environment, loading from .env file...")
+        load_dotenv()  # Loads environment variables from a .env file
+        username = os.getenv('MLFLOW_TRACKING_USERNAME')
+        password = os.getenv('MLFLOW_TRACKING_PASSWORD')
+
+    # Check again after loading .env file
+    if not password:
+        raise ValueError("Password is not set in environment variables or .env file!")
+    else:
+        print("Password loaded successfully")
+    os.environ['MLFLOW_TRACKING_USERNAME'] = username
+    os.environ['MLFLOW_TRACKING_PASSWORD'] = password
+
+    train()
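Both training entry points read `MLFLOW_TRACKING_USERNAME` and `MLFLOW_TRACKING_PASSWORD` from the environment, falling back to `load_dotenv()`. A minimal `.env` file for that fallback might look like this (placeholder values, not real credentials):

MLFLOW_TRACKING_USERNAME=<your-mlflow-username>
MLFLOW_TRACKING_PASSWORD=<your-mlflow-password>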
octopi/stopping_criteria.py
@@ -0,0 +1,143 @@
+import numpy as np
+
+class EarlyStoppingChecker:
+    """
+    A class to manage various early stopping criteria for model training.
+    """
+
+    def __init__(self,
+                 max_nan_epochs=15,
+                 plateau_patience=20,
+                 plateau_min_delta=0.001,
+                 stagnation_patience=50,
+                 convergence_window=5,
+                 convergence_threshold=0.005,
+                 val_interval=15,
+                 monitor_metric='avg_fbeta'):
+        """
+        Initialize early stopping parameters.
+
+        Args:
+            max_nan_epochs: Maximum number of epochs with NaN loss before stopping
+            plateau_patience: Number of validation checks to wait for plateau detection
+            plateau_min_delta: Minimum change to qualify as improvement
+            stagnation_patience: Number of validation intervals to wait for best metric improvement
+            convergence_window: Window size for calculating improvement rate
+            convergence_threshold: Minimum improvement rate threshold
+            val_interval: Number of epochs between validation runs
+            monitor_metric: Primary metric to monitor for early stopping criteria
+        """
+        self.max_nan_epochs = max_nan_epochs
+        self.plateau_patience = plateau_patience
+        self.plateau_min_delta = plateau_min_delta
+        self.stagnation_patience = stagnation_patience
+        self.convergence_window = convergence_window
+        self.convergence_threshold = convergence_threshold
+        self.val_interval = val_interval
+        self.monitor_metric = monitor_metric
+
+        # Counters
+        self.nan_counter = 0
+
+        # Flags for detailed reporting
+        self.stopped_reason = None
+
+    def check_for_nan(self, epoch_loss):
+        """Check for NaN in the loss."""
+        if np.isnan(epoch_loss):
+            self.nan_counter += 1
+            if self.nan_counter > self.max_nan_epochs:
+                self.stopped_reason = f"NaN values in loss for more than {self.max_nan_epochs} epochs"
+                return True
+        else:
+            self.nan_counter = 0  # Reset the counter if loss is valid
+        return False
+
+    def check_for_plateau(self, results):
+        """Detect plateaus in validation metrics."""
+        if len(results[self.monitor_metric]) < self.plateau_patience + 1:
+            return False
+
+        # Get the last 'patience' number of validation points
+        recent_values = [x[1] for x in results[self.monitor_metric][-self.plateau_patience:]]
+        # Find the max value in the window
+        max_value = max(recent_values)
+        # Find the min value in the window
+        min_value = min(recent_values)
+
+        # If the range of values is small, consider it a plateau
+        if max_value - min_value < self.plateau_min_delta:
+            self.stopped_reason = f"{self.monitor_metric} plateaued for {self.plateau_patience} validations"
+            return True
+
+        return False
+
+    def check_best_metric_stagnation(self, results):
+        """Stop if best metric hasn't improved for a number of validation intervals."""
+        if "best_metric_epoch" not in results or len(results[self.monitor_metric]) < self.stagnation_patience + 1:
+            return False
+
+        # Get epoch of the best metric so far
+        best_epoch = results["best_metric_epoch"]
+        current_epoch = results[self.monitor_metric][-1][0]
+
+        # Check if it's been more than 'patience' validation intervals
+        if (current_epoch - best_epoch) >= (self.stagnation_patience * self.val_interval):
+            self.stopped_reason = f"No improvement for {self.stagnation_patience} validation intervals"
+            return True
+
+        return False
+
+    # def check_convergence_rate(self, results):
+    #     """Stop when improvement rate slows below threshold."""
+    #     if len(results[self.monitor_metric]) < self.convergence_window + 1:
+    #         return False
+
+    #     # Calculate average improvement rate over window
+    #     recent_values = [x[1] for x in results[self.monitor_metric][-(self.convergence_window+1):]]
+    #     improvements = [recent_values[i+1] - recent_values[i] for i in range(self.convergence_window)]
+    #     avg_improvement = sum(improvements) / self.convergence_window
+
+    #     if avg_improvement < self.convergence_threshold and avg_improvement > 0:
+    #         self.stopped_reason = f"Convergence rate ({avg_improvement:.6f}) below threshold"
+    #         return True
+
+    #     return False
+
+    def should_stop_training(self, epoch_loss, results=None, check_metrics=False):
+        """
+        Comprehensive check for whether training should stop.
+
+        Args:
+            epoch_loss: Current epoch's loss value
+            results: Dictionary containing training metrics history
+            check_metrics: Whether to also check validation metrics-based criteria
+
+        Returns:
+            bool: True if training should stop, False otherwise
+        """
+        # Check for NaN in loss (can be done every epoch)
+        if self.check_for_nan(epoch_loss):
+            return True
+
+        # Only check metric-based criteria if requested and results are provided
+        if check_metrics and results:
+            # Check for plateau in validation metrics
+            if self.check_for_plateau(results):
+                return True
+
+            # Check if best metric hasn't improved for a while
+            if self.check_best_metric_stagnation(results):
+                return True
+
+            # # Check if convergence rate has slowed down
+            # if self.check_convergence_rate(results):
+            #     return True
+
+        return False
+
+    def get_stopped_reason(self):
+        """Get the reason for stopping, if any."""
+        if self.stopped_reason:
+            return f"Early stopping triggered: {self.stopped_reason}"
+        return "No early stopping criteria met."
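A minimal usage sketch for `EarlyStoppingChecker` (hypothetical training loop; the `results` layout, one list of `(epoch, value)` pairs per metric plus a `best_metric_epoch` entry, is inferred from `check_for_plateau` and `check_best_metric_stagnation` above):

import random

def train_one_epoch():  # hypothetical stand-in for the real training step
    return random.random()

def validate():  # hypothetical stand-in for the real validation pass
    return random.random()

checker = EarlyStoppingChecker(val_interval=15, monitor_metric='avg_fbeta')
results = {'avg_fbeta': [], 'best_metric_epoch': 0}
best = float('-inf')

for epoch in range(1, 1001):
    epoch_loss = train_one_epoch()
    ran_validation = epoch % checker.val_interval == 0
    if ran_validation:
        score = validate()
        results['avg_fbeta'].append((epoch, score))
        if score > best:
            best = score
            results['best_metric_epoch'] = epoch
    # NaN checks run every epoch; metric-based checks only on validation epochs
    if checker.should_stop_training(epoch_loss, results, check_metrics=ran_validation):
        print(checker.get_stopped_reason())
        break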