spacr 0.0.66__py3-none-any.whl → 0.0.71__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
- spacr/annotate_app.py +2 -4
- spacr/core.py +32 -32
- spacr/foldseek.py +6 -6
- spacr/get_alfafold_structures.py +3 -3
- spacr/io.py +53 -50
- spacr/sim.py +24 -29
- spacr/utils.py +18 -78
- {spacr-0.0.66.dist-info → spacr-0.0.71.dist-info}/METADATA +10 -8
- {spacr-0.0.66.dist-info → spacr-0.0.71.dist-info}/RECORD +13 -16
- spacr/graph_learning_lap.py +0 -84
- spacr/train.py +0 -667
- spacr/umap.py +0 -0
- {spacr-0.0.66.dist-info → spacr-0.0.71.dist-info}/LICENSE +0 -0
- {spacr-0.0.66.dist-info → spacr-0.0.71.dist-info}/WHEEL +0 -0
- {spacr-0.0.66.dist-info → spacr-0.0.71.dist-info}/entry_points.txt +0 -0
- {spacr-0.0.66.dist-info → spacr-0.0.71.dist-info}/top_level.txt +0 -0
spacr/train.py
DELETED
@@ -1,667 +0,0 @@
-import os, torch, time, gc, datetime
-import pandas as pd
-from torch.optim import Adagrad
-from torch.optim import AdamW
-from torch.autograd import grad
-from torch.optim.lr_scheduler import StepLR
-import torch.nn.functional as F
-from IPython.display import display, clear_output
-import difflib
-
-from .logger import log_function_call
-
-def evaluate_model_core(model, loader, loader_name, epoch, loss_type):
-    """
-    Evaluates the performance of a model on a given data loader.
-
-    Args:
-        model (torch.nn.Module): The model to evaluate.
-        loader (torch.utils.data.DataLoader): The data loader to evaluate the model on.
-        loader_name (str): The name of the data loader.
-        epoch (int): The current epoch number.
-        loss_type (str): The type of loss function to use.
-
-    Returns:
-        data_df (pandas.DataFrame): The classification metrics data as a DataFrame.
-        prediction_pos_probs (list): The positive class probabilities for each prediction.
-        all_labels (list): The true labels for each prediction.
-    """
-
-    from .utils import calculate_loss, classification_metrics
-
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    model.eval()
-    loss = 0
-    correct = 0
-    total_samples = 0
-    prediction_pos_probs = []
-    all_labels = []
-    model = model.to(device)
-    with torch.no_grad():
-        for batch_idx, (data, target, _) in enumerate(loader, start=1):
-            start_time = time.time()
-            data, target = data.to(device), target.to(device).float()
-            #data, target = data.to(torch.float).to(device), target.to(device).float()
-            output = model(data)
-            loss += F.binary_cross_entropy_with_logits(output, target, reduction='sum').item()
-            loss = calculate_loss(output, target, loss_type=loss_type)
-            loss += loss.item()
-            total_samples += data.size(0)
-            pred = torch.where(output >= 0.5,
-                               torch.Tensor([1.0]).to(device).float(),
-                               torch.Tensor([0.0]).to(device).float())
-            correct += pred.eq(target.view_as(pred)).sum().item()
-            batch_prediction_pos_prob = torch.sigmoid(output).cpu().numpy()
-            prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
-            all_labels.extend(target.cpu().numpy().tolist())
-            mean_loss = loss / total_samples
-            acc = correct / total_samples
-            end_time = time.time()
-            test_time = end_time - start_time
-            print(f'\rTest: epoch: {epoch} Accuracy: {acc:.5f} batch: {batch_idx+1}/{len(loader)} loss: {mean_loss:.5f} loss: {mean_loss:.5f} time {test_time:.5f}', end='\r', flush=True)
-    loss /= len(loader)
-    data_df = classification_metrics(all_labels, prediction_pos_probs, loader_name, loss, epoch)
-    return data_df, prediction_pos_probs, all_labels
-
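
For readers tracing the hunk above: it accumulates loss, hard predictions, and positive-class probabilities in a single pass. A minimal, self-contained sketch of that per-batch bookkeeping follows (names are illustrative, not spacr's API); note the deleted code thresholds the raw logits at 0.5, while the equivalent probability-space form shown here thresholds sigmoid(logit) at 0.5:

    import torch

    def binary_batch_stats(logits: torch.Tensor, target: torch.Tensor):
        # Positive-class probabilities from raw logits
        probs = torch.sigmoid(logits)
        # Hard 0/1 predictions by thresholding the probability at 0.5
        preds = (probs >= 0.5).float()
        # Number of correct predictions in this batch
        correct = preds.eq(target.view_as(preds)).sum().item()
        return probs.flatten().tolist(), preds.flatten().tolist(), correct
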
-def evaluate_model_performance(loaders, model, loader_name_list, epoch, train_mode, loss_type):
-    """
-    Evaluate the performance of a model on given data loaders.
-
-    Args:
-        loaders (list): List of data loaders.
-        model: The model to evaluate.
-        loader_name_list (list): List of names for the data loaders.
-        epoch (int): The current epoch.
-        train_mode (str): The training mode ('erm' or 'irm').
-        loss_type: The type of loss function.
-
-    Returns:
-        tuple: A tuple containing the evaluation result and the time taken for evaluation.
-    """
-    start_time = time.time()
-    df_list = []
-    if train_mode == 'erm':
-        result, _, _ = evaluate_model_core(model, loaders, loader_name_list, epoch, loss_type)
-    if train_mode == 'irm':
-        for loader_index in range(0, len(loaders)):
-            loader = loaders[loader_index]
-            loader_name = loader_name_list[loader_index]
-            data_df, _, _ = evaluate_model_core(model, loader, loader_name, epoch, loss_type)
-            torch.cuda.empty_cache()
-            df_list.append(data_df)
-        result = pd.concat(df_list)
-        nc_mean = result['neg_accuracy'].mean(skipna=True)
-        pc_mean = result['pos_accuracy'].mean(skipna=True)
-        tot_mean = result['accuracy'].mean(skipna=True)
-        loss_mean = result['loss'].mean(skipna=True)
-        prauc_mean = result['prauc'].mean(skipna=True)
-        data_mean = {'accuracy': tot_mean, 'neg_accuracy': nc_mean, 'pos_accuracy': pc_mean, 'loss': loss_mean, 'prauc': prauc_mean}
-        result = pd.concat([pd.DataFrame(result), pd.DataFrame(data_mean, index=[str(epoch)+'_mean'])])
-    end_time = time.time()
-    test_time = end_time - start_time
-    return result, test_time
-
-def test_model_core(model, loader, loader_name, epoch, loss_type):
-
-    from .utils import calculate_loss, classification_metrics
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    model.eval()
-    loss = 0
-    correct = 0
-    total_samples = 0
-    prediction_pos_probs = []
-    all_labels = []
-    filenames = []
-    true_targets = []
-    predicted_outputs = []
-
-    model = model.to(device)
-    with torch.no_grad():
-        for batch_idx, (data, target, filename) in enumerate(loader, start=1):  # Assuming loader provides filenames
-            start_time = time.time()
-            data, target = data.to(device), target.to(device).float()
-            output = model(data)
-            loss += F.binary_cross_entropy_with_logits(output, target, reduction='sum').item()
-            loss = calculate_loss(output, target, loss_type=loss_type)
-            loss += loss.item()
-            total_samples += data.size(0)
-            pred = torch.where(output >= 0.5,
-                               torch.Tensor([1.0]).to(device).float(),
-                               torch.Tensor([0.0]).to(device).float())
-            correct += pred.eq(target.view_as(pred)).sum().item()
-            batch_prediction_pos_prob = torch.sigmoid(output).cpu().numpy()
-            prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
-            all_labels.extend(target.cpu().numpy().tolist())
-
-            # Storing intermediate results in lists
-            true_targets.extend(target.cpu().numpy().tolist())
-            predicted_outputs.extend(pred.cpu().numpy().tolist())
-            filenames.extend(filename)
-
-            mean_loss = loss / total_samples
-            acc = correct / total_samples
-            end_time = time.time()
-            test_time = end_time - start_time
-            print(f'\rTest: epoch: {epoch} Accuracy: {acc:.5f} batch: {batch_idx}/{len(loader)} loss: {mean_loss:.5f} time {test_time:.5f}', end='\r', flush=True)
-
-    # Constructing the DataFrame
-    results_df = pd.DataFrame({
-        'filename': filenames,
-        'true_label': true_targets,
-        'predicted_label': predicted_outputs,
-        'class_1_probability': prediction_pos_probs})
-
-    loss /= len(loader)
-    data_df = classification_metrics(all_labels, prediction_pos_probs, loader_name, loss, epoch)
-    return data_df, prediction_pos_probs, all_labels, results_df
-
-def test_model_performance(loaders, model, loader_name_list, epoch, train_mode, loss_type):
-    """
-    Test the performance of a model on given data loaders.
-
-    Args:
-        loaders (list): List of data loaders.
-        model: The model to be tested.
-        loader_name_list (list): List of names for the data loaders.
-        epoch (int): The current epoch.
-        train_mode (str): The training mode ('erm' or 'irm').
-        loss_type: The type of loss function.
-
-    Returns:
-        tuple: A tuple containing the test results and the results dataframe.
-    """
-    start_time = time.time()
-    df_list = []
-    if train_mode == 'erm':
-        result, prediction_pos_probs, all_labels, results_df = test_model_core(model, loaders, loader_name_list, epoch, loss_type)
-    if train_mode == 'irm':
-        for loader_index in range(0, len(loaders)):
-            loader = loaders[loader_index]
-            loader_name = loader_name_list[loader_index]
-            data_df, prediction_pos_probs, all_labels, results_df = test_model_core(model, loader, loader_name, epoch, loss_type)
-            torch.cuda.empty_cache()
-            df_list.append(data_df)
-        result = pd.concat(df_list)
-        nc_mean = result['neg_accuracy'].mean(skipna=True)
-        pc_mean = result['pos_accuracy'].mean(skipna=True)
-        tot_mean = result['accuracy'].mean(skipna=True)
-        loss_mean = result['loss'].mean(skipna=True)
-        prauc_mean = result['prauc'].mean(skipna=True)
-        data_mean = {'accuracy': tot_mean, 'neg_accuracy': nc_mean, 'pos_accuracy': pc_mean, 'loss': loss_mean, 'prauc': prauc_mean}
-        result = pd.concat([pd.DataFrame(result), pd.DataFrame(data_mean, index=[str(epoch)+'_mean'])])
-    end_time = time.time()
-    test_time = end_time - start_time
-    return result, results_df
-
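
The erm/irm branches above imply two call shapes for `test_model_performance`: a single DataLoader (with `loader_name_list` a plain string) for 'erm', and a list of per-environment loaders with matching names for 'irm'. A hypothetical sketch of both forms (the loader variables are placeholders, not spacr objects):

    # 'erm': one loader, one name
    result, results_df = test_model_performance(
        loaders=test_loader, model=model, loader_name_list='test',
        epoch=1, train_mode='erm', loss_type='binary_cross_entropy_with_logits')

    # 'irm': one loader and one name per environment (e.g. per plate)
    result, results_df = test_model_performance(
        loaders=[plate1_loader, plate2_loader], model=model,
        loader_name_list=['plate1', 'plate2'],
        epoch=1, train_mode='irm', loss_type='binary_cross_entropy_with_logits')
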
-def train_test_model(src, settings, custom_model=False, custom_model_path=None):
-
-    from .io import _save_settings, _copy_missclassified
-    from .utils import pick_best_model
-    from .core import generate_loaders
-
-    settings['src'] = src
-    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
-    settings_csv = os.path.join(src,'settings','train_test_model_settings.csv')
-    os.makedirs(os.path.join(src,'settings'), exist_ok=True)
-    settings_df.to_csv(settings_csv, index=False)
-
-    if custom_model:
-        model = torch.load(custom_model_path)
-
-    if settings['train']:
-        _save_settings(settings, src)
-    torch.cuda.empty_cache()
-    torch.cuda.memory.empty_cache()
-    gc.collect()
-    dst = os.path.join(src,'model')
-    os.makedirs(dst, exist_ok=True)
-    settings['src'] = src
-    settings['dst'] = dst
-    if settings['train']:
-        train, val, plate_names = generate_loaders(src,
-                                                   train_mode=settings['train_mode'],
-                                                   mode='train',
-                                                   image_size=settings['image_size'],
-                                                   batch_size=settings['batch_size'],
-                                                   classes=settings['classes'],
-                                                   num_workers=settings['num_workers'],
-                                                   validation_split=settings['val_split'],
-                                                   pin_memory=settings['pin_memory'],
-                                                   normalize=settings['normalize'],
-                                                   channels=settings['channels'],
-                                                   verbose=settings['verbose'])
-
-
-    if settings['test']:
-        test, _, plate_names_test = generate_loaders(src,
-                                                     train_mode=settings['train_mode'],
-                                                     mode='test',
-                                                     image_size=settings['image_size'],
-                                                     batch_size=settings['batch_size'],
-                                                     classes=settings['classes'],
-                                                     num_workers=settings['num_workers'],
-                                                     validation_split=0.0,
-                                                     pin_memory=settings['pin_memory'],
-                                                     normalize=settings['normalize'],
-                                                     channels=settings['channels'],
-                                                     verbose=settings['verbose'])
-        if model == None:
-            model_path = pick_best_model(src+'/model')
-            print(f'Best model: {model_path}')
-
-            model = torch.load(model_path, map_location=lambda storage, loc: storage)
-
-        model_type = settings['model_type']
-        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        print(type(model))
-        print(model)
-
-        model_fldr = os.path.join(src,'model')
-        time_now = datetime.date.today().strftime('%y%m%d')
-        result_loc = f'{model_fldr}/{model_type}_time_{time_now}_result.csv'
-        acc_loc = f'{model_fldr}/{model_type}_time_{time_now}_acc.csv'
-        print(f'Results wil be saved in: {result_loc}')
-
-        result, accuracy = test_model_performance(loaders=test,
-                                                  model=model,
-                                                  loader_name_list='test',
-                                                  epoch=1,
-                                                  train_mode=settings['train_mode'],
-                                                  loss_type=settings['loss_type'])
-
-        result.to_csv(result_loc, index=True, header=True, mode='w')
-        accuracy.to_csv(acc_loc, index=True, header=True, mode='w')
-        _copy_missclassified(accuracy)
-    else:
-        test = None
-
-    if settings['train']:
-        train_model(dst = settings['dst'],
-                    model_type=settings['model_type'],
-                    train_loaders = train,
-                    train_loader_names = plate_names,
-                    train_mode = settings['train_mode'],
-                    epochs = settings['epochs'],
-                    learning_rate = settings['learning_rate'],
-                    init_weights = settings['init_weights'],
-                    weight_decay = settings['weight_decay'],
-                    amsgrad = settings['amsgrad'],
-                    optimizer_type = settings['optimizer_type'],
-                    use_checkpoint = settings['use_checkpoint'],
-                    dropout_rate = settings['dropout_rate'],
-                    num_workers = settings['num_workers'],
-                    val_loaders = val,
-                    test_loaders = test,
-                    intermedeate_save = settings['intermedeate_save'],
-                    schedule = settings['schedule'],
-                    loss_type=settings['loss_type'],
-                    gradient_accumulation=settings['gradient_accumulation'],
-                    gradient_accumulation_steps=settings['gradient_accumulation_steps'])
-
-    torch.cuda.empty_cache()
-    torch.cuda.memory.empty_cache()
-    gc.collect()
-
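
For context, the removed `train_test_model` entry point was driven by a flat settings dict. A minimal invocation sketch, using only the keys the listing above actually reads (all values are illustrative, not recommended defaults):

    settings = {
        'train': True, 'test': False, 'train_mode': 'erm',
        'model_type': 'maxvit', 'image_size': 224, 'batch_size': 64,
        'classes': ['nc', 'pc'], 'num_workers': 20, 'val_split': 0.1,
        'pin_memory': True, 'normalize': True, 'channels': [1, 2, 3],
        'verbose': False, 'epochs': 100, 'learning_rate': 0.0001,
        'init_weights': 'imagenet', 'weight_decay': 0.05, 'amsgrad': False,
        'optimizer_type': 'adamw', 'use_checkpoint': False, 'dropout_rate': 0,
        'intermedeate_save': None, 'schedule': None,
        'loss_type': 'binary_cross_entropy_with_logits',
        'gradient_accumulation': False, 'gradient_accumulation_steps': 4,
    }
    train_test_model('/path/to/experiment', settings)
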
-def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='erm', epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, num_workers=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule = None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4):
-    """
-    Trains a model using the specified parameters.
-
-    Args:
-        dst (str): The destination path to save the model and results.
-        model_type (str): The type of model to train.
-        train_loaders (list): A list of training data loaders.
-        train_loader_names (list): A list of names for the training data loaders.
-        train_mode (str, optional): The training mode. Defaults to 'erm'.
-        epochs (int, optional): The number of training epochs. Defaults to 100.
-        learning_rate (float, optional): The learning rate for the optimizer. Defaults to 0.0001.
-        weight_decay (float, optional): The weight decay for the optimizer. Defaults to 0.05.
-        amsgrad (bool, optional): Whether to use AMSGrad for the optimizer. Defaults to False.
-        optimizer_type (str, optional): The type of optimizer to use. Defaults to 'adamw'.
-        use_checkpoint (bool, optional): Whether to use checkpointing during training. Defaults to False.
-        dropout_rate (float, optional): The dropout rate for the model. Defaults to 0.
-        num_workers (int, optional): The number of workers for data loading. Defaults to 20.
-        val_loaders (list, optional): A list of validation data loaders. Defaults to None.
-        test_loaders (list, optional): A list of test data loaders. Defaults to None.
-        init_weights (str, optional): The initialization weights for the model. Defaults to 'imagenet'.
-        intermedeate_save (list, optional): The intermediate save thresholds. Defaults to None.
-        chan_dict (dict, optional): The channel dictionary. Defaults to None.
-        schedule (str, optional): The learning rate schedule. Defaults to None.
-        loss_type (str, optional): The loss function type. Defaults to 'binary_cross_entropy_with_logits'.
-        gradient_accumulation (bool, optional): Whether to use gradient accumulation. Defaults to False.
-        gradient_accumulation_steps (int, optional): The number of steps for gradient accumulation. Defaults to 4.
-
-    Returns:
-        None
-    """
-
-    from .io import _save_model, _save_progress
-    from .utils import compute_irm_penalty, calculate_loss, choose_model
-
-    print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders)}')
-
-    if test_loaders != None:
-        print(f'Test batches:{len(test_loaders)}')
-
-    use_cuda = torch.cuda.is_available()
-    device = torch.device("cuda" if use_cuda else "cpu")
-    kwargs = {'num_workers': num_workers, 'pin_memory': True} if use_cuda else {}
-
-    for idx, (images, labels, filenames) in enumerate(train_loaders):
-        batch, channels, height, width = images.shape
-        break
-
-    model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
-
-    if model is None:
-        print(f'Model {model_type} not found')
-        return
-
-    model.to(device)
-
-    if optimizer_type == 'adamw':
-        optimizer = AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=weight_decay, amsgrad=amsgrad)
-
-    if optimizer_type == 'adagrad':
-        optimizer = Adagrad(model.parameters(), lr=learning_rate, eps=1e-8, weight_decay=weight_decay)
-
-    if schedule == 'step_lr':
-        StepLR_step_size = int(epochs/5)
-        StepLR_gamma = 0.75
-        scheduler = StepLR(optimizer, step_size=StepLR_step_size, gamma=StepLR_gamma)
-    elif schedule == 'reduce_lr_on_plateau':
-        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
-    else:
-        scheduler = None
-
-    if train_mode == 'erm':
-        for epoch in range(1, epochs+1):
-            model.train()
-            start_time = time.time()
-            running_loss = 0.0
-
-            # Initialize gradients if using gradient accumulation
-            if gradient_accumulation:
-                optimizer.zero_grad()
-
-            for batch_idx, (data, target, filenames) in enumerate(train_loaders, start=1):
-                data, target = data.to(device), target.to(device).float()
-                output = model(data)
-                loss = calculate_loss(output, target, loss_type=loss_type)
-                # Normalize loss if using gradient accumulation
-                if gradient_accumulation:
-                    loss /= gradient_accumulation_steps
-                    running_loss += loss.item() * gradient_accumulation_steps  # correct the running_loss
-                loss.backward()
-
-                # Step optimizer if not using gradient accumulation or every gradient_accumulation_steps
-                if not gradient_accumulation or (batch_idx % gradient_accumulation_steps == 0):
-                    optimizer.step()
-                    optimizer.zero_grad()
-
-                avg_loss = running_loss / batch_idx
-                print(f'\rTrain: epoch: {epoch} batch: {batch_idx}/{len(train_loaders)} avg_loss: {avg_loss:.5f} time: {(time.time()-start_time):.5f}', end='\r', flush=True)
-
-            end_time = time.time()
-            train_time = end_time - start_time
-            train_metrics = {'epoch':epoch,'loss':loss.cpu().item(), 'train_time':train_time}
-            train_metrics_df = pd.DataFrame(train_metrics, index=[epoch])
-            train_names = 'train'
-            results_df, train_test_time = evaluate_model_performance(train_loaders, model, train_names, epoch, train_mode='erm', loss_type=loss_type)
-            train_metrics_df['train_test_time'] = train_test_time
-            if val_loaders != None:
-                val_names = 'val'
-                result, val_time = evaluate_model_performance(val_loaders, model, val_names, epoch, train_mode='erm', loss_type=loss_type)
-
-                if schedule == 'reduce_lr_on_plateau':
-                    val_loss = result['loss']
-
-                results_df = pd.concat([results_df, result])
-                train_metrics_df['val_time'] = val_time
-            if test_loaders != None:
-                test_names = 'test'
-                result, test_test_time = evaluate_model_performance(test_loaders, model, test_names, epoch, train_mode='erm', loss_type=loss_type)
-                results_df = pd.concat([results_df, result])
-                test_time = (train_test_time+val_time+test_test_time)/3
-                train_metrics_df['test_time'] = test_time
-
-            if scheduler:
-                if schedule == 'reduce_lr_on_plateau':
-                    scheduler.step(val_loss)
-                if schedule == 'step_lr':
-                    scheduler.step()
-
-            _save_progress(dst, results_df, train_metrics_df)
-            clear_output(wait=True)
-            display(results_df)
-            _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
-
-    if train_mode == 'irm':
-        dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
-        phi = torch.nn.Parameter (torch.ones(4,1))
-        for epoch in range(1, epochs):
-            model.train()
-            penalty_factor = epoch * 1e-5
-            epoch_names = [str(epoch) + '_' + item for item in train_loader_names]
-            loader_erm_loss_list = []
-            total_erm_loss_mean = 0
-            for loader_index in range(0, len(train_loaders)):
-                start_time = time.time()
-                loader = train_loaders[loader_index]
-                loader_erm_loss_mean = 0
-                batch_count = 0
-                batch_erm_loss_list = []
-                for batch_idx, (data, target, filenames) in enumerate(loader, start=1):
-                    optimizer.zero_grad()
-                    data, target = data.to(device), target.to(device).float()
-
-                    output = model(data)
-                    erm_loss = F.binary_cross_entropy_with_logits(output * dummy_w, target, reduction='none')
-
-                    batch_erm_loss_list.append(erm_loss.mean())
-                    print(f'\repoch: {epoch} loader: {loader_index} batch: {batch_idx+1}/{len(loader)}', end='\r', flush=True)
-                loader_erm_loss_mean = torch.stack(batch_erm_loss_list).mean()
-                loader_erm_loss_list.append(loader_erm_loss_mean)
-            total_erm_loss_mean = torch.stack(loader_erm_loss_list).mean()
-            irm_loss = compute_irm_penalty(loader_erm_loss_list, dummy_w, device)
-
-            (total_erm_loss_mean + penalty_factor * irm_loss).backward()
-            optimizer.step()
-
-            end_time = time.time()
-            train_time = end_time - start_time
-
-            train_metrics = {'epoch': epoch, 'irm_loss': irm_loss, 'erm_loss': total_erm_loss_mean, 'penalty_factor': penalty_factor, 'train_time': train_time}
-            #train_metrics = {'epoch':epoch,'irm_loss':irm_loss.cpu().item(),'erm_loss':total_erm_loss_mean.cpu().item(),'penalty_factor':penalty_factor, 'train_time':train_time}
-            train_metrics_df = pd.DataFrame(train_metrics, index=[epoch])
-            print(f'\rTrain: epoch: {epoch} loader: {loader_index} batch: {batch_idx+1}/{len(loader)} irm_loss: {irm_loss:.5f} mean_erm_loss: {total_erm_loss_mean:.5f} train time {train_time:.5f}', end='\r', flush=True)
-
-            train_names = [item + '_train' for item in train_loader_names]
-            results_df, train_test_time = evaluate_model_performance(train_loaders, model, train_names, epoch, train_mode='irm', loss_type=loss_type)
-            train_metrics_df['train_test_time'] = train_test_time
-
-            if val_loaders != None:
-                val_names = [item + '_val' for item in train_loader_names]
-                result, val_time = evaluate_model_performance(val_loaders, model, val_names, epoch, train_mode='irm', loss_type=loss_type)
-
-                if schedule == 'reduce_lr_on_plateau':
-                    val_loss = result['loss']
-
-                results_df = pd.concat([results_df, result])
-                train_metrics_df['val_time'] = val_time
-
-            if test_loaders != None:
-                test_names = [item + '_test' for item in train_loader_names]  #test_loader_names?
-                result, test_test_time = evaluate_model_performance(test_loaders, model, test_names, epoch, train_mode='irm', loss_type=loss_type)
-                results_df = pd.concat([results_df, result])
-                train_metrics_df['test_test_time'] = test_test_time
-
-            if scheduler:
-                if schedule == 'reduce_lr_on_plateau':
-                    scheduler.step(val_loss)
-                if schedule == 'step_lr':
-                    scheduler.step()
-
-            clear_output(wait=True)
-            display(results_df)
-            _save_progress(dst, results_df, train_metrics_df)
-            _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
-            print(f'Saved model: {dst}')
-    return
-
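
In the 'irm' branch above, each environment's BCE risk is computed on logits scaled by a frozen dummy weight `dummy_w = 1.0`, and `penalty_factor * irm_loss` is added to the mean risk. `compute_irm_penalty` lives in spacr.utils and is not part of this diff; as a reading aid, here is a sketch of the standard IRMv1 penalty (Arjovsky et al., 2019) that such a helper conventionally computes: the squared gradient of each environment's risk with respect to the dummy weight.

    import torch
    from torch.autograd import grad

    def irm_penalty_sketch(env_losses, dummy_w, device):
        # env_losses: one scalar risk per environment, each computed from
        # logits multiplied by dummy_w so that dummy_w is in the graph.
        penalty = torch.zeros(1, device=device)
        for env_loss in env_losses:
            g = grad(env_loss, dummy_w, create_graph=True)[0]
            penalty = penalty + g.pow(2).sum()
        return penalty
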
-def get_submodules(model, prefix=''):
-    submodules = []
-    for name, module in model.named_children():
-        full_name = prefix + ('.' if prefix else '') + name
-        submodules.append(full_name)
-        submodules.extend(get_submodules(module, full_name))
-    return submodules
-
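
`get_submodules` recursively walks `named_children` and returns dotted module paths. For example (torchvision's resnet18 used purely for illustration):

    import torchvision

    names = get_submodules(torchvision.models.resnet18())
    # ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer1.0',
    #  'layer1.0.conv1', 'layer1.0.bn1', ...]
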
-def visualize_model_attention_v2(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
-    import torch
-    import os
-    from spacr.utils import SaliencyMapGenerator, preprocess_image
-    import matplotlib.pyplot as plt
-    import numpy as np
-    from PIL import Image
-
-    use_cuda = torch.cuda.is_available()
-    device = torch.device("cuda" if use_cuda else "cpu")
-
-    # Load the entire model object
-    model = torch.load(model_path)
-    model.to(device)
-
-    # Create directory for saving saliency maps if it does not exist
-    if save_saliency and not os.path.exists(save_dir):
-        os.makedirs(save_dir)
-
-    # Collect all images and their tensors
-    images = []
-    input_tensors = []
-    filenames = []
-    for file in os.listdir(src):
-        image_path = os.path.join(src, file)
-        image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels)
-        images.append(image)
-        input_tensors.append(input_tensor)
-        filenames.append(file)
-
-    input_tensors = torch.cat(input_tensors).to(device)
-    class_labels = torch.zeros(input_tensors.size(0), dtype=torch.long).to(device)  # Replace with actual class labels if available
-
-    # Generate saliency maps
-    cam_generator = SaliencyMapGenerator(model)
-    saliency_maps = cam_generator.compute_saliency_maps(input_tensors, class_labels)
-
-    # Plot images, saliency maps, and overlays
-    saliency_maps = saliency_maps.cpu().numpy()
-    N = len(images)
-
-    dst = os.path.join(src, 'saliency_maps')
-    os.makedirs(dst, exist_ok=True)
-
-    for i in range(N):
-        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
-
-        # Original image
-        axes[0].imshow(images[i])
-        axes[0].axis('off')
-        if class_names:
-            axes[0].set_title(class_names[class_labels[i].item()])
-
-        # Saliency map
-        axes[1].imshow(saliency_maps[i], cmap=plt.cm.hot)
-        axes[1].axis('off')
-
-        # Overlay
-        overlay = np.array(images[i])
-        axes[2].imshow(overlay)
-        axes[2].imshow(saliency_maps[i], cmap='jet', alpha=0.5)
-        axes[2].axis('off')
-
-        plt.tight_layout()
-        plt.show()
-
-        # Save the saliency map if required
-        if save_saliency:
-            saliency_image = Image.fromarray((saliency_maps[i] * 255).astype(np.uint8))
-            saliency_image.save(os.path.join(dst, f'saliency_{filenames[i]}'))
-
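
`SaliencyMapGenerator` is imported from spacr.utils and its implementation is not part of this diff. A plausible sketch of what a `compute_saliency_maps` of this kind does, assuming vanilla gradient saliency (Simonyan et al., 2013), i.e. the absolute input-gradient of the selected class score:

    import torch

    def compute_saliency_maps_sketch(model, inputs, class_labels):
        # Track gradients with respect to the input pixels
        inputs = inputs.clone().detach().requires_grad_(True)
        model.eval()
        output = model(inputs)
        # Pick each sample's score for its class (binary heads fall back
        # to the single logit)
        if output.dim() == 2 and output.size(1) > 1:
            scores = output.gather(1, class_labels.unsqueeze(1)).squeeze(1)
        else:
            scores = output.squeeze()
        scores.sum().backward()
        # Saliency: max absolute gradient over channels, shape (N, H, W)
        return inputs.grad.abs().max(dim=1).values
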
-def visualize_model_attention(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
-    import torch
-    import os
-    from spacr.utils import SaliencyMapGenerator, preprocess_image
-    import matplotlib.pyplot as plt
-    import numpy as np
-    from PIL import Image
-
-    use_cuda = torch.cuda.is_available()
-    device = torch.device("cuda" if use_cuda else "cpu")
-
-    # Load the entire model object
-    model = torch.load(model_path)
-    model.to(device)
-
-    # Create directory for saving saliency maps if it does not exist
-    if save_saliency and not os.path.exists(save_dir):
-        os.makedirs(save_dir)
-
-    # Collect all images and their tensors
-    images = []
-    input_tensors = []
-    filenames = []
-    for file in os.listdir(src):
-        if not file.endswith('.png'):
-            continue
-        image_path = os.path.join(src, file)
-        image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels)
-        images.append(image)
-        input_tensors.append(input_tensor)
-        filenames.append(file)
-
-    input_tensors = torch.cat(input_tensors).to(device)
-    class_labels = torch.zeros(input_tensors.size(0), dtype=torch.long).to(device)  # Replace with actual class labels if available
-
-    # Generate saliency maps
-    cam_generator = SaliencyMapGenerator(model)
-    saliency_maps = cam_generator.compute_saliency_maps(input_tensors, class_labels)
-
-    # Convert saliency maps to numpy arrays
-    saliency_maps = saliency_maps.cpu().numpy()
-
-    N = len(images)
-
-    dst = os.path.join(src, 'saliency_maps')
-
-    for i in range(N):
-        fig, axes = plt.subplots(1, 3, figsize=(20, 5))
-
-        # Original image
-        axes[0].imshow(images[i])
-        axes[0].axis('off')
-        if class_names:
-            axes[0].set_title(f"Class: {class_names[class_labels[i].item()]}")
-
-        # Saliency Map
-        axes[1].imshow(saliency_maps[i, 0], cmap='hot')
-        axes[1].axis('off')
-        axes[1].set_title("Saliency Map")
-
-        # Overlay
-        overlay = np.array(images[i])
-        overlay = overlay / overlay.max()
-        saliency_map_rgb = np.stack([saliency_maps[i, 0]] * 3, axis=-1)  # Convert saliency map to RGB
-        overlay = (overlay * 0.5 + saliency_map_rgb * 0.5).clip(0, 1)
-        axes[2].imshow(overlay)
-        axes[2].axis('off')
-        axes[2].set_title("Overlay")
-
-        plt.tight_layout()
-        plt.show()
-
-        # Save the saliency map if required
-        if save_saliency:
-            os.makedirs(dst, exist_ok=True)
-            saliency_image = Image.fromarray((saliency_maps[i, 0] * 255).astype(np.uint8))
-            saliency_image.save(os.path.join(dst, f'saliency_{filenames[i]}'))
spacr/umap.py
DELETED
File without changes

{spacr-0.0.66.dist-info → spacr-0.0.71.dist-info}/LICENSE
File without changes

{spacr-0.0.66.dist-info → spacr-0.0.71.dist-info}/WHEEL
File without changes

{spacr-0.0.66.dist-info → spacr-0.0.71.dist-info}/entry_points.txt
File without changes

{spacr-0.0.66.dist-info → spacr-0.0.71.dist-info}/top_level.txt
File without changes