octopi 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. octopi/__init__.py +7 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +83 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +458 -0
  7. octopi/datasets/io.py +200 -0
  8. octopi/datasets/mixup.py +49 -0
  9. octopi/datasets/multi_config_generator.py +252 -0
  10. octopi/entry_points/__init__.py +0 -0
  11. octopi/entry_points/common.py +119 -0
  12. octopi/entry_points/create_slurm_submission.py +251 -0
  13. octopi/entry_points/groups.py +152 -0
  14. octopi/entry_points/run_create_targets.py +234 -0
  15. octopi/entry_points/run_evaluate.py +99 -0
  16. octopi/entry_points/run_extract_mb_picks.py +191 -0
  17. octopi/entry_points/run_extract_midpoint.py +143 -0
  18. octopi/entry_points/run_localize.py +176 -0
  19. octopi/entry_points/run_optuna.py +161 -0
  20. octopi/entry_points/run_segment.py +154 -0
  21. octopi/entry_points/run_train.py +189 -0
  22. octopi/extract/__init__.py +0 -0
  23. octopi/extract/localize.py +217 -0
  24. octopi/extract/membranebound_extract.py +263 -0
  25. octopi/extract/midpoint_extract.py +193 -0
  26. octopi/main.py +33 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +72 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +224 -0
  37. octopi/processing/downloader.py +138 -0
  38. octopi/processing/downsample.py +125 -0
  39. octopi/processing/evaluate.py +302 -0
  40. octopi/processing/importers.py +116 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/pytorch/__init__.py +0 -0
  43. octopi/pytorch/hyper_search.py +244 -0
  44. octopi/pytorch/model_search_submitter.py +291 -0
  45. octopi/pytorch/segmentation.py +363 -0
  46. octopi/pytorch/segmentation_multigpu.py +162 -0
  47. octopi/pytorch/trainer.py +465 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/utils/__init__.py +0 -0
  52. octopi/utils/config.py +57 -0
  53. octopi/utils/io.py +215 -0
  54. octopi/utils/losses.py +86 -0
  55. octopi/utils/parsers.py +162 -0
  56. octopi/utils/progress.py +78 -0
  57. octopi/utils/stopping_criteria.py +143 -0
  58. octopi/utils/submit_slurm.py +95 -0
  59. octopi/utils/visualization_tools.py +290 -0
  60. octopi/workflows.py +262 -0
  61. octopi-1.4.0.dist-info/METADATA +119 -0
  62. octopi-1.4.0.dist-info/RECORD +65 -0
  63. octopi-1.4.0.dist-info/WHEEL +4 -0
  64. octopi-1.4.0.dist-info/entry_points.txt +3 -0
  65. octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
@@ -0,0 +1,273 @@
1
+ import os
2
+ import argparse
3
+ import copick
4
+ import torch
5
+ from tqdm import tqdm
6
+ from typing import Optional, Union, Tuple, List
7
+ from collections import defaultdict
8
+ import pytorch_lightning as pl
9
+ import torch.distributed as dist
10
+ from pytorch_lightning import Trainer
11
+ from pytorch_lightning.loggers import MLFlowLogger
12
+ from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
13
+ from monai.data import DataLoader, Dataset, CacheDataset, decollate_batch
14
+ from dotenv import load_dotenv
15
+ from monai.transforms import (
16
+ Compose,
17
+ EnsureChannelFirstd,
18
+ Orientationd,
19
+ AsDiscrete,
20
+ RandFlipd,
21
+ RandRotate90d,
22
+ NormalizeIntensityd,
23
+ NormalizeIntensityd,
24
+ RandCropByLabelClassesd,
25
+ )
26
+ from monai.networks.nets import UNet
27
+ from monai.losses import TverskyLoss
28
+ from monai.metrics import DiceMetric, ConfusionMatrixMetric
29
+ import optuna
30
+ from optuna.integration import PyTorchLightningPruningCallback
31
+
32
+
33
def get_args():
    """Parse the command-line options of the Optuna hyperparameter search.

    Returns:
        argparse.Namespace: the copick project settings, batch sizes,
        learning rate, epoch / GPU / trial counts and the ``pruning`` flag.
    """
    parser = argparse.ArgumentParser(
        description="Hyperparameter tuning using PyTorch Lightning distributed data-parallel and Optuna."
    )
    parser.add_argument('--copick_config_path', type=str, default='copick_config_dataportal_10439.json')
    parser.add_argument('--copick_user_name', type=str, default='user0')
    parser.add_argument('--copick_segmentation_name', type=str, default='paintedPicks')
    parser.add_argument('--train_batch_size', type=int, default=1)
    parser.add_argument('--val_batch_size', type=int, default=1)
    parser.add_argument('--num_random_samples_per_batch', type=int, default=16)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--num_epochs', type=int, default=20)
    parser.add_argument('--num_gpus', type=int, default=1)
    parser.add_argument('--num_optuna_trials', type=int, default=10)
    parser.add_argument('--pruning', action="store_true", help="Activate the pruning feature. `MedianPruner` stops unpromising trials at the early stages of training.")
    return parser.parse_args()
49
+
50
+
51
class Model(pl.LightningModule):
    """Lightning module wrapping a MONAI ``UNet`` for multi-class segmentation.

    Training minimises a Tversky loss; validation logs one Dice value per
    foreground class plus their mean under the key ``val_metric``.

    Args:
        spatial_dims: number of spatial dimensions forwarded to ``UNet``.
        in_channels: number of input channels.
        out_channels: number of output channels, also used as the one-hot
            depth when scoring predictions.
        channels: feature-map sizes forwarded to ``UNet``.
        strides: convolution strides forwarded to ``UNet``.
        num_res_units: number of residual units forwarded to ``UNet``.
        lr: learning rate of the Adam optimizer.
    """

    def __init__(
        self,
        spatial_dims: int = 3,
        in_channels: int = 1,
        out_channels: int = 8,
        channels: Union[Tuple[int, ...], List[int]] = (48, 64, 80, 80),
        strides: Union[Tuple[int, ...], List[int]] = (2, 2, 1),
        num_res_units: int = 1,
        lr: float = 1e-3,
    ):
        super().__init__()
        # Stores every constructor argument under self.hparams.
        self.save_hyperparameters()
        self.model = UNet(
            spatial_dims=self.hparams.spatial_dims,
            in_channels=self.hparams.in_channels,
            out_channels=self.hparams.out_channels,
            channels=self.hparams.channels,
            strides=self.hparams.strides,
            num_res_units=self.hparams.num_res_units,
        )
        self.loss_fn = TverskyLoss(include_background=True, to_onehot_y=True, softmax=True)  # softmax=True for multiclass
        self.metric_fn = DiceMetric(include_background=False, reduction="mean", ignore_empty=True)

    def forward(self, x):
        """Run the wrapped U-Net on ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Return the Tversky loss for one batch with ``image``/``label`` keys."""
        x, y = batch['image'], batch['label']
        y_hat = self(x)
        return self.loss_fn(y_hat, y)

    def on_validation_epoch_start(self):
        """Clear the Dice buffer so each epoch is scored on its own batches.

        ``DiceMetric`` accumulates every batch it is called on until
        ``reset()``; without this, ``aggregate()`` would also average in
        results from earlier epochs and from the sanity-check pass.
        """
        self.metric_fn.reset()

    def validation_step(self, batch, batch_idx):
        """Score one validation batch and log per-class and mean Dice.

        Returns:
            dict: ``{'val_metric': <mean Dice over the aggregated classes>}``.
        """
        with torch.no_grad():  # This ensures that gradients are not stored in memory
            x, y = batch['image'], batch['label']
            y_hat = self(x)

            # Build the post-processing transforms once per step, not once per sample.
            to_onehot_pred = AsDiscrete(argmax=True, to_onehot=self.hparams.out_channels)
            to_onehot_label = AsDiscrete(to_onehot=self.hparams.out_channels)
            metric_val_outputs = [to_onehot_pred(i) for i in decollate_batch(y_hat)]
            metric_val_labels = [to_onehot_label(i) for i in decollate_batch(y)]

            # compute metric for current iteration
            self.metric_fn(y_pred=metric_val_outputs, y=metric_val_labels)
            metrics = self.metric_fn.aggregate(reduction="mean_batch")
            for i, m in enumerate(metrics):
                self.log(f"validation metric class {i+1}", m, prog_bar=True, on_epoch=True, sync_dist=True)
            metric = torch.mean(metrics)  # cannot log ndarray
            self.log('val_metric', metric, prog_bar=True, on_epoch=True, sync_dist=True)  # sync_dist=True for distributed training
        return {'val_metric': metric}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``hparams.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
102
+
103
+
104
class CopickDataModule(pl.LightningDataModule):
    """Lightning data module that loads tomograms and label volumes from copick.

    The loaded runs are split in half: the first half is used for training,
    the second half for validation.

    Args:
        copick_config_path: path of the copick configuration file.
        train_batch_size: batch size of the training loader.
        val_batch_size: batch size of the validation loader.
        num_random_samples_per_batch: number of crops drawn per volume by
            ``RandCropByLabelClassesd``.
        copick_user_name: copick user id owning the segmentation
            (defaults to the previously hard-coded ``'user0'``).
        copick_segmentation_name: name of the segmentation to load
            (defaults to the previously hard-coded ``'paintedPicks'``).
    """

    def __init__(
        self,
        copick_config_path: str,
        train_batch_size: int,
        val_batch_size: int,
        num_random_samples_per_batch: int,
        copick_user_name: str = 'user0',
        copick_segmentation_name: str = 'paintedPicks',
    ):
        super().__init__()
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size

        self.data_dicts, self.nclasses = self.data_from_copick(
            copick_config_path, copick_user_name, copick_segmentation_name
        )
        half = len(self.data_dicts) // 2
        self.train_files = self.data_dicts[:half]
        self.val_files = self.data_dicts[half:]
        print(f"Number of training samples: {len(self.train_files)}")
        print(f"Number of validation samples: {len(self.val_files)}")

        # Non-random transforms to be cached
        self.non_random_transforms = Compose([
            EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
            NormalizeIntensityd(keys="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
        ])

        # Random transforms to be applied during training
        self.random_transforms = Compose([
            RandCropByLabelClassesd(
                keys=["image", "label"],
                label_key="label",
                spatial_size=[96, 96, 96],
                num_classes=self.nclasses,
                num_samples=num_random_samples_per_batch,
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        ])

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the datasets: cached deterministic stage, then random stage.

        The validation set goes through the same random crop/rotate/flip
        pipeline as the training set.
        """
        cached_train = CacheDataset(data=self.train_files, transform=self.non_random_transforms, cache_rate=1.0)
        self.train_ds = Dataset(data=cached_train, transform=self.random_transforms)
        cached_val = CacheDataset(data=self.val_files, transform=self.non_random_transforms, cache_rate=1.0)
        self.val_ds = Dataset(data=cached_val, transform=self.random_transforms)

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training loader."""
        return DataLoader(
            self.train_ds,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=4,
            persistent_workers=True,
            pin_memory=torch.cuda.is_available(),
        )

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (no shuffling)."""
        return DataLoader(
            self.val_ds,
            batch_size=self.val_batch_size,
            shuffle=False,  # Ensure the data order remains consistent
            num_workers=4,
            persistent_workers=True,
            pin_memory=torch.cuda.is_available(),
        )

    @staticmethod
    def data_from_copick(copick_config_path, copick_user_name='user0', copick_segmentation_name='paintedPicks'):
        """Load image/label pairs for the first eight runs of a copick project.

        Args:
            copick_config_path: path of the copick configuration file.
            copick_user_name: user id of the segmentation to load.
            copick_segmentation_name: name of the segmentation to load.

        Returns:
            tuple: ``(data_dicts, nclasses)`` where ``data_dicts`` is a list of
            ``{"image": ndarray, "label": ndarray}`` and ``nclasses`` is the
            number of pickable objects plus one.
        """
        root = copick.from_file(copick_config_path)
        # One class per pickable object plus one extra (presumably background).
        nclasses = len(root.pickable_objects) + 1

        data_dicts = []
        for run in tqdm(root.runs[:8]):
            tomogram = run.get_voxel_spacing(10).get_tomogram('wbp').numpy()
            segmentation = run.get_segmentations(
                name=copick_segmentation_name,
                user_id=copick_user_name,
                voxel_size=10,
                is_multilabel=True,
            )[0].numpy()

            # Membrane voxels are painted as label 1. Runs without a membrane
            # segmentation are kept as-is instead of failing with IndexError.
            membrane_segs = run.get_segmentations(name='membrane', user_id="data-portal")
            if membrane_segs:
                segmentation[membrane_segs[0].numpy() == 1] = 1

            data_dicts.append({"image": tomogram, "label": segmentation})

        return data_dicts, nclasses
188
+
189
+
190
def objective(trial: optuna.trial.Trial) -> float:
    """Train one U-Net configuration sampled by Optuna and return its score.

    The trial samples the number of layers, the base channel width, the
    number of stride-2 (downsampling) stages and the number of residual
    units. The model is trained with the ``ddp_spawn`` strategy and the
    last logged ``val_metric`` is returned.

    Args:
        trial: the Optuna trial supplying the hyperparameters.

    Returns:
        float: the final ``val_metric`` value reported by the trainer.
    """
    args = get_args()
    mlf_logger = MLFlowLogger(
        experiment_name='training-3D-UNet-model-for-the-cryoET-ML-Challenge',
        tracking_uri='http://mlflow.mlflow.svc.cluster.local:5000',
    )

    # Trainer callbacks
    checkpoint_callback = ModelCheckpoint(monitor='val_metric', save_top_k=1, mode='max')
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # Use the first `num_gpus` devices on this node.
    devices = list(range(args.num_gpus))

    # We optimize the number of layers, strides, and number of residual units
    num_layers = trial.suggest_int("num_layers", 3, 5)
    base_channel = trial.suggest_categorical("base_channel", [8, 16, 32, 64])
    channels = [base_channel * (2 ** i) for i in range(num_layers)]
    num_downsampling_layers = trial.suggest_int("num_downsampling_layers", 1, num_layers - 1)
    # num_layers channel entries need num_layers - 1 strides: stride-2 stages first, then stride-1.
    strides_pattern = [2] * num_downsampling_layers + [1] * (num_layers - num_downsampling_layers - 1)
    num_res_units = trial.suggest_int("num_res_units", 1, 3)

    model = Model(channels=channels, strides=strides_pattern, num_res_units=num_res_units, lr=args.learning_rate)
    datamodule = CopickDataModule(args.copick_config_path, args.train_batch_size, args.val_batch_size, args.num_random_samples_per_batch)
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_metric")

    # Prioritize performance over precision
    torch.set_float32_matmul_precision('medium')  # or torch.set_float32_matmul_precision('high')

    # Trainer for distributed training with DDP.
    # The pruning callback has to be registered with the trainer; otherwise it
    # never reports intermediate values and the pruner can never stop a trial.
    trainer = Trainer(
        max_epochs=args.num_epochs,
        logger=mlf_logger,
        callbacks=[checkpoint_callback, lr_monitor, pruning_callback],
        strategy="ddp_spawn",
        accelerator="gpu",
        devices=devices,
        num_nodes=1,
        log_every_n_steps=1,
    )

    hyperparameters = dict(
        op_num_layers=num_layers,
        op_base_channel=base_channel,
        op_num_downsampling_layers=num_downsampling_layers,
        op_num_res_units=num_res_units,
    )
    trainer.logger.log_hyperparams(hyperparameters)
    trainer.fit(model, datamodule=datamodule)
    # With ddp_spawn the pruning decision is made in a worker process;
    # check_pruned() surfaces it in the process that runs the study.
    pruning_callback.check_pruned()
    return trainer.callback_metrics["val_metric"].item()
237
+
238
+
239
if __name__ == "__main__":
    args = get_args()

    # MLflow setup: read the credentials from the environment and fall back
    # to a .env file when either one is missing.
    username = os.getenv('MLFLOW_TRACKING_USERNAME')
    password = os.getenv('MLFLOW_TRACKING_PASSWORD')
    if not password or not username:
        print("Password not found in environment, loading from .env file...")
        load_dotenv()  # Loads environment variables from a .env file
        username = os.getenv('MLFLOW_TRACKING_USERNAME')
        password = os.getenv('MLFLOW_TRACKING_PASSWORD')

    # Check again after loading .env file. Both values are required:
    # assigning None to os.environ would raise an opaque TypeError.
    if not password:
        raise ValueError("Password is not set in environment variables or .env file!")
    if not username:
        raise ValueError("Username is not set in environment variables or .env file!")
    print("Password loaded successfully")
    os.environ['MLFLOW_TRACKING_USERNAME'] = username
    os.environ['MLFLOW_TRACKING_PASSWORD'] = password

    pruner: optuna.pruners.BasePruner = (
        optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner()
    )
    storage = "sqlite:///example.db"
    study = optuna.create_study(
        study_name="pl_ddp",
        storage=storage,
        direction="maximize",
        load_if_exists=True,
        pruner=pruner,
    )
    study.optimize(objective, n_trials=args.num_optuna_trials)

    # Print the best hyperparameters
    print(f"Best trial: {study.best_trial.value}")
    print(f"Best hyperparameters: {study.best_trial.params}")
@@ -0,0 +1,244 @@
1
+ import os
2
+ import argparse
3
+ from typing import Optional, Union, Tuple, List
4
+ from collections import defaultdict
5
+ import torch
6
+ from torch.utils.data import DataLoader
7
+ import pytorch_lightning as pl
8
+ from pytorch_lightning import Trainer
9
+ from pytorch_lightning.loggers import MLFlowLogger
10
+ from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
11
+ from pytorch_lightning.strategies import DDPStrategy
12
+ import os
13
+ import copick
14
+ from tqdm import tqdm
15
+ from monai.data import DataLoader, Dataset, CacheDataset, decollate_batch
16
+ from dotenv import load_dotenv
17
+ from monai.transforms import (
18
+ Compose,
19
+ EnsureChannelFirstd,
20
+ Orientationd,
21
+ AsDiscrete,
22
+ RandFlipd,
23
+ RandRotate90d,
24
+ NormalizeIntensityd,
25
+ NormalizeIntensityd,
26
+ RandCropByLabelClassesd,
27
+ )
28
+ from monai.networks.nets import UNet
29
+ from monai.losses import TverskyLoss
30
+ from monai.metrics import DiceMetric, ConfusionMatrixMetric
31
+
32
+
33
def get_args():
    """Parse the command-line options of the Lightning training script.

    Returns:
        argparse.Namespace: the copick project settings, batch sizes,
        learning rate and epoch / GPU counts.
    """
    parser = argparse.ArgumentParser(
        description = "Train a 3d U-Net model with PyTorch Lightning supporting distributed training strategies."
    )

    # (flag, value type, default) for every supported option, in help order.
    option_table = (
        ('--copick_config_path', str, 'copick_config_dataportal_10439.json'),
        ('--copick_user_name', str, 'user0'),
        ('--copick_segmentation_name', str, 'paintedPicks'),
        ('--train_batch_size', int, 1),
        ('--val_batch_size', int, 1),
        ('--num_random_samples_per_batch', int, 16),
        ('--learning_rate', float, 1e-4),
        ('--num_epochs', int, 20),
        ('--num_gpus', int, 1),
    )
    for flag, value_type, default_value in option_table:
        parser.add_argument(flag, type=value_type, default=default_value)

    return parser.parse_args()
47
+
48
+
49
class Model(pl.LightningModule):
    """Lightning module wrapping a MONAI ``UNet`` for multi-class segmentation.

    Training minimises a Tversky loss; validation logs one Dice value per
    foreground class plus their mean under the key ``val_metric``.

    Args:
        spatial_dims: number of spatial dimensions forwarded to ``UNet``.
        in_channels: number of input channels.
        out_channels: number of output channels, also used as the one-hot
            depth when scoring predictions.
        channels: feature-map sizes forwarded to ``UNet``.
        strides: convolution strides forwarded to ``UNet``.
        num_res_units: number of residual units forwarded to ``UNet``.
        lr: learning rate of the Adam optimizer.
    """

    def __init__(
        self,
        spatial_dims: int = 3,
        in_channels: int = 1,
        out_channels: int = 8,
        channels: Union[Tuple[int, ...], List[int]] = (48, 64, 80, 80),
        strides: Union[Tuple[int, ...], List[int]] = (2, 2, 1),
        num_res_units: int = 1,
        lr: float = 1e-3,
    ):
        super().__init__()
        # Stores every constructor argument under self.hparams.
        self.save_hyperparameters()

        self.model = UNet(
            spatial_dims=self.hparams.spatial_dims,
            in_channels=self.hparams.in_channels,
            out_channels=self.hparams.out_channels,
            channels=self.hparams.channels,
            strides=self.hparams.strides,
            num_res_units=self.hparams.num_res_units,
        )
        self.loss_fn = TverskyLoss(include_background=True, to_onehot_y=True, softmax=True)  # softmax=True for multiclass
        self.metric_fn = DiceMetric(include_background=False, reduction="mean", ignore_empty=True)

    def forward(self, x):
        """Run the wrapped U-Net on ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Return the Tversky loss for one batch with ``image``/``label`` keys."""
        x, y = batch['image'], batch['label']
        y_hat = self(x)
        return self.loss_fn(y_hat, y)

    def on_validation_epoch_start(self):
        """Clear the Dice buffer so each epoch is scored on its own batches.

        ``DiceMetric`` accumulates every batch it is called on until
        ``reset()``; without this, ``aggregate()`` would also average in
        results from earlier epochs and from the sanity-check pass.
        """
        self.metric_fn.reset()

    def validation_step(self, batch, batch_idx):
        """Score one validation batch and log per-class and mean Dice.

        Returns:
            dict: ``{'val_metric': <mean Dice over the aggregated classes>}``.
        """
        x, y = batch['image'], batch['label']
        y_hat = self(x)

        # Build the post-processing transforms once per step, not once per sample.
        to_onehot_pred = AsDiscrete(argmax=True, to_onehot=self.hparams.out_channels)
        to_onehot_label = AsDiscrete(to_onehot=self.hparams.out_channels)
        metric_val_outputs = [to_onehot_pred(i) for i in decollate_batch(y_hat)]
        metric_val_labels = [to_onehot_label(i) for i in decollate_batch(y)]

        # compute metric for current iteration
        self.metric_fn(y_pred=metric_val_outputs, y=metric_val_labels)
        metrics = self.metric_fn.aggregate(reduction="mean_batch")
        for i, m in enumerate(metrics):
            self.log(f"validation metric class {i+1}", m, prog_bar=True, on_epoch=True, sync_dist=True)
        metric = torch.mean(metrics)  # cannot log ndarray
        self.log('val_metric', metric, prog_bar=True, on_epoch=True, sync_dist=True)  # sync_dist=True for distributed training
        return {'val_metric': metric}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``hparams.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
100
+
101
+
102
class CopickDataModule(pl.LightningDataModule):
    """Lightning data module that loads tomograms and label volumes from copick.

    The loaded runs are split in half: the first half is used for training,
    the second half for validation.

    Args:
        copick_config_path: path of the copick configuration file.
        copick_user_name: copick user id owning the segmentation.
        copick_segmentation_name: name of the segmentation to load.
        train_batch_size: batch size of the training loader.
        val_batch_size: batch size of the validation loader.
        num_random_samples_per_batch: number of crops drawn per volume by
            ``RandCropByLabelClassesd``.
    """

    def __init__(
        self,
        copick_config_path: str,
        copick_user_name: str,
        copick_segmentation_name: str,
        train_batch_size: int,
        val_batch_size: int,
        num_random_samples_per_batch: int,
    ):
        super().__init__()
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size

        self.data_dicts, self.nclasses = self.data_from_copick(
            copick_config_path, copick_user_name, copick_segmentation_name
        )
        half = len(self.data_dicts) // 2
        self.train_files = self.data_dicts[:half]
        self.val_files = self.data_dicts[half:]
        print(f"Number of training samples: {len(self.train_files)}")
        print(f"Number of validation samples: {len(self.val_files)}")

        # Non-random transforms to be cached
        self.non_random_transforms = Compose([
            EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
            NormalizeIntensityd(keys="image"),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
        ])

        # Random transforms to be applied during training
        self.random_transforms = Compose([
            RandCropByLabelClassesd(
                keys=["image", "label"],
                label_key="label",
                spatial_size=[96, 96, 96],
                num_classes=self.nclasses,
                num_samples=num_random_samples_per_batch,
            ),
            RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        ])

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the datasets: cached deterministic stage, then random stage.

        The validation set goes through the same random crop/rotate/flip
        pipeline as the training set.
        """
        cached_train = CacheDataset(data=self.train_files, transform=self.non_random_transforms, cache_rate=1.0)
        self.train_ds = Dataset(data=cached_train, transform=self.random_transforms)
        cached_val = CacheDataset(data=self.val_files, transform=self.non_random_transforms, cache_rate=1.0)
        self.val_ds = Dataset(data=cached_val, transform=self.random_transforms)

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training loader."""
        return DataLoader(
            self.train_ds,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=torch.cuda.is_available(),
        )

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (no shuffling)."""
        return DataLoader(
            self.val_ds,
            batch_size=self.val_batch_size,
            num_workers=4,
            pin_memory=torch.cuda.is_available(),
            shuffle=False,  # Ensure the data order remains consistent
        )

    @staticmethod
    def data_from_copick(copick_config_path, copick_user_name, copick_segmentation_name):
        """Load image/label pairs for the first two runs of a copick project.

        Args:
            copick_config_path: path of the copick configuration file.
            copick_user_name: user id of the segmentation to load.
            copick_segmentation_name: name of the segmentation to load.

        Returns:
            tuple: ``(data_dicts, nclasses)`` where ``data_dicts`` is a list of
            ``{"image": ndarray, "label": ndarray}`` and ``nclasses`` is the
            number of pickable objects plus one.
        """
        root = copick.from_file(copick_config_path)
        # One class per pickable object plus one extra (presumably background).
        nclasses = len(root.pickable_objects) + 1

        data_dicts = []
        for run in tqdm(root.runs[:2]):
            tomogram = run.get_voxel_spacing(10).get_tomogram('wbp').numpy()
            segmentation = run.get_segmentations(
                name=copick_segmentation_name,
                user_id=copick_user_name,
                voxel_size=10,
                is_multilabel=True,
            )[0].numpy()

            # Query the membrane segmentation once and reuse the result;
            # membrane voxels are painted as label 1 when it exists.
            membrane_segs = run.get_segmentations(name='membrane', user_id="data-portal")
            if membrane_segs:
                segmentation[membrane_segs[0].numpy() == 1] = 1

            data_dicts.append({"image": tomogram, "label": segmentation})

        return data_dicts, nclasses
188
+
189
+
190
def train():
    """Train the default U-Net on a copick project with DDP and MLflow logging.

    Options come from ``get_args()``. The best checkpoint according to
    ``val_metric`` is kept.
    """
    args = get_args()
    mlf_logger = MLFlowLogger(
        experiment_name='training-3D-UNet-model-for-the-cryoET-ML-Challenge',
        tracking_uri='http://mlflow.mlflow.svc.cluster.local:5000',
    )

    # Trainer callbacks.
    # val_metric is a Dice score (higher is better), so the checkpoint to keep
    # is the one with the maximum value; mode='min' would keep the worst epoch.
    checkpoint_callback = ModelCheckpoint(monitor='val_metric', save_top_k=1, mode='max')
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # Use the first `num_gpus` devices on this node.
    devices = list(range(args.num_gpus))

    # Initialize model
    model = Model(lr=args.learning_rate)
    datamodule = CopickDataModule(
        args.copick_config_path,
        args.copick_user_name,
        args.copick_segmentation_name,
        args.train_batch_size,
        args.val_batch_size,
        args.num_random_samples_per_batch,
    )

    # Prioritize performance over precision
    torch.set_float32_matmul_precision('medium')  # or torch.set_float32_matmul_precision('high')

    # Trainer for distributed training with DDP
    trainer = Trainer(
        max_epochs=args.num_epochs,
        logger=mlf_logger,
        callbacks=[checkpoint_callback, lr_monitor],
        strategy=DDPStrategy(find_unused_parameters=False),
        accelerator="gpu",
        devices=devices,
        num_nodes=1,
        log_every_n_steps=1,
    )

    trainer.fit(model, datamodule=datamodule)
224
+
225
+
226
if __name__ == "__main__":
    # MLflow setup: read the credentials from the environment and fall back
    # to a .env file when either one is missing.
    username = os.getenv('MLFLOW_TRACKING_USERNAME')
    password = os.getenv('MLFLOW_TRACKING_PASSWORD')
    if not password or not username:
        print("Password not found in environment, loading from .env file...")
        load_dotenv()  # Loads environment variables from a .env file
        username = os.getenv('MLFLOW_TRACKING_USERNAME')
        password = os.getenv('MLFLOW_TRACKING_PASSWORD')

    # Check again after loading .env file. Both values are required:
    # assigning None to os.environ would raise an opaque TypeError.
    if not password:
        raise ValueError("Password is not set in environment variables or .env file!")
    if not username:
        raise ValueError("Username is not set in environment variables or .env file!")
    print("Password loaded successfully")
    os.environ['MLFLOW_TRACKING_USERNAME'] = username
    os.environ['MLFLOW_TRACKING_PASSWORD'] = password

    train()
File without changes
octopi/utils/config.py ADDED
@@ -0,0 +1,57 @@
1
+ """
2
+ Configuration utilities for MLflow setup and reproducibility.
3
+ """
4
+
5
+ from dotenv import load_dotenv
6
+ import torch, numpy as np
7
+ import os, random
8
+ import octopi
9
+
10
+
11
def mlflow_setup():
    """
    Set up MLflow configuration from environment variables.

    Loads a ``.env`` file located relative to the installed ``octopi``
    package, then makes sure ``MLFLOW_TRACKING_USERNAME`` and
    ``MLFLOW_TRACKING_PASSWORD`` are present in ``os.environ``.

    Returns:
        The value of ``MLFLOW_TRACKING_URI`` from the environment, or
        ``None`` when it is not set.

    Raises:
        ValueError: if the username or the password cannot be found in the
            environment or in a ``.env`` file.
    """
    module_root = os.path.dirname(octopi.__file__)
    # NOTE(review): the replace() only strips a 'src/octopi' suffix, i.e. it
    # looks like this targets a source checkout - confirm the intended path
    # for an installed package.
    dotenv_path = module_root.replace('src/octopi','') + '.env'
    load_dotenv(dotenv_path=dotenv_path)

    # MLflow setup
    username = os.getenv('MLFLOW_TRACKING_USERNAME')
    password = os.getenv('MLFLOW_TRACKING_PASSWORD')
    if not password or not username:
        print("Password not found in environment, loading from .env file...")
        load_dotenv()  # Loads environment variables from a .env file
        username = os.getenv('MLFLOW_TRACKING_USERNAME')
        password = os.getenv('MLFLOW_TRACKING_PASSWORD')

    # Check again after loading .env file. Both values are required:
    # assigning None to os.environ would raise an opaque TypeError.
    if not password:
        raise ValueError("Password is not set in environment variables or .env file!")
    if not username:
        raise ValueError("Username is not set in environment variables or .env file!")
    print("Password loaded successfully")
    os.environ['MLFLOW_TRACKING_USERNAME'] = username
    os.environ['MLFLOW_TRACKING_PASSWORD'] = password

    return os.getenv('MLFLOW_TRACKING_URI')
37
+
38
+
39
def set_seed(seed):
    """
    Seed Python's ``random``, NumPy and PyTorch, and switch cuDNN to
    deterministic mode, so that repeated runs draw the same random numbers.
    """
    # CPU-side generators: Python, NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # GPU generators (current device, then every device for multi-GPU runs).
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Deterministic kernels on; benchmark-driven autotuning off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False