spacr 0.0.2__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/timelapse.py CHANGED
@@ -3,7 +3,6 @@ import numpy as np
 import pandas as pd
 from collections import defaultdict
 import matplotlib.pyplot as plt
-from matplotlib.animation import FuncAnimation
 from IPython.display import display
 from IPython.display import Image as ipyimage
 import trackpy as tp
@@ -11,6 +10,7 @@ from btrack import datasets as btrack_datasets
 from skimage.measure import regionprops
 from scipy.signal import find_peaks
 from scipy.optimize import curve_fit
+from scipy.integrate import trapz
 import matplotlib.pyplot as plt
 
 from .logger import log_function_call
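
Note on the swapped imports: the unused FuncAnimation import is dropped, and the new trapz import powers the AUC calculations added to analyze_calcium_oscillations further down. As an aside, newer SciPy releases deprecate and ultimately remove scipy.integrate.trapz in favor of trapezoid, so downstream code on a recent SciPy may want a guarded import. A minimal compatibility sketch, not part of this diff:

    try:
        from scipy.integrate import trapezoid as trapz  # current SciPy name
    except ImportError:
        from scipy.integrate import trapz  # legacy name used by this module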
@@ -590,6 +590,80 @@ def infected_vs_noninfected(result_df, measurement):
     plt.tight_layout()
     plt.show()
 
+def save_figure(fig, src, figure_number):
+    source = os.path.dirname(src)
+    results_fldr = os.path.join(source, 'results')
+    os.makedirs(results_fldr, exist_ok=True)
+    fig_loc = os.path.join(results_fldr, f'figure_{figure_number}.pdf')
+    fig.savefig(fig_loc)
+    print(f'Saved figure:{fig_loc}')
+
+def save_results_dataframe(df, src, results_name):
+    source = os.path.dirname(src)
+    results_fldr = os.path.join(source, 'results')
+    os.makedirs(results_fldr, exist_ok=True)
+    csv_loc = os.path.join(results_fldr, f'{results_name}.csv')
+    df.to_csv(csv_loc, index=True)
+    print(f'Saved results:{csv_loc}')
+
+def summarize_per_well(peak_details_df):
+    # Step 1: Split the 'ID' column into its plate/row/column/field/object parts
+    split_columns = peak_details_df['ID'].str.split('_', expand=True)
+    peak_details_df[['plate', 'row', 'column', 'field', 'object_number']] = split_columns
+
+    # Step 2: Create 'well_ID' by combining 'row' and 'column'
+    peak_details_df['well_ID'] = peak_details_df['row'] + '_' + peak_details_df['column']
+
+    # Keep only entries where 'amplitude' is not null
+    filtered_df = peak_details_df[peak_details_df['amplitude'].notna()]
+
+    # Identify numeric columns for averaging from the filtered dataframe
+    numeric_cols = filtered_df.select_dtypes(include=['number']).columns
+
+    # Step 3: Calculate per-well summary statistics
+    summary_df = filtered_df.groupby('well_ID').agg(
+        peaks_per_well=('ID', 'size'),
+        unique_IDs_with_amplitude=('ID', 'nunique'),  # unique IDs per well with non-null amplitude
+        **{col: (col, 'mean') for col in numeric_cols}
+    ).reset_index()
+
+    # Step 4: Count cells per well across all entries (including those without peaks)
+    summary_df_2 = peak_details_df.groupby('well_ID').agg(
+        cells_per_well=('object_number', 'nunique'),
+    ).reset_index()
+
+    summary_df['cells_per_well'] = summary_df_2['cells_per_well']
+    summary_df['peaks_per_cell'] = summary_df['peaks_per_well'] / summary_df['cells_per_well']
+
+    return summary_df
+
+def summarize_per_well_inf_non_inf(peak_details_df):
+    # Step 1: Split the 'ID' column into its plate/row/column/field/object parts
+    split_columns = peak_details_df['ID'].str.split('_', expand=True)
+    peak_details_df[['plate', 'row', 'column', 'field', 'object_number']] = split_columns
+
+    # Step 2: Create 'well_ID' by combining 'row' and 'column'
+    peak_details_df['well_ID'] = peak_details_df['row'] + '_' + peak_details_df['column']
+
+    # A cell counts as infected when 'infected' > 0;
+    # add an 'infected_status' column to classify cells
+    peak_details_df['infected_status'] = peak_details_df['infected'].apply(lambda x: 'infected' if x > 0 else 'non_infected')
+
+    # Identify numeric columns for averaging
+    numeric_cols = peak_details_df.select_dtypes(include=['number']).columns
+
+    # Step 3: Calculate summary statistics per well and infection status
+    summary_df = peak_details_df.groupby(['well_ID', 'infected_status']).agg(
+        cells_per_well=('object_number', 'nunique'),
+        peaks_per_well=('ID', 'size'),
+        **{col: (col, 'mean') for col in numeric_cols}
+    ).reset_index()
+
+    # Calculate peaks per cell
+    summary_df['peaks_per_cell'] = summary_df['peaks_per_well'] / summary_df['cells_per_well']
+
+    return summary_df
+
 def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intensity', size_filter='cell_area', fluctuation_threshold=0.25, num_lines=None, peak_height=0.01, pathogen=None, cytoplasm=None, remove_transient=True, verbose=False, transience_threshold=0.9):
     # Load data
     conn = sqlite3.connect(db_loc)
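
To make the new summarizers concrete, here is a toy run of summarize_per_well (values invented; the ID strings follow the plate_row_column_field_object format that analyze_calcium_oscillations builds):

    import numpy as np
    import pandas as pd

    toy = pd.DataFrame({
        'ID': ['p1_A_1_f1_1', 'p1_A_1_f1_1', 'p1_A_1_f1_2'],  # two peaks for cell 1, one peakless row for cell 2
        'amplitude': [0.5, 0.7, np.nan],                       # NaN amplitude marks a cell with no detected peaks
        'infected': [1, 1, 0],
    })
    summary = summarize_per_well(toy.copy())
    # peaks_per_well=2 (the NaN-amplitude row is excluded), cells_per_well=2, peaks_per_cell=1.0

One caveat: cells_per_well is assigned positionally from a second groupby, which assumes both groupbys yield wells in the same order; that holds only while every well has at least one non-null amplitude row, so a merge on well_ID would be the defensive alternative.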
@@ -626,7 +700,7 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
     cell_df['plate_row_column_field_object'] = cell_df['plate'].astype(str) + '_' + cell_df['row'].astype(str) + '_' + cell_df['column'].astype(str) + '_' + cell_df['field'].astype(str) + '_' + cell_df['object_label'].astype(str)
 
     df = cell_df.copy()
-
+
     # Fit exponential decay model to all scaled fluorescence data
     try:
         params, _ = curve_fit(exponential_decay, df['time'], df[measurement], p0=[max(df[measurement]), 0.01, min(df[measurement])], maxfev=10000)
@@ -653,11 +727,14 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
         if verbose:
             print(f'Group length: {len(group)} Timelapse length: {total_timepoints}, threshold:{threshold}')
 
-        if not len(group) <= threshold:
+        if len(group) <= threshold:
             transience_removed += 1
+            if verbose:
+                print(f'removed group {unique_id} due to transience')
             continue
 
         size_diff = group[size_filter].std() / group[size_filter].mean()
+
         if size_diff <= fluctuation_threshold:
             group['delta_' + measurement] = group['corrected_' + measurement].diff().fillna(0)
             corrected_dfs.append(group)
@@ -665,12 +742,50 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
             # Detect peaks
             peaks, properties = find_peaks(group['delta_' + measurement], height=peak_height)
 
+            # Clip values < 0 to 0 before integrating
+            group_filtered = group.copy()
+            group_filtered['delta_' + measurement] = group['delta_' + measurement].clip(lower=0)
+            above_zero_auc = trapz(y=group_filtered['delta_' + measurement], x=group_filtered['time'])
+            auc = trapz(y=group['delta_' + measurement], x=group_filtered['time'])
+            is_infected = (group['parasite_count'] > 0).any()
+
+            if is_infected:
+                is_infected = 1
+            else:
+                is_infected = 0
+
+            if len(peaks) == 0:
+                peak_details_list.append({
+                    'ID': unique_id,
+                    'plate': group['plate'].iloc[0],
+                    'row': group['row'].iloc[0],
+                    'column': group['column'].iloc[0],
+                    'field': group['field'].iloc[0],
+                    'object_number': group['object_number'].iloc[0],
+                    'time': np.nan,  # no peak, so no peak time
+                    'amplitude': np.nan,
+                    'delta': np.nan,
+                    'AUC': auc,
+                    'AUC_positive': above_zero_auc,
+                    'AUC_peak': np.nan,
+                    'infected': is_infected
+                })
+
             # Inside the for loop where peaks are detected
             for i, peak in enumerate(peaks):
-                amplitude = properties['peak_heights'][i]  # Correctly access the amplitude
-                peak_time = group['time'].iloc[peak]  # Time corresponding to the peak
-                # Get the number of pathogens in the cell at the time of the peak
+
+                amplitude = properties['peak_heights'][i]
+                peak_time = group['time'].iloc[peak]
                 pathogen_count_at_peak = group['parasite_count'].iloc[peak]
+
+                start_idx = max(peak - 1, 0)
+                end_idx = min(peak + 1, len(group) - 1)
+
+                # Slice a three-point window around the peak for the AUC calculation
+                peak_segment_y = group['delta_' + measurement].iloc[start_idx:end_idx + 1]
+                peak_segment_x = group['time'].iloc[start_idx:end_idx + 1]
+                peak_auc = trapz(y=peak_segment_y, x=peak_segment_x)
+
                 peak_details_list.append({
                     'ID': unique_id,
                     'plate': group['plate'].iloc[0],
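
The per-peak AUC integrates the delta trace over a three-point window centered on the peak and clamped to the ends of the trace. A small numeric check of that slice logic (values invented for illustration):

    import numpy as np
    from scipy.integrate import trapz  # 'trapezoid' on newer SciPy

    delta = np.array([0.0, 0.2, 1.0, 0.3, 0.0])
    time = np.arange(5.0)
    peak = 2  # index of the detected peak

    start_idx = max(peak - 1, 0)
    end_idx = min(peak + 1, len(delta) - 1)
    # Trapezoid rule over t in [1, 3]: 0.5*(0.2 + 1.0) + 0.5*(1.0 + 0.3) = 1.25
    print(trapz(y=delta[start_idx:end_idx + 1], x=time[start_idx:end_idx + 1]))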
@@ -681,6 +796,9 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
                     'time': peak_time,  # The time of the peak
                     'amplitude': amplitude,
                     'delta': group['delta_' + measurement].iloc[peak],
+                    'AUC': auc,
+                    'AUC_positive': above_zero_auc,
+                    'AUC_peak': peak_auc,
                     'infected': pathogen_count_at_peak
                 })
         else:
@@ -697,7 +815,14 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
         return
 
     peak_details_df = pd.DataFrame(peak_details_list)
-
+    summary_df = summarize_per_well(peak_details_df)
+    summary_df_inf_non_inf = summarize_per_well_inf_non_inf(peak_details_df)
+
+    save_results_dataframe(df=peak_details_df, src=db_loc, results_name='peak_details')
+    save_results_dataframe(df=result_df, src=db_loc, results_name='results')
+    save_results_dataframe(df=summary_df, src=db_loc, results_name='well_results')
+    save_results_dataframe(df=summary_df_inf_non_inf, src=db_loc, results_name='well_results_inf_non_inf')
+
     # Plotting
     fig, ax = plt.subplots(figsize=(10, 8))
     sampled_groups = result_df['plate_row_column_field_object'].unique()
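
Together with save_figure and save_results_dataframe defined earlier in this diff, these calls write the analysis outputs into a results folder next to the database. With db_loc = '/data/plate1/measurements.db' (an invented path), the layout would be:

    /data/plate1/results/
        peak_details.csv
        results.csv
        well_results.csv
        well_results_inf_non_inf.csv
        figure_1.pdf    # plus figure_2.pdf when pathogen is set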
@@ -714,12 +839,16 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
     ax.set_xlabel('Time')
     ax.set_ylabel('Normalized Delta ' + measurement)
     plt.tight_layout()
+
     plt.show()
+
+    save_figure(fig, src=db_loc, figure_number=1)
 
     if pathogen:
         infected_vs_noninfected(result_df, measurement)
+        save_figure(fig, src=db_loc, figure_number=2)
 
-        # Identifying cells with and without infection
+        # Identify cells with and without pathogens
         infected_cells = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') > 0]['plate_row_column_field_object'].unique()
         noninfected_cells = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') == 0]['plate_row_column_field_object'].unique()
 
@@ -733,5 +862,5 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
 
         print(f'Average number of peaks per infected cell: {avg_inf_peaks_per_cell:.2f}')
         print(f'Average number of peaks per non-infected cell: {avg_non_inf_peaks_per_cell:.2f}')
-
-    return result_df, peak_details_df
+    print(f'done')
+    return result_df, peak_details_df, fig
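
Because analyze_calcium_oscillations now returns the figure as a third value, callers written against 0.0.2 that unpack two values will break. A usage sketch against 0.0.6 (the path and pathogen label are placeholders):

    from spacr.timelapse import analyze_calcium_oscillations

    result_df, peak_details_df, fig = analyze_calcium_oscillations(
        db_loc='path/to/measurements.db',             # placeholder path
        measurement='cell_channel_1_mean_intensity',
        pathogen='pathogen',                          # any truthy value enables the infected/non-infected analysis
    )
    # CSVs and figure_1.pdf are also written automatically to <db folder>/results/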
spacr/train.py CHANGED
@@ -6,6 +6,7 @@ from torch.autograd import grad
 from torch.optim.lr_scheduler import StepLR
 import torch.nn.functional as F
 from IPython.display import display, clear_output
+import difflib
 
 from .logger import log_function_call
 
@@ -194,8 +195,8 @@ def test_model_performance(loaders, model, loader_name_list, epoch, train_mode,
 
 def train_test_model(src, settings, custom_model=False, custom_model_path=None):
 
-    from .io import save_settings, _copy_missclassified
-    from .utils import pick_best_model, test_model_performance
+    from .io import _save_settings, _copy_missclassified
+    from .utils import pick_best_model
     from .core import generate_loaders
 
     settings['src'] = src
@@ -208,7 +209,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
         model = torch.load(custom_model_path)
 
     if settings['train']:
-        save_settings(settings, src)
+        _save_settings(settings, src)
         torch.cuda.empty_cache()
         torch.cuda.memory.empty_cache()
         gc.collect()
@@ -227,20 +228,23 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
                                       validation_split=settings['val_split'],
                                       pin_memory=settings['pin_memory'],
                                       normalize=settings['normalize'],
-                                      verbose=settings['verbose'])
+                                      channels=settings['channels'],
+                                      verbose=settings['verbose'])
+
 
     if settings['test']:
         test, _, plate_names_test = generate_loaders(src,
-                                                     train_mode=settings['train_mode'],
-                                                     mode='test',
-                                                     image_size=settings['image_size'],
-                                                     batch_size=settings['batch_size'],
-                                                     classes=settings['classes'],
-                                                     num_workers=settings['num_workers'],
-                                                     validation_split=0.0,
-                                                     pin_memory=settings['pin_memory'],
-                                                     normalize=settings['normalize'],
-                                                     verbose=settings['verbose'])
+                                                     train_mode=settings['train_mode'],
+                                                     mode='test',
+                                                     image_size=settings['image_size'],
+                                                     batch_size=settings['batch_size'],
+                                                     classes=settings['classes'],
+                                                     num_workers=settings['num_workers'],
+                                                     validation_split=0.0,
+                                                     pin_memory=settings['pin_memory'],
+                                                     normalize=settings['normalize'],
+                                                     channels=settings['channels'],
+                                                     verbose=settings['verbose'])
     if model == None:
         model_path = pick_best_model(src+'/model')
         print(f'Best model: {model_path}')
@@ -330,8 +334,8 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
     None
     """
 
-    from .io import save_model, save_progress
-    from .utils import evaluate_model_performance, compute_irm_penalty, calculate_loss, choose_model
+    from .io import _save_model, _save_progress
+    from .utils import compute_irm_penalty, calculate_loss, choose_model
 
     print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders)}')
 
@@ -347,6 +351,11 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
             break
 
     model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
+
+    if model is None:
+        print(f'Model {model_type} not found')
+        return
+
     model.to(device)
 
     if optimizer_type == 'adamw':
@@ -421,10 +430,10 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
         if schedule == 'step_lr':
             scheduler.step()
 
-        save_progress(dst, results_df, train_metrics_df)
+        _save_progress(dst, results_df, train_metrics_df)
         clear_output(wait=True)
         display(results_df)
-        save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
+        _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
 
     if train_mode == 'irm':
         dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
@@ -494,7 +503,165 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
 
     clear_output(wait=True)
     display(results_df)
-    save_progress(dst, results_df, train_metrics_df)
-    save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
+    _save_progress(dst, results_df, train_metrics_df)
+    _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
     print(f'Saved model: {dst}')
-    return
+    return
+
+def get_submodules(model, prefix=''):
+    # Recursively collect the dotted names of all submodules
+    submodules = []
+    for name, module in model.named_children():
+        full_name = prefix + ('.' if prefix else '') + name
+        submodules.append(full_name)
+        submodules.extend(get_submodules(module, full_name))
+    return submodules
+
+def visualize_model_attention_v2(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
+    import torch
+    import os
+    from spacr.utils import SaliencyMapGenerator, preprocess_image
+    import matplotlib.pyplot as plt
+    import numpy as np
+    from PIL import Image
+
+    use_cuda = torch.cuda.is_available()
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    # Load the entire model object
+    model = torch.load(model_path)
+    model.to(device)
+
+    # Create directory for saving saliency maps if it does not exist
+    if save_saliency and not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+
+    # Collect all images and their tensors
+    images = []
+    input_tensors = []
+    filenames = []
+    for file in os.listdir(src):
+        image_path = os.path.join(src, file)
+        image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels)
+        images.append(image)
+        input_tensors.append(input_tensor)
+        filenames.append(file)
+
+    input_tensors = torch.cat(input_tensors).to(device)
+    class_labels = torch.zeros(input_tensors.size(0), dtype=torch.long).to(device)  # Replace with actual class labels if available
+
+    # Generate saliency maps
+    cam_generator = SaliencyMapGenerator(model)
+    saliency_maps = cam_generator.compute_saliency_maps(input_tensors, class_labels)
+
+    # Plot images, saliency maps, and overlays
+    saliency_maps = saliency_maps.cpu().numpy()
+    N = len(images)
+
+    dst = os.path.join(src, 'saliency_maps')
+    os.makedirs(dst, exist_ok=True)
+
+    for i in range(N):
+        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
+
+        # Original image
+        axes[0].imshow(images[i])
+        axes[0].axis('off')
+        if class_names:
+            axes[0].set_title(class_names[class_labels[i].item()])
+
+        # Saliency map
+        axes[1].imshow(saliency_maps[i], cmap=plt.cm.hot)
+        axes[1].axis('off')
+
+        # Overlay
+        overlay = np.array(images[i])
+        axes[2].imshow(overlay)
+        axes[2].imshow(saliency_maps[i], cmap='jet', alpha=0.5)
+        axes[2].axis('off')
+
+        plt.tight_layout()
+        plt.show()
+
+        # Save the saliency map if required
+        if save_saliency:
+            saliency_image = Image.fromarray((saliency_maps[i] * 255).astype(np.uint8))
+            saliency_image.save(os.path.join(dst, f'saliency_{filenames[i]}'))
+
+def visualize_model_attention(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
+    import torch
+    import os
+    from spacr.utils import SaliencyMapGenerator, preprocess_image
+    import matplotlib.pyplot as plt
+    import numpy as np
+    from PIL import Image
+
+    use_cuda = torch.cuda.is_available()
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    # Load the entire model object
+    model = torch.load(model_path)
+    model.to(device)
+
+    # Create directory for saving saliency maps if it does not exist
+    if save_saliency and not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+
+    # Collect all images and their tensors
+    images = []
+    input_tensors = []
+    filenames = []
+    for file in os.listdir(src):
+        if not file.endswith('.png'):
+            continue
+        image_path = os.path.join(src, file)
+        image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels)
+        images.append(image)
+        input_tensors.append(input_tensor)
+        filenames.append(file)
+
+    input_tensors = torch.cat(input_tensors).to(device)
+    class_labels = torch.zeros(input_tensors.size(0), dtype=torch.long).to(device)  # Replace with actual class labels if available
+
+    # Generate saliency maps
+    cam_generator = SaliencyMapGenerator(model)
+    saliency_maps = cam_generator.compute_saliency_maps(input_tensors, class_labels)
+
+    # Convert saliency maps to numpy arrays
+    saliency_maps = saliency_maps.cpu().numpy()
+
+    N = len(images)
+
+    dst = os.path.join(src, 'saliency_maps')
+
+    for i in range(N):
+        fig, axes = plt.subplots(1, 3, figsize=(20, 5))
+
+        # Original image
+        axes[0].imshow(images[i])
+        axes[0].axis('off')
+        if class_names:
+            axes[0].set_title(f"Class: {class_names[class_labels[i].item()]}")
+
+        # Saliency Map
+        axes[1].imshow(saliency_maps[i, 0], cmap='hot')
+        axes[1].axis('off')
+        axes[1].set_title("Saliency Map")
+
+        # Overlay
+        overlay = np.array(images[i])
+        overlay = overlay / overlay.max()
+        saliency_map_rgb = np.stack([saliency_maps[i, 0]] * 3, axis=-1)  # Convert saliency map to RGB
+        overlay = (overlay * 0.5 + saliency_map_rgb * 0.5).clip(0, 1)
+        axes[2].imshow(overlay)
+        axes[2].axis('off')
+        axes[2].set_title("Overlay")
+
+        plt.tight_layout()
+        plt.show()
+
+        # Save the saliency map if required
+        if save_saliency:
+            os.makedirs(dst, exist_ok=True)
+            saliency_image = Image.fromarray((saliency_maps[i, 0] * 255).astype(np.uint8))
+            saliency_image.save(os.path.join(dst, f'saliency_{filenames[i]}'))
+
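A usage sketch for the new saliency helper (paths and class names are invented; the function expects a folder of .png tiles and a model saved whole with torch.save(model, ...), not a state_dict):

    from spacr.train import visualize_model_attention

    visualize_model_attention(
        src='path/to/test_images',           # placeholder folder of .png images
        model_path='path/to/model.pth',      # full model object, not a state_dict
        image_size=224,
        channels=[1, 2, 3],
        class_names=['class_0', 'class_1'],  # invented labels
        save_saliency=True,
    )

Note that save_dir is created when save_saliency is set, but the maps themselves are written to a saliency_maps folder inside src, and class labels default to all zeros unless real labels are wired in.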