octopi-1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- octopi/__init__.py +0 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +84 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +429 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +253 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +80 -0
- octopi/entry_points/create_slurm_submission.py +243 -0
- octopi/entry_points/run_create_targets.py +281 -0
- octopi/entry_points/run_evaluate.py +65 -0
- octopi/entry_points/run_extract_mb_picks.py +141 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +222 -0
- octopi/entry_points/run_optuna.py +139 -0
- octopi/entry_points/run_segment_predict.py +166 -0
- octopi/entry_points/run_train.py +201 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +254 -0
- octopi/extract/membranebound_extract.py +262 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/io.py +457 -0
- octopi/losses.py +86 -0
- octopi/main.py +101 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +62 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +106 -0
- octopi/processing/downsample.py +129 -0
- octopi/processing/evaluate.py +289 -0
- octopi/processing/importers.py +213 -0
- octopi/processing/my_metrics.py +26 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/processing/writers.py +102 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +243 -0
- octopi/pytorch/model_search_submitter.py +290 -0
- octopi/pytorch/segmentation.py +317 -0
- octopi/pytorch/trainer.py +438 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/stopping_criteria.py +143 -0
- octopi/submit_slurm.py +95 -0
- octopi/utils.py +238 -0
- octopi/visualization_tools.py +201 -0
- octopi-1.0.dist-info/LICENSE +41 -0
- octopi-1.0.dist-info/METADATA +209 -0
- octopi-1.0.dist-info/RECORD +59 -0
- octopi-1.0.dist-info/WHEEL +4 -0
- octopi-1.0.dist-info/entry_points.txt +4 -0
octopi/pytorch/trainer.py
@@ -0,0 +1,438 @@
from octopi import visualization_tools as viz
from monai.inferers import sliding_window_inference
from octopi import stopping_criteria
from monai.transforms import AsDiscrete
from monai.data import decollate_batch
import torch, os, mlflow, re
import torch_ema as ema
from tqdm import tqdm
import numpy as np

# Not ideal, but necessary if a class is missing from the dataset
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

class ModelTrainer:

    def __init__(self,
                 model,
                 device,
                 loss_function,
                 metrics_function,
                 optimizer,
                 use_ema: bool = True):

        self.model = model
        self.device = device
        self.loss_function = loss_function
        self.metrics_function = metrics_function
        self.optimizer = optimizer

        self.parallel_mlflow = False
        self.client = None
        self.trial_run_id = None

        # Default F-beta value
        self.beta = 2

        # Initialize the EMA handler for the model
        self.ema_experiment = use_ema
        if self.ema_experiment:
            self.ema_handler = ema.ExponentialMovingAverage(self.model.parameters(), decay=0.99)

        # Initialize figure and axes for plotting
        self.fig = None; self.axs = None

    def set_parallel_mlflow(self, client, trial_run_id):

        self.parallel_mlflow = True
        self.client = client
        self.trial_run_id = trial_run_id

    def train_update(self):

        step = 0
        epoch_loss = 0
        self.model.train()
        for batch_data in self.train_loader:
            step += 1
            inputs = batch_data["image"].to(self.device)  # Shape: [B, C, H, W, D]
            labels = batch_data["label"].to(self.device)  # Shape: [B, C, H, W, D]
            self.optimizer.zero_grad()
            outputs = self.model(inputs)  # Output shape: [B, num_classes, H, W, D]
            loss = self.loss_function(outputs, labels)
            loss.backward()
            self.optimizer.step()

            # Update EMA weights
            if self.ema_experiment:
                self.ema_handler.update()

            # Update running epoch loss
            epoch_loss += loss.item()

        # Compute and return the average epoch loss
        epoch_loss /= step
        return epoch_loss

    def validate_update(self):
        """
        Perform validation and compute metrics, including validation loss.
        """

        # Set model to evaluation mode
        self.model.eval()
        val_loss = 0
        with torch.no_grad():
            for val_data in self.val_loader:
                val_inputs = val_data["image"].to(self.device)
                val_labels = val_data["label"].to(self.device)

                # Apply sliding window inference
                # roi_size=self.input_dim,  # try a fixed size of 128, 144 or 160?
                val_outputs = sliding_window_inference(
                    inputs=val_inputs,
                    roi_size=(144, 144, 144),
                    sw_batch_size=4,
                    predictor=self.model,
                    overlap=0.5,
                    device=self.device
                )

                # Compute the loss for this batch
                loss = self.loss_function(val_outputs, val_labels)
                val_loss += loss.item()  # Accumulate the loss

                # Apply post-processing
                metric_val_outputs = [self.post_pred(i) for i in decollate_batch(val_outputs)]
                metric_val_labels = [self.post_label(i) for i in decollate_batch(val_labels)]

                # Compute metrics
                self.metrics_function(y_pred=metric_val_outputs, y=metric_val_labels)

        # Contains recall, precision, and f1 for each class
        metric_values = self.metrics_function.aggregate(reduction='mean_batch')

        # Compute the average validation loss and append it to the metric list
        val_loss /= len(self.val_loader)
        metric_values.append(val_loss)

        return metric_values

    def train(
        self,
        data_load_gen,
        model_save_path: str = 'results',
        my_num_samples: int = 15,
        crop_size: int = 96,
        max_epochs: int = 100,
        val_interval: int = 15,
        lr_scheduler_type: str = 'cosine',
        best_metric: str = 'avg_f1',
        use_mlflow: bool = False,
        verbose: bool = False
    ):

        # The best-performing scheduler options are 'cosine' or 'reduce'
        self.warmup_epochs = 5
        self.warmup_lr_factor = 0.1
        self.min_lr = 1e-6

        self.max_epochs = max_epochs
        self.crop_size = crop_size
        self.num_samples = my_num_samples
        self.val_interval = val_interval
        self.use_mlflow = use_mlflow

        # Create the save folder if it doesn't exist
        if model_save_path is not None:
            os.makedirs(model_save_path, exist_ok=True)

        Nclass = data_load_gen.Nclasses
        self.create_results_dictionary(Nclass)

        # Resolve the best metric
        best_metric = self.resolve_best_metric(best_metric)

        # Stopping criteria
        self.stopping_criteria = stopping_criteria.EarlyStoppingChecker(
            monitor_metric=best_metric, val_interval=val_interval)

        self.post_pred = AsDiscrete(argmax=True, to_onehot=Nclass)
        self.post_label = AsDiscrete(to_onehot=Nclass)

        # Produce dataloaders for the first training iteration
        self.train_loader, self.val_loader = data_load_gen.create_train_dataloaders(
            crop_size=crop_size, num_samples=my_num_samples)
        self.input_dim = data_load_gen.input_dim

        # Save the original learning rate
        original_lr = self.optimizer.param_groups[0]['lr']
        self.load_learning_rate_scheduler(lr_scheduler_type)

        # Wrap the epoch loop in tqdm
        for epoch in tqdm(range(max_epochs), desc="Training Progress", unit="epoch"):

            # Reload dataloaders periodically
            if data_load_gen.reload_frequency > 0 and (epoch + 1) % data_load_gen.reload_frequency == 0:
                self.train_loader, self.val_loader = data_load_gen.create_train_dataloaders(num_samples=my_num_samples)
                # Lower the learning rate for the warm-up period
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = original_lr * self.warmup_lr_factor

            # Compute and log the average epoch loss
            epoch_loss = self.train_update()

            # Check for NaN in the loss
            if self.stopping_criteria.should_stop_training(epoch_loss):
                tqdm.write(f"Training stopped early due to {self.stopping_criteria.get_stopped_reason()}")
                break

            current_lr = self.optimizer.param_groups[0]['lr']
            self.my_log_metrics(metrics_dict={"loss": epoch_loss}, curr_step=epoch + 1)
            self.my_log_metrics(metrics_dict={"learning_rate": current_lr}, curr_step=epoch + 1)

            # Validation and metric logging
            if (epoch + 1) % val_interval == 0 or (epoch + 1) == max_epochs:
                if verbose:
                    tqdm.write(f"Epoch {epoch + 1}/{max_epochs}, avg_train_loss: {epoch_loss:.4f}")

                # Validate the model with or without EMA weights
                if self.ema_experiment:
                    with self.ema_handler.average_parameters():
                        metric_values = self.validate_update()
                else:
                    metric_values = self.validate_update()

                # Log all metrics
                self.my_log_metrics(metrics_dict=metric_values, curr_step=epoch + 1)

                # Update the tqdm description
                if verbose:
                    (avg_f1, avg_recall, avg_precision) = (self.results['avg_f1'][-1][1],
                                                           self.results['avg_recall'][-1][1],
                                                           self.results['avg_precision'][-1][1])
                    tqdm.write(f"Epoch {epoch + 1}/{max_epochs}, avg_f1_score: {avg_f1:.4f}, avg_recall: {avg_recall:.4f}, avg_precision: {avg_precision:.4f}")

                # Reset the metrics function
                self.metrics_function.reset()

                # Save the best model
                if self.results[best_metric][-1][1] > self.results["best_metric"]:
                    self.results["best_metric"] = self.results[best_metric][-1][1]
                    self.results["best_metric_epoch"] = epoch + 1

                    # Read model weights and save
                    if self.ema_experiment:
                        with self.ema_handler.average_parameters():
                            self.save_model(model_save_path)
                    else:
                        self.save_model(model_save_path)

                # Save a plot for local (non-MLflow) training calls
                if not self.use_mlflow:
                    self.fig, self.axs = viz.plot_training_results(
                        self.results,
                        save_plot=os.path.join(model_save_path, "net_train_history.png"),
                        fig=self.fig,
                        axs=self.axs)

                # After validation metrics are logged, check for early stopping
                if self.stopping_criteria.should_stop_training(epoch_loss, results=self.results, check_metrics=True):
                    tqdm.write(f"Training stopped early due to {self.stopping_criteria.get_stopped_reason()}")
                    break

            # Run the learning rate scheduler
            early_stop = self.run_scheduler(data_load_gen, original_lr, epoch, val_interval, lr_scheduler_type)
            if early_stop:
                break

        return self.results

    def load_learning_rate_scheduler(self, scheduler_type: str = 'cosine'):
        """
        Initialize the learning rate scheduler based on the given type.
        """
        # Configure the learning rate scheduler based on the type
        if scheduler_type == "cosine":
            self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=self.max_epochs, eta_min=self.min_lr)
        elif scheduler_type == "onecycle":
            max_lr = 1e-3
            steps_per_epoch = len(self.train_loader)
            self.lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
                self.optimizer, max_lr=max_lr, epochs=self.max_epochs, steps_per_epoch=steps_per_epoch)
        elif scheduler_type == "reduce":
            mode = "min"
            patience = 3
            factor = 0.5
            self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode=mode, patience=patience, factor=factor)
        elif scheduler_type == 'exponential':
            gamma = 0.9
            self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
                self.optimizer, gamma=gamma)
        else:
            raise ValueError(f"Unsupported scheduler type: {scheduler_type}")

    def run_scheduler(
        self,
        data_load_gen,
        original_lr: float,
        epoch: int,
        val_interval: int,
        scheduler_type: str
    ):
        """
        Manage the learning rate scheduler, including warm-up and normal scheduling.
        """
        # Apply warm-up logic
        if (epoch + 1) <= self.warmup_epochs:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = original_lr * self.warmup_lr_factor
            return False  # Continue training

        # Step the scheduler
        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.CosineAnnealingLR):
            self.lr_scheduler.step()
        elif isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) and (epoch + 1) % val_interval == 0:
            metric_value = self.results['val_loss'][-1][1]
            self.lr_scheduler.step(metric_value)
        else:
            self.lr_scheduler.step()  # Step for other schedulers

        # Check the learning rate for early stopping
        current_lr = self.optimizer.param_groups[0]['lr']
        if current_lr < self.min_lr and scheduler_type != 'onecycle':
            print(f"Early stopping triggered at epoch {epoch + 1} as learning rate fell below {self.min_lr}.")
            return True  # Indicate early stopping

        return False  # Continue training

    def save_model(self, model_save_path: str):

        # Store model weights as a member variable
        self.model_weights = self.model.state_dict()

        # Save model weights to a *.pth file
        if model_save_path is not None:
            torch.save(self.model_weights, os.path.join(model_save_path, "best_model.pth"))

    def create_results_dictionary(self, Nclass: int):

        self.results = {
            'loss': [],
            'val_loss': [],
            'avg_f1': [],
            'avg_recall': [],
            'avg_precision': [],
            'avg_fbeta': [],
            'best_metric': -1,  # Best validation metric value seen so far
            'best_metric_epoch': -1
        }

        # Per-class metric histories (class indices start at 1; background is excluded)
        for i in range(Nclass - 1):
            self.results[f'fbeta_class{i+1}'] = []
            self.results[f'f1_class{i+1}'] = []
            self.results[f'recall_class{i+1}'] = []
            self.results[f'precision_class{i+1}'] = []

        self.metric_names = self.results.keys()

    def my_log_metrics(
        self,
        metrics_dict,
        curr_step: int,
    ):

        # When called after validation, metrics_dict is the list returned by
        # validate_update(): per-class recall, precision, and f1 tensors,
        # followed by the scalar validation loss.
        if len(metrics_dict) > 1:

            recall, precision, f1s, val_loss = metrics_dict[0], metrics_dict[1], metrics_dict[2], metrics_dict[3]

            # Log per-class metrics
            metrics_to_log = {}
            for i, (rec, prec, f1) in enumerate(zip(recall, precision, f1s)):
                metrics_to_log[f"recall_class{i+1}"] = rec.item()
                metrics_to_log[f"precision_class{i+1}"] = prec.item()
                metrics_to_log[f"f1_class{i+1}"] = f1.item()
                metrics_to_log[f"fbeta_class{i+1}"] = self.fbeta(prec, rec).item()

            # Prepare average metrics
            metrics_to_log["avg_recall"] = recall.mean().cpu().item()
            metrics_to_log["avg_precision"] = precision.mean().cpu().item()
            metrics_to_log["avg_f1"] = f1s.mean().cpu().item()
            metrics_to_log["avg_fbeta"] = self.fbeta(precision, recall).mean().cpu().item()
            metrics_to_log['val_loss'] = val_loss

            # Update metrics_dict for further logging
            metrics_dict = metrics_to_log

        # Record all metrics (per-class and averages) in the results dictionary
        for metric_name, value in metrics_dict.items():
            if metric_name not in self.results:
                self.results[metric_name] = []
            self.results[metric_name].append((curr_step, value))

        # Log to MLflow, either through the parallel client or directly
        if self.client is not None and self.trial_run_id is not None:
            for metric_name, value in metrics_dict.items():
                self.client.log_metric(
                    run_id=self.trial_run_id,
                    key=metric_name,
                    value=value,
                    step=curr_step,
                )
        elif self.use_mlflow:
            for metric_name, value in metrics_dict.items():
                mlflow.log_metric(metric_name, value, step=curr_step)

    def fbeta(self, precision, recall):

        numerator = (1 + self.beta**2) * (precision * recall)
        denominator = (self.beta**2 * precision) + recall

        # Use torch.where to handle division by zero
        result = torch.where(
            denominator > 0,
            numerator / denominator,
            torch.zeros_like(precision)
        )
        return result

    def my_log_params(
        self,
        params_dict: dict,
    ):

        if self.client is not None and self.trial_run_id is not None:
            for key, value in params_dict.items():
                self.client.log_param(run_id=self.trial_run_id, key=key, value=value)
        else:
            mlflow.log_params(params_dict)

    # Example input: best_metric = 'fBeta2_class3' or 'fBeta1' or 'f1_class2'
    def resolve_best_metric(self, best_metric):
        fbeta_pattern = r"^fBeta(\d+)(?:_class(\d+))?$"  # Matches fBetaX or fBetaX_classY
        match = re.match(fbeta_pattern, best_metric)

        if match:
            self.beta = int(match.group(1))  # Extract the beta value
            class_part = match.group(2)
            if class_part:
                best_metric = f'fbeta_class{class_part}'  # fBeta2_class3 → fbeta_class3
            else:
                best_metric = 'avg_fbeta'  # fBeta2 → avg_fbeta

        elif best_metric in self.metric_names:
            pass  # Already a valid metric in the results dictionary

        else:
            print(f"'{best_metric}' is not a valid metric. Defaulting to 'avg_f1'.\n")
            best_metric = 'avg_f1'

        return best_metric
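For orientation, a minimal driver for ModelTrainer can be sketched as follows. This is illustrative only, not code from the package: the MONAI model, loss, and metric choices are assumptions, and data_load_gen stands in for one of the generator classes under octopi/datasets/, whose constructors are not shown in this diff. The requirements visible in trainer.py are that the generator expose Nclasses, input_dim, reload_frequency, and create_train_dataloaders(), and that metrics_function.aggregate(reduction='mean_batch') yield per-class recall, precision, and f1, in that order.

import torch
from monai.networks.nets import UNet
from monai.losses import TverskyLoss
from monai.metrics import ConfusionMatrixMetric
from octopi.pytorch.trainer import ModelTrainer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Any model mapping [B, C, H, W, D] -> [B, num_classes, H, W, D] works here.
model = UNet(spatial_dims=3, in_channels=1, out_channels=3,
             channels=(16, 32, 64, 128), strides=(2, 2, 2)).to(device)

# Order matters: my_log_metrics unpacks aggregate() as recall, precision, f1.
metrics = ConfusionMatrixMetric(include_background=False,
                                metric_name=["recall", "precision", "f1 score"],
                                reduction="mean_batch")

loss = TverskyLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

trainer = ModelTrainer(model, device, loss, metrics, optimizer, use_ema=True)

# data_load_gen is hypothetical here; see octopi/datasets/generators.py for candidates.
# results = trainer.train(data_load_gen, model_save_path="results",
#                         max_epochs=100, val_interval=15,
#                         lr_scheduler_type="cosine", best_metric="fBeta2")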