spacr 0.0.20__py3-none-any.whl → 0.0.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/timelapse.py CHANGED
@@ -11,6 +11,7 @@ from btrack import datasets as btrack_datasets
  from skimage.measure import regionprops
  from scipy.signal import find_peaks
  from scipy.optimize import curve_fit
+ from scipy.integrate import trapz
  import matplotlib.pyplot as plt

  from .logger import log_function_call
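The new `trapz` import backs the AUC calculations added further down in this file. A minimal sketch of what it computes; note that newer SciPy releases drop `trapz` in favour of the equivalent `scipy.integrate.trapezoid`:

```python
# Trapezoidal-rule area under a delta-F trace, as used for the AUC columns below.
# On recent SciPy (>= 1.14) import trapezoid instead; trapz has been removed there.
import numpy as np
from scipy.integrate import trapz

time = np.array([0.0, 1.0, 2.0, 3.0])
delta = np.array([0.0, 0.5, 0.2, -0.1])
auc = trapz(y=delta, x=time)                              # signed area: 0.65
auc_positive = trapz(y=np.clip(delta, 0, None), x=time)   # negatives zeroed first: 0.7
print(auc, auc_positive)
```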
@@ -590,6 +591,80 @@ def infected_vs_noninfected(result_df, measurement):
      plt.tight_layout()
      plt.show()

+ def save_figure(fig, src, figure_number):
+     source = os.path.dirname(src)
+     results_fldr = os.path.join(source,'results')
+     os.makedirs(results_fldr, exist_ok=True)
+     fig_loc = os.path.join(results_fldr, f'figure_{figure_number}.pdf')
+     fig.savefig(fig_loc)
+     print(f'Saved figure:{fig_loc}')
+
+ def save_results_dataframe(df, src, results_name):
+     source = os.path.dirname(src)
+     results_fldr = os.path.join(source,'results')
+     os.makedirs(results_fldr, exist_ok=True)
+     csv_loc = os.path.join(results_fldr, f'{results_name}.csv')
+     df.to_csv(csv_loc, index=True)
+     print(f'Saved results:{csv_loc}')
+
+ def summarize_per_well(peak_details_df):
+     # Step 1: Split the 'ID' column
+     split_columns = peak_details_df['ID'].str.split('_', expand=True)
+     peak_details_df[['plate', 'row', 'column', 'field', 'object_number']] = split_columns
+
+     # Step 2: Create 'well_ID' by combining 'row' and 'column'
+     peak_details_df['well_ID'] = peak_details_df['row'] + '_' + peak_details_df['column']
+
+     # Filter entries where 'amplitude' is not null
+     filtered_df = peak_details_df[peak_details_df['amplitude'].notna()]
+
+     # Preparation for Step 3: Identify numeric columns for averaging from the filtered dataframe
+     numeric_cols = filtered_df.select_dtypes(include=['number']).columns
+
+     # Step 3: Calculate summary statistics
+     summary_df = filtered_df.groupby('well_ID').agg(
+         peaks_per_well=('ID', 'size'),
+         unique_IDs_with_amplitude=('ID', 'nunique'), # Count unique IDs per well with non-null amplitude
+         **{col: (col, 'mean') for col in numeric_cols} # exclude 'amplitude' from averaging if it's numeric
+     ).reset_index()
+
+     # Step 3: Calculate summary statistics
+     summary_df_2 = peak_details_df.groupby('well_ID').agg(
+         cells_per_well=('object_number', 'nunique'),
+     ).reset_index()
+
+     summary_df['cells_per_well'] = summary_df_2['cells_per_well']
+     summary_df['peaks_per_cell'] = summary_df['peaks_per_well'] / summary_df['cells_per_well']
+
+     return summary_df
+
+ def summarize_per_well_inf_non_inf(peak_details_df):
+     # Step 1: Split the 'ID' column
+     split_columns = peak_details_df['ID'].str.split('_', expand=True)
+     peak_details_df[['plate', 'row', 'column', 'field', 'object_number']] = split_columns
+
+     # Step 2: Create 'well_ID' by combining 'row' and 'column'
+     peak_details_df['well_ID'] = peak_details_df['row'] + '_' + peak_details_df['column']
+
+     # Assume 'pathogen_count' indicates infection if > 0
+     # Add an 'infected_status' column to classify cells
+     peak_details_df['infected_status'] = peak_details_df['infected'].apply(lambda x: 'infected' if x > 0 else 'non_infected')
+
+     # Preparation for Step 3: Identify numeric columns for averaging
+     numeric_cols = peak_details_df.select_dtypes(include=['number']).columns
+
+     # Step 3: Calculate summary statistics
+     summary_df = peak_details_df.groupby(['well_ID', 'infected_status']).agg(
+         cells_per_well=('object_number', 'nunique'),
+         peaks_per_well=('ID', 'size'),
+         **{col: (col, 'mean') for col in numeric_cols}
+     ).reset_index()
+
+     # Calculate peaks per cell
+     summary_df['peaks_per_cell'] = summary_df['peaks_per_well'] / summary_df['cells_per_well']
+
+     return summary_df
+
  def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intensity', size_filter='cell_area', fluctuation_threshold=0.25, num_lines=None, peak_height=0.01, pathogen=None, cytoplasm=None, remove_transient=True, verbose=False, transience_threshold=0.9):
      # Load data
      conn = sqlite3.connect(db_loc)
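A toy illustration of what the new `summarize_per_well` helper expects: IDs in the `plate_row_column_field_object` format built below, with NaN-amplitude rows excluded from the per-well peak counts and averages. The data here is hypothetical and assumes the helper above is in scope:

```python
# Hypothetical peak table for summarize_per_well from the hunk above.
import pandas as pd

peaks = pd.DataFrame({
    'ID': ['p1_A_1_f1_1', 'p1_A_1_f1_1', 'p1_A_1_f2_2'],
    'amplitude': [0.4, 0.6, None],   # the NaN row counts as a cell but not as a peak
})
summary = summarize_per_well(peaks)
print(summary[['well_ID', 'peaks_per_well', 'cells_per_well', 'peaks_per_cell']])
# well_ID A_1: 2 peaks, 2 cells, 1.0 peaks_per_cell
```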
@@ -626,7 +701,7 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
      cell_df['plate_row_column_field_object'] = cell_df['plate'].astype(str) + '_' + cell_df['row'].astype(str) + '_' + cell_df['column'].astype(str) + '_' + cell_df['field'].astype(str) + '_' + cell_df['object_label'].astype(str)

      df = cell_df.copy()
-
+
      # Fit exponential decay model to all scaled fluorescence data
      try:
          params, _ = curve_fit(exponential_decay, df['time'], df[measurement], p0=[max(df[measurement]), 0.01, min(df[measurement])], maxfev=10000)
@@ -653,11 +728,14 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
          if verbose:
              print(f'Group length: {len(group)} Timelapse length: {total_timepoints}, threshold:{threshold}')

-         if not len(group) <= threshold:
+         if len(group) <= threshold:
              transience_removed += 1
+             if verbose:
+                 print(f'removed group {unique_id} due to transience')
              continue

          size_diff = group[size_filter].std() / group[size_filter].mean()
+
          if size_diff <= fluctuation_threshold:
              group['delta_' + measurement] = group['corrected_' + measurement].diff().fillna(0)
              corrected_dfs.append(group)
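The inverted condition is the substantive fix in this hunk: the old test discarded every track longer than the threshold and kept the transient ones. A quick check of the corrected predicate, assuming the threshold is `transience_threshold` times the timelapse length as the surrounding code suggests:

```python
# Corrected transience filter: short tracks are dropped, full-length tracks kept.
total_timepoints = 10
transience_threshold = 0.9                            # function default
threshold = transience_threshold * total_timepoints   # 9.0
print(5 <= threshold)    # True  -> transient track, skipped via continue
print(10 <= threshold)   # False -> complete track, analyzed
```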
@@ -665,12 +743,50 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
              # Detect peaks
              peaks, properties = find_peaks(group['delta_' + measurement], height=peak_height)

+             # Set values < 0 to 0
+             group_filtered = group.copy()
+             group_filtered['delta_' + measurement] = group['delta_' + measurement].clip(lower=0)
+             above_zero_auc = trapz(y=group_filtered['delta_' + measurement], x=group_filtered['time'])
+             auc = trapz(y=group['delta_' + measurement], x=group_filtered['time'])
+             is_infected = (group['parasite_count'] > 0).any()
+
+             if is_infected:
+                 is_infected = 1
+             else:
+                 is_infected = 0
+
+             if len(peaks) == 0:
+                 peak_details_list.append({
+                     'ID': unique_id,
+                     'plate': group['plate'].iloc[0],
+                     'row': group['row'].iloc[0],
+                     'column': group['column'].iloc[0],
+                     'field': group['field'].iloc[0],
+                     'object_number': group['object_number'].iloc[0],
+                     'time': np.nan, # The time of the peak
+                     'amplitude': np.nan,
+                     'delta': np.nan,
+                     'AUC': auc,
+                     'AUC_positive': above_zero_auc,
+                     'AUC_peak': np.nan,
+                     'infected': is_infected
+                 })
+
              # Inside the for loop where peaks are detected
              for i, peak in enumerate(peaks):
-                 amplitude = properties['peak_heights'][i] # Correctly access the amplitude
-                 peak_time = group['time'].iloc[peak] # Time corresponding to the peak
-                 # Get the number of pathogens in the cell at the time of the peak
+
+                 amplitude = properties['peak_heights'][i]
+                 peak_time = group['time'].iloc[peak]
                  pathogen_count_at_peak = group['parasite_count'].iloc[peak]
+
+                 start_idx = max(peak - 1, 0)
+                 end_idx = min(peak + 1, len(group) - 1)
+
+                 # Using indices to slice for AUC calculation
+                 peak_segment_y = group['delta_' + measurement].iloc[start_idx:end_idx + 1]
+                 peak_segment_x = group['time'].iloc[start_idx:end_idx + 1]
+                 peak_auc = trapz(y=peak_segment_y, x=peak_segment_x)
+
                  peak_details_list.append({
                      'ID': unique_id,
                      'plate': group['plate'].iloc[0],
@@ -681,6 +797,9 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
                      'time': peak_time, # The time of the peak
                      'amplitude': amplitude,
                      'delta': group['delta_' + measurement].iloc[peak],
+                     'AUC': auc,
+                     'AUC_positive': above_zero_auc,
+                     'AUC_peak': peak_auc,
                      'infected': pathogen_count_at_peak
                  })
          else:
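The per-peak AUC added above integrates delta-F over a one-sample window on either side of each detected peak. A standalone sketch of that windowing on a toy trace (values hypothetical):

```python
# find_peaks + trapezoidal AUC over a [peak-1, peak+1] window, mirroring the hunk above.
import numpy as np
from scipy.signal import find_peaks
from scipy.integrate import trapz

time = np.arange(8, dtype=float)
delta = np.array([0.0, 0.02, 0.30, 0.05, 0.0, 0.25, 0.04, 0.0])
peaks, properties = find_peaks(delta, height=0.01)   # indices 2 and 5
for i, peak in enumerate(peaks):
    start_idx = max(peak - 1, 0)
    end_idx = min(peak + 1, len(delta) - 1)
    peak_auc = trapz(y=delta[start_idx:end_idx + 1], x=time[start_idx:end_idx + 1])
    print(f"t={time[peak]:.0f} height={properties['peak_heights'][i]:.2f} AUC={peak_auc:.3f}")
```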
@@ -697,7 +816,14 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
          return

      peak_details_df = pd.DataFrame(peak_details_list)
-
+     summary_df = summarize_per_well(peak_details_df)
+     summary_df_inf_non_inf = summarize_per_well_inf_non_inf(peak_details_df)
+
+     save_results_dataframe(df=peak_details_df, src=db_loc, results_name='peak_details')
+     save_results_dataframe(df=result_df, src=db_loc, results_name='results')
+     save_results_dataframe(df=summary_df, src=db_loc, results_name='well_results')
+     save_results_dataframe(df=summary_df_inf_non_inf, src=db_loc, results_name='well_results_inf_non_inf')
+
      # Plotting
      fig, ax = plt.subplots(figsize=(10, 8))
      sampled_groups = result_df['plate_row_column_field_object'].unique()
@@ -714,12 +840,16 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
      ax.set_xlabel('Time')
      ax.set_ylabel('Normalized Delta ' + measurement)
      plt.tight_layout()
+
      plt.show()
+
+     save_figure(fig, src=db_loc, figure_number=1)

      if pathogen:
          infected_vs_noninfected(result_df, measurement)
+         save_figure(fig, src=db_loc, figure_number=2)

-         # Identifying cells with and without infection
+         # Identify cells with and without pathogens
          infected_cells = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') > 0]['plate_row_column_field_object'].unique()
          noninfected_cells = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') == 0]['plate_row_column_field_object'].unique()

@@ -733,5 +863,5 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens

          print(f'Average number of peaks per infected cell: {avg_inf_peaks_per_cell:.2f}')
          print(f'Average number of peaks per non-infected cell: {avg_non_inf_peaks_per_cell:.2f}')
-
-     return result_df, peak_details_df
+     print(f'done')
+     return result_df, peak_details_df, fig
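With these changes the function persists its own outputs (four CSVs plus the figures, written to a `results` folder next to the database) and returns the figure alongside the two dataframes. A hedged usage sketch; the database path is hypothetical:

```python
# New three-value return signature of analyze_calcium_oscillations.
result_df, peak_details_df, fig = analyze_calcium_oscillations(
    db_loc='/path/to/measurements.db',
    measurement='cell_channel_1_mean_intensity',
)
```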
spacr/train.py CHANGED
@@ -194,8 +194,8 @@ def test_model_performance(loaders, model, loader_name_list, epoch, train_mode,

  def train_test_model(src, settings, custom_model=False, custom_model_path=None):

-     from .io import save_settings, _copy_missclassified
-     from .utils import pick_best_model, test_model_performance
+     from .io import _save_settings, _copy_missclassified
+     from .utils import pick_best_model
      from .core import generate_loaders

      settings['src'] = src
@@ -208,7 +208,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
          model = torch.load(custom_model_path)

      if settings['train']:
-         save_settings(settings, src)
+         _save_settings(settings, src)
          torch.cuda.empty_cache()
          torch.cuda.memory.empty_cache()
          gc.collect()
@@ -227,7 +227,9 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
             validation_split=settings['val_split'],
             pin_memory=settings['pin_memory'],
             normalize=settings['normalize'],
-            verbose=settings['verbose'])
+            channels=settings['channels'],
+            verbose=settings['verbose'])
+

      if settings['test']:
          test, _, plate_names_test = generate_loaders(src,
@@ -240,6 +242,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
             validation_split=0.0,
             pin_memory=settings['pin_memory'],
             normalize=settings['normalize'],
+            channels=settings['channels'],
             verbose=settings['verbose'])
      if model == None:
          model_path = pick_best_model(src+'/model')
@@ -330,8 +333,8 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
          None
      """

-     from .io import save_model, save_progress
-     from .utils import evaluate_model_performance, compute_irm_penalty, calculate_loss, choose_model
+     from .io import _save_model, _save_progress
+     from .utils import compute_irm_penalty, calculate_loss, choose_model #evaluate_model_performance,

      print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders)}')

@@ -347,6 +350,11 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
              break

      model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
+
+     if model is None:
+         print(f'Model {model_type} not found')
+         return
+
      model.to(device)

      if optimizer_type == 'adamw':
@@ -421,10 +429,10 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
          if schedule == 'step_lr':
              scheduler.step()

-         save_progress(dst, results_df, train_metrics_df)
+         _save_progress(dst, results_df, train_metrics_df)
          clear_output(wait=True)
          display(results_df)
-         save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
+         _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])

      if train_mode == 'irm':
          dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
@@ -494,7 +502,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='

      clear_output(wait=True)
      display(results_df)
-     save_progress(dst, results_df, train_metrics_df)
-     save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
+     _save_progress(dst, results_df, train_metrics_df)
+     _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
      print(f'Saved model: {dst}')
      return
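Both `generate_loaders` calls now forward `channels=settings['channels']`, so callers of `train_test_model` must provide that key. An illustrative, non-exhaustive settings fragment; values are placeholders and the function reads more keys than listed here:

```python
# Partial settings dict for train_test_model; only keys touched by this diff shown.
settings = {
    'train': True,
    'test': False,
    'channels': [1, 2, 3],   # newly required by generate_loaders
    'val_split': 0.1,
    'pin_memory': True,
    'normalize': True,
    'verbose': True,
}
```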
spacr/utils.py CHANGED
@@ -1,4 +1,4 @@
- import os, re, sqlite3, gc, torch, torchvision, time, random, string, shutil, cv2, tarfile, glob
+ import sys, os, re, sqlite3, gc, torch, torchvision, time, random, string, shutil, cv2, tarfile, glob

  import numpy as np
  from cellpose import models as cp_models
@@ -46,12 +46,6 @@ from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet

  from .logger import log_function_call

- #from .io import _read_and_join_tables, _save_figure
- #from .timelapse import _btrack_track_cells, _trackpy_track_cells
- #from .plot import _plot_images_on_grid, plot_masks, _plot_histograms_and_stats, plot_resize, _plot_plates, _reg_v_plot, plot_masks
- #from .core import identify_masks
-
-
  def _gen_rgb_image(image, cahnnels):
      rgb_image = np.take(image, cahnnels, axis=-1)
      rgb_image = rgb_image.astype(float)
@@ -71,7 +65,8 @@ def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_th
          outline = np.zeros_like(mask, dtype=np.uint8) # Use uint8 for contour detection efficiency

          # Find and draw contours
-         for j in np.unique(mask)[1:]:
+         for j in np.unique(mask):
+         #for j in np.unique(mask)[1:]:
              contours = find_contours(mask == j, 0.5)
              # Convert contours for OpenCV format and draw directly to optimize
              cv_contours = [np.flip(contour.astype(int), axis=1) for contour in contours]
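Dropping the `[1:]` slice means the contour loop now also visits label 0, which is conventionally the background in labeled masks:

```python
# np.unique on a labeled mask includes the background label 0.
import numpy as np

mask = np.array([[0, 0, 1],
                 [0, 2, 2]])
print(np.unique(mask))       # [0 1 2] -> new loop also traces the background
print(np.unique(mask)[1:])   # [1 2]   -> old loop, objects only
```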
@@ -781,8 +776,8 @@ def _get_object_settings(object_type, settings):
      object_settings = {}

      object_settings['diameter'] = _get_diam(settings['magnification'], obj=object_type)
-     object_settings['minimum_size'] = (object_settings['diameter']**2)/5
-     object_settings['maximum_size'] = (object_settings['diameter']**2)*3
+     object_settings['minimum_size'] = (object_settings['diameter']**2)/4
+     object_settings['maximum_size'] = (object_settings['diameter']**2)*10
      object_settings['merge'] = False
      object_settings['resample'] = True
      object_settings['remove_border_objects'] = False
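A worked example of the relaxed size window: for a 30-pixel object diameter the bounds move from [180, 2700] to [225, 9000] square pixels, so the lower cutoff rises slightly while the upper cutoff admits much larger objects.

```python
# Old vs. new size bounds for a 30-pixel object diameter.
diameter = 30
old_min, old_max = diameter**2 / 5, diameter**2 * 3    # 180.0, 2700
new_min, new_max = diameter**2 / 4, diameter**2 * 10   # 225.0, 9000
print(old_min, old_max, new_min, new_max)
```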
@@ -793,21 +788,22 @@
              object_settings['model_name'] = 'cyto'
          else:
              object_settings['model_name'] = 'cyto2'
-         object_settings['filter_size'] = True
-         object_settings['filter_intensity'] = True
+         object_settings['filter_size'] = False
+         object_settings['filter_intensity'] = False
          object_settings['restore_type'] = settings.get('cell_restore_type', None)

      elif object_type == 'nucleus':
          object_settings['model_name'] = 'nuclei'
-         object_settings['filter_size'] = True
-         object_settings['filter_intensity'] = True
+         object_settings['filter_size'] = False
+         object_settings['filter_intensity'] = False
          object_settings['restore_type'] = settings.get('nucleus_restore_type', None)

      elif object_type == 'pathogen':
          object_settings['model_name'] = 'cyto'
          object_settings['filter_size'] = True
-         object_settings['filter_intensity'] = True
+         object_settings['filter_intensity'] = False
          object_settings['restore_type'] = settings.get('pathogen_restore_type', None)
+         object_settings['merge'] = settings['merge_pathogens']

      else:
          print(f'Object type: {object_type} not supported. Supported object types are : cell, nucleus and pathogen')
@@ -884,17 +880,15 @@ def _get_cellpose_channels(src, nucleus_channel, pathogen_channel, cell_channel)

      if not pathogen_channel is None:
          if not nucleus_channel is None:
-             cellpose_channels['pathogen'] = [0,1]
+             if not pathogen_channel is None:
+                 cellpose_channels['pathogen'] = [0,2]
+             else:
+                 cellpose_channels['pathogen'] = [0,1]
          else:
              cellpose_channels['pathogen'] = [0,0]

      if not cell_channel is None:
          if not nucleus_channel is None:
-             if not pathogen_channel is None:
-                 cellpose_channels['cell'] = [0,2]
-             else:
-                 cellpose_channels['cell'] = [0,1]
-         elif not pathogen_channel is None:
              cellpose_channels['cell'] = [0,1]
          else:
              cellpose_channels['cell'] = [0,0]
@@ -1069,7 +1063,7 @@ class Cache:
          cache (OrderedDict): The cache data structure.
      """

-     def _init__(self, max_size):
+     def __init__(self, max_size):
          self.cache = OrderedDict()
          self.max_size = max_size

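The `_init__` to `__init__` rename here (repeated for the classes below) fixes a silent bug: Python only invokes the double-underscore `__init__`, so the misspelled constructors never ran.

```python
# With the misspelled _init__, object.__init__ is used and the body never executes.
class Broken:
    def _init__(self, max_size):   # not a dunder; never invoked by Broken(...)
        self.max_size = max_size

class Fixed:
    def __init__(self, max_size):
        self.max_size = max_size

print(Fixed(8).max_size)   # 8
# Broken(8) raises TypeError: Broken() takes no arguments
```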
@@ -1100,7 +1094,7 @@ class ScaledDotProductAttention(nn.Module):

      """

-     def _init__(self, d_k):
+     def __init__(self, d_k):
          super(ScaledDotProductAttention, self).__init__()
          self.d_k = d_k

@@ -1131,7 +1125,7 @@ class SelfAttention(nn.Module):
          d_k (int): Dimensionality of the key and query vectors.
      """

-     def _init__(self, in_channels, d_k):
+     def __init__(self, in_channels, d_k):
          super(SelfAttention, self).__init__()
          self.W_q = nn.Linear(in_channels, d_k)
          self.W_k = nn.Linear(in_channels, d_k)
@@ -1155,7 +1149,7 @@ class SelfAttention(nn.Module):
          return output

  class ScaledDotProductAttention(nn.Module):
-     def _init__(self, d_k):
+     def __init__(self, d_k):
          """
          Initializes the ScaledDotProductAttention module.

@@ -1192,7 +1186,7 @@ class SelfAttention(nn.Module):
          in_channels (int): Number of input channels.
          d_k (int): Dimensionality of the key and query vectors.
      """
-     def _init__(self, in_channels, d_k):
+     def __init__(self, in_channels, d_k):
          super(SelfAttention, self).__init__()
          self.W_q = nn.Linear(in_channels, d_k)
          self.W_k = nn.Linear(in_channels, d_k)
@@ -1223,7 +1217,7 @@ class EarlyFusion(nn.Module):
      Args:
          in_channels (int): Number of input channels.
      """
-     def _init__(self, in_channels):
+     def __init__(self, in_channels):
          super(EarlyFusion, self).__init__()
          self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=1, stride=1)

@@ -1242,7 +1236,7 @@ class EarlyFusion(nn.Module):

  # Spatial Attention Mechanism
  class SpatialAttention(nn.Module):
-     def _init__(self, kernel_size=7):
+     def __init__(self, kernel_size=7):
          """
          Initializes the SpatialAttention module.

@@ -1287,7 +1281,7 @@ class MultiScaleBlockWithAttention(nn.Module):
          forward: Forward method for the module.
      """

-     def _init__(self, in_channels, out_channels):
+     def __init__(self, in_channels, out_channels):
          super(MultiScaleBlockWithAttention, self).__init__()
          self.dilated_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, dilation=1, padding=1)
          self.spatial_attention = nn.Conv2d(out_channels, out_channels, kernel_size=1)
@@ -1320,7 +1314,7 @@ class MultiScaleBlockWithAttention(nn.Module):

  # Final Classifier
  class CustomCellClassifier(nn.Module):
-     def _init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
+     def __init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
          super(CustomCellClassifier, self).__init__()
          self.early_fusion = EarlyFusion(in_channels=3)

@@ -1349,7 +1343,7 @@ class CustomCellClassifier(nn.Module):

  #CNN and Transformer class, pick any Torch model.
  class TorchModel(nn.Module):
-     def _init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
+     def __init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
          super(TorchModel, self).__init__()
          self.model_name = model_name
          self.use_checkpoint = use_checkpoint
@@ -1423,7 +1417,7 @@ class TorchModel(nn.Module):
          return logits

  class FocalLossWithLogits(nn.Module):
-     def _init__(self, alpha=1, gamma=2):
+     def __init__(self, alpha=1, gamma=2):
          super(FocalLossWithLogits, self).__init__()
          self.alpha = alpha
          self.gamma = gamma
@@ -1435,7 +1429,7 @@ class FocalLossWithLogits(nn.Module):
          return focal_loss.mean()

  class ResNet(nn.Module):
-     def _init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
+     def __init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
          super(ResNet, self).__init__()

          resnet_map = {
@@ -1788,25 +1782,24 @@ def annotate_predictions(csv_loc):
      df['cond'] = df.apply(assign_condition, axis=1)
      return df

- def init_globals(counter_, lock_):
+ def initiate_counter(counter_, lock_):
      global counter, lock
      counter = counter_
      lock = lock_

- def add_images_to_tar(args):
-     global counter, lock, total_images
-     paths_chunk, tar_path = args
+ def add_images_to_tar(paths_chunk, tar_path, total_images):
      with tarfile.open(tar_path, 'w') as tar:
-         for img_path in paths_chunk:
+         for i, img_path in enumerate(paths_chunk):
              arcname = os.path.basename(img_path)
              try:
                  tar.add(img_path, arcname=arcname)
                  with lock:
                      counter.value += 1
-                     print(f"\rProcessed: {counter.value}/{total_images}", end='', flush=True)
+                     if counter.value % 100 == 0: # Print every 100 updates
+                         progress = (counter.value / total_images) * 100
+                         print(f"Progress: {counter.value}/{total_images} ({progress:.2f}%)", end='\r', file=sys.stdout, flush=True)
              except FileNotFoundError:
                  print(f"File not found: {img_path}")
-     return tar_path

  def generate_fraction_map(df, gene_column, min_frequency=0.0):
      df['fraction'] = df['count']/df['well_read_sum']
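A hedged driver sketch for the reworked helpers: `initiate_counter` installs a shared counter and lock in each worker so `add_images_to_tar` can report progress. Paths and the worker count are hypothetical:

```python
# Fan image paths out to a Pool; each worker writes one tar and bumps the shared counter.
import glob
from multiprocessing import Pool, Value, Lock
from spacr.utils import initiate_counter, add_images_to_tar

if __name__ == '__main__':
    paths = glob.glob('/path/to/images/*.png')   # hypothetical image folder
    n_workers = 4
    chunks = [paths[i::n_workers] for i in range(n_workers)]
    jobs = [(chunk, f'images_{i}.tar', len(paths)) for i, chunk in enumerate(chunks)]
    counter, lock = Value('i', 0), Lock()
    with Pool(n_workers, initializer=initiate_counter, initargs=(counter, lock)) as pool:
        pool.starmap(add_images_to_tar, jobs)
```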
@@ -2255,8 +2248,8 @@ def dice_coefficient(mask1, mask2):
  def extract_boundaries(mask, dilation_radius=1):
      binary_mask = (mask > 0).astype(np.uint8)
      struct_elem = np.ones((dilation_radius*2+1, dilation_radius*2+1))
-     dilated = binary_dilation(binary_mask, footprint=struct_elem)
-     eroded = binary_erosion(binary_mask, footprint=struct_elem)
+     dilated = morphology.binary_dilation(binary_mask, footprint=struct_elem)
+     eroded = morphology.binary_erosion(binary_mask, footprint=struct_elem)
      boundary = dilated ^ eroded
      return boundary

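The fix namespaces the calls through skimage's morphology module, matching the `footprint=` keyword used here. A standalone check, assuming the usual `from skimage import morphology` at the top of utils.py:

```python
# Boundary of a 3x3 square: dilation XOR erosion gives a ring around the edge.
import numpy as np
from skimage import morphology

mask = np.zeros((7, 7), dtype=np.uint8)
mask[2:5, 2:5] = 1
struct_elem = np.ones((3, 3))   # dilation_radius = 1
dilated = morphology.binary_dilation(mask > 0, footprint=struct_elem)
eroded = morphology.binary_erosion(mask > 0, footprint=struct_elem)
boundary = dilated ^ eroded
print(boundary.sum())           # 24 boundary pixels
```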
@@ -2669,6 +2662,13 @@ def _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size,
              print(f'Number of objects before filtration: {num_objects}')
              plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)

+         if merge:
+             mask = merge_touching_objects(mask, threshold=0.66)
+             if plot and idx == 0:
+                 num_objects = mask_object_count(mask)
+                 print(f'Number of objects after merging adjacent objects, : {num_objects}')
+                 plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
+
          if filter_size:
              props = measure.regionprops_table(mask, properties=['label', 'area'])
              valid_labels = props['label'][np.logical_and(props['area'] > minimum_size, props['area'] < maximum_size)]
@@ -2714,13 +2714,6 @@ def _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size,
              print(f'Number of objects after removing border objects, : {num_objects}')
              plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)

-         if merge:
-             mask = merge_touching_objects(mask, threshold=0.25)
-             if plot and idx == 0:
-                 num_objects = mask_object_count(mask)
-                 print(f'Number of objects after merging adjacent objects, : {num_objects}')
-                 plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
-
          mask_stack.append(mask)

      return mask_stack
@@ -2789,6 +2782,7 @@ def _run_test_mode(src, regex, timelapse=False):

      # Prepare for random selection
      set_identifiers = list(images_by_set.keys())
+     random.seed(42)
      random.shuffle(set_identifiers) # Randomize the order

      # Select a subset based on the test_images count
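Seeding before the shuffle makes the test-mode image selection deterministic across runs:

```python
# Same shuffled order on every run once the seed is fixed.
import random

items = ['set1', 'set2', 'set3', 'set4']
random.seed(42)
random.shuffle(items)
print(items)   # identical output each time
```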
spacr-0.0.20.dist-info/METADATA → spacr-0.0.35.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: spacr
- Version: 0.0.20
+ Version: 0.0.35
  Summary: Spatial phenotype analysis of crisp screens (SpaCr)
  Home-page: https://github.com/EinarOlafsson/spacr
  Author: Einar Birnir Olafsson
@@ -19,6 +19,7 @@ Requires-Dist: scikit-image >=0.22.0
  Requires-Dist: scikit-learn >=1.4.1
  Requires-Dist: seaborn >=0.13.2
  Requires-Dist: matplotlib >=3.8.3
+ Requires-Dist: shap >=0.45.0
  Requires-Dist: pillow >=10.2.0
  Requires-Dist: imageio >=2.34.0
  Requires-Dist: scipy >=1.12.0
@@ -31,6 +32,8 @@ Requires-Dist: IPython >=8.18.1
  Requires-Dist: opencv-python-headless >=4.9.0.80
  Requires-Dist: umap >=0.1.1
  Requires-Dist: ttkthemes >=3.2.2
+ Requires-Dist: xgboost >=2.0.3
+ Requires-Dist: PyWavelets >=1.6.0
  Requires-Dist: lxml >=5.1.0
  Provides-Extra: dev
  Requires-Dist: pytest >=3.9 ; extra == 'dev'
@@ -49,7 +52,7 @@ Requires-Dist: opencv-python-headless ; extra == 'headless'
  <tr>
  <td>

- Spatial phenotype analysis of crisp screens (SpaCr). A collection of functions for generating cellpose masks -> single object images and measurements -> annotation and classification of single object images. Spacr uses batch normalization to facilitate accurate segmentation of objects with low foreground representation.
+ Spatial phenotype analysis of CRISPR-Cas9 screens (SpaCr). The spatial organization of organelles and proteins within cells constitutes a key level of functional regulation. In the context of infectious disease, the spatial relationships between host cell structures and intracellular pathogens are critical to understand host clearance mechanisms and how pathogens evade them. Spacr is a Python-based software package for generating single cell image data for deep-learning sub-cellular/cellular phenotypic classification from pooled genetic CRISPR-Cas9 screens. Spacr provides a flexible toolset to extract single cell images and measurements from high content cell painting experiments, train deep-learning models to classify cellular/ subcellular phenotypes, simulate and analyze pooled CRISPR-Cas9 imaging screens.

  </td>
  <td>
spacr-0.0.35.dist-info/RECORD ADDED
@@ -0,0 +1,35 @@
+ spacr/__init__.py,sha256=mDi-Qu5r1vZnqIbUBV1JAoSq-mxmMEOmni1JSG2e4Wo,879
+ spacr/__main__.py,sha256=_qRkhbFrH_cXr7AZs6KHL8Hh4VApqNdpNCtiKn2ePTo,285
+ spacr/alpha.py,sha256=1LUtTaeVHqcTMxoCMA7mlsNYyaR0KEaXglwXsUQLwKo,13776
+ spacr/annotate_app.py,sha256=8ziG6HZ6Kvany2yYDR15jtW84OnPFe8SZXykIrrNfX0,20031
+ spacr/chris.py,sha256=YlBjSgeZaY8HPy6jkrT_ISAnCMAKVfvCxF0I9eAZLFM,2418
+ spacr/cli.py,sha256=507jfOOEV8BoL4eeUcblvH-iiDHdBrEVJLu1ghAAPSc,1800
+ spacr/core.py,sha256=_R8gXNnjf680yrnbCi2piWQUz7PDbqWYn7SL5MACLfo,156457
+ spacr/foldseek.py,sha256=cWtLzvFF2O_mq5I71UMiuU9DTvDCp7wl6aaWAZRrBZc,33970
+ spacr/get_alfafold_structures.py,sha256=n0g8gne-oyAV3Uo6qxZoJq5X1cUUyD8u0pOC_W2PX40,3541
+ spacr/graph_learning.py,sha256=sD4eOC7Q16rr7WO20mCi_E16_LqioGUUgPamAHIIeNI,12568
+ spacr/graph_learning_lap.py,sha256=MyNRLb63gsjBlui-ByZ0anHugYulL6M-OsGm8rnGBmE,3385
+ spacr/gui.py,sha256=2d2JHYVWhEFUkB3u_2OarCV_V07eLtJKUMKVKuJ7nAo,6430
+ spacr/gui_classify_app.py,sha256=RqVC5Ac0GiFhKg1qUyU-xfbSVMwSTjySkReWCvmsZ1U,7917
+ spacr/gui_mask_app.py,sha256=c0NwgzzMSYx7xE60sp6zLE1h7ct3pnT7n7gz1SqmWbs,9750
+ spacr/gui_measure_app.py,sha256=I6OXPJZZ0sG3OqRO4l7Fr9Sb_5jr6X1Y2LXLgPW_Q_4,9675
+ spacr/gui_sim_app.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ spacr/gui_utils.py,sha256=URyEUfDn7RSE8iWBDwmC1jXWNA2Vl1oA2GsSjbSbuFU,40715
+ spacr/io.py,sha256=-Ho1Rw46s5DfutcTAtp0RhEs3I-GPgXIphOpjJHutPU,108688
+ spacr/logger.py,sha256=7Zqr3TuuOQLWT32gYr2q1qvv7x0a2JhLANmZcnBXAW8,670
+ spacr/mask_app.py,sha256=p9oA0JH0Rcly2Fbsrg-Vye_iThRCZZF9axU6hkE3SAI,39376
+ spacr/measure.py,sha256=_f6UDugw75rILEg0uo2-QsUGUvc4AQdrdHl-BPZk74I,54686
+ spacr/old_code.py,sha256=KxljHpKNsV5EfX9ifN2xJTnUeqAhyabZyfDWd5THOOc,11226
+ spacr/plot.py,sha256=VtDKTJ_zo8CAVC3ILuIN_wUP6197vq089wNZuom7T8g,61655
+ spacr/sim.py,sha256=2NR5hm--HVcYQnj1SCHoUCVbh_b2XUjjjfoAUIXFwnQ,72997
+ spacr/timelapse.py,sha256=plPjR8nZ7_Q50VAvMvHK2TUE4F-vh7R23JnI6tSW02g,39661
+ spacr/train.py,sha256=lp66dWYkiMMlgdYlMjAsJnkIZFWLizKB-xwyVnKgFBs,25904
+ spacr/umap.py,sha256=4QSrQ16Og-Ijq-SwguMQT2f20UWz1LE5HQeSLmzSl8c,29370
+ spacr/utils.py,sha256=D3WRf_0w0T6dZHh3BfwScGBQjorljgWW6CQUdM0ToN8,120918
+ spacr/version.py,sha256=axH5tnGwtgSnJHb5IDhiu4Zjk5GhLyAEDRe-rnaoFOA,409
+ spacr-0.0.35.dist-info/LICENSE,sha256=SR-2MeGc6SCM1UORJYyarSWY_A-JaOMFDj7ReSs9tRM,1083
+ spacr-0.0.35.dist-info/METADATA,sha256=0iEqhFIza7SaHVeYbl0Rc8WocPgGEVtiwMoWsIBBZzQ,4973
+ spacr-0.0.35.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ spacr-0.0.35.dist-info/entry_points.txt,sha256=_khj_UcegrI5N29QcgxECsxVsfpVQzG3U5OUoKdtTcw,288
+ spacr-0.0.35.dist-info/top_level.txt,sha256=GJPU8FgwRXGzKeut6JopsSRY2R8T3i9lDgya42tLInY,6
+ spacr-0.0.35.dist-info/RECORD,,
spacr-0.0.35.dist-info/entry_points.txt ADDED
@@ -0,0 +1,8 @@
+ [console_scripts]
+ annotate = spacr.annotate_app:gui_annotation
+ classify = spacr.gui_classify_app:gui_classify
+ gui = spacr.gui:gui_app
+ make_masks = spacr.mask_app:gui_make_masks
+ mask = spacr.gui_mask_app:gui_mask
+ measure = spacr.gui_measure_app:gui_measure
+ sim = spacr.gui_sim_app:gui_sim