spacr 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. spacr/__init__.py +19 -3
  2. spacr/cellpose.py +311 -0
  3. spacr/core.py +245 -2494
  4. spacr/deep_spacr.py +335 -163
  5. spacr/gui.py +2 -0
  6. spacr/gui_core.py +85 -65
  7. spacr/gui_elements.py +110 -5
  8. spacr/gui_utils.py +375 -7
  9. spacr/io.py +680 -141
  10. spacr/logger.py +28 -9
  11. spacr/measure.py +108 -133
  12. spacr/mediar.py +0 -3
  13. spacr/ml.py +1051 -0
  14. spacr/openai.py +37 -0
  15. spacr/plot.py +707 -20
  16. spacr/resources/data/lopit.csv +3833 -0
  17. spacr/resources/data/toxoplasma_metadata.csv +8843 -0
  18. spacr/resources/icons/convert.png +0 -0
  19. spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
  20. spacr/sequencing.py +241 -1311
  21. spacr/settings.py +181 -50
  22. spacr/sim.py +0 -2
  23. spacr/submodules.py +349 -0
  24. spacr/timelapse.py +0 -2
  25. spacr/toxo.py +238 -0
  26. spacr/utils.py +776 -182
  27. {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/METADATA +31 -22
  28. {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/RECORD +32 -33
  29. spacr/chris.py +0 -50
  30. spacr/graph_learning.py +0 -340
  31. spacr/resources/MEDIAR/.git +0 -1
  32. spacr/resources/MEDIAR_weights/.DS_Store +0 -0
  33. spacr/resources/icons/.DS_Store +0 -0
  34. spacr/resources/icons/spacr_logo_rotation.gif +0 -0
  35. spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
  36. spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
  37. spacr/sim_app.py +0 -0
  38. {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/LICENSE +0 -0
  39. {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/WHEEL +0 -0
  40. {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/entry_points.txt +0 -0
  41. {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/top_level.txt +0 -0
spacr/deep_spacr.py CHANGED
@@ -1,5 +1,4 @@
- import os, torch, time, gc, datetime
-
+ import os, torch, time, gc, datetime, cv2
  torch.backends.cudnn.benchmark = True

  import numpy as np
@@ -8,15 +7,146 @@ from torch.optim import Adagrad, AdamW
  from torch.autograd import grad
  from torch.optim.lr_scheduler import StepLR
  import torch.nn.functional as F
- from IPython.display import display, clear_output
  import matplotlib.pyplot as plt
  from PIL import Image
  from sklearn.metrics import auc, precision_recall_curve
- from multiprocessing import set_start_method
- #set_start_method('spawn', force=True)
+ from IPython.display import display
+ from multiprocessing import cpu_count
+
+ from torchvision import transforms
+ from torch.utils.data import DataLoader
+
+ def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True, n_jobs=10):
+
+     from .io import NoClassDataset
+     from .utils import print_progress
+
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+     if normalize:
+         transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.CenterCrop(size=(image_size, image_size)),
+             transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+     else:
+         transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.CenterCrop(size=(image_size, image_size))])
+
+     model = torch.load(model_path)
+     print(model)
+
+     print(f'Loading dataset in {src} with {len(src)} images')
+     dataset = NoClassDataset(data_dir=src, transform=transform, shuffle=True, load_to_memory=False)
+     data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_jobs)
+     print(f'Loaded {len(src)} images')
+
+     result_loc = os.path.splitext(model_path)[0]+datetime.date.today().strftime('%y%m%d')+'_'+os.path.splitext(model_path)[1]+'_test_result.csv'
+     print(f'Results will be saved in: {result_loc}')
+
+     model.eval()
+     model = model.to(device)
+     prediction_pos_probs = []
+     filenames_list = []
+     time_ls = []
+     with torch.no_grad():
+         for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1):
+             start = time.time()
+             images = batch_images.to(torch.float).to(device)
+             outputs = model(images)
+             batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
+             prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
+             filenames_list.extend(filenames)
+             stop = time.time()
+             duration = stop - start
+             time_ls.append(duration)
+             files_processed = batch_idx*batch_size
+             files_to_process = len(data_loader)
+             print_progress(files_processed, files_to_process, n_jobs=n_jobs, time_ls=time_ls, batch_size=batch_size, operation_type="Generating predictions")
+
+     data = {'path':filenames_list, 'pred':prediction_pos_probs}
+     df = pd.DataFrame(data, index=None)
+     df.to_csv(result_loc, index=True, header=True, mode='w')
+     torch.cuda.empty_cache()
+     torch.cuda.memory.empty_cache()
+     return df
+
+ def apply_model_to_tar(settings={}):
+
+     from .io import TarImageDataset
+     from .utils import process_vision_results, print_progress
+
+     if os.path.exists(settings['dataset']):
+         tar_path = settings['dataset']
+     else:
+         tar_path = os.path.join(settings['src'], 'datasets', settings['dataset'])
+     model_path = settings['model_path']
+
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     if settings['normalize']:
+         transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.CenterCrop(size=(settings['image_size'], settings['image_size'])),
+             transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+     else:
+         transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.CenterCrop(size=(settings['image_size'], settings['image_size']))])
+
+     if settings['verbose']:
+         print(f"Loading model from {model_path}")
+         print(f"Loading dataset from {tar_path}")
+
+     model = torch.load(settings['model_path'])
+
+     dataset = TarImageDataset(tar_path, transform=transform)
+     data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=settings['n_jobs'], pin_memory=True)
+
+     model_name = os.path.splitext(os.path.basename(model_path))[0]
+     dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
+     date_name = datetime.date.today().strftime('%y%m%d')
+     dst = os.path.dirname(tar_path)
+     result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv'

- from .logger import log_function_call
- from .utils import close_multiprocessing_processes, reset_mp
+     model.eval()
+     model = model.to(device)
+
+     if settings['verbose']:
+         print(model)
+         print(f'Generated dataset with {len(dataset)} images')
+         print(f'Generating loader from {len(data_loader)} batches')
+         print(f'Results will be saved in: {result_loc}')
+         print(f'Model is in eval mode')
+         print(f'Model loaded to device')
+
+     prediction_pos_probs = []
+     filenames_list = []
+     time_ls = []
+     gc.collect()
+     with torch.no_grad():
+         for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1):
+             start = time.time()
+             images = batch_images.to(torch.float).to(device)
+             outputs = model(images)
+             batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
+             prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
+             filenames_list.extend(filenames)
+             stop = time.time()
+             duration = stop - start
+             time_ls.append(duration)
+             files_processed = batch_idx*settings['batch_size']
+             files_to_process = len(data_loader)*settings['batch_size']
+             print_progress(files_processed, files_to_process, n_jobs=settings['n_jobs'], time_ls=time_ls, batch_size=settings['batch_size'], operation_type="Tar dataset")
+
+     data = {'path':filenames_list, 'pred':prediction_pos_probs}
+     df = pd.DataFrame(data, index=None)
+     df = process_vision_results(df, settings['score_threshold'])
+
+     df.to_csv(result_loc, index=True, header=True, mode='w')
+     print(f"Saved results to {result_loc}")
+     torch.cuda.empty_cache()
+     torch.cuda.memory.empty_cache()
+     return df

  def evaluate_model_performance(model, loader, epoch, loss_type):
      """
@@ -118,7 +248,7 @@ def evaluate_model_performance(model, loader, epoch, loss_type):

      loss /= len(loader)
      data_dict = classification_metrics(all_labels, prediction_pos_probs)
-     data_dict['loss'] = loss
+     data_dict['loss'] = loss.item()
      data_dict['epoch'] = epoch
      data_dict['Accuracy'] = acc

@@ -175,7 +305,7 @@ def test_model_core(model, loader, loader_name, epoch, loss_type):
                                 'class_1_probability':prediction_pos_probs})

      loss /= len(loader)
-     data_df = classification_metrics(all_labels, prediction_pos_probs, loader_name, loss, epoch)
+     data_df = classification_metrics(all_labels, prediction_pos_probs, loss, epoch)
      return data_df, prediction_pos_probs, all_labels, results_df

  def test_model_performance(loaders, model, loader_name_list, epoch, loss_type):
@@ -201,9 +331,12 @@ def test_model_performance(loaders, model, loader_name_list, epoch, loss_type):

  def train_test_model(settings):

-     from .io import _save_settings, _copy_missclassified
-     from .utils import pick_best_model
-     from .core import generate_loaders
+     from .io import _copy_missclassified
+     from .utils import pick_best_model, save_settings
+     from .io import generate_loaders
+     from .settings import get_train_test_model_settings
+
+     settings = get_train_test_model_settings(settings)

      torch.cuda.empty_cache()
      torch.cuda.memory.empty_cache()
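Note: settings are now normalized through get_train_test_model_settings before use. The actual defaults live in spacr.settings and are not shown in this diff; the sketch below only illustrates the fill-missing-keys pattern such a helper typically follows, with made-up values.

    def get_train_test_model_settings(settings):
        # Illustrative only: the real defaults are defined in spacr.settings.
        defaults = {'train': True, 'test': False, 'model_type': 'maxvit',
                    'epochs': 100, 'image_size': 224, 'batch_size': 64}
        for key, value in defaults.items():
            settings.setdefault(key, value)
        return settings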
@@ -221,7 +354,12 @@ def train_test_model(settings):

      model = torch.load(settings['custom_model_path'])

      if settings['train']:
-         _save_settings(settings, src)
+         if settings['train'] and settings['test']:
+             save_settings(settings, name=f"train_test_{settings['model_type']}_{settings['epochs']}", show=True)
+         elif settings['train'] is True:
+             save_settings(settings, name=f"train_{settings['model_type']}_{settings['epochs']}", show=True)
+         elif settings['test'] is True:
+             save_settings(settings, name=f"test_{settings['model_type']}_{settings['epochs']}", show=True)

      if settings['train']:
          train, val, train_fig = generate_loaders(src,
@@ -235,7 +373,6 @@ def train_test_model(settings):
                                                   normalize=settings['normalize'],
                                                   channels=settings['train_channels'],
                                                   augment=settings['augment'],
-                                                  preload_batches=settings['preload_batches'],
                                                   verbose=settings['verbose'])

      #train_batch_1_figure = os.path.join(dst, 'batch_1.pdf')
@@ -264,19 +401,19 @@ def train_test_model(settings):
                                                   channels=settings['train_channels'])

      if settings['test']:
-         test, _, plate_names_test, train_fig = generate_loaders(src,
-                                                                 mode='test',
-                                                                 image_size=settings['image_size'],
-                                                                 batch_size=settings['batch_size'],
-                                                                 classes=settings['classes'],
-                                                                 n_jobs=settings['n_jobs'],
-                                                                 validation_split=0.0,
-                                                                 pin_memory=settings['pin_memory'],
-                                                                 normalize=settings['normalize'],
-                                                                 channels=settings['train_channels'],
-                                                                 augment=False,
-                                                                 preload_batches=settings['preload_batches'],
-                                                                 verbose=settings['verbose'])
+         test, _, train_fig = generate_loaders(src,
+                                               mode='test',
+                                               image_size=settings['image_size'],
+                                               batch_size=settings['batch_size'],
+                                               classes=settings['classes'],
+                                               n_jobs=settings['n_jobs'],
+                                               validation_split=0.0,
+                                               pin_memory=settings['pin_memory'],
+                                               normalize=settings['normalize'],
+                                               channels=settings['train_channels'],
+                                               augment=False,
+                                               verbose=settings['verbose'])
+
      if model == None:
          model_path = pick_best_model(src+'/model')
          print(f'Best model: {model_path}')
@@ -308,7 +445,10 @@ def train_test_model(settings):
      torch.cuda.memory.empty_cache()
      gc.collect()

-     return model_path
+     if settings['train']:
+         return model_path
+     if settings['test']:
+         return result_loc

  def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, n_jobs=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule=None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4, channels=['r','g','b'], verbose=False):
      """
@@ -355,11 +495,6 @@ def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001

      kwargs = {'n_jobs': n_jobs, 'pin_memory': True} if use_cuda else {}

-     #for idx, (images, labels, filenames) in enumerate(train_loaders):
-     #    batch, chans, height, width = images.shape
-     #    break
-
-
      model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint, verbose=verbose)


@@ -452,19 +587,21 @@ def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001
          if schedule == 'step_lr':
              scheduler.step()

-         if epoch % 10 == 0 or epoch == epochs:
-             if accumulated_train_dicts:
-                 train_df = pd.DataFrame(accumulated_train_dicts)
-                 _save_progress(dst, train_df, result_type='train')
-
-             if accumulated_val_dicts:
-                 val_df = pd.DataFrame(accumulated_val_dicts)
-                 _save_progress(dst, val_df, result_type='validation')
-
-             if accumulated_test_dicts:
-                 val_df = pd.DataFrame(accumulated_test_dicts)
-                 _save_progress(dst, val_df, result_type='test')
-
+         if accumulated_train_dicts and accumulated_val_dicts:
+             train_df = pd.DataFrame(accumulated_train_dicts)
+             validation_df = pd.DataFrame(accumulated_val_dicts)
+             _save_progress(dst, train_df, validation_df)
+             accumulated_train_dicts, accumulated_val_dicts = [], []
+
+         elif accumulated_train_dicts:
+             train_df = pd.DataFrame(accumulated_train_dicts)
+             _save_progress(dst, train_df, None)
+             accumulated_train_dicts = []
+         elif accumulated_test_dicts:
+             test_df = pd.DataFrame(accumulated_test_dicts)
+             _save_progress(dst, test_df, None)
+             accumulated_test_dicts = []
+
          batch_size = len(train_loaders)
          duration = time.time() - start_time
          time_ls.append(duration)
@@ -473,133 +610,168 @@ def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001

      return model, model_path

- def visualize_saliency_map(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
-
-     from spacr.utils import SaliencyMapGenerator, preprocess_image
-
-     use_cuda = torch.cuda.is_available()
-     device = torch.device("cuda" if use_cuda else "cpu")
-
-     # Load the entire model object
-     model = torch.load(model_path)
-     model.to(device)
-
-     # Create directory for saving saliency maps if it does not exist
-     if save_saliency and not os.path.exists(save_dir):
-         os.makedirs(save_dir)
-
-     # Collect all images and their tensors
-     images = []
-     input_tensors = []
-     filenames = []
-     for file in os.listdir(src):
-         if not file.endswith('.png'):
-             continue
-         image_path = os.path.join(src, file)
-         image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels)
-         images.append(image)
-         input_tensors.append(input_tensor)
-         filenames.append(file)
-
-     input_tensors = torch.cat(input_tensors).to(device)
-     class_labels = torch.zeros(input_tensors.size(0), dtype=torch.long).to(device)  # Replace with actual class labels if available
-
-     # Generate saliency maps
-     cam_generator = SaliencyMapGenerator(model)
-     saliency_maps = cam_generator.compute_saliency_maps(input_tensors, class_labels)
-
-     # Convert saliency maps to numpy arrays
-     saliency_maps = saliency_maps.cpu().numpy()
-
-     N = len(images)
-
-     dst = os.path.join(src, 'saliency_maps')
-
-     for i in range(N):
-         fig, axes = plt.subplots(1, 3, figsize=(20, 5))
-
-         # Original image
-         axes[0].imshow(images[i])
-         axes[0].axis('off')
-         if class_names:
-             axes[0].set_title(f"Class: {class_names[class_labels[i].item()]}")
-
-         # Saliency Map
-         axes[1].imshow(saliency_maps[i, 0], cmap='hot')
-         axes[1].axis('off')
-         axes[1].set_title("Saliency Map")
-
-         # Overlay
-         overlay = np.array(images[i])
-         overlay = overlay / overlay.max()
-         saliency_map_rgb = np.stack([saliency_maps[i, 0]] * 3, axis=-1)  # Convert saliency map to RGB
-         overlay = (overlay * 0.5 + saliency_map_rgb * 0.5).clip(0, 1)
-         axes[2].imshow(overlay)
-         axes[2].axis('off')
-         axes[2].set_title("Overlay")
-
-         plt.tight_layout()
-         plt.show()
-
-         # Save the saliency map if required
-         if save_saliency:
-             os.makedirs(dst, exist_ok=True)
-             saliency_image = Image.fromarray((saliency_maps[i, 0] * 255).astype(np.uint8))
-             saliency_image.save(os.path.join(dst, f'saliency_{filenames[i]}'))
-
- def visualize_grad_cam(src, model_path, target_layers=None, image_size=224, channels=[1, 2, 3], normalize=True, class_names=None, save_cam=False, save_dir='grad_cam'):
-
-     from spacr.utils import GradCAM, preprocess_image, show_cam_on_image, recommend_target_layers
-
+ def generate_activation_map(settings):
+
+     from .utils import SaliencyMapGenerator, GradCAMGenerator, SelectChannels, activation_maps_to_database, activation_correlations_to_database
+     from .utils import print_progress, save_settings, calculate_activation_correlations
+     from .io import TarImageDataset
+     from .settings import get_default_generate_activation_map_settings
+
+     torch.cuda.empty_cache()
+     gc.collect()
+
+     plt.clf()
      use_cuda = torch.cuda.is_available()
      device = torch.device("cuda" if use_cuda else "cpu")

-     model = torch.load(model_path)
-     model.to(device)
+     source_folder = os.path.dirname(os.path.dirname(settings['dataset']))
+     settings['src'] = source_folder
+     settings = get_default_generate_activation_map_settings(settings)
+     save_settings(settings, name=f"{settings['cam_type']}_settings", show=False)

-     # If no target layers provided, recommend a target layer
-     if target_layers is None:
-         target_layers, all_layers = recommend_target_layers(model)
-         print(f"No target layer provided. Using recommended layer: {target_layers[0]}")
-         print("All possible target layers:")
-         for layer in all_layers:
-             print(layer)
+     if settings['model_type'] == 'maxvit' and settings['target_layer'] == None:
+         settings['target_layer'] = 'base_model.blocks.3.layers.1.layers.MBconv.layers.conv_b'
+     if settings['cam_type'] in ['saliency_image', 'saliency_channel']:
+         settings['target_layer'] = None

-     grad_cam = GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda)
-
-     if save_cam and not os.path.exists(save_dir):
-         os.makedirs(save_dir)
-
-     images = []
-     filenames = []
-     for file in os.listdir(src):
-         if not file.endswith('.png'):
-             continue
-         image_path = os.path.join(src, file)
-         image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels)
-         images.append(image)
-         filenames.append(file)
+     # Set number of jobs for loading
+     n_jobs = settings['n_jobs']
+     if n_jobs is None:
+         n_jobs = max(1, cpu_count() - 4)
+
+     # Set transforms for images
+     transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.CenterCrop(size=(settings['image_size'], settings['image_size'])),
+         transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) if settings['normalize_input'] else None,
+         SelectChannels(settings['channels'])
+     ])
+
+     # Handle dataset path
+     if not os.path.exists(settings['dataset']):
+         print(f"Dataset not found at {settings['dataset']}")
+         return

-         input_tensor = input_tensor.to(device)
-         cam = grad_cam(input_tensor)
-         cam_image = show_cam_on_image(np.array(image) / 255.0, cam)
+     # Load the model
+     model = torch.load(settings['model_path'])
+     model.to(device)
+     model.eval()

-         fig, ax = plt.subplots(1, 2, figsize=(10, 5))
-         ax[0].imshow(image)
-         ax[0].axis('off')
-         ax[0].set_title("Original Image")
-         ax[1].imshow(cam_image)
-         ax[1].axis('off')
-         ax[1].set_title("Grad-CAM")
-         plt.show()
+     # Create directory for saving activation maps if it does not exist
+     dataset_dir = os.path.dirname(settings['dataset'])
+     dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
+     save_dir = os.path.join(dataset_dir, dataset_name, settings['cam_type'])
+     batch_grid_fldr = os.path.join(save_dir, 'batch_grids')
+
+     if settings['save']:
+         os.makedirs(save_dir, exist_ok=True)
+         print(f"Activation maps will be saved in: {save_dir}")
+
+     if settings['plot']:
+         os.makedirs(batch_grid_fldr, exist_ok=True)
+         print(f"Batch grid maps will be saved in: {batch_grid_fldr}")
+
+     # Load dataset
+     dataset = TarImageDataset(settings['dataset'], transform=transform)
+     data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=settings['shuffle'], num_workers=n_jobs, pin_memory=True)
+
+     # Initialize generator based on cam_type
+     if settings['cam_type'] in ['gradcam', 'gradcam_pp']:
+         cam_generator = GradCAMGenerator(model, target_layer=settings['target_layer'], cam_type=settings['cam_type'])
+     elif settings['cam_type'] in ['saliency_image', 'saliency_channel']:
+         cam_generator = SaliencyMapGenerator(model)
+
+     time_ls = []
+     for batch_idx, (inputs, filenames) in enumerate(data_loader):
+         start = time.time()
+         img_paths = []
+         inputs = inputs.to(device)
+
+         # Compute activation maps and predictions
+         if settings['cam_type'] in ['gradcam', 'gradcam_pp']:
+             activation_maps, predicted_classes = cam_generator.compute_gradcam_and_predictions(inputs)
+         elif settings['cam_type'] in ['saliency_image', 'saliency_channel']:
+             activation_maps, predicted_classes = cam_generator.compute_saliency_and_predictions(inputs)
+
+         # Move activation maps to CPU
+         activation_maps = activation_maps.cpu()
+
+         # Sum saliency maps for 'saliency_image' type
+         if settings['cam_type'] == 'saliency_image':
+             summed_activation_maps = []
+             for i in range(activation_maps.size(0)):
+                 activation_map = activation_maps[i]
+                 #print(f"1: {activation_map.shape}")
+                 activation_map_sum = activation_map.sum(dim=0, keepdim=False)
+                 #print(f"2: {activation_map.shape}")
+                 activation_map_sum = np.squeeze(activation_map_sum, axis=0)
+                 #print(f"3: {activation_map_sum.shape}")
+                 summed_activation_maps.append(activation_map_sum)
+             activation_maps = torch.stack(summed_activation_maps)
+
+         # For plotting
+         if settings['plot']:
+             fig = cam_generator.plot_activation_grid(inputs, activation_maps, predicted_classes, overlay=settings['overlay'], normalize=settings['normalize'])
+             pdf_save_path = os.path.join(batch_grid_fldr, f"batch_{batch_idx}_grid.pdf")
+             fig.savefig(pdf_save_path, format='pdf')
+             print(f"Saved batch grid to {pdf_save_path}")
+             #plt.show()
+             display(fig)
+
+         for i in range(inputs.size(0)):
+             activation_map = activation_maps[i].detach().numpy()
+
+             if settings['cam_type'] in ['saliency_image', 'gradcam', 'gradcam_pp']:
+                 #activation_map = activation_map.sum(axis=0)
+                 activation_map = (activation_map - activation_map.min()) / (activation_map.max() - activation_map.min())
+                 activation_map = (activation_map * 255).astype(np.uint8)
+                 activation_image = Image.fromarray(activation_map, mode='L')
+
+             elif settings['cam_type'] == 'saliency_channel':
+                 # Handle each channel separately and save as RGB
+                 rgb_activation_map = np.zeros((activation_map.shape[1], activation_map.shape[2], 3), dtype=np.uint8)
+                 for c in range(min(activation_map.shape[0], 3)):  # Limit to 3 channels for RGB
+                     channel_map = activation_map[c]
+                     channel_map = (channel_map - channel_map.min()) / (channel_map.max() - channel_map.min())
+                     rgb_activation_map[:, :, c] = (channel_map * 255).astype(np.uint8)
+                 activation_image = Image.fromarray(rgb_activation_map, mode='RGB')
+
+             # Save activation maps
+             class_pred = predicted_classes[i].item()
+             parts = filenames[i].split('_')
+             plate = parts[0]
+             well = parts[1]
+             save_class_dir = os.path.join(save_dir, f'class_{class_pred}', str(plate), str(well))
+             os.makedirs(save_class_dir, exist_ok=True)
+             save_path = os.path.join(save_class_dir, f'{filenames[i]}')
+             if settings['save']:
+                 activation_image.save(save_path)
+                 img_paths.append(save_path)
+
+         if settings['save']:
+             activation_maps_to_database(img_paths, source_folder, settings)
+
+         if settings['correlation']:
+             df = calculate_activation_correlations(inputs, activation_maps, filenames, manders_thresholds=settings['manders_thresholds'])
+             if settings['plot']:
+                 display(df)
+             if settings['save']:
+                 activation_correlations_to_database(df, img_paths, source_folder, settings)
+
+         stop = time.time()
+         duration = stop - start
+         time_ls.append(duration)
+         files_processed = batch_idx * settings['batch_size']
+         files_to_process = len(data_loader) * settings['batch_size']
+         print_progress(files_processed, files_to_process, n_jobs=n_jobs, time_ls=time_ls, batch_size=settings['batch_size'], operation_type="Generating Activation Maps")

-     if save_cam:
-         cam_pil = Image.fromarray(cam_image)
-         cam_pil.save(os.path.join(save_dir, f'grad_cam_{file}'))
+     torch.cuda.empty_cache()
+     gc.collect()
+     print("Activation map generation complete.")

  def visualize_classes(model, dtype, class_names, **kwargs):

-     from spacr.utils import class_visualization
+     from .utils import class_visualization

      for target_y in range(2):  # Assuming binary classification
          print(f"Visualizing class: {class_names[target_y]}")
@@ -732,7 +904,7 @@ def visualize_smooth_grad(src, model_path, target_label_idx, image_size=224, cha

  def deep_spacr(settings={}):
      from .settings import deep_spacr_defaults
-     from .core import generate_training_dataset, generate_dataset, apply_model_to_tar
+     from .io import generate_training_dataset, generate_dataset
      from .utils import save_settings

      settings = deep_spacr_defaults(settings)
spacr/gui.py CHANGED
@@ -48,6 +48,7 @@ class MainApp(tk.Tk):
          }

          self.additional_gui_apps = {
+             "Convert": (lambda frame: initiate_root(self, 'convert'), "Convert images to Grayscale TIFs."),
              "Umap": (lambda frame: initiate_root(self, 'umap'), "Generate UMAP embeddings with datapoints represented as images."),
              "Train Cellpose": (lambda frame: initiate_root(self, 'train_cellpose'), "Train custom Cellpose models."),
              "ML Analyze": (lambda frame: initiate_root(self, 'ml_analyze'), "Machine learning analysis of data."),
@@ -56,6 +57,7 @@ class MainApp(tk.Tk):
              "Map Barcodes": (lambda frame: initiate_root(self, 'map_barcodes'), "Map barcodes to data."),
              "Regression": (lambda frame: initiate_root(self, 'regression'), "Perform regression analysis."),
              "Recruitment": (lambda frame: initiate_root(self, 'recruitment'), "Analyze recruitment data."),
+             "Activation": (lambda frame: initiate_root(self, 'activation'), "Generate activation maps of computer vision models and measure channel-activation correlation."),
              "Plaque": (lambda frame: initiate_root(self, 'analyze_plaques'), "Analyze plaque data.")
          }
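Note: each registry entry maps a display name to a (launcher, tooltip) pair, where the launcher receives the hosting frame. A minimal sketch of consuming such a registry; the button wiring below is illustrative only, not code from this diff.

    import tkinter as tk

    def build_app_buttons(parent, additional_gui_apps):
        # One launch button per registered app; the default argument pins each launcher.
        for name, (launcher, description) in additional_gui_apps.items():
            tk.Button(parent, text=name, command=lambda fn=launcher: fn(parent)).pack(fill='x')
            tk.Label(parent, text=description, anchor='w').pack(fill='x')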