spacr 0.3.1__py3-none-any.whl → 0.3.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +19 -3
- spacr/cellpose.py +311 -0
- spacr/core.py +245 -2494
- spacr/deep_spacr.py +316 -48
- spacr/gui.py +1 -0
- spacr/gui_core.py +74 -63
- spacr/gui_elements.py +110 -5
- spacr/gui_utils.py +346 -6
- spacr/io.py +680 -141
- spacr/logger.py +28 -9
- spacr/measure.py +107 -95
- spacr/mediar.py +0 -3
- spacr/ml.py +1051 -0
- spacr/openai.py +37 -0
- spacr/plot.py +707 -20
- spacr/resources/data/lopit.csv +3833 -0
- spacr/resources/data/toxoplasma_metadata.csv +8843 -0
- spacr/resources/icons/convert.png +0 -0
- spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
- spacr/sequencing.py +241 -1311
- spacr/settings.py +134 -47
- spacr/sim.py +0 -2
- spacr/submodules.py +349 -0
- spacr/timelapse.py +0 -2
- spacr/toxo.py +238 -0
- spacr/utils.py +419 -180
- {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/METADATA +31 -22
- {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/RECORD +32 -33
- spacr/chris.py +0 -50
- spacr/graph_learning.py +0 -340
- spacr/resources/MEDIAR/.git +0 -1
- spacr/resources/MEDIAR_weights/.DS_Store +0 -0
- spacr/resources/icons/.DS_Store +0 -0
- spacr/resources/icons/spacr_logo_rotation.gif +0 -0
- spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
- spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/sim_app.py +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/LICENSE +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/WHEEL +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/top_level.txt +0 -0
spacr/deep_spacr.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
|
-
import os, torch, time, gc, datetime
|
2
|
-
|
1
|
+
import os, torch, time, gc, datetime, cv2
|
3
2
|
torch.backends.cudnn.benchmark = True
|
4
3
|
|
5
4
|
import numpy as np
|
@@ -8,15 +7,146 @@ from torch.optim import Adagrad, AdamW
|
|
8
7
|
from torch.autograd import grad
|
9
8
|
from torch.optim.lr_scheduler import StepLR
|
10
9
|
import torch.nn.functional as F
|
11
|
-
from IPython.display import display, clear_output
|
12
10
|
import matplotlib.pyplot as plt
|
13
11
|
from PIL import Image
|
14
12
|
from sklearn.metrics import auc, precision_recall_curve
|
15
|
-
from
|
16
|
-
|
13
|
+
from IPython.display import display
|
14
|
+
from multiprocessing import cpu_count
|
15
|
+
|
16
|
+
from torchvision import transforms
|
17
|
+
from torch.utils.data import DataLoader
|
17
18
|
|
18
|
-
|
19
|
-
|
19
|
+
def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True, n_jobs=10):
|
20
|
+
|
21
|
+
from .io import NoClassDataset
|
22
|
+
from .utils import print_progress
|
23
|
+
|
24
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
25
|
+
|
26
|
+
if normalize:
|
27
|
+
transform = transforms.Compose([
|
28
|
+
transforms.ToTensor(),
|
29
|
+
transforms.CenterCrop(size=(image_size, image_size)),
|
30
|
+
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
31
|
+
else:
|
32
|
+
transform = transforms.Compose([
|
33
|
+
transforms.ToTensor(),
|
34
|
+
transforms.CenterCrop(size=(image_size, image_size))])
|
35
|
+
|
36
|
+
model = torch.load(model_path)
|
37
|
+
print(model)
|
38
|
+
|
39
|
+
print(f'Loading dataset in {src} with {len(src)} images')
|
40
|
+
dataset = NoClassDataset(data_dir=src, transform=transform, shuffle=True, load_to_memory=False)
|
41
|
+
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_jobs)
|
42
|
+
print(f'Loaded {len(src)} images')
|
43
|
+
|
44
|
+
result_loc = os.path.splitext(model_path)[0]+datetime.date.today().strftime('%y%m%d')+'_'+os.path.splitext(model_path)[1]+'_test_result.csv'
|
45
|
+
print(f'Results wil be saved in: {result_loc}')
|
46
|
+
|
47
|
+
model.eval()
|
48
|
+
model = model.to(device)
|
49
|
+
prediction_pos_probs = []
|
50
|
+
filenames_list = []
|
51
|
+
time_ls = []
|
52
|
+
with torch.no_grad():
|
53
|
+
for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1):
|
54
|
+
start = time.time()
|
55
|
+
images = batch_images.to(torch.float).to(device)
|
56
|
+
outputs = model(images)
|
57
|
+
batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
|
58
|
+
prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
|
59
|
+
filenames_list.extend(filenames)
|
60
|
+
stop = time.time()
|
61
|
+
duration = stop - start
|
62
|
+
time_ls.append(duration)
|
63
|
+
files_processed = batch_idx*batch_size
|
64
|
+
files_to_process = len(data_loader)
|
65
|
+
print_progress(files_processed, files_to_process, n_jobs=n_jobs, time_ls=time_ls, batch_size=batch_size, operation_type="Generating predictions")
|
66
|
+
|
67
|
+
data = {'path':filenames_list, 'pred':prediction_pos_probs}
|
68
|
+
df = pd.DataFrame(data, index=None)
|
69
|
+
df.to_csv(result_loc, index=True, header=True, mode='w')
|
70
|
+
torch.cuda.empty_cache()
|
71
|
+
torch.cuda.memory.empty_cache()
|
72
|
+
return df
|
73
|
+
|
74
|
+
def apply_model_to_tar(settings={}):
|
75
|
+
|
76
|
+
from .io import TarImageDataset
|
77
|
+
from .utils import process_vision_results, print_progress
|
78
|
+
|
79
|
+
if os.path.exists(settings['dataset']):
|
80
|
+
tar_path = settings['dataset']
|
81
|
+
else:
|
82
|
+
tar_path = os.path.join(settings['src'], 'datasets', settings['dataset'])
|
83
|
+
model_path = settings['model_path']
|
84
|
+
|
85
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
86
|
+
if settings['normalize']:
|
87
|
+
transform = transforms.Compose([
|
88
|
+
transforms.ToTensor(),
|
89
|
+
transforms.CenterCrop(size=(settings['image_size'], settings['image_size'])),
|
90
|
+
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
91
|
+
else:
|
92
|
+
transform = transforms.Compose([
|
93
|
+
transforms.ToTensor(),
|
94
|
+
transforms.CenterCrop(size=(settings['image_size'], settings['image_size']))])
|
95
|
+
|
96
|
+
if settings['verbose']:
|
97
|
+
print(f"Loading model from {model_path}")
|
98
|
+
print(f"Loading dataset from {tar_path}")
|
99
|
+
|
100
|
+
model = torch.load(settings['model_path'])
|
101
|
+
|
102
|
+
dataset = TarImageDataset(tar_path, transform=transform)
|
103
|
+
data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=settings['n_jobs'], pin_memory=True)
|
104
|
+
|
105
|
+
model_name = os.path.splitext(os.path.basename(model_path))[0]
|
106
|
+
dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
|
107
|
+
date_name = datetime.date.today().strftime('%y%m%d')
|
108
|
+
dst = os.path.dirname(tar_path)
|
109
|
+
result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv'
|
110
|
+
|
111
|
+
model.eval()
|
112
|
+
model = model.to(device)
|
113
|
+
|
114
|
+
if settings['verbose']:
|
115
|
+
print(model)
|
116
|
+
print(f'Generated dataset with {len(dataset)} images')
|
117
|
+
print(f'Generating loader from {len(data_loader)} batches')
|
118
|
+
print(f'Results wil be saved in: {result_loc}')
|
119
|
+
print(f'Model is in eval mode')
|
120
|
+
print(f'Model loaded to device')
|
121
|
+
|
122
|
+
prediction_pos_probs = []
|
123
|
+
filenames_list = []
|
124
|
+
time_ls = []
|
125
|
+
gc.collect()
|
126
|
+
with torch.no_grad():
|
127
|
+
for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1):
|
128
|
+
start = time.time()
|
129
|
+
images = batch_images.to(torch.float).to(device)
|
130
|
+
outputs = model(images)
|
131
|
+
batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
|
132
|
+
prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
|
133
|
+
filenames_list.extend(filenames)
|
134
|
+
stop = time.time()
|
135
|
+
duration = stop - start
|
136
|
+
time_ls.append(duration)
|
137
|
+
files_processed = batch_idx*settings['batch_size']
|
138
|
+
files_to_process = len(data_loader)*settings['batch_size']
|
139
|
+
print_progress(files_processed, files_to_process, n_jobs=settings['n_jobs'], time_ls=time_ls, batch_size=settings['batch_size'], operation_type="Tar dataset")
|
140
|
+
|
141
|
+
data = {'path':filenames_list, 'pred':prediction_pos_probs}
|
142
|
+
df = pd.DataFrame(data, index=None)
|
143
|
+
df = process_vision_results(df, settings['score_threshold'])
|
144
|
+
|
145
|
+
df.to_csv(result_loc, index=True, header=True, mode='w')
|
146
|
+
print(f"Saved results to {result_loc}")
|
147
|
+
torch.cuda.empty_cache()
|
148
|
+
torch.cuda.memory.empty_cache()
|
149
|
+
return df
|
20
150
|
|
21
151
|
def evaluate_model_performance(model, loader, epoch, loss_type):
|
22
152
|
"""
|
@@ -118,7 +248,7 @@ def evaluate_model_performance(model, loader, epoch, loss_type):
|
|
118
248
|
|
119
249
|
loss /= len(loader)
|
120
250
|
data_dict = classification_metrics(all_labels, prediction_pos_probs)
|
121
|
-
data_dict['loss'] = loss
|
251
|
+
data_dict['loss'] = loss.item()
|
122
252
|
data_dict['epoch'] = epoch
|
123
253
|
data_dict['Accuracy'] = acc
|
124
254
|
|
@@ -175,7 +305,7 @@ def test_model_core(model, loader, loader_name, epoch, loss_type):
|
|
175
305
|
'class_1_probability':prediction_pos_probs})
|
176
306
|
|
177
307
|
loss /= len(loader)
|
178
|
-
data_df = classification_metrics(all_labels, prediction_pos_probs,
|
308
|
+
data_df = classification_metrics(all_labels, prediction_pos_probs, loss, epoch)
|
179
309
|
return data_df, prediction_pos_probs, all_labels, results_df
|
180
310
|
|
181
311
|
def test_model_performance(loaders, model, loader_name_list, epoch, loss_type):
|
@@ -201,9 +331,12 @@ def test_model_performance(loaders, model, loader_name_list, epoch, loss_type):
|
|
201
331
|
|
202
332
|
def train_test_model(settings):
|
203
333
|
|
204
|
-
from .io import
|
205
|
-
from .utils import pick_best_model
|
206
|
-
from .
|
334
|
+
from .io import _copy_missclassified
|
335
|
+
from .utils import pick_best_model, save_settings
|
336
|
+
from .io import generate_loaders
|
337
|
+
from .settings import get_train_test_model_settings
|
338
|
+
|
339
|
+
settings = get_train_test_model_settings(settings)
|
207
340
|
|
208
341
|
torch.cuda.empty_cache()
|
209
342
|
torch.cuda.memory.empty_cache()
|
@@ -221,7 +354,12 @@ def train_test_model(settings):
|
|
221
354
|
model = torch.load(settings['custom_model_path'])
|
222
355
|
|
223
356
|
if settings['train']:
|
224
|
-
|
357
|
+
if settings['train'] and settings['test']:
|
358
|
+
save_settings(settings, name=f"train_test_{settings['model_type']}_{settings['epochs']}", show=True)
|
359
|
+
elif settings['train'] is True:
|
360
|
+
save_settings(settings, name=f"train_{settings['model_type']}_{settings['epochs']}", show=True)
|
361
|
+
elif settings['test'] is True:
|
362
|
+
save_settings(settings, name=f"test_{settings['model_type']}_{settings['epochs']}", show=True)
|
225
363
|
|
226
364
|
if settings['train']:
|
227
365
|
train, val, train_fig = generate_loaders(src,
|
@@ -235,7 +373,6 @@ def train_test_model(settings):
|
|
235
373
|
normalize=settings['normalize'],
|
236
374
|
channels=settings['train_channels'],
|
237
375
|
augment=settings['augment'],
|
238
|
-
preload_batches=settings['preload_batches'],
|
239
376
|
verbose=settings['verbose'])
|
240
377
|
|
241
378
|
#train_batch_1_figure = os.path.join(dst, 'batch_1.pdf')
|
@@ -264,19 +401,19 @@ def train_test_model(settings):
|
|
264
401
|
channels=settings['train_channels'])
|
265
402
|
|
266
403
|
if settings['test']:
|
267
|
-
test, _,
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
404
|
+
test, _, train_fig = generate_loaders(src,
|
405
|
+
mode='test',
|
406
|
+
image_size=settings['image_size'],
|
407
|
+
batch_size=settings['batch_size'],
|
408
|
+
classes=settings['classes'],
|
409
|
+
n_jobs=settings['n_jobs'],
|
410
|
+
validation_split=0.0,
|
411
|
+
pin_memory=settings['pin_memory'],
|
412
|
+
normalize=settings['normalize'],
|
413
|
+
channels=settings['train_channels'],
|
414
|
+
augment=False,
|
415
|
+
verbose=settings['verbose'])
|
416
|
+
|
280
417
|
if model == None:
|
281
418
|
model_path = pick_best_model(src+'/model')
|
282
419
|
print(f'Best model: {model_path}')
|
@@ -308,7 +445,10 @@ def train_test_model(settings):
|
|
308
445
|
torch.cuda.memory.empty_cache()
|
309
446
|
gc.collect()
|
310
447
|
|
311
|
-
|
448
|
+
if settings['train']:
|
449
|
+
return model_path
|
450
|
+
if settings['test']:
|
451
|
+
return result_loc
|
312
452
|
|
313
453
|
def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, n_jobs=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule = None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4, channels=['r','g','b'], verbose=False):
|
314
454
|
"""
|
@@ -355,11 +495,6 @@ def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001
|
|
355
495
|
|
356
496
|
kwargs = {'n_jobs': n_jobs, 'pin_memory': True} if use_cuda else {}
|
357
497
|
|
358
|
-
#for idx, (images, labels, filenames) in enumerate(train_loaders):
|
359
|
-
# batch, chans, height, width = images.shape
|
360
|
-
# break
|
361
|
-
|
362
|
-
|
363
498
|
model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint, verbose=verbose)
|
364
499
|
|
365
500
|
|
@@ -452,19 +587,21 @@ def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001
|
|
452
587
|
if schedule == 'step_lr':
|
453
588
|
scheduler.step()
|
454
589
|
|
455
|
-
if
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
590
|
+
if accumulated_train_dicts and accumulated_val_dicts:
|
591
|
+
train_df = pd.DataFrame(accumulated_train_dicts)
|
592
|
+
validation_df = pd.DataFrame(accumulated_val_dicts)
|
593
|
+
_save_progress(dst, train_df, validation_df)
|
594
|
+
accumulated_train_dicts, accumulated_val_dicts = [], []
|
595
|
+
|
596
|
+
elif accumulated_train_dicts:
|
597
|
+
train_df = pd.DataFrame(accumulated_train_dicts)
|
598
|
+
_save_progress(dst, train_df, None)
|
599
|
+
accumulated_train_dicts = []
|
600
|
+
elif accumulated_test_dicts:
|
601
|
+
test_df = pd.DataFrame(accumulated_test_dicts)
|
602
|
+
_save_progress(dst, test_df, None)
|
603
|
+
accumulated_test_dicts = []
|
604
|
+
|
468
605
|
batch_size = len(train_loaders)
|
469
606
|
duration = time.time() - start_time
|
470
607
|
time_ls.append(duration)
|
@@ -473,7 +610,138 @@ def train_model(dst, model_type, train_loaders, epochs=100, learning_rate=0.0001
|
|
473
610
|
|
474
611
|
return model, model_path
|
475
612
|
|
476
|
-
def visualize_saliency_map(
|
613
|
+
def visualize_saliency_map(settings):
|
614
|
+
from spacr.utils import SaliencyMapGenerator, print_progress
|
615
|
+
from spacr.io import TarImageDataset # Assuming you have a dataset class
|
616
|
+
from torchvision.utils import make_grid
|
617
|
+
|
618
|
+
use_cuda = torch.cuda.is_available()
|
619
|
+
device = torch.device("cuda" if use_cuda else "cpu")
|
620
|
+
|
621
|
+
# Set number of jobs for loading
|
622
|
+
if settings['n_jobs'] is None:
|
623
|
+
n_jobs = max(1, cpu_count() - 4)
|
624
|
+
else:
|
625
|
+
n_jobs = settings['n_jobs']
|
626
|
+
|
627
|
+
# Set transforms for images
|
628
|
+
if settings['normalize']:
|
629
|
+
transform = transforms.Compose([
|
630
|
+
transforms.ToTensor(),
|
631
|
+
transforms.CenterCrop(size=(settings['image_size'], settings['image_size'])),
|
632
|
+
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
633
|
+
else:
|
634
|
+
transform = transforms.Compose([
|
635
|
+
transforms.ToTensor(),
|
636
|
+
transforms.CenterCrop(size=(settings['image_size'], settings['image_size']))])
|
637
|
+
|
638
|
+
# Handle dataset path
|
639
|
+
if os.path.exists(settings['dataset']):
|
640
|
+
tar_path = settings['dataset']
|
641
|
+
else:
|
642
|
+
print(f"Dataset not found at {settings['dataset']}")
|
643
|
+
return
|
644
|
+
|
645
|
+
if settings.get('save', False):
|
646
|
+
if settings['dtype'] not in ['uint8', 'uint16']:
|
647
|
+
print("Invalid dtype in settings. Please use 'uint8' or 'uint16'.")
|
648
|
+
return
|
649
|
+
|
650
|
+
# Load the model
|
651
|
+
model = torch.load(settings['model_path'])
|
652
|
+
model.to(device)
|
653
|
+
model.eval() # Ensure the model is in evaluation mode
|
654
|
+
|
655
|
+
# Create directory for saving saliency maps if it does not exist
|
656
|
+
if settings.get('save', False):
|
657
|
+
dataset_dir = os.path.dirname(tar_path)
|
658
|
+
dataset_name = os.path.splitext(os.path.basename(tar_path))[0]
|
659
|
+
save_dir = os.path.join(dataset_dir, dataset_name, 'saliency_maps')
|
660
|
+
os.makedirs(save_dir, exist_ok=True)
|
661
|
+
print(f"Saliency maps will be saved in: {save_dir}")
|
662
|
+
|
663
|
+
# Load dataset
|
664
|
+
dataset = TarImageDataset(tar_path, transform=transform)
|
665
|
+
data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=n_jobs, pin_memory=True)
|
666
|
+
|
667
|
+
# Initialize SaliencyMapGenerator
|
668
|
+
cam_generator = SaliencyMapGenerator(model)
|
669
|
+
time_ls = []
|
670
|
+
|
671
|
+
for batch_idx, (inputs, filenames) in enumerate(data_loader):
|
672
|
+
start = time.time()
|
673
|
+
inputs = inputs.to(device)
|
674
|
+
|
675
|
+
saliency_maps, predicted_classes = cam_generator.compute_saliency_and_predictions(inputs)
|
676
|
+
|
677
|
+
if settings['saliency_mode'] not in ['mean', 'sum']:
|
678
|
+
print("To generate channel average or sum saliency maps set saliency_mode to 'mean' or 'sum', respectively.")
|
679
|
+
|
680
|
+
if settings['saliency_mode'] == 'mean':
|
681
|
+
saliency_maps = saliency_maps.mean(dim=1, keepdim=True)
|
682
|
+
|
683
|
+
elif settings['saliency_mode'] == 'sum':
|
684
|
+
saliency_maps = saliency_maps.sum(dim=1, keepdim=True)
|
685
|
+
|
686
|
+
# Example usage with the class
|
687
|
+
if settings.get('plot', False):
|
688
|
+
if settings['plot_mode'] not in ['mean', 'channel', '3-channel']:
|
689
|
+
print("Invalid plot_mode in settings. Please use 'mean', 'channel', or '3-channel'.")
|
690
|
+
return
|
691
|
+
else:
|
692
|
+
cam_generator.plot_saliency_grid(inputs, saliency_maps, predicted_classes, mode=settings['plot_mode'])
|
693
|
+
|
694
|
+
if settings.get('save', False):
|
695
|
+
for i in range(inputs.size(0)):
|
696
|
+
saliency_map = saliency_maps[i].detach().cpu().numpy()
|
697
|
+
|
698
|
+
# Check dtype in settings and normalize accordingly
|
699
|
+
if settings['dtype'] == 'uint16':
|
700
|
+
saliency_map = np.clip(saliency_map, 0, 1) * 65535
|
701
|
+
saliency_map = saliency_map.astype(np.uint16)
|
702
|
+
mode = 'I;16'
|
703
|
+
elif settings['dtype'] == 'uint8':
|
704
|
+
saliency_map = np.clip(saliency_map, 0, 1) * 255
|
705
|
+
saliency_map = saliency_map.astype(np.uint8)
|
706
|
+
mode = 'L' # Grayscale mode for uint8
|
707
|
+
|
708
|
+
# Get the class prediction (0 or 1)
|
709
|
+
class_pred = predicted_classes[i].item()
|
710
|
+
|
711
|
+
save_class_dir = os.path.join(save_dir, f'class_{class_pred}')
|
712
|
+
os.makedirs(save_class_dir, exist_ok=True)
|
713
|
+
save_path = os.path.join(save_class_dir, filenames[i])
|
714
|
+
|
715
|
+
# Handle different cases based on saliency_map dimensions
|
716
|
+
if saliency_map.ndim == 3: # Multi-channel case (C, H, W)
|
717
|
+
if saliency_map.shape[0] == 3: # RGB-like saliency map
|
718
|
+
saliency_image = Image.fromarray(np.moveaxis(saliency_map, 0, -1), mode="RGB") # Convert (C, H, W) to (H, W, C)
|
719
|
+
elif saliency_map.shape[0] == 1: # Single-channel case (1, H, W)
|
720
|
+
saliency_map = np.squeeze(saliency_map) # Remove the extra channel dimension
|
721
|
+
saliency_image = Image.fromarray(saliency_map, mode=mode) # Use grayscale mode for single-channel
|
722
|
+
else:
|
723
|
+
raise ValueError(f"Unexpected number of channels: {saliency_map.shape[0]}")
|
724
|
+
|
725
|
+
elif saliency_map.ndim == 2: # Single-channel case (H, W)
|
726
|
+
saliency_image = Image.fromarray(saliency_map, mode=mode) # Keep single channel (H, W)
|
727
|
+
|
728
|
+
else:
|
729
|
+
raise ValueError(f"Unexpected number of dimensions: {saliency_map.ndim}")
|
730
|
+
|
731
|
+
# Save the image
|
732
|
+
saliency_image.save(save_path)
|
733
|
+
|
734
|
+
|
735
|
+
stop = time.time()
|
736
|
+
duration = stop - start
|
737
|
+
time_ls.append(duration)
|
738
|
+
files_processed = batch_idx * settings['batch_size']
|
739
|
+
files_to_process = len(data_loader)
|
740
|
+
print_progress(files_processed, files_to_process, n_jobs=n_jobs, time_ls=time_ls, batch_size=settings['batch_size'], operation_type="Generating Saliency Maps")
|
741
|
+
|
742
|
+
print("Saliency map generation complete.")
|
743
|
+
|
744
|
+
def visualize_saliency_map_v1(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
|
477
745
|
|
478
746
|
from spacr.utils import SaliencyMapGenerator, preprocess_image
|
479
747
|
|
@@ -732,7 +1000,7 @@ def visualize_smooth_grad(src, model_path, target_label_idx, image_size=224, cha
|
|
732
1000
|
|
733
1001
|
def deep_spacr(settings={}):
|
734
1002
|
from .settings import deep_spacr_defaults
|
735
|
-
from .
|
1003
|
+
from .io import generate_training_dataset, generate_dataset
|
736
1004
|
from .utils import save_settings
|
737
1005
|
|
738
1006
|
settings = deep_spacr_defaults(settings)
|
spacr/gui.py
CHANGED
@@ -48,6 +48,7 @@ class MainApp(tk.Tk):
|
|
48
48
|
}
|
49
49
|
|
50
50
|
self.additional_gui_apps = {
|
51
|
+
"Convert": (lambda frame: initiate_root(self, 'convert'), "Convert images to Grayscale TIFs."),
|
51
52
|
"Umap": (lambda frame: initiate_root(self, 'umap'), "Generate UMAP embeddings with datapoints represented as images."),
|
52
53
|
"Train Cellpose": (lambda frame: initiate_root(self, 'train_cellpose'), "Train custom Cellpose models."),
|
53
54
|
"ML Analyze": (lambda frame: initiate_root(self, 'ml_analyze'), "Machine learning analysis of data."),
|