spacr 0.2.4__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +1 -11
- spacr/core.py +277 -349
- spacr/deep_spacr.py +248 -269
- spacr/gui.py +58 -54
- spacr/gui_core.py +689 -535
- spacr/gui_elements.py +1002 -153
- spacr/gui_utils.py +452 -107
- spacr/io.py +158 -91
- spacr/measure.py +199 -151
- spacr/plot.py +159 -47
- spacr/resources/font/open_sans/OFL.txt +93 -0
- spacr/resources/font/open_sans/OpenSans-Italic-VariableFont_wdth,wght.ttf +0 -0
- spacr/resources/font/open_sans/OpenSans-VariableFont_wdth,wght.ttf +0 -0
- spacr/resources/font/open_sans/README.txt +100 -0
- spacr/resources/font/open_sans/static/OpenSans-Bold.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans-BoldItalic.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans-ExtraBold.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans-ExtraBoldItalic.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans-Italic.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans-Light.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans-LightItalic.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans-Medium.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans-MediumItalic.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans-Regular.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans-SemiBold.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans-SemiBoldItalic.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_Condensed-Bold.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_Condensed-BoldItalic.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_Condensed-ExtraBold.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_Condensed-ExtraBoldItalic.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_Condensed-Italic.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_Condensed-Light.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_Condensed-LightItalic.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_Condensed-Medium.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_Condensed-MediumItalic.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_Condensed-Regular.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_Condensed-SemiBold.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_Condensed-SemiBoldItalic.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Bold.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-BoldItalic.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-ExtraBold.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-ExtraBoldItalic.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Italic.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Light.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-LightItalic.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Medium.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-MediumItalic.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Regular.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-SemiBold.ttf +0 -0
- spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-SemiBoldItalic.ttf +0 -0
- spacr/resources/icons/logo.pdf +2786 -6
- spacr/resources/icons/logo_spacr.png +0 -0
- spacr/resources/icons/logo_spacr_1.png +0 -0
- spacr/sequencing.py +477 -587
- spacr/settings.py +217 -144
- spacr/utils.py +46 -46
- {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/METADATA +46 -35
- spacr-0.2.8.dist-info/RECORD +100 -0
- {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/WHEEL +1 -1
- spacr-0.2.4.dist-info/RECORD +0 -58
- {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/LICENSE +0 -0
- {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/entry_points.txt +0 -0
- {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/top_level.txt +0 -0
spacr/deep_spacr.py
CHANGED
@@ -1,4 +1,7 @@
 import os, torch, time, gc, datetime
+
+torch.backends.cudnn.benchmark = True
+
 import numpy as np
 import pandas as pd
 from torch.optim import Adagrad, AdamW
@@ -8,13 +11,14 @@ import torch.nn.functional as F
 from IPython.display import display, clear_output
 import matplotlib.pyplot as plt
 from PIL import Image
+from sklearn.metrics import auc, precision_recall_curve
+from multiprocessing import set_start_method
+#set_start_method('spawn', force=True)

 from .logger import log_function_call
 from .utils import close_multiprocessing_processes, reset_mp
-#reset_mp()
-#close_multiprocessing_processes()

-def evaluate_model_core(model, loader, loader_name, epoch, loss_type):
+def evaluate_model_performance(model, loader, epoch, loss_type):
     """
     Evaluates the performance of a model on a given data loader.

@@ -31,7 +35,56 @@ def evaluate_model_core(model, loader, loader_name, epoch, loss_type):
         all_labels (list): The true labels for each prediction.
     """

-    from .utils import calculate_loss
+    from .utils import calculate_loss
+
+    def classification_metrics(all_labels, prediction_pos_probs):
+        """
+        Calculate classification metrics for binary classification.
+
+        Parameters:
+        - all_labels (list): List of true labels.
+        - prediction_pos_probs (list): List of predicted positive probabilities.
+        - loader_name (str): Name of the data loader.
+
+        Returns:
+        - data_df (DataFrame): DataFrame containing the calculated metrics.
+        """
+
+        if len(all_labels) != len(prediction_pos_probs):
+            raise ValueError(f"all_labels ({len(all_labels)}) and pred_labels ({len(prediction_pos_probs)}) have different lengths")
+
+        unique_labels = np.unique(all_labels)
+        if len(unique_labels) >= 2:
+            pr_labels = np.array(all_labels).astype(int)
+            precision, recall, thresholds = precision_recall_curve(pr_labels, prediction_pos_probs, pos_label=1)
+            pr_auc = auc(recall, precision)
+            thresholds = np.append(thresholds, 0.0)
+            f1_scores = 2 * (precision * recall) / (precision + recall)
+            optimal_idx = np.nanargmax(f1_scores)
+            optimal_threshold = thresholds[optimal_idx]
+            pred_labels = [int(p > 0.5) for p in prediction_pos_probs]
+        if len(unique_labels) < 2:
+            optimal_threshold = 0.5
+            pred_labels = [int(p > optimal_threshold) for p in prediction_pos_probs]
+            pr_auc = np.nan
+        data = {'label': all_labels, 'pred': pred_labels}
+        df = pd.DataFrame(data)
+        pc_df = df[df['label'] == 1.0]
+        nc_df = df[df['label'] == 0.0]
+        correct = df[df['label'] == df['pred']]
+        acc_all = len(correct) / len(df)
+        if len(pc_df) > 0:
+            correct_pc = pc_df[pc_df['label'] == pc_df['pred']]
+            acc_pc = len(correct_pc) / len(pc_df)
+        else:
+            acc_pc = np.nan
+        if len(nc_df) > 0:
+            correct_nc = nc_df[nc_df['label'] == nc_df['pred']]
+            acc_nc = len(correct_nc) / len(nc_df)
+        else:
+            acc_nc = np.nan
+        data_dict = {'accuracy': acc_all, 'neg_accuracy': acc_nc, 'pos_accuracy': acc_pc, 'prauc':pr_auc, 'optimal_threshold':optimal_threshold}
+        return data_dict

     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     model.eval()
@@ -61,48 +114,15 @@ def evaluate_model_core(model, loader, loader_name, epoch, loss_type):
         acc = correct / total_samples
         end_time = time.time()
         test_time = end_time - start_time
-        print(f'\rTest: epoch: {epoch} Accuracy: {acc:.5f} batch: {batch_idx+1}/{len(loader)} loss: {mean_loss:.5f} loss: {mean_loss:.5f} time {test_time:.5f}', end='\r', flush=True)
+        #print(f'\rTest: epoch: {epoch} Accuracy: {acc:.5f} batch: {batch_idx+1}/{len(loader)} loss: {mean_loss:.5f} loss: {mean_loss:.5f} time {test_time:.5f}', end='\r', flush=True)
+
     loss /= len(loader)
-
-
-
-
-
-
-
-    Args:
-        loaders (list): List of data loaders.
-        model: The model to evaluate.
-        loader_name_list (list): List of names for the data loaders.
-        epoch (int): The current epoch.
-        train_mode (str): The training mode ('erm' or 'irm').
-        loss_type: The type of loss function.
-
-    Returns:
-        tuple: A tuple containing the evaluation result and the time taken for evaluation.
-    """
-    start_time = time.time()
-    df_list = []
-    if train_mode == 'erm':
-        result, _, _ = evaluate_model_core(model, loaders, loader_name_list, epoch, loss_type)
-    if train_mode == 'irm':
-        for loader_index in range(0, len(loaders)):
-            loader = loaders[loader_index]
-            loader_name = loader_name_list[loader_index]
-            data_df, _, _ = evaluate_model_core(model, loader, loader_name, epoch, loss_type)
-            torch.cuda.empty_cache()
-            df_list.append(data_df)
-        result = pd.concat(df_list)
-        nc_mean = result['neg_accuracy'].mean(skipna=True)
-        pc_mean = result['pos_accuracy'].mean(skipna=True)
-        tot_mean = result['accuracy'].mean(skipna=True)
-        loss_mean = result['loss'].mean(skipna=True)
-        prauc_mean = result['prauc'].mean(skipna=True)
-        data_mean = {'accuracy': tot_mean, 'neg_accuracy': nc_mean, 'pos_accuracy': pc_mean, 'loss': loss_mean, 'prauc': prauc_mean}
-        result = pd.concat([pd.DataFrame(result), pd.DataFrame(data_mean, index=[str(epoch)+'_mean'])])
-    end_time = time.time()
-    test_time = end_time - start_time
-    return result, test_time
+    data_dict = classification_metrics(all_labels, prediction_pos_probs)
+    data_dict['loss'] = loss
+    data_dict['epoch'] = epoch
+    data_dict['Accuracy'] = acc
+
+    return data_dict, [prediction_pos_probs, all_labels]

 def test_model_core(model, loader, loader_name, epoch, loss_type):

@@ -145,7 +165,7 @@ def test_model_core(model, loader, loader_name, epoch, loss_type):
         acc = correct / total_samples
         end_time = time.time()
         test_time = end_time - start_time
-        print(f'\rTest: epoch: {epoch} Accuracy: {acc:.5f} batch: {batch_idx}/{len(loader)} loss: {mean_loss:.5f} time {test_time:.5f}', end='\r', flush=True)
+        #print(f'\rTest: epoch: {epoch} Accuracy: {acc:.5f} batch: {batch_idx}/{len(loader)} loss: {mean_loss:.5f} time {test_time:.5f}', end='\r', flush=True)

     # Constructing the DataFrame
     results_df = pd.DataFrame({
@@ -158,7 +178,7 @@ def test_model_core(model, loader, loader_name, epoch, loss_type):
     data_df = classification_metrics(all_labels, prediction_pos_probs, loader_name, loss, epoch)
     return data_df, prediction_pos_probs, all_labels, results_df

-def test_model_performance(loaders, model, loader_name_list, epoch, train_mode, loss_type):
+def test_model_performance(loaders, model, loader_name_list, epoch, loss_type):
     """
     Test the performance of a model on given data loaders.

@@ -167,7 +187,6 @@ def test_model_performance(loaders, model, loader_name_list, epoch, train_mode,
         model: The model to be tested.
         loader_name_list (list): List of names for the data loaders.
         epoch (int): The current epoch.
-        train_mode (str): The training mode ('erm' or 'irm').
         loss_type: The type of loss function.

     Returns:
@@ -175,114 +194,89 @@ def test_model_performance(loaders, model, loader_name_list, epoch, train_mode,
     """
     start_time = time.time()
     df_list = []
-
-
-
-    for loader_index in range(0, len(loaders)):
-        loader = loaders[loader_index]
-        loader_name = loader_name_list[loader_index]
-        data_df, prediction_pos_probs, all_labels, results_df = test_model_core(model, loader, loader_name, epoch, loss_type)
-        torch.cuda.empty_cache()
-        df_list.append(data_df)
-    result = pd.concat(df_list)
-    nc_mean = result['neg_accuracy'].mean(skipna=True)
-    pc_mean = result['pos_accuracy'].mean(skipna=True)
-    tot_mean = result['accuracy'].mean(skipna=True)
-    loss_mean = result['loss'].mean(skipna=True)
-    prauc_mean = result['prauc'].mean(skipna=True)
-    data_mean = {'accuracy': tot_mean, 'neg_accuracy': nc_mean, 'pos_accuracy': pc_mean, 'loss': loss_mean, 'prauc': prauc_mean}
-    result = pd.concat([pd.DataFrame(result), pd.DataFrame(data_mean, index=[str(epoch)+'_mean'])])
-    end_time = time.time()
-    test_time = end_time - start_time
+
+    result, prediction_pos_probs, all_labels, results_df = test_model_core(model, loaders, loader_name_list, epoch, loss_type)
+
     return result, results_df

-def train_test_model(src, settings, custom_model=False, custom_model_path=None):
+def train_test_model(settings):

     from .io import _save_settings, _copy_missclassified
     from .utils import pick_best_model
     from .core import generate_loaders
-    from .settings import set_default_train_test_model

     torch.cuda.empty_cache()
     torch.cuda.memory.empty_cache()
     gc.collect()

-
-
+    src = settings['src']
+
+    channels_str = ''.join(settings['train_channels'])
     dst = os.path.join(src,'model', settings['model_type'], channels_str, str(f"epochs_{settings['epochs']}"))
     os.makedirs(dst, exist_ok=True)
     settings['src'] = src
     settings['dst'] = dst
-    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
-    settings_csv = os.path.join(dst,'train_test_model_settings.csv')
-    settings_df.to_csv(settings_csv, index=False)

-    if custom_model:
-        model = torch.load(custom_model_path)
+    if settings['custom_model']:
+        model = torch.load(settings['custom_model_path'])

     if settings['train']:
         _save_settings(settings, src)

     if settings['train']:
-        train, val,
-
-
-
-
-
-
-
-
-
-
-
-
+        train, val, train_fig = generate_loaders(src,
+                                                 mode='train',
+                                                 image_size=settings['image_size'],
+                                                 batch_size=settings['batch_size'],
+                                                 classes=settings['classes'],
+                                                 n_jobs=settings['n_jobs'],
+                                                 validation_split=settings['val_split'],
+                                                 pin_memory=settings['pin_memory'],
+                                                 normalize=settings['normalize'],
+                                                 channels=settings['train_channels'],
+                                                 augment=settings['augment'],
+                                                 preload_batches=settings['preload_batches'],
+                                                 verbose=settings['verbose'])

-        train_batch_1_figure = os.path.join(dst, 'batch_1.pdf')
-        train_fig.savefig(train_batch_1_figure, format='pdf', dpi=300)
+        #train_batch_1_figure = os.path.join(dst, 'batch_1.pdf')
+        #train_fig.savefig(train_batch_1_figure, format='pdf', dpi=300)

     if settings['train']:
-        model = train_model(dst = settings['dst'],
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                            gradient_accumulation_steps=settings['gradient_accumulation_steps'],
-                            channels=settings['channels'])
-
-        torch.cuda.empty_cache()
-        torch.cuda.memory.empty_cache()
-        gc.collect()
+        model, model_path = train_model(dst = settings['dst'],
+                                        model_type=settings['model_type'],
+                                        train_loaders = train,
+                                        epochs = settings['epochs'],
+                                        learning_rate = settings['learning_rate'],
+                                        init_weights = settings['init_weights'],
+                                        weight_decay = settings['weight_decay'],
+                                        amsgrad = settings['amsgrad'],
+                                        optimizer_type = settings['optimizer_type'],
+                                        use_checkpoint = settings['use_checkpoint'],
+                                        dropout_rate = settings['dropout_rate'],
+                                        n_jobs = settings['n_jobs'],
+                                        val_loaders = val,
+                                        test_loaders = None,
+                                        intermedeate_save = settings['intermedeate_save'],
+                                        schedule = settings['schedule'],
+                                        loss_type=settings['loss_type'],
+                                        gradient_accumulation=settings['gradient_accumulation'],
+                                        gradient_accumulation_steps=settings['gradient_accumulation_steps'],
+                                        channels=settings['train_channels'])

     if settings['test']:
         test, _, plate_names_test, train_fig = generate_loaders(src,
-
-
-
-
-
-
-
-
-
-
-
-
+                                                                mode='test',
+                                                                image_size=settings['image_size'],
+                                                                batch_size=settings['batch_size'],
+                                                                classes=settings['classes'],
+                                                                n_jobs=settings['n_jobs'],
+                                                                validation_split=0.0,
+                                                                pin_memory=settings['pin_memory'],
+                                                                normalize=settings['normalize'],
+                                                                channels=settings['train_channels'],
+                                                                augment=False,
+                                                                preload_batches=settings['preload_batches'],
+                                                                verbose=settings['verbose'])
         if model == None:
             model_path = pick_best_model(src+'/model')
             print(f'Best model: {model_path}')
@@ -304,7 +298,6 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
                                                     model=model,
                                                     loader_name_list='test',
                                                     epoch=1,
-                                                    train_mode=settings['train_mode'],
                                                     loss_type=settings['loss_type'])

         result.to_csv(result_loc, index=True, header=True, mode='w')
@@ -314,8 +307,10 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
     torch.cuda.empty_cache()
     torch.cuda.memory.empty_cache()
     gc.collect()
+
+    return model_path

-def train_model(dst, model_type, train_loaders,
+def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, n_jobs=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule = None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4, channels=['r','g','b'], verbose=False):
     """
     Trains a model using the specified parameters.

@@ -323,8 +318,6 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
         dst (str): The destination path to save the model and results.
         model_type (str): The type of model to train.
         train_loaders (list): A list of training data loaders.
-        train_loader_names (list): A list of names for the training data loaders.
-        train_mode (str, optional): The training mode. Defaults to 'erm'.
         epochs (int, optional): The number of training epochs. Defaults to 100.
         learning_rate (float, optional): The learning rate for the optimizer. Defaults to 0.0001.
         weight_decay (float, optional): The weight decay for the optimizer. Defaults to 0.05.
@@ -348,29 +341,35 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
     """

     from .io import _save_model, _save_progress
-    from .utils import
+    from .utils import calculate_loss, choose_model

     print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders)}')

     if test_loaders != None:
         print(f'Test batches:{len(test_loaders)}')
-
+
     use_cuda = torch.cuda.is_available()
     device = torch.device("cuda" if use_cuda else "cpu")
+
+    print(f'Using {device} for Torch')
+
     kwargs = {'n_jobs': n_jobs, 'pin_memory': True} if use_cuda else {}

-    for idx, (images, labels, filenames) in enumerate(train_loaders):
-
-
-
-    model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
+    #for idx, (images, labels, filenames) in enumerate(train_loaders):
+    #    batch, chans, height, width = images.shape
+    #    break

+
+    model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint, verbose=verbose)
+
+
     if model is None:
         print(f'Model {model_type} not found')
         return

+    print(f'Loading Model to {device}...')
     model.to(device)
-
+
     if optimizer_type == 'adamw':
         optimizer = AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=weight_decay, amsgrad=amsgrad)

@@ -386,140 +385,93 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
     else:
         scheduler = None

-
-    for epoch in range(1, epochs+1):
-        model.train()
-        start_time = time.time()
-        running_loss = 0.0
-
-        # Initialize gradients if using gradient accumulation
-        if gradient_accumulation:
-            optimizer.zero_grad()
+    time_ls = []

-
-
-
-
-            # Normalize loss if using gradient accumulation
-            if gradient_accumulation:
-                loss /= gradient_accumulation_steps
-            running_loss += loss.item() * gradient_accumulation_steps # correct the running_loss
-            loss.backward()
+    # Initialize lists to accumulate results
+    accumulated_train_dicts = []
+    accumulated_val_dicts = []
+    accumulated_test_dicts = []

-
-
-
-
+    print(f'Training ...')
+    for epoch in range(1, epochs+1):
+        model.train()
+        start_time = time.time()
+        running_loss = 0.0

-
-
+        # Initialize gradients if using gradient accumulation
+        if gradient_accumulation:
+            optimizer.zero_grad()

-
-
-
-
-        train_names = 'train'
-        results_df, train_test_time = evaluate_model_performance(train_loaders, model, train_names, epoch, train_mode='erm', loss_type=loss_type)
-        train_metrics_df['train_test_time'] = train_test_time
-        if val_loaders != None:
-            val_names = 'val'
-            result, val_time = evaluate_model_performance(val_loaders, model, val_names, epoch, train_mode='erm', loss_type=loss_type)
-
-            if schedule == 'reduce_lr_on_plateau':
-                val_loss = result['loss']
-
-            results_df = pd.concat([results_df, result])
-            train_metrics_df['val_time'] = val_time
-        if test_loaders != None:
-            test_names = 'test'
-            result, test_test_time = evaluate_model_performance(test_loaders, model, test_names, epoch, train_mode='erm', loss_type=loss_type)
-            results_df = pd.concat([results_df, result])
-            test_time = (train_test_time+val_time+test_test_time)/3
-            train_metrics_df['test_time'] = test_time
-
-        if scheduler:
-            if schedule == 'reduce_lr_on_plateau':
-                scheduler.step(val_loss)
-            if schedule == 'step_lr':
-                scheduler.step()
-
-        _save_progress(dst, results_df, train_metrics_df, epoch, epochs)
-        clear_output(wait=True)
-        display(results_df)
-        _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94], channels=channels)
-
-    if train_mode == 'irm':
-        dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
-        phi = torch.nn.Parameter (torch.ones(4,1))
-        for epoch in range(1, epochs):
-            model.train()
-            penalty_factor = epoch * 1e-5
-            epoch_names = [str(epoch) + '_' + item for item in train_loader_names]
-            loader_erm_loss_list = []
-            total_erm_loss_mean = 0
-            for loader_index in range(0, len(train_loaders)):
-                start_time = time.time()
-                loader = train_loaders[loader_index]
-                loader_erm_loss_mean = 0
-                batch_count = 0
-                batch_erm_loss_list = []
-                for batch_idx, (data, target, filenames) in enumerate(loader, start=1):
-                    optimizer.zero_grad()
-                    data, target = data.to(device), target.to(device).float()
-
-                    output = model(data)
-                    erm_loss = F.binary_cross_entropy_with_logits(output * dummy_w, target, reduction='none')
-
-                    batch_erm_loss_list.append(erm_loss.mean())
-                    print(f'\repoch: {epoch} loader: {loader_index} batch: {batch_idx+1}/{len(loader)}', end='\r', flush=True)
-                loader_erm_loss_mean = torch.stack(batch_erm_loss_list).mean()
-                loader_erm_loss_list.append(loader_erm_loss_mean)
-            total_erm_loss_mean = torch.stack(loader_erm_loss_list).mean()
-            irm_loss = compute_irm_penalty(loader_erm_loss_list, dummy_w, device)
-
-            (total_erm_loss_mean + penalty_factor * irm_loss).backward()
-            optimizer.step()
-
-            end_time = time.time()
-            train_time = end_time - start_time
+        for batch_idx, (data, target, filenames) in enumerate(train_loaders, start=1):
+            data, target = data.to(device), target.to(device).float()
+            output = model(data)
+            loss = calculate_loss(output, target, loss_type=loss_type)

-
-
-
-
+            # Normalize loss if using gradient accumulation
+            if gradient_accumulation:
+                loss /= gradient_accumulation_steps
+            running_loss += loss.item() * gradient_accumulation_steps # correct the running_loss
+            loss.backward()
+
+            # Step optimizer if not using gradient accumulation or every gradient_accumulation_steps
+            if not gradient_accumulation or (batch_idx % gradient_accumulation_steps == 0):
+                optimizer.step()
+                optimizer.zero_grad()
+
+            avg_loss = running_loss / batch_idx
+            batch_size = len(train_loaders)
+            duration = time.time() - start_time
+            time_ls.append(duration)
+            #print(f'Progress: {batch_idx}/{batch_size}, operation_type: DL-Batch, Epoch {epoch}/{epochs}, Loss {avg_loss}, Time {duration}')

-
-
-
+        end_time = time.time()
+        train_time = end_time - start_time
+        train_dict, _ = evaluate_model_performance(model, train_loaders, epoch, loss_type=loss_type)
+        train_dict['train_time'] = train_time
+        accumulated_train_dicts.append(train_dict)
+
+        if val_loaders != None:
+            val_dict, _ = evaluate_model_performance(model, val_loaders, epoch, loss_type=loss_type)
+            accumulated_val_dicts.append(val_dict)

-        if
-
-
+            if schedule == 'reduce_lr_on_plateau':
+                val_loss = val_dict['loss']
+
+            print(f"Progress: {train_dict['epoch']}/{epochs}, operation_type: Training, Train Loss: {train_dict['loss']:.3f}, Val Loss: {val_dict['loss']:.3f}, Train acc.: {train_dict['accuracy']:.3f}, Val acc.: {val_dict['accuracy']:.3f}, Train NC acc.: {train_dict['neg_accuracy']:.3f}, Val NC acc.: {val_dict['neg_accuracy']:.3f}, Train PC acc.: {train_dict['pos_accuracy']:.3f}, Val PC acc.: {val_dict['pos_accuracy']:.3f}, Train PRAUC: {train_dict['prauc']:.3f}, Val PRAUC: {val_dict['prauc']:.3f}")
+
+        else:
+            print(f"Progress: {train_dict['epoch']}/{epochs}, operation_type: Training, Train Loss: {train_dict['loss']:.3f}, Train acc.: {train_dict['accuracy']:.3f}, Train NC acc.: {train_dict['neg_accuracy']:.3f}, Train PC acc.: {train_dict['pos_accuracy']:.3f}, Train PRAUC: {train_dict['prauc']:.3f}")
+        if test_loaders != None:
+            test_dict, _ = evaluate_model_performance(model, test_loaders, epoch, loss_type=loss_type)
+            accumulated_test_dicts.append(test_dict)
+            print(f"Progress: {test_dict['epoch']}/{epochs}, operation_type: Training, Train Loss: {test_dict['loss']:.3f}, Train acc.: {test_dict['accuracy']:.3f}, Train NC acc.: {test_dict['neg_accuracy']:.3f}, Train PC acc.: {test_dict['pos_accuracy']:.3f}, Train PRAUC: {test_dict['prauc']:.3f}")
+
+        if scheduler:
+            if schedule == 'reduce_lr_on_plateau':
+                scheduler.step(val_loss)
+            if schedule == 'step_lr':
+                scheduler.step()
+
+        if epoch % 10 == 0 or epoch == epochs:
+            if accumulated_train_dicts:
+                train_df = pd.DataFrame(accumulated_train_dicts)
+                _save_progress(dst, train_df, result_type='train')

-
-
+            if accumulated_val_dicts:
+                val_df = pd.DataFrame(accumulated_val_dicts)
+                _save_progress(dst, val_df,result_type='validation')

-
-
-
-        if test_loaders != None:
-            test_names = [item + '_test' for item in train_loader_names] #test_loader_names?
-            result, test_test_time = evaluate_model_performance(test_loaders, model, test_names, epoch, train_mode='irm', loss_type=loss_type)
-            results_df = pd.concat([results_df, result])
-            train_metrics_df['test_test_time'] = test_test_time
+            if accumulated_test_dicts:
+                val_df = pd.DataFrame(accumulated_test_dicts)
+                _save_progress(dst, val_df, result_type='test')

-
-
-
-
-
+        batch_size = len(train_loaders)
+        duration = time.time() - start_time
+        time_ls.append(duration)
+
+        model_path = _save_model(model, model_type, train_dict, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94], channels=channels)

-
-            display(results_df)
-            _save_progress(dst, results_df, train_metrics_df, epoch, epochs)
-            _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
-    print(f'Saved model: {dst}')
-    return model
+    return model, model_path

 def visualize_saliency_map(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):

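
Note: the rewritten training loop above folds gradient accumulation into a single ERM loop, stepping the optimizer only every gradient_accumulation_steps batches. A minimal sketch of that pattern in isolation (generic PyTorch; model, loader, and criterion are hypothetical stand-ins, not spacr API):

    import torch

    def train_one_epoch(model, loader, optimizer, criterion, accumulation_steps=4):
        model.train()
        optimizer.zero_grad()
        for batch_idx, (data, target) in enumerate(loader, start=1):
            loss = criterion(model(data), target)
            loss = loss / accumulation_steps   # scale so the summed gradients match one large batch
            loss.backward()                    # gradients accumulate across backward() calls
            if batch_idx % accumulation_steps == 0:
                optimizer.step()               # apply the accumulated gradient
                optimizer.zero_grad()          # start the next accumulation window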
@@ -778,8 +730,35 @@ def visualize_smooth_grad(src, model_path, target_label_idx, image_size=224, cha
     smooth_grad_image = Image.fromarray((smooth_grad_map * 255).astype(np.uint8))
     smooth_grad_image.save(os.path.join(save_dir, f'smooth_grad_{file}'))

-
-
-
-
-
+def deep_spacr(settings={}):
+    from .settings import deep_spacr_defaults
+    from .core import generate_training_dataset, generate_dataset, apply_model_to_tar
+    from .utils import save_settings
+
+    settings = deep_spacr_defaults(settings)
+    src = settings['src']
+
+    save_settings(settings, name='DL_model')
+
+    if settings['train'] or settings['test']:
+        if settings['generate_training_dataset']:
+            print(f"Generating train and test datasets ...")
+            train_path, test_path = generate_training_dataset(settings)
+            print(f'Generated Train set: {train_path}')
+            print(f'Generated Test set: {test_path}')
+            settings['src'] = os.path.dirname(train_path)
+
+        if settings['train_DL_model']:
+            print(f"Training model ...")
+            model_path = train_test_model(settings)
+            settings['model_path'] = model_path
+            settings['src'] = src
+
+    if settings['apply_model_to_dataset']:
+        if not settings['tar_path'] and os.path.isabs(settings['tar_path']) and os.path.exists(settings['tar_path']):
+            print(f"{settings['tar_path']} not found generating dataset ...")
+            tar_path = generate_dataset(settings)
+            settings['tar_path'] = tar_path
+
+        if os.path.exists(settings['model_path']):
+            apply_model_to_tar(settings)