spacr 0.2.4__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. spacr/__init__.py +1 -11
  2. spacr/core.py +277 -349
  3. spacr/deep_spacr.py +248 -269
  4. spacr/gui.py +58 -54
  5. spacr/gui_core.py +689 -535
  6. spacr/gui_elements.py +1002 -153
  7. spacr/gui_utils.py +452 -107
  8. spacr/io.py +158 -91
  9. spacr/measure.py +199 -151
  10. spacr/plot.py +159 -47
  11. spacr/resources/font/open_sans/OFL.txt +93 -0
  12. spacr/resources/font/open_sans/OpenSans-Italic-VariableFont_wdth,wght.ttf +0 -0
  13. spacr/resources/font/open_sans/OpenSans-VariableFont_wdth,wght.ttf +0 -0
  14. spacr/resources/font/open_sans/README.txt +100 -0
  15. spacr/resources/font/open_sans/static/OpenSans-Bold.ttf +0 -0
  16. spacr/resources/font/open_sans/static/OpenSans-BoldItalic.ttf +0 -0
  17. spacr/resources/font/open_sans/static/OpenSans-ExtraBold.ttf +0 -0
  18. spacr/resources/font/open_sans/static/OpenSans-ExtraBoldItalic.ttf +0 -0
  19. spacr/resources/font/open_sans/static/OpenSans-Italic.ttf +0 -0
  20. spacr/resources/font/open_sans/static/OpenSans-Light.ttf +0 -0
  21. spacr/resources/font/open_sans/static/OpenSans-LightItalic.ttf +0 -0
  22. spacr/resources/font/open_sans/static/OpenSans-Medium.ttf +0 -0
  23. spacr/resources/font/open_sans/static/OpenSans-MediumItalic.ttf +0 -0
  24. spacr/resources/font/open_sans/static/OpenSans-Regular.ttf +0 -0
  25. spacr/resources/font/open_sans/static/OpenSans-SemiBold.ttf +0 -0
  26. spacr/resources/font/open_sans/static/OpenSans-SemiBoldItalic.ttf +0 -0
  27. spacr/resources/font/open_sans/static/OpenSans_Condensed-Bold.ttf +0 -0
  28. spacr/resources/font/open_sans/static/OpenSans_Condensed-BoldItalic.ttf +0 -0
  29. spacr/resources/font/open_sans/static/OpenSans_Condensed-ExtraBold.ttf +0 -0
  30. spacr/resources/font/open_sans/static/OpenSans_Condensed-ExtraBoldItalic.ttf +0 -0
  31. spacr/resources/font/open_sans/static/OpenSans_Condensed-Italic.ttf +0 -0
  32. spacr/resources/font/open_sans/static/OpenSans_Condensed-Light.ttf +0 -0
  33. spacr/resources/font/open_sans/static/OpenSans_Condensed-LightItalic.ttf +0 -0
  34. spacr/resources/font/open_sans/static/OpenSans_Condensed-Medium.ttf +0 -0
  35. spacr/resources/font/open_sans/static/OpenSans_Condensed-MediumItalic.ttf +0 -0
  36. spacr/resources/font/open_sans/static/OpenSans_Condensed-Regular.ttf +0 -0
  37. spacr/resources/font/open_sans/static/OpenSans_Condensed-SemiBold.ttf +0 -0
  38. spacr/resources/font/open_sans/static/OpenSans_Condensed-SemiBoldItalic.ttf +0 -0
  39. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Bold.ttf +0 -0
  40. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-BoldItalic.ttf +0 -0
  41. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-ExtraBold.ttf +0 -0
  42. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-ExtraBoldItalic.ttf +0 -0
  43. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Italic.ttf +0 -0
  44. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Light.ttf +0 -0
  45. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-LightItalic.ttf +0 -0
  46. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Medium.ttf +0 -0
  47. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-MediumItalic.ttf +0 -0
  48. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Regular.ttf +0 -0
  49. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-SemiBold.ttf +0 -0
  50. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-SemiBoldItalic.ttf +0 -0
  51. spacr/resources/icons/logo.pdf +2786 -6
  52. spacr/resources/icons/logo_spacr.png +0 -0
  53. spacr/resources/icons/logo_spacr_1.png +0 -0
  54. spacr/sequencing.py +477 -587
  55. spacr/settings.py +217 -144
  56. spacr/utils.py +46 -46
  57. {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/METADATA +46 -35
  58. spacr-0.2.8.dist-info/RECORD +100 -0
  59. {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/WHEEL +1 -1
  60. spacr-0.2.4.dist-info/RECORD +0 -58
  61. {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/LICENSE +0 -0
  62. {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/entry_points.txt +0 -0
  63. {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/top_level.txt +0 -0
spacr/deep_spacr.py CHANGED
@@ -1,4 +1,7 @@
 import os, torch, time, gc, datetime
+
+torch.backends.cudnn.benchmark = True
+
 import numpy as np
 import pandas as pd
 from torch.optim import Adagrad, AdamW
@@ -8,13 +11,14 @@ import torch.nn.functional as F
 from IPython.display import display, clear_output
 import matplotlib.pyplot as plt
 from PIL import Image
+from sklearn.metrics import auc, precision_recall_curve
+from multiprocessing import set_start_method
+#set_start_method('spawn', force=True)
 
 from .logger import log_function_call
 from .utils import close_multiprocessing_processes, reset_mp
-#reset_mp()
-#close_multiprocessing_processes()
 
-def evaluate_model_core(model, loader, loader_name, epoch, loss_type):
+def evaluate_model_performance(model, loader, epoch, loss_type):
     """
     Evaluates the performance of a model on a given data loader.
 
@@ -31,7 +35,56 @@ def evaluate_model_core(model, loader, loader_name, epoch, loss_type):
         all_labels (list): The true labels for each prediction.
     """
 
-    from .utils import calculate_loss, classification_metrics
+    from .utils import calculate_loss
+
+    def classification_metrics(all_labels, prediction_pos_probs):
+        """
+        Calculate classification metrics for binary classification.
+
+        Parameters:
+        - all_labels (list): List of true labels.
+        - prediction_pos_probs (list): List of predicted positive probabilities.
+        - loader_name (str): Name of the data loader.
+
+        Returns:
+        - data_df (DataFrame): DataFrame containing the calculated metrics.
+        """
+
+        if len(all_labels) != len(prediction_pos_probs):
+            raise ValueError(f"all_labels ({len(all_labels)}) and pred_labels ({len(prediction_pos_probs)}) have different lengths")
+
+        unique_labels = np.unique(all_labels)
+        if len(unique_labels) >= 2:
+            pr_labels = np.array(all_labels).astype(int)
+            precision, recall, thresholds = precision_recall_curve(pr_labels, prediction_pos_probs, pos_label=1)
+            pr_auc = auc(recall, precision)
+            thresholds = np.append(thresholds, 0.0)
+            f1_scores = 2 * (precision * recall) / (precision + recall)
+            optimal_idx = np.nanargmax(f1_scores)
+            optimal_threshold = thresholds[optimal_idx]
+            pred_labels = [int(p > 0.5) for p in prediction_pos_probs]
+        if len(unique_labels) < 2:
+            optimal_threshold = 0.5
+            pred_labels = [int(p > optimal_threshold) for p in prediction_pos_probs]
+            pr_auc = np.nan
+        data = {'label': all_labels, 'pred': pred_labels}
+        df = pd.DataFrame(data)
+        pc_df = df[df['label'] == 1.0]
+        nc_df = df[df['label'] == 0.0]
+        correct = df[df['label'] == df['pred']]
+        acc_all = len(correct) / len(df)
+        if len(pc_df) > 0:
+            correct_pc = pc_df[pc_df['label'] == pc_df['pred']]
+            acc_pc = len(correct_pc) / len(pc_df)
+        else:
+            acc_pc = np.nan
+        if len(nc_df) > 0:
+            correct_nc = nc_df[nc_df['label'] == nc_df['pred']]
+            acc_nc = len(correct_nc) / len(nc_df)
+        else:
+            acc_nc = np.nan
+        data_dict = {'accuracy': acc_all, 'neg_accuracy': acc_nc, 'pos_accuracy': acc_pc, 'prauc':pr_auc, 'optimal_threshold':optimal_threshold}
+        return data_dict
 
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     model.eval()
@@ -61,48 +114,15 @@ def evaluate_model_core(model, loader, loader_name, epoch, loss_type):
         acc = correct / total_samples
         end_time = time.time()
         test_time = end_time - start_time
-        print(f'\rTest: epoch: {epoch} Accuracy: {acc:.5f} batch: {batch_idx+1}/{len(loader)} loss: {mean_loss:.5f} loss: {mean_loss:.5f} time {test_time:.5f}', end='\r', flush=True)
+        #print(f'\rTest: epoch: {epoch} Accuracy: {acc:.5f} batch: {batch_idx+1}/{len(loader)} loss: {mean_loss:.5f} loss: {mean_loss:.5f} time {test_time:.5f}', end='\r', flush=True)
+
     loss /= len(loader)
-    data_df = classification_metrics(all_labels, prediction_pos_probs, loader_name, loss, epoch)
-    return data_df, prediction_pos_probs, all_labels
-
-def evaluate_model_performance(loaders, model, loader_name_list, epoch, train_mode, loss_type):
-    """
-    Evaluate the performance of a model on given data loaders.
-
-    Args:
-        loaders (list): List of data loaders.
-        model: The model to evaluate.
-        loader_name_list (list): List of names for the data loaders.
-        epoch (int): The current epoch.
-        train_mode (str): The training mode ('erm' or 'irm').
-        loss_type: The type of loss function.
-
-    Returns:
-        tuple: A tuple containing the evaluation result and the time taken for evaluation.
-    """
-    start_time = time.time()
-    df_list = []
-    if train_mode == 'erm':
-        result, _, _ = evaluate_model_core(model, loaders, loader_name_list, epoch, loss_type)
-    if train_mode == 'irm':
-        for loader_index in range(0, len(loaders)):
-            loader = loaders[loader_index]
-            loader_name = loader_name_list[loader_index]
-            data_df, _, _ = evaluate_model_core(model, loader, loader_name, epoch, loss_type)
-            torch.cuda.empty_cache()
-            df_list.append(data_df)
-        result = pd.concat(df_list)
-        nc_mean = result['neg_accuracy'].mean(skipna=True)
-        pc_mean = result['pos_accuracy'].mean(skipna=True)
-        tot_mean = result['accuracy'].mean(skipna=True)
-        loss_mean = result['loss'].mean(skipna=True)
-        prauc_mean = result['prauc'].mean(skipna=True)
-        data_mean = {'accuracy': tot_mean, 'neg_accuracy': nc_mean, 'pos_accuracy': pc_mean, 'loss': loss_mean, 'prauc': prauc_mean}
-        result = pd.concat([pd.DataFrame(result), pd.DataFrame(data_mean, index=[str(epoch)+'_mean'])])
-    end_time = time.time()
-    test_time = end_time - start_time
-    return result, test_time
+    data_dict = classification_metrics(all_labels, prediction_pos_probs)
+    data_dict['loss'] = loss
+    data_dict['epoch'] = epoch
+    data_dict['Accuracy'] = acc
+
+    return data_dict, [prediction_pos_probs, all_labels]
 
 def test_model_core(model, loader, loader_name, epoch, loss_type):
 
@@ -145,7 +165,7 @@ def test_model_core(model, loader, loader_name, epoch, loss_type):
         acc = correct / total_samples
         end_time = time.time()
         test_time = end_time - start_time
-        print(f'\rTest: epoch: {epoch} Accuracy: {acc:.5f} batch: {batch_idx}/{len(loader)} loss: {mean_loss:.5f} time {test_time:.5f}', end='\r', flush=True)
+        #print(f'\rTest: epoch: {epoch} Accuracy: {acc:.5f} batch: {batch_idx}/{len(loader)} loss: {mean_loss:.5f} time {test_time:.5f}', end='\r', flush=True)
 
     # Constructing the DataFrame
     results_df = pd.DataFrame({
@@ -158,7 +178,7 @@ def test_model_core(model, loader, loader_name, epoch, loss_type):
    data_df = classification_metrics(all_labels, prediction_pos_probs, loader_name, loss, epoch)
    return data_df, prediction_pos_probs, all_labels, results_df
 
-def test_model_performance(loaders, model, loader_name_list, epoch, train_mode, loss_type):
+def test_model_performance(loaders, model, loader_name_list, epoch, loss_type):
     """
     Test the performance of a model on given data loaders.
 
@@ -167,7 +187,6 @@ def test_model_performance(loaders, model, loader_name_list, epoch, train_mode,
         model: The model to be tested.
         loader_name_list (list): List of names for the data loaders.
         epoch (int): The current epoch.
-        train_mode (str): The training mode ('erm' or 'irm').
         loss_type: The type of loss function.
 
     Returns:
@@ -175,114 +194,89 @@ def test_model_performance(loaders, model, loader_name_list, epoch, train_mode,
     """
     start_time = time.time()
     df_list = []
-    if train_mode == 'erm':
-        result, prediction_pos_probs, all_labels, results_df = test_model_core(model, loaders, loader_name_list, epoch, loss_type)
-    if train_mode == 'irm':
-        for loader_index in range(0, len(loaders)):
-            loader = loaders[loader_index]
-            loader_name = loader_name_list[loader_index]
-            data_df, prediction_pos_probs, all_labels, results_df = test_model_core(model, loader, loader_name, epoch, loss_type)
-            torch.cuda.empty_cache()
-            df_list.append(data_df)
-        result = pd.concat(df_list)
-        nc_mean = result['neg_accuracy'].mean(skipna=True)
-        pc_mean = result['pos_accuracy'].mean(skipna=True)
-        tot_mean = result['accuracy'].mean(skipna=True)
-        loss_mean = result['loss'].mean(skipna=True)
-        prauc_mean = result['prauc'].mean(skipna=True)
-        data_mean = {'accuracy': tot_mean, 'neg_accuracy': nc_mean, 'pos_accuracy': pc_mean, 'loss': loss_mean, 'prauc': prauc_mean}
-        result = pd.concat([pd.DataFrame(result), pd.DataFrame(data_mean, index=[str(epoch)+'_mean'])])
-    end_time = time.time()
-    test_time = end_time - start_time
+
+    result, prediction_pos_probs, all_labels, results_df = test_model_core(model, loaders, loader_name_list, epoch, loss_type)
+
     return result, results_df
 
-def train_test_model(src, settings, custom_model=False, custom_model_path=None):
+def train_test_model(settings):
 
     from .io import _save_settings, _copy_missclassified
     from .utils import pick_best_model
     from .core import generate_loaders
-    from .settings import set_default_train_test_model
 
     torch.cuda.empty_cache()
    torch.cuda.memory.empty_cache()
    gc.collect()
 
-    settings = set_default_train_test_model(settings)
-    channels_str = ''.join(settings['channels'])
+    src = settings['src']
+
+    channels_str = ''.join(settings['train_channels'])
     dst = os.path.join(src,'model', settings['model_type'], channels_str, str(f"epochs_{settings['epochs']}"))
     os.makedirs(dst, exist_ok=True)
     settings['src'] = src
     settings['dst'] = dst
-    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
-    settings_csv = os.path.join(dst,'train_test_model_settings.csv')
-    settings_df.to_csv(settings_csv, index=False)
 
-    if custom_model:
-        model = torch.load(custom_model_path)
+    if settings['custom_model']:
+        model = torch.load(settings['custom_model_path'])
 
     if settings['train']:
         _save_settings(settings, src)
 
     if settings['train']:
-        train, val, plate_names, train_fig = generate_loaders(src,
-                train_mode=settings['train_mode'],
-                mode='train',
-                image_size=settings['image_size'],
-                batch_size=settings['batch_size'],
-                classes=settings['classes'],
-                n_jobs=settings['n_jobs'],
-                validation_split=settings['val_split'],
-                pin_memory=settings['pin_memory'],
-                normalize=settings['normalize'],
-                channels=settings['channels'],
-                augment=settings['augment'],
-                verbose=settings['verbose'])
+        train, val, train_fig = generate_loaders(src,
+                mode='train',
+                image_size=settings['image_size'],
+                batch_size=settings['batch_size'],
+                classes=settings['classes'],
+                n_jobs=settings['n_jobs'],
+                validation_split=settings['val_split'],
+                pin_memory=settings['pin_memory'],
+                normalize=settings['normalize'],
+                channels=settings['train_channels'],
+                augment=settings['augment'],
+                preload_batches=settings['preload_batches'],
+                verbose=settings['verbose'])
 
-        train_batch_1_figure = os.path.join(dst, 'batch_1.pdf')
-        train_fig.savefig(train_batch_1_figure, format='pdf', dpi=600)
+        #train_batch_1_figure = os.path.join(dst, 'batch_1.pdf')
+        #train_fig.savefig(train_batch_1_figure, format='pdf', dpi=300)
 
     if settings['train']:
-        model = train_model(dst = settings['dst'],
-                model_type=settings['model_type'],
-                train_loaders = train,
-                train_loader_names = plate_names,
-                train_mode = settings['train_mode'],
-                epochs = settings['epochs'],
-                learning_rate = settings['learning_rate'],
-                init_weights = settings['init_weights'],
-                weight_decay = settings['weight_decay'],
-                amsgrad = settings['amsgrad'],
-                optimizer_type = settings['optimizer_type'],
-                use_checkpoint = settings['use_checkpoint'],
-                dropout_rate = settings['dropout_rate'],
-                n_jobs = settings['n_jobs'],
-                val_loaders = val,
-                test_loaders = None,
-                intermedeate_save = settings['intermedeate_save'],
-                schedule = settings['schedule'],
-                loss_type=settings['loss_type'],
-                gradient_accumulation=settings['gradient_accumulation'],
-                gradient_accumulation_steps=settings['gradient_accumulation_steps'],
-                channels=settings['channels'])
-
-        torch.cuda.empty_cache()
-        torch.cuda.memory.empty_cache()
-        gc.collect()
+        model, model_path = train_model(dst = settings['dst'],
+                model_type=settings['model_type'],
+                train_loaders = train,
+                epochs = settings['epochs'],
+                learning_rate = settings['learning_rate'],
+                init_weights = settings['init_weights'],
+                weight_decay = settings['weight_decay'],
+                amsgrad = settings['amsgrad'],
+                optimizer_type = settings['optimizer_type'],
+                use_checkpoint = settings['use_checkpoint'],
+                dropout_rate = settings['dropout_rate'],
+                n_jobs = settings['n_jobs'],
+                val_loaders = val,
+                test_loaders = None,
+                intermedeate_save = settings['intermedeate_save'],
+                schedule = settings['schedule'],
+                loss_type=settings['loss_type'],
+                gradient_accumulation=settings['gradient_accumulation'],
+                gradient_accumulation_steps=settings['gradient_accumulation_steps'],
+                channels=settings['train_channels'])
 
     if settings['test']:
         test, _, plate_names_test, train_fig = generate_loaders(src,
-                train_mode=settings['train_mode'],
-                mode='test',
-                image_size=settings['image_size'],
-                batch_size=settings['batch_size'],
-                classes=settings['classes'],
-                n_jobs=settings['n_jobs'],
-                validation_split=0.0,
-                pin_memory=settings['pin_memory'],
-                normalize=settings['normalize'],
-                channels=settings['channels'],
-                augment=False,
-                verbose=settings['verbose'])
+                mode='test',
+                image_size=settings['image_size'],
+                batch_size=settings['batch_size'],
+                classes=settings['classes'],
+                n_jobs=settings['n_jobs'],
+                validation_split=0.0,
+                pin_memory=settings['pin_memory'],
+                normalize=settings['normalize'],
+                channels=settings['train_channels'],
+                augment=False,
+                preload_batches=settings['preload_batches'],
+                verbose=settings['verbose'])
     if model == None:
         model_path = pick_best_model(src+'/model')
         print(f'Best model: {model_path}')
@@ -304,7 +298,6 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
                model=model,
                loader_name_list='test',
                epoch=1,
-               train_mode=settings['train_mode'],
                loss_type=settings['loss_type'])
 
        result.to_csv(result_loc, index=True, header=True, mode='w')
@@ -314,8 +307,10 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
    torch.cuda.empty_cache()
    torch.cuda.memory.empty_cache()
    gc.collect()
+
+    return model_path
 
-def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='erm', epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, n_jobs=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule = None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4, channels=['r','g','b']):
+def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, n_jobs=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule = None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4, channels=['r','g','b'], verbose=False):
     """
     Trains a model using the specified parameters.
 
@@ -323,8 +318,6 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
         dst (str): The destination path to save the model and results.
         model_type (str): The type of model to train.
         train_loaders (list): A list of training data loaders.
-        train_loader_names (list): A list of names for the training data loaders.
-        train_mode (str, optional): The training mode. Defaults to 'erm'.
         epochs (int, optional): The number of training epochs. Defaults to 100.
         learning_rate (float, optional): The learning rate for the optimizer. Defaults to 0.0001.
         weight_decay (float, optional): The weight decay for the optimizer. Defaults to 0.05.
@@ -348,29 +341,35 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
     """
 
     from .io import _save_model, _save_progress
-    from .utils import compute_irm_penalty, calculate_loss, choose_model
+    from .utils import calculate_loss, choose_model
 
     print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders)}')
 
     if test_loaders != None:
         print(f'Test batches:{len(test_loaders)}')
-
+
     use_cuda = torch.cuda.is_available()
     device = torch.device("cuda" if use_cuda else "cpu")
+
+    print(f'Using {device} for Torch')
+
     kwargs = {'n_jobs': n_jobs, 'pin_memory': True} if use_cuda else {}
 
-    for idx, (images, labels, filenames) in enumerate(train_loaders):
-        batch, chans, height, width = images.shape
-        break
-
-    model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
+    #for idx, (images, labels, filenames) in enumerate(train_loaders):
+    #    batch, chans, height, width = images.shape
+    #    break
 
+
+    model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint, verbose=verbose)
+
+
     if model is None:
         print(f'Model {model_type} not found')
         return
 
+    print(f'Loading Model to {device}...')
     model.to(device)
-
+
 
     if optimizer_type == 'adamw':
         optimizer = AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=weight_decay, amsgrad=amsgrad)
@@ -386,140 +385,93 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
     else:
         scheduler = None
 
-    if train_mode == 'erm':
-        for epoch in range(1, epochs+1):
-            model.train()
-            start_time = time.time()
-            running_loss = 0.0
-
-            # Initialize gradients if using gradient accumulation
-            if gradient_accumulation:
-                optimizer.zero_grad()
+    time_ls = []
 
-            for batch_idx, (data, target, filenames) in enumerate(train_loaders, start=1):
-                data, target = data.to(device), target.to(device).float()
-                output = model(data)
-                loss = calculate_loss(output, target, loss_type=loss_type)
-                # Normalize loss if using gradient accumulation
-                if gradient_accumulation:
-                    loss /= gradient_accumulation_steps
-                running_loss += loss.item() * gradient_accumulation_steps # correct the running_loss
-                loss.backward()
+    # Initialize lists to accumulate results
+    accumulated_train_dicts = []
+    accumulated_val_dicts = []
+    accumulated_test_dicts = []
 
-                # Step optimizer if not using gradient accumulation or every gradient_accumulation_steps
-                if not gradient_accumulation or (batch_idx % gradient_accumulation_steps == 0):
-                    optimizer.step()
-                    optimizer.zero_grad()
+    print(f'Training ...')
+    for epoch in range(1, epochs+1):
+        model.train()
+        start_time = time.time()
+        running_loss = 0.0
 
-                avg_loss = running_loss / batch_idx
-                print(f'\rTrain: epoch: {epoch} batch: {batch_idx}/{len(train_loaders)} avg_loss: {avg_loss:.5f} time: {(time.time()-start_time):.5f}', end='\r', flush=True)
+        # Initialize gradients if using gradient accumulation
+        if gradient_accumulation:
+            optimizer.zero_grad()
 
-            end_time = time.time()
-            train_time = end_time - start_time
-            train_metrics = {'epoch':epoch,'loss':loss.cpu().item(), 'train_time':train_time}
-            train_metrics_df = pd.DataFrame(train_metrics, index=[epoch])
-            train_names = 'train'
-            results_df, train_test_time = evaluate_model_performance(train_loaders, model, train_names, epoch, train_mode='erm', loss_type=loss_type)
-            train_metrics_df['train_test_time'] = train_test_time
-            if val_loaders != None:
-                val_names = 'val'
-                result, val_time = evaluate_model_performance(val_loaders, model, val_names, epoch, train_mode='erm', loss_type=loss_type)
-
-                if schedule == 'reduce_lr_on_plateau':
-                    val_loss = result['loss']
-
-                results_df = pd.concat([results_df, result])
-                train_metrics_df['val_time'] = val_time
-            if test_loaders != None:
-                test_names = 'test'
-                result, test_test_time = evaluate_model_performance(test_loaders, model, test_names, epoch, train_mode='erm', loss_type=loss_type)
-                results_df = pd.concat([results_df, result])
-                test_time = (train_test_time+val_time+test_test_time)/3
-                train_metrics_df['test_time'] = test_time
-
-            if scheduler:
-                if schedule == 'reduce_lr_on_plateau':
-                    scheduler.step(val_loss)
-                if schedule == 'step_lr':
-                    scheduler.step()
-
-            _save_progress(dst, results_df, train_metrics_df, epoch, epochs)
-            clear_output(wait=True)
-            display(results_df)
-            _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94], channels=channels)
-
-    if train_mode == 'irm':
-        dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
-        phi = torch.nn.Parameter (torch.ones(4,1))
-        for epoch in range(1, epochs):
-            model.train()
-            penalty_factor = epoch * 1e-5
-            epoch_names = [str(epoch) + '_' + item for item in train_loader_names]
-            loader_erm_loss_list = []
-            total_erm_loss_mean = 0
-            for loader_index in range(0, len(train_loaders)):
-                start_time = time.time()
-                loader = train_loaders[loader_index]
-                loader_erm_loss_mean = 0
-                batch_count = 0
-                batch_erm_loss_list = []
-                for batch_idx, (data, target, filenames) in enumerate(loader, start=1):
-                    optimizer.zero_grad()
-                    data, target = data.to(device), target.to(device).float()
-
-                    output = model(data)
-                    erm_loss = F.binary_cross_entropy_with_logits(output * dummy_w, target, reduction='none')
-
-                    batch_erm_loss_list.append(erm_loss.mean())
-                    print(f'\repoch: {epoch} loader: {loader_index} batch: {batch_idx+1}/{len(loader)}', end='\r', flush=True)
-                loader_erm_loss_mean = torch.stack(batch_erm_loss_list).mean()
-                loader_erm_loss_list.append(loader_erm_loss_mean)
-            total_erm_loss_mean = torch.stack(loader_erm_loss_list).mean()
-            irm_loss = compute_irm_penalty(loader_erm_loss_list, dummy_w, device)
-
-            (total_erm_loss_mean + penalty_factor * irm_loss).backward()
-            optimizer.step()
-
-            end_time = time.time()
-            train_time = end_time - start_time
+        for batch_idx, (data, target, filenames) in enumerate(train_loaders, start=1):
+            data, target = data.to(device), target.to(device).float()
+            output = model(data)
+            loss = calculate_loss(output, target, loss_type=loss_type)
 
-            train_metrics = {'epoch': epoch, 'irm_loss': irm_loss, 'erm_loss': total_erm_loss_mean, 'penalty_factor': penalty_factor, 'train_time': train_time}
-            #train_metrics = {'epoch':epoch,'irm_loss':irm_loss.cpu().item(),'erm_loss':total_erm_loss_mean.cpu().item(),'penalty_factor':penalty_factor, 'train_time':train_time}
-            train_metrics_df = pd.DataFrame(train_metrics, index=[epoch])
-            print(f'\rTrain: epoch: {epoch} loader: {loader_index} batch: {batch_idx+1}/{len(loader)} irm_loss: {irm_loss:.5f} mean_erm_loss: {total_erm_loss_mean:.5f} train time {train_time:.5f}', end='\r', flush=True)
+            # Normalize loss if using gradient accumulation
+            if gradient_accumulation:
+                loss /= gradient_accumulation_steps
+            running_loss += loss.item() * gradient_accumulation_steps # correct the running_loss
+            loss.backward()
+
+            # Step optimizer if not using gradient accumulation or every gradient_accumulation_steps
+            if not gradient_accumulation or (batch_idx % gradient_accumulation_steps == 0):
+                optimizer.step()
+                optimizer.zero_grad()
+
+            avg_loss = running_loss / batch_idx
+            batch_size = len(train_loaders)
+            duration = time.time() - start_time
+            time_ls.append(duration)
+            #print(f'Progress: {batch_idx}/{batch_size}, operation_type: DL-Batch, Epoch {epoch}/{epochs}, Loss {avg_loss}, Time {duration}')
 
-            train_names = [item + '_train' for item in train_loader_names]
-            results_df, train_test_time = evaluate_model_performance(train_loaders, model, train_names, epoch, train_mode='irm', loss_type=loss_type)
-            train_metrics_df['train_test_time'] = train_test_time
+        end_time = time.time()
+        train_time = end_time - start_time
+        train_dict, _ = evaluate_model_performance(model, train_loaders, epoch, loss_type=loss_type)
+        train_dict['train_time'] = train_time
+        accumulated_train_dicts.append(train_dict)
+
+        if val_loaders != None:
+            val_dict, _ = evaluate_model_performance(model, val_loaders, epoch, loss_type=loss_type)
+            accumulated_val_dicts.append(val_dict)
 
-            if val_loaders != None:
-                val_names = [item + '_val' for item in train_loader_names]
-                result, val_time = evaluate_model_performance(val_loaders, model, val_names, epoch, train_mode='irm', loss_type=loss_type)
+            if schedule == 'reduce_lr_on_plateau':
+                val_loss = val_dict['loss']
+
+            print(f"Progress: {train_dict['epoch']}/{epochs}, operation_type: Training, Train Loss: {train_dict['loss']:.3f}, Val Loss: {val_dict['loss']:.3f}, Train acc.: {train_dict['accuracy']:.3f}, Val acc.: {val_dict['accuracy']:.3f}, Train NC acc.: {train_dict['neg_accuracy']:.3f}, Val NC acc.: {val_dict['neg_accuracy']:.3f}, Train PC acc.: {train_dict['pos_accuracy']:.3f}, Val PC acc.: {val_dict['pos_accuracy']:.3f}, Train PRAUC: {train_dict['prauc']:.3f}, Val PRAUC: {val_dict['prauc']:.3f}")
+
+        else:
+            print(f"Progress: {train_dict['epoch']}/{epochs}, operation_type: Training, Train Loss: {train_dict['loss']:.3f}, Train acc.: {train_dict['accuracy']:.3f}, Train NC acc.: {train_dict['neg_accuracy']:.3f}, Train PC acc.: {train_dict['pos_accuracy']:.3f}, Train PRAUC: {train_dict['prauc']:.3f}")
+        if test_loaders != None:
+            test_dict, _ = evaluate_model_performance(model, test_loaders, epoch, loss_type=loss_type)
+            accumulated_test_dicts.append(test_dict)
+            print(f"Progress: {test_dict['epoch']}/{epochs}, operation_type: Training, Train Loss: {test_dict['loss']:.3f}, Train acc.: {test_dict['accuracy']:.3f}, Train NC acc.: {test_dict['neg_accuracy']:.3f}, Train PC acc.: {test_dict['pos_accuracy']:.3f}, Train PRAUC: {test_dict['prauc']:.3f}")
+
+        if scheduler:
+            if schedule == 'reduce_lr_on_plateau':
+                scheduler.step(val_loss)
+            if schedule == 'step_lr':
+                scheduler.step()
+
+        if epoch % 10 == 0 or epoch == epochs:
+            if accumulated_train_dicts:
+                train_df = pd.DataFrame(accumulated_train_dicts)
+                _save_progress(dst, train_df, result_type='train')
 
-                if schedule == 'reduce_lr_on_plateau':
-                    val_loss = result['loss']
+            if accumulated_val_dicts:
+                val_df = pd.DataFrame(accumulated_val_dicts)
+                _save_progress(dst, val_df,result_type='validation')
 
-                results_df = pd.concat([results_df, result])
-                train_metrics_df['val_time'] = val_time
-
-            if test_loaders != None:
-                test_names = [item + '_test' for item in train_loader_names] #test_loader_names?
-                result, test_test_time = evaluate_model_performance(test_loaders, model, test_names, epoch, train_mode='irm', loss_type=loss_type)
-                results_df = pd.concat([results_df, result])
-                train_metrics_df['test_test_time'] = test_test_time
+            if accumulated_test_dicts:
+                val_df = pd.DataFrame(accumulated_test_dicts)
+                _save_progress(dst, val_df, result_type='test')
 
-            if scheduler:
-                if schedule == 'reduce_lr_on_plateau':
-                    scheduler.step(val_loss)
-                if schedule == 'step_lr':
-                    scheduler.step()
+        batch_size = len(train_loaders)
+        duration = time.time() - start_time
+        time_ls.append(duration)
+
+        model_path = _save_model(model, model_type, train_dict, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94], channels=channels)
 
-            clear_output(wait=True)
-            display(results_df)
-            _save_progress(dst, results_df, train_metrics_df, epoch, epochs)
-            _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
-            print(f'Saved model: {dst}')
-    return model
+    return model, model_path
 
 
 def visualize_saliency_map(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
 
@@ -778,8 +730,35 @@ def visualize_smooth_grad(src, model_path, target_label_idx, image_size=224, cha
         smooth_grad_image = Image.fromarray((smooth_grad_map * 255).astype(np.uint8))
         smooth_grad_image.save(os.path.join(save_dir, f'smooth_grad_{file}'))
 
-# Usage
-#src = '/path/to/images'
-#model_path = '/path/to/model.pth'
-#target_label_idx = 0 # Change this to the target class index
-#visualize_smooth_grad(src, model_path, target_label_idx)
+def deep_spacr(settings={}):
+    from .settings import deep_spacr_defaults
+    from .core import generate_training_dataset, generate_dataset, apply_model_to_tar
+    from .utils import save_settings
+
+    settings = deep_spacr_defaults(settings)
+    src = settings['src']
+
+    save_settings(settings, name='DL_model')
+
+    if settings['train'] or settings['test']:
+        if settings['generate_training_dataset']:
+            print(f"Generating train and test datasets ...")
+            train_path, test_path = generate_training_dataset(settings)
+            print(f'Generated Train set: {train_path}')
+            print(f'Generated Test set: {test_path}')
+            settings['src'] = os.path.dirname(train_path)
+
+        if settings['train_DL_model']:
+            print(f"Training model ...")
+            model_path = train_test_model(settings)
+            settings['model_path'] = model_path
+            settings['src'] = src
+
+    if settings['apply_model_to_dataset']:
+        if not settings['tar_path'] and os.path.isabs(settings['tar_path']) and os.path.exists(settings['tar_path']):
+            print(f"{settings['tar_path']} not found generating dataset ...")
+            tar_path = generate_dataset(settings)
+            settings['tar_path'] = tar_path
+
+        if os.path.exists(settings['model_path']):
+            apply_model_to_tar(settings)
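
The deep_spacr() entry point added in the last hunk chains dataset generation, model training, and applying the trained model from a single settings dictionary. A minimal calling sketch, assuming deep_spacr_defaults fills in any keys omitted here; the source path and the specific key values are hypothetical, not taken from the diff:

from spacr.deep_spacr import deep_spacr

# Hypothetical settings; only keys that appear in the diff are used here.
settings = {
    'src': '/path/to/train_test_dir',     # folder holding the train/test image sets (hypothetical path)
    'train': True,                        # run the training branch via train_test_model()
    'test': True,                         # also evaluate on the held-out test split
    'generate_training_dataset': False,   # set True to build the train/test sets first
    'train_DL_model': True,               # train and store the model, recording model_path in settings
    'apply_model_to_dataset': False,      # skip applying the model to a tar dataset
}

deep_spacr(settings)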