octopi-1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of octopi might be problematic.

Files changed (59)
  1. octopi/__init__.py +0 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +84 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +429 -0
  7. octopi/datasets/mixup.py +49 -0
  8. octopi/datasets/multi_config_generator.py +253 -0
  9. octopi/entry_points/__init__.py +0 -0
  10. octopi/entry_points/common.py +80 -0
  11. octopi/entry_points/create_slurm_submission.py +243 -0
  12. octopi/entry_points/run_create_targets.py +281 -0
  13. octopi/entry_points/run_evaluate.py +65 -0
  14. octopi/entry_points/run_extract_mb_picks.py +141 -0
  15. octopi/entry_points/run_extract_midpoint.py +143 -0
  16. octopi/entry_points/run_localize.py +222 -0
  17. octopi/entry_points/run_optuna.py +139 -0
  18. octopi/entry_points/run_segment_predict.py +166 -0
  19. octopi/entry_points/run_train.py +201 -0
  20. octopi/extract/__init__.py +0 -0
  21. octopi/extract/localize.py +254 -0
  22. octopi/extract/membranebound_extract.py +262 -0
  23. octopi/extract/midpoint_extract.py +193 -0
  24. octopi/io.py +457 -0
  25. octopi/losses.py +86 -0
  26. octopi/main.py +101 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +62 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +106 -0
  37. octopi/processing/downsample.py +129 -0
  38. octopi/processing/evaluate.py +289 -0
  39. octopi/processing/importers.py +213 -0
  40. octopi/processing/my_metrics.py +26 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/processing/writers.py +102 -0
  43. octopi/pytorch/__init__.py +0 -0
  44. octopi/pytorch/hyper_search.py +243 -0
  45. octopi/pytorch/model_search_submitter.py +290 -0
  46. octopi/pytorch/segmentation.py +317 -0
  47. octopi/pytorch/trainer.py +438 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/stopping_criteria.py +143 -0
  52. octopi/submit_slurm.py +95 -0
  53. octopi/utils.py +238 -0
  54. octopi/visualization_tools.py +201 -0
  55. octopi-1.0.dist-info/LICENSE +41 -0
  56. octopi-1.0.dist-info/METADATA +209 -0
  57. octopi-1.0.dist-info/RECORD +59 -0
  58. octopi-1.0.dist-info/WHEEL +4 -0
  59. octopi-1.0.dist-info/entry_points.txt +4 -0
octopi/processing/segmentation_from_picks.py
@@ -0,0 +1,167 @@
+ # This code is adapted from the copick-utils project,
+ # originally available at: https://github.com/copick/copick-utils/blob/main/src/copick_utils/segmentation/segmentation_from_picks.py
+ # Licensed under the MIT License.
+
+ # Copyright (c) 2023 The copick-utils authors
+
+ import numpy as np
+ import zarr
+ from scipy.ndimage import zoom
+ import copick
+
+ def from_picks(pick,
+                seg_volume,
+                radius: float = 10.0,
+                label_value: int = 1,
+                voxel_spacing: float = 10):
+     """
+     Paints picks into a segmentation volume as spheres.
+
+     Parameters:
+     -----------
+     pick : copick.models.CopickPicks
+         Copick object containing `points`, where each point has a `location` attribute with `x`, `y`, `z` coordinates.
+     seg_volume : numpy.ndarray
+         3D segmentation volume (numpy array) where the spheres are painted. Shape should be (Z, Y, X).
+     radius : float, optional
+         The radius of the spheres to be inserted in physical units (not voxel units). Default is 10.0.
+     label_value : int, optional
+         The integer value used to label the sphere regions in the segmentation volume. Default is 1.
+     voxel_spacing : float, optional
+         The spacing of voxels in the segmentation volume, used to scale the radius of the spheres. Default is 10.
+
+     Returns:
+     --------
+     numpy.ndarray
+         The modified segmentation volume with spheres inserted at pick locations.
+     """
+     def create_sphere(shape, center, radius, val):
+         zc, yc, xc = center
+         z, y, x = np.indices(shape)
+         distance_sq = (x - xc)**2 + (y - yc)**2 + (z - zc)**2
+         sphere = np.zeros(shape, dtype=np.float32)
+         sphere[distance_sq <= radius**2] = val
+         return sphere
+
+     def get_relative_target_coordinates(center, delta, shape):
+         low = max(int(np.floor(center - delta)), 0)
+         high = min(int(np.ceil(center + delta + 1)), shape)
+         return low, high
+
+     # Adjust radius for voxel spacing
+     radius_voxel = max(radius / voxel_spacing, 1)
+     delta = int(np.ceil(radius_voxel))
+
+     # Paint each pick as a sphere
+     for point in pick.points:
+         # Convert the pick's location from angstroms to voxel units
+         cx, cy, cz = point.location.x / voxel_spacing, point.location.y / voxel_spacing, point.location.z / voxel_spacing
+
+         # Calculate subarray bounds
+         xLow, xHigh = get_relative_target_coordinates(cx, delta, seg_volume.shape[2])
+         yLow, yHigh = get_relative_target_coordinates(cy, delta, seg_volume.shape[1])
+         zLow, zHigh = get_relative_target_coordinates(cz, delta, seg_volume.shape[0])
+
+         # Subarray shape
+         subarray_shape = (zHigh - zLow, yHigh - yLow, xHigh - xLow)
+         if any(dim <= 0 for dim in subarray_shape):
+             continue
+
+         # Compute the local center of the sphere within the subarray
+         local_center = (cz - zLow, cy - yLow, cx - xLow)
+         sphere = create_sphere(subarray_shape, local_center, radius_voxel, label_value)
+
+         # Assign Sphere to Segmentation Target Volume
+         seg_volume[zLow:zHigh, yLow:yHigh, xLow:xHigh] = np.maximum(seg_volume[zLow:zHigh, yLow:yHigh, xLow:xHigh], sphere)
+
+     return seg_volume
+
+
+ def downsample_to_exact_shape(array, target_shape):
+     """
+     Downsamples a 3D array to match the target shape using nearest-neighbor interpolation.
+     Ensures that the resulting array has the exact target shape.
+     """
+     zoom_factors = [t / s for t, s in zip(target_shape, array.shape)]
+     return zoom(array, zoom_factors, order=0)
+
+
+ def segmentation_from_picks(radius, painting_segmentation_name, run, voxel_spacing, tomo_type, pickable_object, pick_set, user_id="paintedPicks", session_id="0"):
+     """
+     Paints picks from a run into a multiscale segmentation array, representing them as spheres in 3D space.
+
+     Parameters:
+     -----------
+     radius : float
+         Radius of the spheres in physical units.
+     painting_segmentation_name : str
+         The name of the segmentation dataset to be created or modified.
+     run : copick.Run
+         The current Copick run object.
+     voxel_spacing : float
+         The spacing of the voxels in the tomogram data.
+     tomo_type : str
+         The type of tomogram to retrieve.
+     pickable_object : copick.models.CopickObject
+         The object that defines the label value to be used in segmentation.
+     pick_set : copick.models.CopickPicks
+         The set of picks containing the locations to paint spheres.
+     user_id : str, optional
+         The ID of the user creating the segmentation. Default is "paintedPicks".
+     session_id : str, optional
+         The session ID for this segmentation. Default is "0".
+
+     Returns:
+     --------
+     copick.Segmentation
+         The created or modified segmentation object.
+     """
+     # Fetch the tomogram and determine its multiscale structure
+     tomogram = run.get_voxel_spacing(voxel_spacing).get_tomogram(tomo_type)
+     if not tomogram:
+         raise ValueError("Tomogram not found for the given parameters.")
+
+     # Use copick to create a new segmentation if one does not exist
+     segs = run.get_segmentations(user_id=user_id, session_id=session_id, is_multilabel=True, name=painting_segmentation_name, voxel_size=voxel_spacing)
+     if len(segs) == 0:
+         seg = run.new_segmentation(voxel_spacing, painting_segmentation_name, session_id, True, user_id=user_id)
+     else:
+         seg = segs[0]
+
+     segmentation_group = zarr.open(seg.zarr(), mode="a")
+     highest_res_name = "0"
+
+     # Get the highest resolution dimensions and create a new array if necessary
+     tomogram_zarr = zarr.open(tomogram.zarr(), "r")
+
+     highest_res_shape = tomogram_zarr[highest_res_name].shape
+     if highest_res_name not in segmentation_group:
+         segmentation_group.create(highest_res_name, shape=highest_res_shape, dtype=np.uint16, overwrite=True)
+
+     # Initialize or load the highest resolution array
+     highest_res_seg = segmentation_group[highest_res_name][:]
+     highest_res_seg.fill(0)
+
+     # Paint picks into the highest resolution array
+     highest_res_seg = from_picks(pick_set, highest_res_seg, radius, pickable_object.label, voxel_spacing)
+
+     # Write back the highest resolution data
+     segmentation_group[highest_res_name][:] = highest_res_seg
+
+     # Downsample to create lower resolution scales
+     multiscale_metadata = tomogram_zarr.attrs.get('multiscales', [{}])[0].get('datasets', [])
+     for level_index, level_metadata in enumerate(multiscale_metadata):
+         if level_index == 0:
+             continue
+
+         level_name = level_metadata.get("path", str(level_index))
+         expected_shape = tuple(tomogram_zarr[level_name].shape)
+
+         # Downsample the highest resolution segmentation to this level's expected shape
+         scaled_array = downsample_to_exact_shape(highest_res_seg, expected_shape)
+
+         # Create/overwrite the Zarr array for this level
+         segmentation_group.create_dataset(level_name, shape=expected_shape, data=scaled_array, dtype=np.uint16, overwrite=True)
+
+         segmentation_group[level_name][:] = scaled_array
+
+     return seg
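For orientation, a minimal usage sketch of `segmentation_from_picks` follows. It is not part of the package diff: the config path, run name, object name, and pick-set lookup are placeholders, and a standard copick project layout is assumed.

import copick
from octopi.processing.segmentation_from_picks import segmentation_from_picks

root = copick.from_file("config.json")            # placeholder copick project config
run = root.get_run("TS_001")                      # placeholder run name
obj = root.get_object("ribosome")                 # pickable object supplying the label value (placeholder name)
picks = run.get_picks(object_name="ribosome")[0]  # first matching pick set (placeholder lookup)

seg = segmentation_from_picks(
    radius=80.0,                          # sphere radius in physical units
    painting_segmentation_name="targets",
    run=run,
    voxel_spacing=10.0,
    tomo_type="wbp",
    pickable_object=obj,
    pick_set=picks,
)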
octopi/processing/writers.py
@@ -0,0 +1,102 @@
+ # This code is adapted from the copick-utils project,
+ # originally available at: https://github.com/copick/copick-utils/blob/main/src/copick_utils/writers/write.py
+ # Licensed under the MIT License.
+
+ # Copyright (c) 2023 The copick-utils authors
+
+ from typing import Any, Dict, List
+ import numpy as np
+
+ def tomogram(
+     run,
+     input_volume,
+     voxel_size=10,
+     algorithm="wbp"
+ ):
+     """
+     Writes a volumetric tomogram into an OME-Zarr format within a Copick directory.
+
+     Parameters:
+     -----------
+     run : copick.Run
+         The current Copick run object.
+     input_volume : np.ndarray
+         The volumetric tomogram data to be written.
+     voxel_size : float, optional
+         The size of the voxels in physical units. Default is 10.
+     algorithm : str, optional
+         The tomographic reconstruction algorithm to use. Default is 'wbp'.
+
+     Returns:
+     --------
+     copick.Tomogram
+         The created or modified tomogram object.
+     """
+
+     # Retrieve or create voxel spacing
+     voxel_spacing = run.get_voxel_spacing(voxel_size)
+     if voxel_spacing is None:
+         voxel_spacing = run.new_voxel_spacing(voxel_size=voxel_size)
+
+     # Check if we need to create a new tomogram for the given algorithm
+     tomogram = voxel_spacing.get_tomogram(algorithm)
+     if tomogram is None:
+         tomogram = voxel_spacing.new_tomogram(tomo_type=algorithm)
+
+     # Write the tomogram data and return the tomogram object, as documented above
+     tomogram.from_numpy(input_volume)
+     return tomogram
+
+
+ def segmentation(
+     run,
+     segmentation_volume,
+     user_id,
+     name="segmentation",
+     session_id="0",
+     voxel_size=10,
+     multilabel=True
+ ):
+     """
+     Writes a segmentation into an OME-Zarr format within a Copick directory.
+
+     Parameters:
+     -----------
+     run : copick.Run
+         The current Copick run object.
+     segmentation_volume : np.ndarray
+         The segmentation data to be written.
+     user_id : str
+         The ID of the user creating the segmentation.
+     name : str, optional
+         The name of the segmentation dataset to be created or modified. Default is 'segmentation'.
+     session_id : str, optional
+         The session ID for this segmentation. Default is '0'.
+     voxel_size : float, optional
+         The size of the voxels in physical units. Default is 10.
+     multilabel : bool, optional
+         Whether the segmentation is a multilabel segmentation. Default is True.
+
+     Returns:
+     --------
+     copick.Segmentation
+         The created or modified segmentation object.
+     """
+
+     # Retrieve or create a segmentation
+     segmentations = run.get_segmentations(name=name, user_id=user_id, session_id=session_id)
+
+     # If no segmentation exists at the given voxel size, create a new one
+     if len(segmentations) == 0 or all(seg.voxel_size != voxel_size for seg in segmentations):
+         segmentation = run.new_segmentation(
+             voxel_size=voxel_size,
+             name=name,
+             session_id=session_id,
+             is_multilabel=multilabel,
+             user_id=user_id
+         )
+     else:
+         # Overwrite the current segmentation at the specified voxel size if it exists
+         segmentation = next(seg for seg in segmentations if seg.voxel_size == voxel_size)
+
+     # Write the segmentation data and return the segmentation object, as documented above
+     segmentation.from_numpy(segmentation_volume, dtype=np.uint8)
+     return segmentation
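By way of illustration, a hypothetical example of the two writer helpers follows; the config path, run name, and arrays are stand-ins, and an existing copick project is assumed.

import numpy as np
import copick
from octopi.processing import writers

root = copick.from_file("config.json")   # placeholder copick project config
run = root.get_run("TS_001")             # placeholder run name

# Write a reconstruction under the 10-unit voxel spacing with the "wbp" tomo type
volume = np.random.rand(64, 128, 128).astype(np.float32)   # stand-in reconstruction
writers.tomogram(run, volume, voxel_size=10, algorithm="wbp")

# Write a label volume produced elsewhere (e.g. by a segmentation step)
labels = np.zeros((64, 128, 128), dtype=np.uint8)           # stand-in label volume
writers.segmentation(run, labels, user_id="octopi", name="prediction", voxel_size=10)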
octopi/pytorch/hyper_search.py
@@ -0,0 +1,243 @@
+ from monai.losses import FocalLoss, TverskyLoss
+ from monai.metrics import ConfusionMatrixMetric
+ from octopi.pytorch import trainer
+ from mlflow.tracking import MlflowClient
+ from octopi.models import common
+ from octopi import io, losses
+ import torch, mlflow, optuna, gc
+
+ class BayesianModelSearch:
+
+     def __init__(self, data_generator, model_type="Unet", parent_run_id=None, parent_run_name=None):
+         """
+         Class to handle model creation, training, and optimization.
+
+         Args:
+             data_generator (object): Data generator object containing dataset properties.
+             model_type (str): Type of model to build ("Unet", "AttentionUnet").
+             parent_run_id (str, optional): ID of the parent MLflow run, logged with each trial.
+             parent_run_name (str, optional): Name of the parent MLflow run, logged with each trial.
+         """
+         self.data_generator = data_generator
+         self.Nclasses = data_generator.Nclasses
+         self.device = None
+         self.model_type = model_type
+         self.model = None
+         self.loss_function = None
+         self.metrics_function = None
+         self.sampling = None
+         self.parent_run_id = parent_run_id
+         # Store the parent run name as well, since objective() logs it alongside the parent run ID
+         self.parent_run_name = parent_run_name
+
+         # Define results directory path
+         self.results_dir = f'explore_results_{self.model_type}'
+
+     def my_build_model(self, trial):
+         """Builds and initializes a model based on Optuna-suggested parameters."""
+
+         # Build the model
+         self.model_builder = common.get_model(self.model_type)
+         self.model_builder.bayesian_search(trial, self.Nclasses)
+         self.model = self.model_builder.model.to(self.device)
+         self.config = self.model_builder.config
+
+         # Define loss function
+         self.loss_function = common.get_loss_function(trial)
+
+         # Define metrics
+         self.metrics_function = ConfusionMatrixMetric(
+             include_background=False,
+             metric_name=["recall", "precision", "f1 score"],
+             reduction="none"
+         )
+
+         # Sample crop size and num_samples
+         self.sampling = {
+             'crop_size': trial.suggest_int("crop_size", 48, 160, step=16),
+             'num_samples': 8
+         }
+         self.config['dim_in'] = self.sampling['crop_size']
+
+     def _define_optimizer(self):
+         # Define optimizer
+         lr0 = 1e-3
+         self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr0, weight_decay=1e-5)
+
+     def _train_model(self, trial, model_trainer, epochs, val_interval, crop_size, num_samples, best_metric):
+         """Handles model training and error handling."""
+         try:
+             results = model_trainer.train(
+                 self.data_generator,
+                 model_save_path=None,
+                 crop_size=crop_size,
+                 max_epochs=epochs,
+                 val_interval=val_interval,
+                 my_num_samples=num_samples,
+                 best_metric=best_metric,
+                 use_mlflow=True,
+                 verbose=False
+             )
+             return results['best_metric']
+
+         except torch.cuda.OutOfMemoryError:
+             print(f"[Trial Failed] OOM Error for model={model_trainer.model}, crop_size={crop_size}, num_samples={num_samples}")
+             trial.set_user_attr("out_of_memory", True)
+             raise optuna.TrialPruned()
+
+         except Exception as e:
+             print(f"[Trial Failed] Unexpected error: {e}")
+             trial.set_user_attr("error", str(e))
+             raise optuna.TrialPruned()
+
+     def objective(self, trial, epochs, device, val_interval=15, best_metric="avg_f1"):
+         """Runs the full training process for a given trial."""
+
+         # Set device
+         self.device = device
+
+         # Set a unique run name for each trial
+         trial_num = f"trial_{trial.number}"
+
+         # Start MLflow run
+         with mlflow.start_run(run_name=trial_num, nested=True):
+
+             # Build model
+             self.my_build_model(trial)
+
+             # Create trainer
+             self._define_optimizer()
+             model_trainer = trainer.ModelTrainer(self.model, self.device, self.loss_function, self.metrics_function, self.optimizer)
+
+             # Train model and evaluate score
+             score = self._train_model(
+                 trial, model_trainer, epochs, val_interval,
+                 self.sampling['crop_size'], self.sampling['num_samples'],
+                 best_metric)
+
+             # Log parameters and metrics
+             params = {
+                 'model': self.model_builder.get_model_parameters(),
+                 'optimizer': io.get_optimizer_parameters(model_trainer)
+             }
+             model_trainer.my_log_params(io.flatten_params(params))
+
+             # Explicitly set the parent run ID
+             mlflow.log_param("parent_run_id", self.parent_run_id)
+             mlflow.log_param("parent_run_name", self.parent_run_name)
+
+             # Save best model
+             self._save_best_model(trial, model_trainer, score)
+
+             # Cleanup
+             self.cleanup(model_trainer, self.optimizer)
+
+             return score
+
+     def _setup_parallel_trial_run(self, trial, parent_run=None, gpu_count=1):
+         """Set up parallel MLflow runs and assign GPU for the trial."""
+
+         trial_num = f"trial_{trial.number}"
+
+         # Create a child run under the parent MLflow experiment
+         mlflow_client = MlflowClient()
+         self.trial_run = mlflow_client.create_run(
+             experiment_id=mlflow_client.get_run(parent_run.info.run_id).info.experiment_id,
+             tags={"mlflow.parentRunId": parent_run.info.run_id},
+             run_name=trial_num
+         )
+         target_run_id = self.trial_run.info.run_id
+         print(f"Logging trial {trial.number} data to MLflow run: {target_run_id}")
+
+         # Assign GPU device
+         if gpu_count > 1:
+             gpu_id = trial.number % gpu_count
+             device = torch.device(f"cuda:{gpu_id}")
+             torch.cuda.set_device(device)
+         else:
+             device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+         return device, mlflow_client, target_run_id
+
+     def multi_gpu_objective(self, parent_run, trial, epochs, val_interval=5, best_metric="avg_f1", gpu_count=1):
+         """
+         Trains model on multiple GPUs using a parent MLflow run.
+         """
+         self.device, self.client, self.target_run_id = self._setup_parallel_trial_run(trial, parent_run, gpu_count)
+
+         # Build model
+         self.my_build_model(trial)
+
+         # Create trainer
+         self._define_optimizer()
+         model_trainer = trainer.ModelTrainer(self.model, self.device, self.loss_function, self.metrics_function, self.optimizer)
+         model_trainer.set_parallel_mlflow(self.client, self.target_run_id)
+
+         # Train model, with error handling
+         try:
+             score = self._train_model(
+                 trial, model_trainer, epochs, val_interval,
+                 self.sampling['crop_size'], self.sampling['num_samples'],
+                 best_metric)
+         except Exception as e:
+             print(f"[Trial Failed] Unexpected error: {e}")
+             trial.set_user_attr("error", str(e))
+             raise optuna.TrialPruned()
+
+         # Log training parameters
+         params = {
+             'model': self.model_builder.get_model_parameters(),
+             'optimizer': io.get_optimizer_parameters(model_trainer)
+         }
+         model_trainer.my_log_params(io.flatten_params(params))
+         model_trainer.my_log_params({"parent_run_name": parent_run.info.run_name})
+
+         # Save best model
+         self._save_best_model(trial, model_trainer, score)
+
+         # Cleanup
+         self.cleanup(model_trainer, self.optimizer)
+         return score
+
+     def _save_best_model(self, trial, model_trainer, score):
+         """Saves the best model if it improves upon previous scores."""
+         best_score_so_far = self.get_best_score(trial)
+         if score > best_score_so_far:
+             torch.save(model_trainer.model_weights, f'{self.results_dir}/best_model.pth')
+             io.save_parameters_to_yaml(self.model_builder, model_trainer, self.data_generator,
+                                        f'{self.results_dir}/best_model_config.yaml')
+
+     def get_best_score(self, trial):
+         """Retrieve the best score from the trial."""
+         try:
+             return trial.study.best_value
+         except ValueError:
+             return 0
+
+     def cleanup(self, model_trainer, optimizer):
+         """Handles cleanup of resources."""
+
+         # Delete the trainer and optimizer objects
+         del model_trainer, optimizer
+
+         # If the model object holds GPU memory, delete it explicitly and set it to None
+         if hasattr(self, "model"):
+             del self.model
+             self.model = None
+
+         # Optional: If your model_builder or other objects hold GPU references, delete them too
+         if hasattr(self, "model_builder"):
+             del self.model_builder
+             self.model_builder = None
+
+         # Clear the CUDA cache and force garbage collection
+         torch.cuda.empty_cache()
+         gc.collect()
+
+         # Terminate the parallel (multi-GPU) MLflow run if one was created; otherwise end the current run (single-GPU case)
+         try:
+             self.client.set_terminated(self.target_run_id, status="FINISHED")
+         except Exception:
+             mlflow.end_run()
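To show how this class might be driven, here is a hypothetical single-GPU wiring sketch. The experiment name, trial budget, and epoch count are placeholders, the data generator construction is omitted (it is assumed to expose `Nclasses` as the class requires), and `objective` is called inside an active parent MLflow run because it starts a nested run per trial.

import optuna, torch, mlflow
from octopi.pytorch.hyper_search import BayesianModelSearch

data_generator = ...  # an octopi dataset generator exposing Nclasses (construction omitted)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mlflow.set_experiment("octopi-model-search")  # placeholder experiment name

with mlflow.start_run(run_name="model_search") as parent_run:
    # Each Optuna trial builds a model, trains it, and logs a nested MLflow run
    search = BayesianModelSearch(
        data_generator,
        model_type="Unet",
        parent_run_id=parent_run.info.run_id,
        parent_run_name=parent_run.info.run_name,
    )
    study = optuna.create_study(direction="maximize")
    study.optimize(
        lambda trial: search.objective(trial, epochs=100, device=device),
        n_trials=20,
    )

print("Best trial value:", study.best_value)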