octopi 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of octopi might be problematic. Click here for more details.

Files changed (59) hide show
  1. octopi/__init__.py +0 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +84 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +429 -0
  7. octopi/datasets/mixup.py +49 -0
  8. octopi/datasets/multi_config_generator.py +253 -0
  9. octopi/entry_points/__init__.py +0 -0
  10. octopi/entry_points/common.py +80 -0
  11. octopi/entry_points/create_slurm_submission.py +243 -0
  12. octopi/entry_points/run_create_targets.py +281 -0
  13. octopi/entry_points/run_evaluate.py +65 -0
  14. octopi/entry_points/run_extract_mb_picks.py +141 -0
  15. octopi/entry_points/run_extract_midpoint.py +143 -0
  16. octopi/entry_points/run_localize.py +222 -0
  17. octopi/entry_points/run_optuna.py +139 -0
  18. octopi/entry_points/run_segment_predict.py +166 -0
  19. octopi/entry_points/run_train.py +201 -0
  20. octopi/extract/__init__.py +0 -0
  21. octopi/extract/localize.py +254 -0
  22. octopi/extract/membranebound_extract.py +262 -0
  23. octopi/extract/midpoint_extract.py +193 -0
  24. octopi/io.py +457 -0
  25. octopi/losses.py +86 -0
  26. octopi/main.py +101 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +62 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +106 -0
  37. octopi/processing/downsample.py +129 -0
  38. octopi/processing/evaluate.py +289 -0
  39. octopi/processing/importers.py +213 -0
  40. octopi/processing/my_metrics.py +26 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/processing/writers.py +102 -0
  43. octopi/pytorch/__init__.py +0 -0
  44. octopi/pytorch/hyper_search.py +243 -0
  45. octopi/pytorch/model_search_submitter.py +290 -0
  46. octopi/pytorch/segmentation.py +317 -0
  47. octopi/pytorch/trainer.py +438 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/stopping_criteria.py +143 -0
  52. octopi/submit_slurm.py +95 -0
  53. octopi/utils.py +238 -0
  54. octopi/visualization_tools.py +201 -0
  55. octopi-1.0.dist-info/LICENSE +41 -0
  56. octopi-1.0.dist-info/METADATA +209 -0
  57. octopi-1.0.dist-info/RECORD +59 -0
  58. octopi-1.0.dist-info/WHEEL +4 -0
  59. octopi-1.0.dist-info/entry_points.txt +4 -0
@@ -0,0 +1,273 @@
1
+ import os
2
+ import argparse
3
+ import copick
4
+ import torch
5
+ from tqdm import tqdm
6
+ from typing import Optional, Union, Tuple, List
7
+ from collections import defaultdict
8
+ import pytorch_lightning as pl
9
+ import torch.distributed as dist
10
+ from pytorch_lightning import Trainer
11
+ from pytorch_lightning.loggers import MLFlowLogger
12
+ from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
13
+ from monai.data import DataLoader, Dataset, CacheDataset, decollate_batch
14
+ from dotenv import load_dotenv
15
+ from monai.transforms import (
16
+ Compose,
17
+ EnsureChannelFirstd,
18
+ Orientationd,
19
+ AsDiscrete,
20
+ RandFlipd,
21
+ RandRotate90d,
22
+ NormalizeIntensityd,
23
+ NormalizeIntensityd,
24
+ RandCropByLabelClassesd,
25
+ )
26
+ from monai.networks.nets import UNet
27
+ from monai.losses import TverskyLoss
28
+ from monai.metrics import DiceMetric, ConfusionMatrixMetric
29
+ import optuna
30
+ from optuna.integration import PyTorchLightningPruningCallback
31
+
32
+
33
def get_args():
    """Parse the command-line options of the Optuna hyperparameter-search script.

    Returns:
        argparse.Namespace: The parsed options; every flag that is absent from
        the command line takes the default listed below.
    """
    cli = argparse.ArgumentParser(
        description = "Hyperparamter tuning using PyTorch Lightning distributed data-parallel and Optuna."
    )

    # (flag, type, default) for each plain value option, in declaration order.
    value_options = (
        ('--copick_config_path', str, 'copick_config_dataportal_10439.json'),
        ('--copick_user_name', str, 'user0'),
        ('--copick_segmentation_name', str, 'paintedPicks'),
        ('--train_batch_size', int, 1),
        ('--val_batch_size', int, 1),
        ('--num_random_samples_per_batch', int, 16),
        ('--learning_rate', float, 1e-4),
        ('--num_epochs', int, 20),
        ('--num_gpus', int, 1),
        ('--num_optuna_trials', int, 10),
    )
    for flag, kind, fallback in value_options:
        cli.add_argument(flag, type=kind, default=fallback)

    # Boolean switch: off unless the flag is given.
    cli.add_argument(
        '--pruning',
        action="store_true",
        help="Activate the pruning feature. `MedianPruner` stops unpromising trials at the early stages of training.",
    )

    parsed = cli.parse_args()
    return parsed
49
+
50
+
51
class Model(pl.LightningModule):
    """MONAI ``UNet`` wrapped as a LightningModule.

    Training minimises a Tversky loss; validation reports Dice scores per
    class and their mean (logged as ``val_metric``).
    """

    def __init__(
        self,
        spatial_dims: int = 3,
        in_channels: int = 1,
        out_channels: int = 8,
        channels: Union[Tuple[int, ...], List[int]] = (48, 64, 80, 80),
        strides: Union[Tuple[int, ...], List[int]] = (2, 2, 1),
        num_res_units: int = 1,
        lr: float=1e-3):
        """Build the network, the loss and the validation metric.

        Args:
            spatial_dims: Forwarded to ``UNet(spatial_dims=...)``.
            in_channels: Forwarded to ``UNet(in_channels=...)``.
            out_channels: Forwarded to ``UNet(out_channels=...)``; also used as
                the one-hot depth when post-processing validation outputs.
            channels: Forwarded to ``UNet(channels=...)``.
            strides: Forwarded to ``UNet(strides=...)``.
            num_res_units: Forwarded to ``UNet(num_res_units=...)``.
            lr: Learning rate of the Adam optimizer.
        """
        super().__init__()
        # Stores all constructor arguments under ``self.hparams``.
        self.save_hyperparameters()
        self.model = UNet(
            spatial_dims=self.hparams.spatial_dims,
            in_channels=self.hparams.in_channels,
            out_channels=self.hparams.out_channels,
            channels=self.hparams.channels,
            strides=self.hparams.strides,
            num_res_units=self.hparams.num_res_units,
        )
        self.loss_fn = TverskyLoss(include_background=True, to_onehot_y=True, softmax=True)  # softmax=True for multiclass
        self.metric_fn = DiceMetric(include_background=False, reduction="mean", ignore_empty=True)

    def forward(self, x):
        """Run the wrapped U-Net on ``x`` and return its raw output."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Return the Tversky loss of the network output against ``batch['label']``."""
        x, y = batch['image'], batch['label']
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        """Accumulate Dice scores for one validation batch and log them.

        Returns:
            dict: ``{'val_metric': <mean Dice over classes>}``.
        """
        with torch.no_grad():  # This ensures that gradients are not stored in memory
            x, y = batch['image'], batch['label']
            y_hat = self(x)

            # Convert predictions (argmax) and labels to one-hot, sample by sample.
            to_onehot_pred = AsDiscrete(argmax=True, to_onehot=self.hparams.out_channels)
            to_onehot_label = AsDiscrete(to_onehot=self.hparams.out_channels)
            metric_val_outputs = [to_onehot_pred(i) for i in decollate_batch(y_hat)]
            metric_val_labels = [to_onehot_label(i) for i in decollate_batch(y)]

            # compute metric for current iteration
            self.metric_fn(y_pred=metric_val_outputs, y=metric_val_labels)
            # aggregate() covers everything buffered since the last reset();
            # the buffer is cleared in on_validation_epoch_end.
            metrics = self.metric_fn.aggregate(reduction="mean_batch")
            for i, m in enumerate(metrics):
                self.log(f"validation metric class {i+1}", m, prog_bar=True, on_epoch=True, sync_dist=True)
            metric = torch.mean(metrics)  # cannot log ndarray
            self.log('val_metric', metric, prog_bar=True, on_epoch=True, sync_dist=True)  # sync_dist=True for distributed training
        return {'val_metric': metric}

    def on_validation_epoch_end(self):
        """Clear the Dice buffer after each validation run.

        Without this, scores from the sanity check and from earlier epochs
        stay in the metric buffer, so later ``val_metric`` values would mix
        in stale predictions and the buffer would grow without bound.
        """
        self.metric_fn.reset()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``hparams.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
102
+
103
+
104
class CopickDataModule(pl.LightningDataModule):
    """Lightning data module backed by a copick project.

    Tomogram/segmentation pairs are loaded eagerly in ``__init__`` and split
    in half: the first half is used for training, the second for validation.
    """

    def __init__(
        self,
        copick_config_path: str,
        train_batch_size: int,
        val_batch_size: int,
        num_random_samples_per_batch: int,
        copick_user_name: str = 'user0',
        copick_segmentation_name: str = 'paintedPicks'):
        """Load the data and build the transform pipelines.

        Args:
            copick_config_path: Path handed to ``copick.from_file``.
            train_batch_size: Batch size of the training loader.
            val_batch_size: Batch size of the validation loader.
            num_random_samples_per_batch: ``num_samples`` of the random
                label-balanced crop.
            copick_user_name: ``user_id`` of the target segmentation
                (defaults to the previously hard-coded ``'user0'``).
            copick_segmentation_name: Name of the target segmentation
                (defaults to the previously hard-coded ``'paintedPicks'``).
        """
        super().__init__()
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size

        self.data_dicts, self.nclasses = self.data_from_copick(
            copick_config_path, copick_user_name, copick_segmentation_name
        )
        split = len(self.data_dicts) // 2
        self.train_files = self.data_dicts[:split]
        self.val_files = self.data_dicts[split:]
        print(f"Number of training samples: {len(self.train_files)}")
        print(f"Number of validation samples: {len(self.val_files)}")

        # Non-random transforms to be cached
        self.non_random_transforms = Compose([
            EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
            NormalizeIntensityd(keys="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS")
        ])

        # Random transforms to be applied during training
        self.random_transforms = Compose([
            RandCropByLabelClassesd(
                keys=["image", "label"],
                label_key="label",
                spatial_size=[96, 96, 96],
                num_classes=self.nclasses,
                num_samples=num_random_samples_per_batch
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        ])

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the datasets: cached deterministic transforms, then random ones on top."""
        self.train_ds = CacheDataset(data=self.train_files, transform=self.non_random_transforms, cache_rate=1.0)
        self.train_ds = Dataset(data=self.train_ds, transform=self.random_transforms)
        self.val_ds = CacheDataset(data=self.val_files, transform=self.non_random_transforms, cache_rate=1.0)
        # NOTE(review): validation also goes through the random crop/rotate/flip
        # pipeline, so validation patches differ between epochs - confirm intended.
        self.val_ds = Dataset(data=self.val_ds, transform=self.random_transforms)

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training loader."""
        return DataLoader(
            self.train_ds,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=4,
            persistent_workers=True,
            pin_memory=torch.cuda.is_available(),
        )

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (no shuffling)."""
        return DataLoader(
            self.val_ds,
            batch_size=self.val_batch_size,
            shuffle=False,  # Ensure the data order remains consistent
            num_workers=4,
            persistent_workers=True,
            pin_memory=torch.cuda.is_available(),
        )

    @staticmethod
    def data_from_copick(copick_config_path, copick_user_name='user0', copick_segmentation_name='paintedPicks'):
        """Read tomograms and label volumes from the first eight runs of a copick project.

        Args:
            copick_config_path: Path handed to ``copick.from_file``.
            copick_user_name: ``user_id`` of the target segmentation.
            copick_segmentation_name: Name of the target segmentation.

        Returns:
            tuple: ``(data_dicts, nclasses)`` where ``data_dicts`` is a list of
            ``{"image": tomogram, "label": segmentation}`` dicts and
            ``nclasses`` is the number of pickable objects plus one.
        """
        root = copick.from_file(copick_config_path)
        nclasses = len(root.pickable_objects) + 1

        data_dicts = []
        for run in tqdm(root.runs[:8]):
            tomogram = run.get_voxel_spacing(10).get_tomogram('wbp').numpy()
            segmentation = run.get_segmentations(
                name=copick_segmentation_name,
                user_id=copick_user_name,
                voxel_size=10,
                is_multilabel=True,
            )[0].numpy()

            # Not every run necessarily has a membrane segmentation; only
            # overlay it (as label 1) when one is available.
            membrane_segs = run.get_segmentations(name='membrane', user_id="data-portal")
            if membrane_segs:
                segmentation[membrane_segs[0].numpy() == 1] = 1
            data_dicts.append({"image": tomogram, "label": segmentation})

        return data_dicts, nclasses
188
+
189
+
190
def objective(trial: optuna.trial.Trial) -> float:
    """Train one U-Net configuration suggested by Optuna and return its validation score.

    Args:
        trial: The Optuna trial used to sample the architecture
            hyperparameters and to report intermediate values for pruning.

    Returns:
        float: The final ``val_metric`` logged by the trainer.
    """
    args = get_args()
    mlf_logger = MLFlowLogger(experiment_name='training-3D-UNet-model-for-the-cryoET-ML-Challenge',
                              tracking_uri='http://mlflow.mlflow.svc.cluster.local:5000',
                              #run_name='test1'
                              )

    # Trainer callbacks
    checkpoint_callback = ModelCheckpoint(monitor='val_metric', save_top_k=1, mode='max')
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # One device index per requested GPU
    devices = list(range(args.num_gpus))

    # We optimize the number of layers, strides, and number of residual units
    num_layers = trial.suggest_int("num_layers", 3, 5)
    base_channel = trial.suggest_categorical("base_channel", [8, 16, 32, 64])
    channels = [base_channel * (2 ** i) for i in range(num_layers)]
    num_downsampling_layers = trial.suggest_int("num_downsampling_layers", 1, num_layers - 1)
    # One stride per transition between layers: num_layers - 1 entries in total.
    strides_pattern = [2] * num_downsampling_layers + [1] * (num_layers - num_downsampling_layers - 1)
    num_res_units = trial.suggest_int("num_res_units", 1, 3)

    model = Model(channels=channels, strides=strides_pattern, num_res_units=num_res_units, lr=args.learning_rate)
    datamodule = CopickDataModule(args.copick_config_path, args.train_batch_size, args.val_batch_size, args.num_random_samples_per_batch)
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_metric")

    # Prioritize performance over precision
    torch.set_float32_matmul_precision('medium')  # or torch.set_float32_matmul_precision('high')

    # Trainer for distributed training with DDP
    trainer = Trainer(
        max_epochs=args.num_epochs,
        logger=mlf_logger,
        # The pruning callback must be registered with the trainer; otherwise
        # it never reports intermediate values and no trial is ever pruned.
        callbacks=[checkpoint_callback, lr_monitor, pruning_callback],
        strategy="ddp_spawn",
        accelerator="gpu",
        devices=devices,
        num_nodes=1,  #int(os.environ.get("WORLD_SIZE", 1)) // args.num_gpus,
        log_every_n_steps=1
    )

    hyperparameters = dict(op_num_layers=num_layers, op_base_channel=base_channel, op_num_downsampling_layers=num_downsampling_layers, op_num_res_units=num_res_units)
    trainer.logger.log_hyperparams(hyperparameters)
    trainer.fit(model, datamodule=datamodule)
    # Raises optuna's pruning exception in this process if a spawned worker pruned the trial.
    pruning_callback.check_pruned()
    return trainer.callback_metrics["val_metric"].item()
237
+
238
+
239
if __name__ == "__main__":
    args = get_args()

    # MLflow setup: take the credentials from the environment, falling back
    # to a .env file when either of them is missing.
    username = os.getenv('MLFLOW_TRACKING_USERNAME')
    password = os.getenv('MLFLOW_TRACKING_PASSWORD')
    if not password or not username:
        print("Password not found in environment, loading from .env file...")
        load_dotenv()  # Loads environment variables from a .env file
        username = os.getenv('MLFLOW_TRACKING_USERNAME')
        password = os.getenv('MLFLOW_TRACKING_PASSWORD')

    # Check again after loading .env file. Both values are required: assigning
    # None to os.environ below would fail with an unhelpful TypeError.
    missing = [
        name
        for name, value in (
            ('MLFLOW_TRACKING_USERNAME', username),
            ('MLFLOW_TRACKING_PASSWORD', password),
        )
        if not value
    ]
    if missing:
        raise ValueError(f"{', '.join(missing)} is not set in environment variables or .env file!")

    print("Password loaded successfully")
    os.environ['MLFLOW_TRACKING_USERNAME'] = username
    os.environ['MLFLOW_TRACKING_PASSWORD'] = password

    # Median pruning only when requested on the command line.
    pruner: optuna.pruners.BasePruner = (
        optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner()
    )
    storage = "sqlite:///example.db"
    study = optuna.create_study(
        study_name="pl_ddp",
        storage=storage,
        direction="maximize",
        load_if_exists=True,
        pruner=pruner
    )
    study.optimize(objective, n_trials=args.num_optuna_trials)

    # Print the best hyperparameters
    print(f"Best trial: {study.best_trial.value}")
    print(f"Best hyperparameters: {study.best_trial.params}")
@@ -0,0 +1,244 @@
1
+ import os
2
+ import argparse
3
+ from typing import Optional, Union, Tuple, List
4
+ from collections import defaultdict
5
+ import torch
6
+ from torch.utils.data import DataLoader
7
+ import pytorch_lightning as pl
8
+ from pytorch_lightning import Trainer
9
+ from pytorch_lightning.loggers import MLFlowLogger
10
+ from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
11
+ from pytorch_lightning.strategies import DDPStrategy
12
+ import os
13
+ import copick
14
+ from tqdm import tqdm
15
+ from monai.data import DataLoader, Dataset, CacheDataset, decollate_batch
16
+ from dotenv import load_dotenv
17
+ from monai.transforms import (
18
+ Compose,
19
+ EnsureChannelFirstd,
20
+ Orientationd,
21
+ AsDiscrete,
22
+ RandFlipd,
23
+ RandRotate90d,
24
+ NormalizeIntensityd,
25
+ NormalizeIntensityd,
26
+ RandCropByLabelClassesd,
27
+ )
28
+ from monai.networks.nets import UNet
29
+ from monai.losses import TverskyLoss
30
+ from monai.metrics import DiceMetric, ConfusionMatrixMetric
31
+
32
+
33
def get_args():
    """Parse the command-line options of the Lightning training script.

    Returns:
        argparse.Namespace: The parsed options; every flag that is absent from
        the command line takes the default listed below.
    """
    cli = argparse.ArgumentParser(
        description = "Train a 3d U-Net model with PyTorch Lightning supporting distributed training strategies."
    )

    # (flag, type, default) for each option, in declaration order.
    option_table = (
        ('--copick_config_path', str, 'copick_config_dataportal_10439.json'),
        ('--copick_user_name', str, 'user0'),
        ('--copick_segmentation_name', str, 'paintedPicks'),
        ('--train_batch_size', int, 1),
        ('--val_batch_size', int, 1),
        ('--num_random_samples_per_batch', int, 16),
        ('--learning_rate', float, 1e-4),
        ('--num_epochs', int, 20),
        ('--num_gpus', int, 1),
    )
    for flag, kind, fallback in option_table:
        cli.add_argument(flag, type=kind, default=fallback)

    parsed = cli.parse_args()
    return parsed
47
+
48
+
49
class Model(pl.LightningModule):
    """MONAI ``UNet`` wrapped as a LightningModule.

    Training minimises a Tversky loss; validation reports Dice scores per
    class and their mean (logged as ``val_metric``).
    """

    def __init__(
        self,
        spatial_dims: int = 3,
        in_channels: int = 1,
        out_channels: int = 8,
        channels: Union[Tuple[int, ...], List[int]] = (48, 64, 80, 80),
        strides: Union[Tuple[int, ...], List[int]] = (2, 2, 1),
        num_res_units: int = 1,
        lr: float=1e-3):
        """Build the network, the loss and the validation metric.

        Args:
            spatial_dims: Forwarded to ``UNet(spatial_dims=...)``.
            in_channels: Forwarded to ``UNet(in_channels=...)``.
            out_channels: Forwarded to ``UNet(out_channels=...)``; also used as
                the one-hot depth when post-processing validation outputs.
            channels: Forwarded to ``UNet(channels=...)``.
            strides: Forwarded to ``UNet(strides=...)``.
            num_res_units: Forwarded to ``UNet(num_res_units=...)``.
            lr: Learning rate of the Adam optimizer.
        """
        super().__init__()
        # Stores all constructor arguments under ``self.hparams``.
        self.save_hyperparameters()

        self.model = UNet(
            spatial_dims=self.hparams.spatial_dims,
            in_channels=self.hparams.in_channels,
            out_channels=self.hparams.out_channels,
            channels=self.hparams.channels,
            strides=self.hparams.strides,
            num_res_units=self.hparams.num_res_units,
        )
        self.loss_fn = TverskyLoss(include_background=True, to_onehot_y=True, softmax=True)  # softmax=True for multiclass
        self.metric_fn = DiceMetric(include_background=False, reduction="mean", ignore_empty=True)

    def forward(self, x):
        """Run the wrapped U-Net on ``x`` and return its raw output."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Return the Tversky loss of the network output against ``batch['label']``."""
        x, y = batch['image'], batch['label']
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        """Accumulate Dice scores for one validation batch and log them.

        Returns:
            dict: ``{'val_metric': <mean Dice over classes>}``.
        """
        x, y = batch['image'], batch['label']
        y_hat = self(x)

        # Convert predictions (argmax) and labels to one-hot, sample by sample.
        to_onehot_pred = AsDiscrete(argmax=True, to_onehot=self.hparams.out_channels)
        to_onehot_label = AsDiscrete(to_onehot=self.hparams.out_channels)
        metric_val_outputs = [to_onehot_pred(i) for i in decollate_batch(y_hat)]
        metric_val_labels = [to_onehot_label(i) for i in decollate_batch(y)]

        # compute metric for current iteration
        self.metric_fn(y_pred=metric_val_outputs, y=metric_val_labels)
        # aggregate() covers everything buffered since the last reset();
        # the buffer is cleared in on_validation_epoch_end.
        metrics = self.metric_fn.aggregate(reduction="mean_batch")
        for i, m in enumerate(metrics):
            self.log(f"validation metric class {i+1}", m, prog_bar=True, on_epoch=True, sync_dist=True)
        metric = torch.mean(metrics)  # cannot log ndarray
        self.log('val_metric', metric, prog_bar=True, on_epoch=True, sync_dist=True)  # sync_dist=True for distributed training
        return {'val_metric': metric}

    def on_validation_epoch_end(self):
        """Clear the Dice buffer after each validation run.

        Without this, scores from the sanity check and from earlier epochs
        stay in the metric buffer, so later ``val_metric`` values would mix
        in stale predictions and the buffer would grow without bound.
        """
        self.metric_fn.reset()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``hparams.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
100
+
101
+
102
class CopickDataModule(pl.LightningDataModule):
    """Lightning data module backed by a copick project.

    Tomogram/segmentation pairs are loaded eagerly in ``__init__`` and split
    in half: the first half is used for training, the second for validation.
    """

    def __init__(
        self,
        copick_config_path: str,
        copick_user_name: str,
        copick_segmentation_name: str,
        train_batch_size: int,
        val_batch_size: int,
        num_random_samples_per_batch: int):
        """Load the data and build the transform pipelines.

        Args:
            copick_config_path: Path handed to ``copick.from_file``.
            copick_user_name: ``user_id`` of the target segmentation.
            copick_segmentation_name: Name of the target segmentation.
            train_batch_size: Batch size of the training loader.
            val_batch_size: Batch size of the validation loader.
            num_random_samples_per_batch: ``num_samples`` of the random
                label-balanced crop.
        """
        super().__init__()
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size

        self.data_dicts, self.nclasses = self.data_from_copick(copick_config_path, copick_user_name, copick_segmentation_name)
        split = len(self.data_dicts) // 2
        self.train_files = self.data_dicts[:split]
        self.val_files = self.data_dicts[split:]
        print(f"Number of training samples: {len(self.train_files)}")
        print(f"Number of validation samples: {len(self.val_files)}")

        # Non-random transforms to be cached
        self.non_random_transforms = Compose([
            EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
            NormalizeIntensityd(keys="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS")
        ])

        # Random transforms to be applied during training
        self.random_transforms = Compose([
            RandCropByLabelClassesd(
                keys=["image", "label"],
                label_key="label",
                spatial_size=[96, 96, 96],
                num_classes=self.nclasses,
                num_samples=num_random_samples_per_batch
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        ])

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the datasets: cached deterministic transforms, then random ones on top."""
        self.train_ds = CacheDataset(data=self.train_files, transform=self.non_random_transforms, cache_rate=1.0)
        self.train_ds = Dataset(data=self.train_ds, transform=self.random_transforms)
        self.val_ds = CacheDataset(data=self.val_files, transform=self.non_random_transforms, cache_rate=1.0)
        # NOTE(review): validation also goes through the random crop/rotate/flip
        # pipeline, so validation patches differ between epochs - confirm intended.
        self.val_ds = Dataset(data=self.val_ds, transform=self.random_transforms)

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training loader."""
        return DataLoader(
            self.train_ds,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=torch.cuda.is_available()
        )

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (no shuffling)."""
        return DataLoader(
            self.val_ds,
            batch_size=self.val_batch_size,
            num_workers=4,
            pin_memory=torch.cuda.is_available(),
            shuffle=False,  # Ensure the data order remains consistent
        )

    @staticmethod
    def data_from_copick(copick_config_path, copick_user_name, copick_segmentation_name):
        """Read tomograms and label volumes from the first two runs of a copick project.

        Args:
            copick_config_path: Path handed to ``copick.from_file``.
            copick_user_name: ``user_id`` of the target segmentation.
            copick_segmentation_name: Name of the target segmentation.

        Returns:
            tuple: ``(data_dicts, nclasses)`` where ``data_dicts`` is a list of
            ``{"image": tomogram, "label": segmentation}`` dicts and
            ``nclasses`` is the number of pickable objects plus one.
        """
        root = copick.from_file(copick_config_path)
        nclasses = len(root.pickable_objects) + 1

        data_dicts = []
        for run in tqdm(root.runs[:2]):
            tomogram = run.get_voxel_spacing(10).get_tomogram('wbp').numpy()
            segmentation = run.get_segmentations(name=copick_segmentation_name, user_id=copick_user_name, voxel_size=10, is_multilabel=True)[0].numpy()

            # Query the membrane segmentation once and overlay it (as label 1)
            # only when the run actually has one.
            membrane_segs = run.get_segmentations(name='membrane', user_id="data-portal")
            if membrane_segs:
                segmentation[membrane_segs[0].numpy() == 1] = 1
            data_dicts.append({"image": tomogram, "label": segmentation})

        return data_dicts, nclasses
188
+
189
+
190
def train():
    """Train the default U-Net with PyTorch Lightning DDP, logging to MLflow.

    Reads its configuration from the command line via ``get_args()``.
    """
    args = get_args()
    mlf_logger = MLFlowLogger(experiment_name='training-3D-UNet-model-for-the-cryoET-ML-Challenge',
                              tracking_uri='http://mlflow.mlflow.svc.cluster.local:5000',
                              #run_name='test1'
                              )
    # Trainer callbacks
    # 'val_metric' is a mean Dice score (higher is better), so the best
    # checkpoint is the one that maximises it.
    checkpoint_callback = ModelCheckpoint(monitor='val_metric', save_top_k=1, mode='max')
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # One device index per requested GPU
    devices = list(range(args.num_gpus))

    # Initialize model
    model = Model(lr=args.learning_rate)
    datamodule = CopickDataModule(args.copick_config_path, args.copick_user_name, args.copick_segmentation_name,
                                  args.train_batch_size, args.val_batch_size, args.num_random_samples_per_batch)

    # Prioritize performance over precision
    torch.set_float32_matmul_precision('medium')  # or torch.set_float32_matmul_precision('high')

    # Trainer for distributed training with DDP
    trainer = Trainer(
        max_epochs=args.num_epochs,
        logger=mlf_logger,
        callbacks=[checkpoint_callback, lr_monitor],
        strategy=DDPStrategy(find_unused_parameters=False),
        accelerator="gpu",
        devices=devices,
        num_nodes=1,  #int(os.environ.get("WORLD_SIZE", 1)) // args.num_gpus,
        log_every_n_steps=1
    )

    trainer.fit(model, datamodule=datamodule)
224
+
225
+
226
if __name__ == "__main__":
    # MLflow setup: take the credentials from the environment, falling back
    # to a .env file when either of them is missing.
    username = os.getenv('MLFLOW_TRACKING_USERNAME')
    password = os.getenv('MLFLOW_TRACKING_PASSWORD')
    if not password or not username:
        print("Password not found in environment, loading from .env file...")
        load_dotenv()  # Loads environment variables from a .env file
        username = os.getenv('MLFLOW_TRACKING_USERNAME')
        password = os.getenv('MLFLOW_TRACKING_PASSWORD')

    # Check again after loading .env file. Both values are required: assigning
    # None to os.environ below would fail with an unhelpful TypeError.
    missing = [
        name
        for name, value in (
            ('MLFLOW_TRACKING_USERNAME', username),
            ('MLFLOW_TRACKING_PASSWORD', password),
        )
        if not value
    ]
    if missing:
        raise ValueError(f"{', '.join(missing)} is not set in environment variables or .env file!")

    print("Password loaded successfully")
    os.environ['MLFLOW_TRACKING_USERNAME'] = username
    os.environ['MLFLOW_TRACKING_PASSWORD'] = password

    train()
@@ -0,0 +1,143 @@
1
+ import numpy as np
2
+
3
class EarlyStoppingChecker:
    """
    Bundle of early-stopping checks (NaN loss, metric plateau, stagnation of
    the best metric) for a training loop.
    """

    def __init__(self,
                 max_nan_epochs=15,
                 plateau_patience=20,
                 plateau_min_delta=0.001,
                 stagnation_patience=50,
                 convergence_window=5,
                 convergence_threshold=0.005,
                 val_interval=15,
                 monitor_metric='avg_fbeta'):
        """
        Store the thresholds used by the individual checks.

        Args:
            max_nan_epochs: Number of consecutive NaN-loss epochs tolerated;
                training stops once the count exceeds this value
            plateau_patience: Number of most recent validation points inspected
                by the plateau check
            plateau_min_delta: The plateau check fires when max - min over the
                inspected window is below this value
            stagnation_patience: Number of validation intervals without a new
                best metric before the stagnation check fires
            convergence_window: Stored only; no active check reads it
            convergence_threshold: Stored only; no active check reads it
            val_interval: Number of epochs between validation runs
            monitor_metric: Key in ``results`` holding the monitored history
        """
        # Thresholds
        self.max_nan_epochs = max_nan_epochs
        self.plateau_patience = plateau_patience
        self.plateau_min_delta = plateau_min_delta
        self.stagnation_patience = stagnation_patience
        self.convergence_window = convergence_window
        self.convergence_threshold = convergence_threshold
        self.val_interval = val_interval
        self.monitor_metric = monitor_metric

        # Running count of consecutive NaN losses
        self.nan_counter = 0

        # Human-readable cause, filled in by whichever check fires
        self.stopped_reason = None

    def check_for_nan(self, epoch_loss):
        """Track consecutive NaN losses; True once the limit is exceeded."""
        if not np.isnan(epoch_loss):
            # A valid loss breaks the streak
            self.nan_counter = 0
            return False

        self.nan_counter += 1
        if self.nan_counter > self.max_nan_epochs:
            self.stopped_reason = f"NaN values in loss for more than {self.max_nan_epochs} epochs"
            return True
        return False

    def check_for_plateau(self, results):
        """True when the monitored metric barely moved over the recent window."""
        # Need more history than the window itself before judging
        if len(results[self.monitor_metric]) < self.plateau_patience + 1:
            return False

        # Entries are indexed as (epoch, value); keep the values of the window
        window = results[self.monitor_metric][-self.plateau_patience:]
        scores = [entry[1] for entry in window]
        spread = max(scores) - min(scores)

        # A small value range over the window counts as a plateau
        if spread < self.plateau_min_delta:
            self.stopped_reason = f"{self.monitor_metric} plateaued for {self.plateau_patience} validations"
            return True

        return False

    def check_best_metric_stagnation(self, results):
        """True when the best metric is older than the allowed number of validation intervals."""
        if "best_metric_epoch" not in results:
            return False
        if len(results[self.monitor_metric]) < self.stagnation_patience + 1:
            return False

        # Epoch distance between the best result and the latest validation
        best_epoch = results["best_metric_epoch"]
        latest_epoch = results[self.monitor_metric][-1][0]
        allowed_gap = self.stagnation_patience * self.val_interval

        if (latest_epoch - best_epoch) >= allowed_gap:
            self.stopped_reason = f"No improvement for {self.stagnation_patience} validation intervals"
            return True

        return False

    # NOTE: a convergence-rate check (average improvement over
    # `convergence_window` compared with `convergence_threshold`) is
    # currently disabled, so those two settings are not consulted.

    def should_stop_training(self, epoch_loss, results=None, check_metrics=False):
        """
        Run every applicable check and report whether training should stop.

        Args:
            epoch_loss: Loss of the current epoch
            results: Dictionary with the metric history (and optionally
                "best_metric_epoch")
            check_metrics: When True and ``results`` is non-empty, the
                metric-based checks run in addition to the NaN check

        Returns:
            bool: True if any check fired, False otherwise
        """
        # The NaN check runs on every call
        if self.check_for_nan(epoch_loss):
            return True

        # Metric-based checks are opt-in and need a history to look at
        if not (check_metrics and results):
            return False

        # Plateau first, then stagnation of the best metric
        return self.check_for_plateau(results) or self.check_best_metric_stagnation(results)

    def get_stopped_reason(self):
        """Describe why training was stopped, or state that no check has fired."""
        if not self.stopped_reason:
            return "No early stopping criteria met."
        return f"Early stopping triggered: {self.stopped_reason}"