octopi-1.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- octopi/__init__.py +7 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +83 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +458 -0
- octopi/datasets/io.py +200 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +252 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +119 -0
- octopi/entry_points/create_slurm_submission.py +251 -0
- octopi/entry_points/groups.py +152 -0
- octopi/entry_points/run_create_targets.py +234 -0
- octopi/entry_points/run_evaluate.py +99 -0
- octopi/entry_points/run_extract_mb_picks.py +191 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +176 -0
- octopi/entry_points/run_optuna.py +161 -0
- octopi/entry_points/run_segment.py +154 -0
- octopi/entry_points/run_train.py +189 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +217 -0
- octopi/extract/membranebound_extract.py +263 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/main.py +33 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +72 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +224 -0
- octopi/processing/downloader.py +138 -0
- octopi/processing/downsample.py +125 -0
- octopi/processing/evaluate.py +302 -0
- octopi/processing/importers.py +116 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +244 -0
- octopi/pytorch/model_search_submitter.py +291 -0
- octopi/pytorch/segmentation.py +363 -0
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +465 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +215 -0
- octopi/utils/losses.py +86 -0
- octopi/utils/parsers.py +162 -0
- octopi/utils/progress.py +78 -0
- octopi/utils/stopping_criteria.py +143 -0
- octopi/utils/submit_slurm.py +95 -0
- octopi/utils/visualization_tools.py +290 -0
- octopi/workflows.py +262 -0
- octopi-1.4.0.dist-info/METADATA +119 -0
- octopi-1.4.0.dist-info/RECORD +65 -0
- octopi-1.4.0.dist-info/WHEEL +4 -0
- octopi-1.4.0.dist-info/entry_points.txt +3 -0
- octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
octopi/pytorch_lightning/optuna_pl_ddp.py
ADDED

@@ -0,0 +1,273 @@
import os
import argparse
import copick
import torch
from tqdm import tqdm
from typing import Optional, Union, Tuple, List
from collections import defaultdict
import pytorch_lightning as pl
import torch.distributed as dist
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from monai.data import DataLoader, Dataset, CacheDataset, decollate_batch
from dotenv import load_dotenv
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    Orientationd,
    AsDiscrete,
    RandFlipd,
    RandRotate90d,
    NormalizeIntensityd,
    RandCropByLabelClassesd,
)
from monai.networks.nets import UNet
from monai.losses import TverskyLoss
from monai.metrics import DiceMetric, ConfusionMatrixMetric
import optuna
from optuna.integration import PyTorchLightningPruningCallback


def get_args():
    parser = argparse.ArgumentParser(
        description="Hyperparameter tuning using PyTorch Lightning distributed data-parallel and Optuna."
    )
    parser.add_argument('--copick_config_path', type=str, default='copick_config_dataportal_10439.json')
    parser.add_argument('--copick_user_name', type=str, default='user0')
    parser.add_argument('--copick_segmentation_name', type=str, default='paintedPicks')
    parser.add_argument('--train_batch_size', type=int, default=1)
    parser.add_argument('--val_batch_size', type=int, default=1)
    parser.add_argument('--num_random_samples_per_batch', type=int, default=16)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--num_epochs', type=int, default=20)
    parser.add_argument('--num_gpus', type=int, default=1)
    parser.add_argument('--num_optuna_trials', type=int, default=10)
    parser.add_argument('--pruning', action="store_true", help="Activate the pruning feature. `MedianPruner` stops unpromising trials at the early stages of training.")
    return parser.parse_args()


class Model(pl.LightningModule):
    def __init__(
        self,
        spatial_dims: int = 3,
        in_channels: int = 1,
        out_channels: int = 8,
        channels: Union[Tuple[int, ...], List[int]] = (48, 64, 80, 80),
        strides: Union[Tuple[int, ...], List[int]] = (2, 2, 1),
        num_res_units: int = 1,
        lr: float = 1e-3):

        super().__init__()
        self.save_hyperparameters()
        self.model = UNet(
            spatial_dims=self.hparams.spatial_dims,
            in_channels=self.hparams.in_channels,
            out_channels=self.hparams.out_channels,
            channels=self.hparams.channels,
            strides=self.hparams.strides,
            num_res_units=self.hparams.num_res_units,
        )
        self.loss_fn = TverskyLoss(include_background=True, to_onehot_y=True, softmax=True)  # softmax=True for multiclass
        self.metric_fn = DiceMetric(include_background=False, reduction="mean", ignore_empty=True)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch['image'], batch['label']
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():  # This ensures that gradients are not stored in memory
            x, y = batch['image'], batch['label']
            y_hat = self(x)
            metric_val_outputs = [AsDiscrete(argmax=True, to_onehot=self.hparams.out_channels)(i) for i in decollate_batch(y_hat)]
            metric_val_labels = [AsDiscrete(to_onehot=self.hparams.out_channels)(i) for i in decollate_batch(y)]

            # compute metric for current iteration
            self.metric_fn(y_pred=metric_val_outputs, y=metric_val_labels)
            metrics = self.metric_fn.aggregate(reduction="mean_batch")
            for i, m in enumerate(metrics):
                self.log(f"validation metric class {i+1}", m, prog_bar=True, on_epoch=True, sync_dist=True)
            metric = torch.mean(metrics)  # cannot log an ndarray, so log the scalar mean
            self.log('val_metric', metric, prog_bar=True, on_epoch=True, sync_dist=True)  # sync_dist=True for distributed training
            return {'val_metric': metric}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


class CopickDataModule(pl.LightningDataModule):
    def __init__(
        self,
        copick_config_path: str,
        train_batch_size: int,
        val_batch_size: int,
        num_random_samples_per_batch: int):

        super().__init__()
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size

        self.data_dicts, self.nclasses = self.data_from_copick(copick_config_path)
        self.train_files = self.data_dicts[:len(self.data_dicts) // 2]
        self.val_files = self.data_dicts[len(self.data_dicts) // 2:]
        print(f"Number of training samples: {len(self.train_files)}")
        print(f"Number of validation samples: {len(self.val_files)}")

        # Non-random transforms to be cached
        self.non_random_transforms = Compose([
            EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
            NormalizeIntensityd(keys="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS")
        ])

        # Random transforms to be applied during training
        self.random_transforms = Compose([
            RandCropByLabelClassesd(
                keys=["image", "label"],
                label_key="label",
                spatial_size=[96, 96, 96],
                num_classes=self.nclasses,
                num_samples=num_random_samples_per_batch
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        ])

    def setup(self, stage: Optional[str] = None) -> None:
        self.train_ds = CacheDataset(data=self.train_files, transform=self.non_random_transforms, cache_rate=1.0)
        self.train_ds = Dataset(data=self.train_ds, transform=self.random_transforms)
        self.val_ds = CacheDataset(data=self.val_files, transform=self.non_random_transforms, cache_rate=1.0)
        self.val_ds = Dataset(data=self.val_ds, transform=self.random_transforms)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_ds,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=4,
            persistent_workers=True,
            pin_memory=torch.cuda.is_available(),
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_ds,
            batch_size=self.val_batch_size,
            shuffle=False,  # Ensure the data order remains consistent
            num_workers=4,
            persistent_workers=True,
            pin_memory=torch.cuda.is_available(),
        )

    @staticmethod
    def data_from_copick(copick_config_path):
        root = copick.from_file(copick_config_path)
        nclasses = len(root.pickable_objects) + 1
        target_objects = defaultdict(dict)
        for obj in root.pickable_objects:
            if obj.is_particle:
                target_objects[obj.name]['label'] = obj.label
                target_objects[obj.name]['radius'] = obj.radius

        data_dicts = []
        for run in tqdm(root.runs[:8]):
            tomogram = run.get_voxel_spacing(10).get_tomogram('wbp').numpy()
            segmentation = run.get_segmentations(name='paintedPicks', user_id='user0', voxel_size=10, is_multilabel=True)[0].numpy()
            membrane_seg = run.get_segmentations(name='membrane', user_id="data-portal")[0].numpy()
            segmentation[membrane_seg == 1] = 1
            data_dicts.append({"image": tomogram, "label": segmentation})

        return data_dicts, nclasses


def objective(trial: optuna.trial.Trial) -> float:
    args = get_args()
    mlf_logger = MLFlowLogger(experiment_name='training-3D-UNet-model-for-the-cryoET-ML-Challenge',
                              tracking_uri='http://mlflow.mlflow.svc.cluster.local:5000',
                              #run_name='test1'
                              )

    # Trainer callbacks
    checkpoint_callback = ModelCheckpoint(monitor='val_metric', save_top_k=1, mode='max')
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # Detect distributed training environment
    devices = list(range(args.num_gpus))

    #channels, strides_pattern, num_res_units = sync_hyperparameters(trial)
    # We optimize the number of layers, strides, and number of residual units
    num_layers = trial.suggest_int("num_layers", 3, 5)
    base_channel = trial.suggest_categorical("base_channel", [8, 16, 32, 64])
    channels = [base_channel * (2 ** i) for i in range(num_layers)]
    num_downsampling_layers = trial.suggest_int("num_downsampling_layers", 1, num_layers - 1)
    strides_pattern = [2] * num_downsampling_layers + [1] * (num_layers - num_downsampling_layers - 1)
    num_res_units = trial.suggest_int("num_res_units", 1, 3)

    model = Model(channels=channels, strides=strides_pattern, num_res_units=num_res_units, lr=args.learning_rate)
    datamodule = CopickDataModule(args.copick_config_path, args.train_batch_size, args.val_batch_size, args.num_random_samples_per_batch)
    callback = PyTorchLightningPruningCallback(trial, monitor="val_metric")

    # Prioritize performance over precision
    torch.set_float32_matmul_precision('medium')  # or torch.set_float32_matmul_precision('high')

    # Trainer for distributed training with DDP
    trainer = Trainer(
        max_epochs=args.num_epochs,
        logger=mlf_logger,
        callbacks=[checkpoint_callback, lr_monitor, callback],  # the pruning callback must be registered here, otherwise check_pruned() below never fires
        strategy="ddp_spawn",
        accelerator="gpu",
        devices=devices,
        num_nodes=1,  # int(os.environ.get("WORLD_SIZE", 1)) // args.num_gpus,
        log_every_n_steps=1
    )

    hyperparameters = dict(op_num_layers=num_layers, op_base_channel=base_channel, op_num_downsampling_layers=num_downsampling_layers, op_num_res_units=num_res_units)
    trainer.logger.log_hyperparams(hyperparameters)
    trainer.fit(model, datamodule=datamodule)
    callback.check_pruned()
    return trainer.callback_metrics["val_metric"].item()


if __name__ == "__main__":
    args = get_args()
    # MLflow setup
    username = os.getenv('MLFLOW_TRACKING_USERNAME')
    password = os.getenv('MLFLOW_TRACKING_PASSWORD')
    if not password or not username:
        print("Password not found in environment, loading from .env file...")
        load_dotenv()  # Loads environment variables from a .env file
        username = os.getenv('MLFLOW_TRACKING_USERNAME')
        password = os.getenv('MLFLOW_TRACKING_PASSWORD')

    # Check again after loading .env file
    if not password:
        raise ValueError("Password is not set in environment variables or .env file!")
    else:
        print("Password loaded successfully")
    os.environ['MLFLOW_TRACKING_USERNAME'] = username
    os.environ['MLFLOW_TRACKING_PASSWORD'] = password

    pruner: optuna.pruners.BasePruner = (
        optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner()
    )
    storage = "sqlite:///example.db"
    study = optuna.create_study(
        study_name="pl_ddp",
        storage=storage,
        direction="maximize",
        load_if_exists=True,
        pruner=pruner
    )
    study.optimize(objective, n_trials=args.num_optuna_trials)

    # Print the best hyperparameters
    print(f"Best trial: {study.best_trial.value}")
    print(f"Best hyperparameters: {study.best_trial.params}")
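
To make the search space above concrete, here is a short illustrative sketch (not part of the package) of one possible trial draw; the specific values are arbitrary, and the point is how `channels` and `strides_pattern` satisfy MONAI `UNet`'s requirement that `len(strides) == len(channels) - 1`:

    # Illustrative values only, mirroring the suggest_* calls in objective() above.
    num_layers = 4                # from trial.suggest_int("num_layers", 3, 5)
    base_channel = 16             # from trial.suggest_categorical("base_channel", [8, 16, 32, 64])
    channels = [base_channel * (2 ** i) for i in range(num_layers)]
    # channels == [16, 32, 64, 128]: feature maps double at each encoder level
    num_downsampling_layers = 2   # from trial.suggest_int("num_downsampling_layers", 1, num_layers - 1)
    strides_pattern = [2] * num_downsampling_layers + [1] * (num_layers - num_downsampling_layers - 1)
    # strides_pattern == [2, 2, 1]: three strides for four channel levels;
    # each stride of 2 halves the spatial resolution, each 1 keeps it unchanged.
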
octopi/pytorch_lightning/train_pl.py
ADDED

@@ -0,0 +1,244 @@
import os
import argparse
from typing import Optional, Union, Tuple, List
from collections import defaultdict
import torch
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.strategies import DDPStrategy
import copick
from tqdm import tqdm
from monai.data import DataLoader, Dataset, CacheDataset, decollate_batch
from dotenv import load_dotenv
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    Orientationd,
    AsDiscrete,
    RandFlipd,
    RandRotate90d,
    NormalizeIntensityd,
    RandCropByLabelClassesd,
)
from monai.networks.nets import UNet
from monai.losses import TverskyLoss
from monai.metrics import DiceMetric, ConfusionMatrixMetric


def get_args():
    parser = argparse.ArgumentParser(
        description="Train a 3D U-Net model with PyTorch Lightning supporting distributed training strategies."
    )
    parser.add_argument('--copick_config_path', type=str, default='copick_config_dataportal_10439.json')
    parser.add_argument('--copick_user_name', type=str, default='user0')
    parser.add_argument('--copick_segmentation_name', type=str, default='paintedPicks')
    parser.add_argument('--train_batch_size', type=int, default=1)
    parser.add_argument('--val_batch_size', type=int, default=1)
    parser.add_argument('--num_random_samples_per_batch', type=int, default=16)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--num_epochs', type=int, default=20)
    parser.add_argument('--num_gpus', type=int, default=1)
    return parser.parse_args()


class Model(pl.LightningModule):
    def __init__(
        self,
        spatial_dims: int = 3,
        in_channels: int = 1,
        out_channels: int = 8,
        channels: Union[Tuple[int, ...], List[int]] = (48, 64, 80, 80),
        strides: Union[Tuple[int, ...], List[int]] = (2, 2, 1),
        num_res_units: int = 1,
        lr: float = 1e-3):

        super().__init__()
        self.save_hyperparameters()

        self.model = UNet(
            spatial_dims=self.hparams.spatial_dims,
            in_channels=self.hparams.in_channels,
            out_channels=self.hparams.out_channels,
            channels=self.hparams.channels,
            strides=self.hparams.strides,
            num_res_units=self.hparams.num_res_units,
        )
        self.loss_fn = TverskyLoss(include_background=True, to_onehot_y=True, softmax=True)  # softmax=True for multiclass
        self.metric_fn = DiceMetric(include_background=False, reduction="mean", ignore_empty=True)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch['image'], batch['label']
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch['image'], batch['label']
        y_hat = self(x)
        metric_val_outputs = [AsDiscrete(argmax=True, to_onehot=self.hparams.out_channels)(i) for i in decollate_batch(y_hat)]
        metric_val_labels = [AsDiscrete(to_onehot=self.hparams.out_channels)(i) for i in decollate_batch(y)]

        # compute metric for current iteration
        self.metric_fn(y_pred=metric_val_outputs, y=metric_val_labels)
        metrics = self.metric_fn.aggregate(reduction="mean_batch")
        for i, m in enumerate(metrics):
            self.log(f"validation metric class {i+1}", m, prog_bar=True, on_epoch=True, sync_dist=True)
        metric = torch.mean(metrics)  # cannot log an ndarray, so log the scalar mean
        self.log('val_metric', metric, prog_bar=True, on_epoch=True, sync_dist=True)  # sync_dist=True for distributed training
        return {'val_metric': metric}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


class CopickDataModule(pl.LightningDataModule):
    def __init__(
        self,
        copick_config_path: str,
        copick_user_name: str,
        copick_segmentation_name: str,
        train_batch_size: int,
        val_batch_size: int,
        num_random_samples_per_batch: int):

        super().__init__()
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size

        self.data_dicts, self.nclasses = self.data_from_copick(copick_config_path, copick_user_name, copick_segmentation_name)
        self.train_files = self.data_dicts[:len(self.data_dicts) // 2]
        self.val_files = self.data_dicts[len(self.data_dicts) // 2:]
        print(f"Number of training samples: {len(self.train_files)}")
        print(f"Number of validation samples: {len(self.val_files)}")

        # Non-random transforms to be cached
        self.non_random_transforms = Compose([
            EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
            NormalizeIntensityd(keys="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS")
        ])

        # Random transforms to be applied during training
        self.random_transforms = Compose([
            RandCropByLabelClassesd(
                keys=["image", "label"],
                label_key="label",
                spatial_size=[96, 96, 96],
                num_classes=self.nclasses,
                num_samples=num_random_samples_per_batch
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        ])

    def setup(self, stage: Optional[str] = None) -> None:
        self.train_ds = CacheDataset(data=self.train_files, transform=self.non_random_transforms, cache_rate=1.0)
        self.train_ds = Dataset(data=self.train_ds, transform=self.random_transforms)
        self.val_ds = CacheDataset(data=self.val_files, transform=self.non_random_transforms, cache_rate=1.0)
        self.val_ds = Dataset(data=self.val_ds, transform=self.random_transforms)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_ds,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=torch.cuda.is_available()
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_ds,
            batch_size=self.val_batch_size,
            num_workers=4,
            pin_memory=torch.cuda.is_available(),
            shuffle=False,  # Ensure the data order remains consistent
        )

    @staticmethod
    def data_from_copick(copick_config_path, copick_user_name, copick_segmentation_name):
        root = copick.from_file(copick_config_path)
        nclasses = len(root.pickable_objects) + 1
        target_objects = defaultdict(dict)
        for obj in root.pickable_objects:
            if obj.is_particle:
                target_objects[obj.name]['label'] = obj.label
                target_objects[obj.name]['radius'] = obj.radius

        data_dicts = []
        for run in tqdm(root.runs[:2]):
            tomogram = run.get_voxel_spacing(10).get_tomogram('wbp').numpy()
            segmentation = run.get_segmentations(name=copick_segmentation_name, user_id=copick_user_name, voxel_size=10, is_multilabel=True)[0].numpy()
            membrane_seg = run.get_segmentations(name='membrane', user_id="data-portal")
            if membrane_seg:
                membrane_seg = membrane_seg[0].numpy()  # reuse the result fetched above rather than querying twice
                segmentation[membrane_seg == 1] = 1
            data_dicts.append({"image": tomogram, "label": segmentation})

        return data_dicts, nclasses


def train():
    args = get_args()
    mlf_logger = MLFlowLogger(experiment_name='training-3D-UNet-model-for-the-cryoET-ML-Challenge',
                              tracking_uri='http://mlflow.mlflow.svc.cluster.local:5000',
                              #run_name='test1'
                              )
    # Trainer callbacks
    checkpoint_callback = ModelCheckpoint(monitor='val_metric', save_top_k=1, mode='max')  # val_metric is a Dice score, so higher is better
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # Detect distributed training environment
    devices = list(range(args.num_gpus))

    # Initialize model
    model = Model(lr=args.learning_rate)
    datamodule = CopickDataModule(args.copick_config_path, args.copick_user_name, args.copick_segmentation_name,
                                  args.train_batch_size, args.val_batch_size, args.num_random_samples_per_batch)

    # Prioritize performance over precision
    torch.set_float32_matmul_precision('medium')  # or torch.set_float32_matmul_precision('high')

    # Trainer for distributed training with DDP
    trainer = Trainer(
        max_epochs=args.num_epochs,
        logger=mlf_logger,
        callbacks=[checkpoint_callback, lr_monitor],
        strategy=DDPStrategy(find_unused_parameters=False),
        accelerator="gpu",
        devices=devices,
        num_nodes=1,  # int(os.environ.get("WORLD_SIZE", 1)) // args.num_gpus,
        log_every_n_steps=1
    )

    trainer.fit(model, datamodule=datamodule)


if __name__ == "__main__":
    # MLflow setup
    username = os.getenv('MLFLOW_TRACKING_USERNAME')
    password = os.getenv('MLFLOW_TRACKING_PASSWORD')
    if not password or not username:
        print("Password not found in environment, loading from .env file...")
        load_dotenv()  # Loads environment variables from a .env file
        username = os.getenv('MLFLOW_TRACKING_USERNAME')
        password = os.getenv('MLFLOW_TRACKING_PASSWORD')

    # Check again after loading .env file
    if not password:
        raise ValueError("Password is not set in environment variables or .env file!")
    else:
        print("Password loaded successfully")
    os.environ['MLFLOW_TRACKING_USERNAME'] = username
    os.environ['MLFLOW_TRACKING_PASSWORD'] = password

    train()
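
Both scripts rely on the same two-stage dataset pattern: deterministic preprocessing is computed once and held in memory by a `CacheDataset`, and random augmentations are layered on top through a plain `Dataset` so they are re-drawn on every access. A minimal self-contained sketch of that pattern, using synthetic arrays in place of copick data:

    import numpy as np
    from monai.data import Dataset, CacheDataset
    from monai.transforms import Compose, EnsureChannelFirstd, NormalizeIntensityd, RandCropByLabelClassesd

    # One synthetic volume with a 3-class label map
    files = [{"image": np.random.rand(128, 128, 128).astype(np.float32),
              "label": np.random.randint(0, 3, (128, 128, 128))}]

    deterministic = Compose([
        EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
        NormalizeIntensityd(keys="image"),
    ])
    augment = Compose([
        RandCropByLabelClassesd(keys=["image", "label"], label_key="label",
                                spatial_size=[64, 64, 64], num_classes=3, num_samples=4),
    ])

    cached = CacheDataset(data=files, transform=deterministic, cache_rate=1.0)  # applied once, up front
    ds = Dataset(data=cached, transform=augment)  # applied lazily, per access
    patches = ds[0]  # a list of 4 freshly drawn crops from the cached volume
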
octopi/utils/__init__.py
ADDED
File without changes
octopi/utils/config.py
ADDED
@@ -0,0 +1,57 @@
"""
Configuration utilities for MLflow setup and reproducibility.
"""

from dotenv import load_dotenv
import torch, numpy as np
import os, random
import octopi


def mlflow_setup():
    """
    Set up MLflow configuration from environment variables.
    """
    module_root = os.path.dirname(octopi.__file__)
    dotenv_path = module_root.replace('src/octopi', '') + '.env'
    load_dotenv(dotenv_path=dotenv_path)

    # MLflow setup
    username = os.getenv('MLFLOW_TRACKING_USERNAME')
    password = os.getenv('MLFLOW_TRACKING_PASSWORD')
    if not password or not username:
        print("Password not found in environment, loading from .env file...")
        load_dotenv()  # Loads environment variables from a .env file
        username = os.getenv('MLFLOW_TRACKING_USERNAME')
        password = os.getenv('MLFLOW_TRACKING_PASSWORD')

    # Check again after loading .env file
    if not password:
        raise ValueError("Password is not set in environment variables or .env file!")
    else:
        print("Password loaded successfully")
    os.environ['MLFLOW_TRACKING_USERNAME'] = username
    os.environ['MLFLOW_TRACKING_PASSWORD'] = password

    return os.getenv('MLFLOW_TRACKING_URI')


def set_seed(seed):
    """
    Set random seeds for reproducibility across Python, NumPy, and PyTorch.
    """
    # Set the seed for Python's random module
    random.seed(seed)

    # Set the seed for NumPy
    np.random.seed(seed)

    # Set the seed for PyTorch (both CPU and GPU)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # If using multi-GPU

    # Ensure reproducibility of operations by disabling certain optimizations
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False