octopi-1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of octopi might be problematic.
- octopi/__init__.py +0 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +84 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +429 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +253 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +80 -0
- octopi/entry_points/create_slurm_submission.py +243 -0
- octopi/entry_points/run_create_targets.py +281 -0
- octopi/entry_points/run_evaluate.py +65 -0
- octopi/entry_points/run_extract_mb_picks.py +141 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +222 -0
- octopi/entry_points/run_optuna.py +139 -0
- octopi/entry_points/run_segment_predict.py +166 -0
- octopi/entry_points/run_train.py +201 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +254 -0
- octopi/extract/membranebound_extract.py +262 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/io.py +457 -0
- octopi/losses.py +86 -0
- octopi/main.py +101 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +62 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +106 -0
- octopi/processing/downsample.py +129 -0
- octopi/processing/evaluate.py +289 -0
- octopi/processing/importers.py +213 -0
- octopi/processing/my_metrics.py +26 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/processing/writers.py +102 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +243 -0
- octopi/pytorch/model_search_submitter.py +290 -0
- octopi/pytorch/segmentation.py +317 -0
- octopi/pytorch/trainer.py +438 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/stopping_criteria.py +143 -0
- octopi/submit_slurm.py +95 -0
- octopi/utils.py +238 -0
- octopi/visualization_tools.py +201 -0
- octopi-1.0.dist-info/LICENSE +41 -0
- octopi-1.0.dist-info/METADATA +209 -0
- octopi-1.0.dist-info/RECORD +59 -0
- octopi-1.0.dist-info/WHEEL +4 -0
- octopi-1.0.dist-info/entry_points.txt +4 -0

octopi/processing/segmentation_from_picks.py
@@ -0,0 +1,167 @@
# This code is adapted from the copick-utils project,
# originally available at: https://github.com/copick/copick-utils/blob/main/src/copick_utils/segmentation/segmentation_from_picks.py
# Licensed under the MIT License.

# Copyright (c) 2023 The copick-utils authors

import numpy as np
import zarr
from scipy.ndimage import zoom
import copick

def from_picks(pick,
               seg_volume,
               radius: float = 10.0,
               label_value: int = 1,
               voxel_spacing: float = 10):
    """
    Paints picks into a segmentation volume as spheres.

    Parameters:
    -----------
    pick : copick.models.CopickPicks
        Copick object containing `points`, where each point has a `location` attribute with `x`, `y`, `z` coordinates.
    seg_volume : numpy.ndarray
        3D segmentation volume (numpy array) where the spheres are painted. Shape should be (Z, Y, X).
    radius : float, optional
        The radius of the spheres to be inserted in physical units (not voxel units). Default is 10.0.
    label_value : int, optional
        The integer value used to label the sphere regions in the segmentation volume. Default is 1.
    voxel_spacing : float, optional
        The spacing of voxels in the segmentation volume, used to scale the radius of the spheres. Default is 10.

    Returns:
    --------
    numpy.ndarray
        The modified segmentation volume with spheres inserted at pick locations.
    """
    def create_sphere(shape, center, radius, val):
        zc, yc, xc = center
        z, y, x = np.indices(shape)
        distance_sq = (x - xc)**2 + (y - yc)**2 + (z - zc)**2
        sphere = np.zeros(shape, dtype=np.float32)
        sphere[distance_sq <= radius**2] = val
        return sphere

    def get_relative_target_coordinates(center, delta, shape):
        low = max(int(np.floor(center - delta)), 0)
        high = min(int(np.ceil(center + delta + 1)), shape)
        return low, high

    # Adjust radius for voxel spacing
    radius_voxel = max(radius / voxel_spacing, 1)
    delta = int(np.ceil(radius_voxel))

    # Paint each pick as a sphere
    for point in pick.points:
        # Convert the pick's location from angstroms to voxel units
        cx, cy, cz = point.location.x / voxel_spacing, point.location.y / voxel_spacing, point.location.z / voxel_spacing

        # Calculate subarray bounds
        xLow, xHigh = get_relative_target_coordinates(cx, delta, seg_volume.shape[2])
        yLow, yHigh = get_relative_target_coordinates(cy, delta, seg_volume.shape[1])
        zLow, zHigh = get_relative_target_coordinates(cz, delta, seg_volume.shape[0])

        # Subarray shape
        subarray_shape = (zHigh - zLow, yHigh - yLow, xHigh - xLow)
        if any(dim <= 0 for dim in subarray_shape):
            continue

        # Compute the local center of the sphere within the subarray
        local_center = (cz - zLow, cy - yLow, cx - xLow)
        sphere = create_sphere(subarray_shape, local_center, radius_voxel, label_value)

        # Assign sphere to segmentation target volume
        seg_volume[zLow:zHigh, yLow:yHigh, xLow:xHigh] = np.maximum(seg_volume[zLow:zHigh, yLow:yHigh, xLow:xHigh], sphere)

    return seg_volume


def downsample_to_exact_shape(array, target_shape):
    """
    Downsamples a 3D array to match the target shape using nearest-neighbor interpolation.
    Ensures that the resulting array has the exact target shape.
    """
    zoom_factors = [t / s for t, s in zip(target_shape, array.shape)]
    return zoom(array, zoom_factors, order=0)


def segmentation_from_picks(radius, painting_segmentation_name, run, voxel_spacing, tomo_type, pickable_object, pick_set, user_id="paintedPicks", session_id="0"):
    """
    Paints picks from a run into a multiscale segmentation array, representing them as spheres in 3D space.

    Parameters:
    -----------
    radius : float
        Radius of the spheres in physical units.
    painting_segmentation_name : str
        The name of the segmentation dataset to be created or modified.
    run : copick.Run
        The current Copick run object.
    voxel_spacing : float
        The spacing of the voxels in the tomogram data.
    tomo_type : str
        The type of tomogram to retrieve.
    pickable_object : copick.models.CopickObject
        The object that defines the label value to be used in segmentation.
    pick_set : copick.models.CopickPicks
        The set of picks containing the locations to paint spheres.
    user_id : str, optional
        The ID of the user creating the segmentation. Default is "paintedPicks".
    session_id : str, optional
        The session ID for this segmentation. Default is "0".

    Returns:
    --------
    copick.Segmentation
        The created or modified segmentation object.
    """
    # Fetch the tomogram and determine its multiscale structure
    tomogram = run.get_voxel_spacing(voxel_spacing).get_tomogram(tomo_type)
    if not tomogram:
        raise ValueError("Tomogram not found for the given parameters.")

    # Use copick to create a new segmentation if one does not exist
    segs = run.get_segmentations(user_id=user_id, session_id=session_id, is_multilabel=True, name=painting_segmentation_name, voxel_size=voxel_spacing)
    if len(segs) == 0:
        seg = run.new_segmentation(voxel_spacing, painting_segmentation_name, session_id, True, user_id=user_id)
    else:
        seg = segs[0]

    segmentation_group = zarr.open(seg.zarr(), mode="a")
    highest_res_name = "0"

    # Get the highest resolution dimensions and create a new array if necessary
    tomogram_zarr = zarr.open(tomogram.zarr(), "r")

    highest_res_shape = tomogram_zarr[highest_res_name].shape
    if highest_res_name not in segmentation_group:
        segmentation_group.create(highest_res_name, shape=highest_res_shape, dtype=np.uint16, overwrite=True)

    # Initialize or load the highest resolution array
    highest_res_seg = segmentation_group[highest_res_name][:]
    highest_res_seg.fill(0)

    # Paint picks into the highest resolution array
    highest_res_seg = from_picks(pick_set, highest_res_seg, radius, pickable_object.label, voxel_spacing)

    # Write back the highest resolution data
    segmentation_group[highest_res_name][:] = highest_res_seg

    # Downsample to create lower resolution scales
    multiscale_metadata = tomogram_zarr.attrs.get('multiscales', [{}])[0].get('datasets', [])
    for level_index, level_metadata in enumerate(multiscale_metadata):
        if level_index == 0:
            continue

        level_name = level_metadata.get("path", str(level_index))
        expected_shape = tuple(tomogram_zarr[level_name].shape)

        # Downsample the painted volume to this level's expected shape
        scaled_array = downsample_to_exact_shape(highest_res_seg, expected_shape)

        # Create/overwrite the Zarr array for this level
        segmentation_group.create_dataset(level_name, shape=expected_shape, data=scaled_array, dtype=np.uint16, overwrite=True)

        segmentation_group[level_name][:] = scaled_array

    return seg
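
For orientation, a minimal usage sketch of the two helpers above follows. It is not part of the package diff: the config path "config.json", run name "TS_001", object name "ribosome", and array shape are hypothetical, and it assumes the standard copick accessors (copick.from_file, root.get_run, root.get_object, run.get_picks).

import copick
import numpy as np
from octopi.processing.segmentation_from_picks import from_picks, segmentation_from_picks

root = copick.from_file("config.json")                 # hypothetical project config
run = root.get_run("TS_001")                           # hypothetical run name
pick_set = run.get_picks(object_name="ribosome")[0]    # hypothetical object name

# Paint spheres directly into an in-memory (Z, Y, X) volume
seg = np.zeros((184, 630, 630), dtype=np.uint16)
seg = from_picks(pick_set, seg, radius=80.0, label_value=1, voxel_spacing=10.0)

# Or write a full multiscale segmentation back into the copick project
seg_obj = segmentation_from_picks(
    radius=80.0,
    painting_segmentation_name="ribosome-spheres",
    run=run,
    voxel_spacing=10.0,
    tomo_type="wbp",
    pickable_object=root.get_object("ribosome"),
    pick_set=pick_set,
)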

octopi/processing/writers.py
@@ -0,0 +1,102 @@
# This code is adapted from the copick-utils project,
# originally available at: https://github.com/copick/copick-utils/blob/main/src/copick_utils/writers/write.py
# Licensed under the MIT License.

# Copyright (c) 2023 The copick-utils authors

from typing import Any, Dict, List
import numpy as np

def tomogram(
    run,
    input_volume,
    voxel_size=10,
    algorithm="wbp"
):
    """
    Writes a volumetric tomogram into an OME-Zarr format within a Copick directory.

    Parameters:
    -----------
    run : copick.Run
        The current Copick run object.
    input_volume : np.ndarray
        The volumetric tomogram data to be written.
    voxel_size : float, optional
        The size of the voxels in physical units. Default is 10.
    algorithm : str, optional
        The tomographic reconstruction algorithm to use. Default is 'wbp'.

    Returns:
    --------
    copick.Tomogram
        The created or modified tomogram object.
    """

    # Retrieve or create voxel spacing
    voxel_spacing = run.get_voxel_spacing(voxel_size)
    if voxel_spacing is None:
        voxel_spacing = run.new_voxel_spacing(voxel_size=voxel_size)

    # Check if we need to create a new tomogram for the given algorithm
    tomogram = voxel_spacing.get_tomogram(algorithm)
    if tomogram is None:
        tomogram = voxel_spacing.new_tomogram(tomo_type=algorithm)

    # Write the tomogram data
    tomogram.from_numpy(input_volume)
    return tomogram


def segmentation(
    run,
    segmentation_volume,
    user_id,
    name="segmentation",
    session_id="0",
    voxel_size=10,
    multilabel=True
):
    """
    Writes a segmentation into an OME-Zarr format within a Copick directory.

    Parameters:
    -----------
    run : copick.Run
        The current Copick run object.
    segmentation_volume : np.ndarray
        The segmentation data to be written.
    user_id : str
        The ID of the user creating the segmentation.
    name : str, optional
        The name of the segmentation dataset to be created or modified. Default is 'segmentation'.
    session_id : str, optional
        The session ID for this segmentation. Default is '0'.
    voxel_size : float, optional
        The size of the voxels in physical units. Default is 10.
    multilabel : bool, optional
        Whether the segmentation is a multilabel segmentation. Default is True.

    Returns:
    --------
    copick.Segmentation
        The created or modified segmentation object.
    """

    # Retrieve or create a segmentation
    segmentations = run.get_segmentations(name=name, user_id=user_id, session_id=session_id)

    # If no segmentation exists at the given voxel size, create a new one
    if len(segmentations) == 0 or all(seg.voxel_size != voxel_size for seg in segmentations):
        segmentation = run.new_segmentation(
            voxel_size=voxel_size,
            name=name,
            session_id=session_id,
            is_multilabel=multilabel,
            user_id=user_id
        )
    else:
        # Overwrite the current segmentation at the specified voxel size if it exists
        segmentation = next(seg for seg in segmentations if seg.voxel_size == voxel_size)

    # Write the segmentation data
    segmentation.from_numpy(segmentation_volume, dtype=np.uint8)
    return segmentation
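
A matching usage sketch for the writers above (again not part of the diff; `run` is a copick run obtained as in the previous sketch, and the array shapes, user ID, and segmentation name are illustrative):

import numpy as np
from octopi.processing import writers

# vol: a reconstructed tomogram as a (Z, Y, X) array
vol = np.random.rand(184, 630, 630).astype(np.float32)
writers.tomogram(run, vol, voxel_size=10, algorithm="wbp")

# labels: an integer mask with the same shape as the tomogram
labels = np.zeros(vol.shape, dtype=np.uint8)
writers.segmentation(run, labels, user_id="octopi", name="predictions",
                     session_id="0", voxel_size=10, multilabel=True)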

octopi/pytorch/__init__.py
(File without changes)

octopi/pytorch/hyper_search.py
@@ -0,0 +1,243 @@
from monai.losses import FocalLoss, TverskyLoss
from monai.metrics import ConfusionMatrixMetric
from octopi.pytorch import trainer
from mlflow.tracking import MlflowClient
from octopi.models import common
from octopi import io, losses
import torch, mlflow, optuna, gc

class BayesianModelSearch:

    def __init__(self, data_generator, model_type="Unet", parent_run_id=None, parent_run_name=None):
        """
        Class to handle model creation, training, and optimization.

        Args:
            data_generator (object): Data generator object containing dataset properties.
            model_type (str): Type of model to build (e.g. "Unet", "AttentionUnet").
            parent_run_id (str, optional): ID of the parent MLflow run, if any.
            parent_run_name (str, optional): Name of the parent MLflow run, if any.
        """
        self.data_generator = data_generator
        self.Nclasses = data_generator.Nclasses
        self.device = None
        self.model_type = model_type
        self.model = None
        self.loss_function = None
        self.metrics_function = None
        self.sampling = None
        self.parent_run_id = parent_run_id
        self.parent_run_name = parent_run_name

        # Define results directory path
        self.results_dir = f'explore_results_{self.model_type}'

    def my_build_model(self, trial):
        """Builds and initializes a model based on Optuna-suggested parameters."""

        # Build the model
        self.model_builder = common.get_model(self.model_type)
        self.model_builder.bayesian_search(trial, self.Nclasses)
        self.model = self.model_builder.model.to(self.device)
        self.config = self.model_builder.config

        # Define loss function
        self.loss_function = common.get_loss_function(trial)

        # Define metrics
        self.metrics_function = ConfusionMatrixMetric(
            include_background=False,
            metric_name=["recall", "precision", "f1 score"],
            reduction="none"
        )

        # Sample crop size and num_samples
        self.sampling = {
            'crop_size': trial.suggest_int("crop_size", 48, 160, step=16),
            'num_samples': 8
        }
        self.config['dim_in'] = self.sampling['crop_size']

    def _define_optimizer(self):
        # Define optimizer
        lr0 = 1e-3
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr0, weight_decay=1e-5)

    def _train_model(self, trial, model_trainer, epochs, val_interval, crop_size, num_samples, best_metric):
        """Handles model training and error handling."""
        try:
            results = model_trainer.train(
                self.data_generator,
                model_save_path=None,
                crop_size=crop_size,
                max_epochs=epochs,
                val_interval=val_interval,
                my_num_samples=num_samples,
                best_metric=best_metric,
                use_mlflow=True,
                verbose=False
            )
            return results['best_metric']

        except torch.cuda.OutOfMemoryError:
            print(f"[Trial Failed] OOM error for model={model_trainer.model}, crop_size={crop_size}, num_samples={num_samples}")
            trial.set_user_attr("out_of_memory", True)
            raise optuna.TrialPruned()

        except Exception as e:
            print(f"[Trial Failed] Unexpected error: {e}")
            trial.set_user_attr("error", str(e))
            raise optuna.TrialPruned()

    def objective(self, trial, epochs, device, val_interval=15, best_metric="avg_f1"):
        """Runs the full training process for a given trial."""

        # Set device
        self.device = device

        # Set a unique run name for each trial
        trial_num = f"trial_{trial.number}"

        # Start MLflow run
        with mlflow.start_run(run_name=trial_num, nested=True):

            # Build model
            self.my_build_model(trial)

            # Create trainer
            self._define_optimizer()
            model_trainer = trainer.ModelTrainer(self.model, self.device, self.loss_function, self.metrics_function, self.optimizer)

            # Train model and evaluate score
            score = self._train_model(
                trial, model_trainer, epochs, val_interval,
                self.sampling['crop_size'], self.sampling['num_samples'],
                best_metric)

            # Log parameters and metrics
            params = {
                'model': self.model_builder.get_model_parameters(),
                'optimizer': io.get_optimizer_parameters(model_trainer)
            }
            model_trainer.my_log_params(io.flatten_params(params))

            # Explicitly record the parent run
            mlflow.log_param("parent_run_id", self.parent_run_id)
            mlflow.log_param("parent_run_name", self.parent_run_name)

            # Save best model
            self._save_best_model(trial, model_trainer, score)

            # Cleanup
            self.cleanup(model_trainer, self.optimizer)

            return score

    def _setup_parallel_trial_run(self, trial, parent_run=None, gpu_count=1):
        """Sets up a parallel MLflow run and assigns a GPU for the trial."""

        trial_num = f"trial_{trial.number}"

        # Create a child run under the parent MLflow experiment
        mlflow_client = MlflowClient()
        self.trial_run = mlflow_client.create_run(
            experiment_id=mlflow_client.get_run(parent_run.info.run_id).info.experiment_id,
            tags={"mlflow.parentRunId": parent_run.info.run_id},
            run_name=trial_num
        )
        target_run_id = self.trial_run.info.run_id
        print(f"Logging trial {trial.number} data to MLflow run: {target_run_id}")

        # Assign GPU device
        if gpu_count > 1:
            gpu_id = trial.number % gpu_count
            device = torch.device(f"cuda:{gpu_id}")
            torch.cuda.set_device(device)
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        return device, mlflow_client, target_run_id

    def multi_gpu_objective(self, parent_run, trial, epochs, val_interval=5, best_metric="avg_f1", gpu_count=1):
        """
        Trains the model on multiple GPUs under a parent MLflow run.
        """
        self.device, self.client, self.target_run_id = self._setup_parallel_trial_run(trial, parent_run, gpu_count)

        # Build model
        self.my_build_model(trial)

        # Create trainer
        self._define_optimizer()
        model_trainer = trainer.ModelTrainer(self.model, self.device, self.loss_function, self.metrics_function, self.optimizer)
        model_trainer.set_parallel_mlflow(self.client, self.target_run_id)

        # Train model, with error handling
        try:
            score = self._train_model(
                trial, model_trainer, epochs, val_interval,
                self.sampling['crop_size'], self.sampling['num_samples'],
                best_metric)
        except Exception as e:
            print(f"[Trial Failed] Unexpected error: {e}")
            trial.set_user_attr("error", str(e))
            raise optuna.TrialPruned()

        # Log training parameters
        params = {
            'model': self.model_builder.get_model_parameters(),
            'optimizer': io.get_optimizer_parameters(model_trainer)
        }
        model_trainer.my_log_params(io.flatten_params(params))
        model_trainer.my_log_params({"parent_run_name": parent_run.info.run_name})

        # Save best model
        self._save_best_model(trial, model_trainer, score)

        # Cleanup
        self.cleanup(model_trainer, self.optimizer)
        return score

    def _save_best_model(self, trial, model_trainer, score):
        """Saves the best model if it improves upon previous scores."""
        best_score_so_far = self.get_best_score(trial)
        if score > best_score_so_far:
            torch.save(model_trainer.model_weights, f'{self.results_dir}/best_model.pth')
            io.save_parameters_to_yaml(self.model_builder, model_trainer, self.data_generator,
                                       f'{self.results_dir}/best_model_config.yaml')

    def get_best_score(self, trial):
        """Retrieves the best score seen so far in the study."""
        try:
            return trial.study.best_value
        except ValueError:
            return 0

    def cleanup(self, model_trainer, optimizer):
        """Handles cleanup of resources."""

        # Delete the trainer and optimizer objects
        del model_trainer, optimizer

        # If the model object holds GPU memory, delete it explicitly and set it to None
        if hasattr(self, "model"):
            del self.model
            self.model = None

        # If the model builder holds GPU references, delete it too
        if hasattr(self, "model_builder"):
            del self.model_builder
            self.model_builder = None

        # Clear the CUDA cache and force garbage collection
        torch.cuda.empty_cache()
        gc.collect()

        # Terminate the parallel (multi-GPU) MLflow run if one exists;
        # otherwise end the active run (single-GPU case)
        try:
            self.client.set_terminated(self.target_run_id, status="FINISHED")
        except Exception:
            mlflow.end_run()
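
For context, a minimal single-GPU driver for this search might look like the sketch below. It is not part of the diff: data_generator is assumed to be built with octopi's dataset generators, and the epoch and trial budgets are illustrative. Since objective starts a nested MLflow run per trial, it is wrapped in a parent run here.

import optuna, mlflow, torch
from octopi.pytorch.hyper_search import BayesianModelSearch

# data_generator: assumed to come from octopi.datasets.generators
search = BayesianModelSearch(data_generator, model_type="Unet")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
study = optuna.create_study(direction="maximize")
with mlflow.start_run(run_name="octopi_model_search"):
    study.optimize(
        lambda trial: search.objective(trial, epochs=100, device=device),
        n_trials=20,
    )
print(study.best_params)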