spacr 0.0.1__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +6 -2
- spacr/__main__.py +0 -2
- spacr/alpha.py +807 -0
- spacr/annotate_app.py +118 -120
- spacr/chris.py +50 -0
- spacr/cli.py +25 -187
- spacr/core.py +1611 -389
- spacr/deep_spacr.py +696 -0
- spacr/foldseek.py +779 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/graph_learning.py +320 -0
- spacr/graph_learning_lap.py +84 -0
- spacr/gui.py +145 -0
- spacr/gui_2.py +90 -0
- spacr/gui_classify_app.py +187 -0
- spacr/gui_mask_app.py +149 -174
- spacr/gui_measure_app.py +116 -109
- spacr/gui_sim_app.py +0 -0
- spacr/gui_utils.py +679 -139
- spacr/io.py +620 -469
- spacr/mask_app.py +116 -9
- spacr/measure.py +178 -84
- spacr/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/old_code.py +255 -1
- spacr/plot.py +263 -100
- spacr/sequencing.py +1130 -0
- spacr/sim.py +634 -122
- spacr/timelapse.py +343 -53
- spacr/train.py +195 -22
- spacr/umap.py +0 -689
- spacr/utils.py +1530 -188
- spacr-0.0.6.dist-info/METADATA +118 -0
- spacr-0.0.6.dist-info/RECORD +39 -0
- {spacr-0.0.1.dist-info → spacr-0.0.6.dist-info}/WHEEL +1 -1
- spacr-0.0.6.dist-info/entry_points.txt +9 -0
- spacr-0.0.1.dist-info/METADATA +0 -64
- spacr-0.0.1.dist-info/RECORD +0 -26
- spacr-0.0.1.dist-info/entry_points.txt +0 -5
- {spacr-0.0.1.dist-info → spacr-0.0.6.dist-info}/LICENSE +0 -0
- {spacr-0.0.1.dist-info → spacr-0.0.6.dist-info}/top_level.txt +0 -0
spacr/train.py
CHANGED
@@ -6,6 +6,7 @@ from torch.autograd import grad
|
|
6
6
|
from torch.optim.lr_scheduler import StepLR
|
7
7
|
import torch.nn.functional as F
|
8
8
|
from IPython.display import display, clear_output
|
9
|
+
import difflib
|
9
10
|
|
10
11
|
from .logger import log_function_call
|
11
12
|
|
@@ -194,15 +195,21 @@ def test_model_performance(loaders, model, loader_name_list, epoch, train_mode,
|
|
194
195
|
|
195
196
|
def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
196
197
|
|
197
|
-
from .io import
|
198
|
-
from .utils import pick_best_model
|
198
|
+
from .io import _save_settings, _copy_missclassified
|
199
|
+
from .utils import pick_best_model
|
199
200
|
from .core import generate_loaders
|
200
201
|
|
202
|
+
settings['src'] = src
|
203
|
+
settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
|
204
|
+
settings_csv = os.path.join(src,'settings','train_test_model_settings.csv')
|
205
|
+
os.makedirs(os.path.join(src,'settings'), exist_ok=True)
|
206
|
+
settings_df.to_csv(settings_csv, index=False)
|
207
|
+
|
201
208
|
if custom_model:
|
202
|
-
model = torch.load(custom_model_path)
|
209
|
+
model = torch.load(custom_model_path)
|
203
210
|
|
204
211
|
if settings['train']:
|
205
|
-
|
212
|
+
_save_settings(settings, src)
|
206
213
|
torch.cuda.empty_cache()
|
207
214
|
torch.cuda.memory.empty_cache()
|
208
215
|
gc.collect()
|
@@ -221,20 +228,23 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
221
228
|
validation_split=settings['val_split'],
|
222
229
|
pin_memory=settings['pin_memory'],
|
223
230
|
normalize=settings['normalize'],
|
224
|
-
|
231
|
+
channels=settings['channels'],
|
232
|
+
verbose=settings['verbose'])
|
233
|
+
|
225
234
|
|
226
235
|
if settings['test']:
|
227
236
|
test, _, plate_names_test = generate_loaders(src,
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
237
|
+
train_mode=settings['train_mode'],
|
238
|
+
mode='test',
|
239
|
+
image_size=settings['image_size'],
|
240
|
+
batch_size=settings['batch_size'],
|
241
|
+
classes=settings['classes'],
|
242
|
+
num_workers=settings['num_workers'],
|
243
|
+
validation_split=0.0,
|
244
|
+
pin_memory=settings['pin_memory'],
|
245
|
+
normalize=settings['normalize'],
|
246
|
+
channels=settings['channels'],
|
247
|
+
verbose=settings['verbose'])
|
238
248
|
if model == None:
|
239
249
|
model_path = pick_best_model(src+'/model')
|
240
250
|
print(f'Best model: {model_path}')
|
@@ -324,8 +334,8 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
324
334
|
None
|
325
335
|
"""
|
326
336
|
|
327
|
-
from .io import
|
328
|
-
from .utils import
|
337
|
+
from .io import _save_model, _save_progress
|
338
|
+
from .utils import compute_irm_penalty, calculate_loss, choose_model
|
329
339
|
|
330
340
|
print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders)}')
|
331
341
|
|
@@ -341,6 +351,11 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
341
351
|
break
|
342
352
|
|
343
353
|
model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
|
354
|
+
|
355
|
+
if model is None:
|
356
|
+
print(f'Model {model_type} not found')
|
357
|
+
return
|
358
|
+
|
344
359
|
model.to(device)
|
345
360
|
|
346
361
|
if optimizer_type == 'adamw':
|
@@ -415,10 +430,10 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
415
430
|
if schedule == 'step_lr':
|
416
431
|
scheduler.step()
|
417
432
|
|
418
|
-
|
433
|
+
_save_progress(dst, results_df, train_metrics_df)
|
419
434
|
clear_output(wait=True)
|
420
435
|
display(results_df)
|
421
|
-
|
436
|
+
_save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
|
422
437
|
|
423
438
|
if train_mode == 'irm':
|
424
439
|
dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
|
@@ -488,7 +503,165 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
488
503
|
|
489
504
|
clear_output(wait=True)
|
490
505
|
display(results_df)
|
491
|
-
|
492
|
-
|
506
|
+
_save_progress(dst, results_df, train_metrics_df)
|
507
|
+
_save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
|
493
508
|
print(f'Saved model: {dst}')
|
494
|
-
return
|
509
|
+
return
|
510
|
+
|
511
|
+
def get_submodules(model, prefix=''):
    """Return the qualified names of every descendant of ``model``.

    Children are visited in ``named_children()`` order, and each child is
    listed before its own descendants (pre-order). Name components are
    joined with '.', starting from ``prefix`` when it is non-empty.

    Args:
        model: Object exposing ``named_children()`` (e.g. a ``torch.nn.Module``).
        prefix: Dotted path prepended to every returned name; '' for none.

    Returns:
        list[str]: e.g. ``['0', '1', '1.0', '1.1']`` for a Sequential whose
        second child is itself a two-element Sequential.
    """
    def _walk(module, path):
        # Yield this level's names, descending into each child right after it.
        for child_name, child in module.named_children():
            qualified = path + '.' + child_name if path else path + child_name
            yield qualified
            yield from _walk(child, qualified)

    return list(_walk(model, prefix))
|
518
|
+
|
519
|
+
def visualize_model_attention_v2(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
    """Show each image in ``src`` beside its saliency map and a heat-map overlay.

    Every regular file directly inside ``src`` is loaded with
    ``spacr.utils.preprocess_image``. The whole batch is passed to
    ``SaliencyMapGenerator(model).compute_saliency_maps``, and one 1x3 figure
    (image, saliency map, map blended over the image) is shown per file.

    Args:
        src: Directory of input images. Sub-directories are skipped.
        model_type: Unused; kept for interface compatibility.
        model_path: Path to a whole model object saved with ``torch.save``.
        image_size, channels, normalize: Forwarded to ``preprocess_image``.
        class_names: Optional sequence used to title the image panel. Every
            image is currently scored against class index 0, so the title is
            always ``class_names[0]``.
        save_saliency: If True, also write each map as an 8-bit image named
            ``saliency_<original filename>``.
        save_dir: Output directory for saved maps. A relative path is resolved
            against ``src``, so the default writes to ``src/saliency_maps``.

    Returns:
        None. If ``src`` contains no files, a message is printed and the
        function returns before loading the model.
    """
    import os

    import torch

    # Collect inputs before any expensive work. Directories are skipped: the
    # default output folder lives inside src, and passing it to
    # preprocess_image on a later run would fail. Sorted for a stable order.
    filenames = sorted(
        name for name in os.listdir(src)
        if os.path.isfile(os.path.join(src, name))
    )
    if not filenames:
        # torch.cat would otherwise raise an opaque error on an empty list.
        print(f'No image files found in {src}')
        return None

    # Deferred until there is actually work to do.
    import matplotlib.pyplot as plt
    import numpy as np
    from PIL import Image
    from spacr.utils import SaliencyMapGenerator, preprocess_image

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the entire model object. map_location allows a model saved on a GPU
    # to be loaded on a CPU-only machine.
    # SECURITY NOTE: torch.load unpickles arbitrary objects; only load trusted files.
    model = torch.load(model_path, map_location=device)
    model.to(device)
    # Inference behaviour for dropout / batch-norm layers.
    model.eval()

    images = []
    tensors = []
    for name in filenames:
        image, tensor = preprocess_image(os.path.join(src, name), normalize=normalize, image_size=image_size, channels=channels)
        images.append(image)
        tensors.append(tensor)

    input_tensors = torch.cat(tensors).to(device)
    # Placeholder target: every image is scored against class 0.
    # Replace with actual class labels if available.
    class_labels = torch.zeros(input_tensors.size(0), dtype=torch.long).to(device)

    cam_generator = SaliencyMapGenerator(model)
    saliency_maps = cam_generator.compute_saliency_maps(input_tensors, class_labels).cpu().numpy()

    # Honour save_dir; a relative path resolves against src so the default
    # output location (src/saliency_maps) is unchanged.
    dst = save_dir if os.path.isabs(save_dir) else os.path.join(src, save_dir)
    if save_saliency:
        os.makedirs(dst, exist_ok=True)

    for i, (image, name) in enumerate(zip(images, filenames)):
        # Drop singleton axes so both (H, W) and (1, H, W) maps can be drawn.
        # NOTE(review): the exact shape returned by compute_saliency_maps is
        # not visible here — confirm.
        saliency = np.squeeze(saliency_maps[i])

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Original image
        axes[0].imshow(image)
        axes[0].axis('off')
        if class_names:
            axes[0].set_title(class_names[class_labels[i].item()])

        # Saliency map
        axes[1].imshow(saliency, cmap=plt.cm.hot)
        axes[1].axis('off')

        # Overlay
        axes[2].imshow(np.array(image))
        axes[2].imshow(saliency, cmap='jet', alpha=0.5)
        axes[2].axis('off')

        plt.tight_layout()
        plt.show()
        # Release the figure; otherwise one stays open per image.
        plt.close(fig)

        if save_saliency:
            # Assumes map values lie in [0, 1] — TODO confirm; out-of-range
            # values are not clipped before the uint8 cast.
            saliency_image = Image.fromarray((saliency * 255).astype(np.uint8))
            saliency_image.save(os.path.join(dst, f'saliency_{name}'))
|
589
|
+
|
590
|
+
def visualize_model_attention(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
    """Show each PNG in ``src`` beside its saliency map and a blended overlay.

    Every ``.png`` file (case-insensitive) directly inside ``src`` is loaded
    with ``spacr.utils.preprocess_image``. The batch is passed to
    ``SaliencyMapGenerator(model).compute_saliency_maps``. For each image a
    1x3 figure is shown:
    - the image;
    - channel 0 of its saliency map;
    - a 50/50 blend of the max-normalised image and the map.

    Args:
        src: Directory of input images.
        model_type: Unused; kept for interface compatibility.
        model_path: Path to a whole model object saved with ``torch.save``.
        image_size, channels, normalize: Forwarded to ``preprocess_image``.
        class_names: Optional sequence used to title the image panel. Every
            image is currently scored against class index 0, so the title is
            always ``class_names[0]``.
        save_saliency: If True, also write each map as an 8-bit image named
            ``saliency_<original filename>``.
        save_dir: Output directory for saved maps. A relative path is resolved
            against ``src``, so the default writes to ``src/saliency_maps``.

    Returns:
        None. If ``src`` contains no PNG files, a message is printed and the
        function returns before loading the model.
    """
    import os

    import torch

    # Collect inputs before any expensive work. Sorted for a stable order.
    filenames = sorted(
        name for name in os.listdir(src)
        if name.lower().endswith('.png') and os.path.isfile(os.path.join(src, name))
    )
    if not filenames:
        # torch.cat would otherwise raise an opaque error on an empty list.
        print(f'No .png files found in {src}')
        return None

    # Deferred until there is actually work to do.
    import matplotlib.pyplot as plt
    import numpy as np
    from PIL import Image
    from spacr.utils import SaliencyMapGenerator, preprocess_image

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the entire model object. map_location allows a model saved on a GPU
    # to be loaded on a CPU-only machine.
    # SECURITY NOTE: torch.load unpickles arbitrary objects; only load trusted files.
    model = torch.load(model_path, map_location=device)
    model.to(device)
    # Inference behaviour for dropout / batch-norm layers.
    model.eval()

    images = []
    tensors = []
    for name in filenames:
        image, tensor = preprocess_image(os.path.join(src, name), normalize=normalize, image_size=image_size, channels=channels)
        images.append(image)
        tensors.append(tensor)

    input_tensors = torch.cat(tensors).to(device)
    # Placeholder target: every image is scored against class 0.
    # Replace with actual class labels if available.
    class_labels = torch.zeros(input_tensors.size(0), dtype=torch.long).to(device)

    cam_generator = SaliencyMapGenerator(model)
    # Convert saliency maps to numpy arrays
    saliency_maps = cam_generator.compute_saliency_maps(input_tensors, class_labels).cpu().numpy()

    # Honour save_dir; a relative path resolves against src so the default
    # output location (src/saliency_maps) is unchanged.
    dst = save_dir if os.path.isabs(save_dir) else os.path.join(src, save_dir)
    if save_saliency:
        os.makedirs(dst, exist_ok=True)

    for i, (image, name) in enumerate(zip(images, filenames)):
        # Channel 0 of the i-th map; assumes maps are indexed (N, C, H, W) — TODO confirm.
        saliency = saliency_maps[i, 0]

        fig, axes = plt.subplots(1, 3, figsize=(20, 5))

        # Original image
        axes[0].imshow(image)
        axes[0].axis('off')
        if class_names:
            axes[0].set_title(f"Class: {class_names[class_labels[i].item()]}")

        # Saliency Map
        axes[1].imshow(saliency, cmap='hot')
        axes[1].axis('off')
        axes[1].set_title("Saliency Map")

        # Overlay: scale the image by its maximum, then blend 50/50 with the map.
        overlay = np.array(image, dtype=float)
        peak = overlay.max()
        if peak != 0:
            # Skipped for all-black images, which previously divided by zero.
            overlay = overlay / peak
        saliency_map_rgb = np.stack([saliency] * 3, axis=-1)  # Convert saliency map to RGB
        overlay = (overlay * 0.5 + saliency_map_rgb * 0.5).clip(0, 1)
        axes[2].imshow(overlay)
        axes[2].axis('off')
        axes[2].set_title("Overlay")

        plt.tight_layout()
        plt.show()
        # Release the figure; otherwise one stays open per image.
        plt.close(fig)

        if save_saliency:
            # Assumes map values lie in [0, 1] — TODO confirm; out-of-range
            # values are not clipped before the uint8 cast.
            saliency_image = Image.fromarray((saliency * 255).astype(np.uint8))
            saliency_image.save(os.path.join(dst, f'saliency_{name}'))
|
667
|
+
|