spacr 0.0.20__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/timelapse.py CHANGED
@@ -11,6 +11,7 @@ from btrack import datasets as btrack_datasets
11
11
  from skimage.measure import regionprops
12
12
  from scipy.signal import find_peaks
13
13
  from scipy.optimize import curve_fit
14
+ from scipy.integrate import trapz
14
15
  import matplotlib.pyplot as plt
15
16
 
16
17
  from .logger import log_function_call
@@ -590,6 +591,80 @@ def infected_vs_noninfected(result_df, measurement):
590
591
  plt.tight_layout()
591
592
  plt.show()
592
593
 
594
+ def save_figure(fig, src, figure_number):
595
+ source = os.path.dirname(src)
596
+ results_fldr = os.path.join(source,'results')
597
+ os.makedirs(results_fldr, exist_ok=True)
598
+ fig_loc = os.path.join(results_fldr, f'figure_{figure_number}.pdf')
599
+ fig.savefig(fig_loc)
600
+ print(f'Saved figure:{fig_loc}')
601
+
602
+ def save_results_dataframe(df, src, results_name):
603
+ source = os.path.dirname(src)
604
+ results_fldr = os.path.join(source,'results')
605
+ os.makedirs(results_fldr, exist_ok=True)
606
+ csv_loc = os.path.join(results_fldr, f'{results_name}.csv')
607
+ df.to_csv(csv_loc, index=True)
608
+ print(f'Saved results:{csv_loc}')
609
+
610
+ def summarize_per_well(peak_details_df):
611
+ # Step 1: Split the 'ID' column
612
+ split_columns = peak_details_df['ID'].str.split('_', expand=True)
613
+ peak_details_df[['plate', 'row', 'column', 'field', 'object_number']] = split_columns
614
+
615
+ # Step 2: Create 'well_ID' by combining 'row' and 'column'
616
+ peak_details_df['well_ID'] = peak_details_df['row'] + '_' + peak_details_df['column']
617
+
618
+ # Filter entries where 'amplitude' is not null
619
+ filtered_df = peak_details_df[peak_details_df['amplitude'].notna()]
620
+
621
+ # Preparation for Step 3: Identify numeric columns for averaging from the filtered dataframe
622
+ numeric_cols = filtered_df.select_dtypes(include=['number']).columns
623
+
624
+ # Step 3: Calculate summary statistics
625
+ summary_df = filtered_df.groupby('well_ID').agg(
626
+ peaks_per_well=('ID', 'size'),
627
+ unique_IDs_with_amplitude=('ID', 'nunique'), # Count unique IDs per well with non-null amplitude
628
+ **{col: (col, 'mean') for col in numeric_cols} # exclude 'amplitude' from averaging if it's numeric
629
+ ).reset_index()
630
+
631
+ # Step 3: Calculate summary statistics
632
+ summary_df_2 = peak_details_df.groupby('well_ID').agg(
633
+ cells_per_well=('object_number', 'nunique'),
634
+ ).reset_index()
635
+
636
+ summary_df['cells_per_well'] = summary_df_2['cells_per_well']
637
+ summary_df['peaks_per_cell'] = summary_df['peaks_per_well'] / summary_df['cells_per_well']
638
+
639
+ return summary_df
640
+
641
+ def summarize_per_well_inf_non_inf(peak_details_df):
642
+ # Step 1: Split the 'ID' column
643
+ split_columns = peak_details_df['ID'].str.split('_', expand=True)
644
+ peak_details_df[['plate', 'row', 'column', 'field', 'object_number']] = split_columns
645
+
646
+ # Step 2: Create 'well_ID' by combining 'row' and 'column'
647
+ peak_details_df['well_ID'] = peak_details_df['row'] + '_' + peak_details_df['column']
648
+
649
+ # Assume 'pathogen_count' indicates infection if > 0
650
+ # Add an 'infected_status' column to classify cells
651
+ peak_details_df['infected_status'] = peak_details_df['infected'].apply(lambda x: 'infected' if x > 0 else 'non_infected')
652
+
653
+ # Preparation for Step 3: Identify numeric columns for averaging
654
+ numeric_cols = peak_details_df.select_dtypes(include=['number']).columns
655
+
656
+ # Step 3: Calculate summary statistics
657
+ summary_df = peak_details_df.groupby(['well_ID', 'infected_status']).agg(
658
+ cells_per_well=('object_number', 'nunique'),
659
+ peaks_per_well=('ID', 'size'),
660
+ **{col: (col, 'mean') for col in numeric_cols}
661
+ ).reset_index()
662
+
663
+ # Calculate peaks per cell
664
+ summary_df['peaks_per_cell'] = summary_df['peaks_per_well'] / summary_df['cells_per_well']
665
+
666
+ return summary_df
667
+
593
668
  def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intensity', size_filter='cell_area', fluctuation_threshold=0.25, num_lines=None, peak_height=0.01, pathogen=None, cytoplasm=None, remove_transient=True, verbose=False, transience_threshold=0.9):
594
669
  # Load data
595
670
  conn = sqlite3.connect(db_loc)
@@ -626,7 +701,7 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
626
701
  cell_df['plate_row_column_field_object'] = cell_df['plate'].astype(str) + '_' + cell_df['row'].astype(str) + '_' + cell_df['column'].astype(str) + '_' + cell_df['field'].astype(str) + '_' + cell_df['object_label'].astype(str)
627
702
 
628
703
  df = cell_df.copy()
629
-
704
+
630
705
  # Fit exponential decay model to all scaled fluorescence data
631
706
  try:
632
707
  params, _ = curve_fit(exponential_decay, df['time'], df[measurement], p0=[max(df[measurement]), 0.01, min(df[measurement])], maxfev=10000)
@@ -653,11 +728,14 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
653
728
  if verbose:
654
729
  print(f'Group length: {len(group)} Timelapse length: {total_timepoints}, threshold:{threshold}')
655
730
 
656
- if not len(group) <= threshold:
731
+ if len(group) <= threshold:
657
732
  transience_removed += 1
733
+ if verbose:
734
+ print(f'removed group {unique_id} due to transience')
658
735
  continue
659
736
 
660
737
  size_diff = group[size_filter].std() / group[size_filter].mean()
738
+
661
739
  if size_diff <= fluctuation_threshold:
662
740
  group['delta_' + measurement] = group['corrected_' + measurement].diff().fillna(0)
663
741
  corrected_dfs.append(group)
@@ -665,12 +743,50 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
665
743
  # Detect peaks
666
744
  peaks, properties = find_peaks(group['delta_' + measurement], height=peak_height)
667
745
 
746
+ # Set values < 0 to 0
747
+ group_filtered = group.copy()
748
+ group_filtered['delta_' + measurement] = group['delta_' + measurement].clip(lower=0)
749
+ above_zero_auc = trapz(y=group_filtered['delta_' + measurement], x=group_filtered['time'])
750
+ auc = trapz(y=group['delta_' + measurement], x=group_filtered['time'])
751
+ is_infected = (group['parasite_count'] > 0).any()
752
+
753
+ if is_infected:
754
+ is_infected = 1
755
+ else:
756
+ is_infected = 0
757
+
758
+ if len(peaks) == 0:
759
+ peak_details_list.append({
760
+ 'ID': unique_id,
761
+ 'plate': group['plate'].iloc[0],
762
+ 'row': group['row'].iloc[0],
763
+ 'column': group['column'].iloc[0],
764
+ 'field': group['field'].iloc[0],
765
+ 'object_number': group['object_number'].iloc[0],
766
+ 'time': np.nan, # The time of the peak
767
+ 'amplitude': np.nan,
768
+ 'delta': np.nan,
769
+ 'AUC': auc,
770
+ 'AUC_positive': above_zero_auc,
771
+ 'AUC_peak': np.nan,
772
+ 'infected': is_infected
773
+ })
774
+
668
775
  # Inside the for loop where peaks are detected
669
776
  for i, peak in enumerate(peaks):
670
- amplitude = properties['peak_heights'][i] # Correctly access the amplitude
671
- peak_time = group['time'].iloc[peak] # Time corresponding to the peak
672
- # Get the number of pathogens in the cell at the time of the peak
777
+
778
+ amplitude = properties['peak_heights'][i]
779
+ peak_time = group['time'].iloc[peak]
673
780
  pathogen_count_at_peak = group['parasite_count'].iloc[peak]
781
+
782
+ start_idx = max(peak - 1, 0)
783
+ end_idx = min(peak + 1, len(group) - 1)
784
+
785
+ # Using indices to slice for AUC calculation
786
+ peak_segment_y = group['delta_' + measurement].iloc[start_idx:end_idx + 1]
787
+ peak_segment_x = group['time'].iloc[start_idx:end_idx + 1]
788
+ peak_auc = trapz(y=peak_segment_y, x=peak_segment_x)
789
+
674
790
  peak_details_list.append({
675
791
  'ID': unique_id,
676
792
  'plate': group['plate'].iloc[0],
@@ -681,6 +797,9 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
681
797
  'time': peak_time, # The time of the peak
682
798
  'amplitude': amplitude,
683
799
  'delta': group['delta_' + measurement].iloc[peak],
800
+ 'AUC': auc,
801
+ 'AUC_positive': above_zero_auc,
802
+ 'AUC_peak': peak_auc,
684
803
  'infected': pathogen_count_at_peak
685
804
  })
686
805
  else:
@@ -697,7 +816,14 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
697
816
  return
698
817
 
699
818
  peak_details_df = pd.DataFrame(peak_details_list)
700
-
819
+ summary_df = summarize_per_well(peak_details_df)
820
+ summary_df_inf_non_inf = summarize_per_well_inf_non_inf(peak_details_df)
821
+
822
+ save_results_dataframe(df=peak_details_df, src=db_loc, results_name='peak_details')
823
+ save_results_dataframe(df=result_df, src=db_loc, results_name='results')
824
+ save_results_dataframe(df=summary_df, src=db_loc, results_name='well_results')
825
+ save_results_dataframe(df=summary_df_inf_non_inf, src=db_loc, results_name='well_results_inf_non_inf')
826
+
701
827
  # Plotting
702
828
  fig, ax = plt.subplots(figsize=(10, 8))
703
829
  sampled_groups = result_df['plate_row_column_field_object'].unique()
@@ -714,12 +840,16 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
714
840
  ax.set_xlabel('Time')
715
841
  ax.set_ylabel('Normalized Delta ' + measurement)
716
842
  plt.tight_layout()
843
+
717
844
  plt.show()
845
+
846
+ save_figure(fig, src=db_loc, figure_number=1)
718
847
 
719
848
  if pathogen:
720
849
  infected_vs_noninfected(result_df, measurement)
850
+ save_figure(fig, src=db_loc, figure_number=2)
721
851
 
722
- # Identifying cells with and without infection
852
+ # Identify cells with and without pathogens
723
853
  infected_cells = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') > 0]['plate_row_column_field_object'].unique()
724
854
  noninfected_cells = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') == 0]['plate_row_column_field_object'].unique()
725
855
 
@@ -733,5 +863,5 @@ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intens
733
863
 
734
864
  print(f'Average number of peaks per infected cell: {avg_inf_peaks_per_cell:.2f}')
735
865
  print(f'Average number of peaks per non-infected cell: {avg_non_inf_peaks_per_cell:.2f}')
736
-
737
- return result_df, peak_details_df
866
+ print(f'done')
867
+ return result_df, peak_details_df, fig
spacr/train.py CHANGED
@@ -194,8 +194,8 @@ def test_model_performance(loaders, model, loader_name_list, epoch, train_mode,
194
194
 
195
195
  def train_test_model(src, settings, custom_model=False, custom_model_path=None):
196
196
 
197
- from .io import save_settings, _copy_missclassified
198
- from .utils import pick_best_model, test_model_performance
197
+ from .io import _save_settings, _copy_missclassified
198
+ from .utils import pick_best_model
199
199
  from .core import generate_loaders
200
200
 
201
201
  settings['src'] = src
@@ -208,7 +208,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
208
208
  model = torch.load(custom_model_path)
209
209
 
210
210
  if settings['train']:
211
- save_settings(settings, src)
211
+ _save_settings(settings, src)
212
212
  torch.cuda.empty_cache()
213
213
  torch.cuda.memory.empty_cache()
214
214
  gc.collect()
@@ -227,7 +227,9 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
227
227
  validation_split=settings['val_split'],
228
228
  pin_memory=settings['pin_memory'],
229
229
  normalize=settings['normalize'],
230
- verbose=settings['verbose'])
230
+ channels=settings['channels'],
231
+ verbose=settings['verbose'])
232
+
231
233
 
232
234
  if settings['test']:
233
235
  test, _, plate_names_test = generate_loaders(src,
@@ -240,6 +242,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
240
242
  validation_split=0.0,
241
243
  pin_memory=settings['pin_memory'],
242
244
  normalize=settings['normalize'],
245
+ channels=settings['channels'],
243
246
  verbose=settings['verbose'])
244
247
  if model == None:
245
248
  model_path = pick_best_model(src+'/model')
@@ -330,8 +333,8 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
330
333
  None
331
334
  """
332
335
 
333
- from .io import save_model, save_progress
334
- from .utils import evaluate_model_performance, compute_irm_penalty, calculate_loss, choose_model
336
+ from .io import _save_model, _save_progress
337
+ from .utils import compute_irm_penalty, calculate_loss, choose_model #evaluate_model_performance,
335
338
 
336
339
  print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders)}')
337
340
 
@@ -347,6 +350,11 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
347
350
  break
348
351
 
349
352
  model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
353
+
354
+ if model is None:
355
+ print(f'Model {model_type} not found')
356
+ return
357
+
350
358
  model.to(device)
351
359
 
352
360
  if optimizer_type == 'adamw':
@@ -421,10 +429,10 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
421
429
  if schedule == 'step_lr':
422
430
  scheduler.step()
423
431
 
424
- save_progress(dst, results_df, train_metrics_df)
432
+ _save_progress(dst, results_df, train_metrics_df)
425
433
  clear_output(wait=True)
426
434
  display(results_df)
427
- save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
435
+ _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
428
436
 
429
437
  if train_mode == 'irm':
430
438
  dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
@@ -494,7 +502,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
494
502
 
495
503
  clear_output(wait=True)
496
504
  display(results_df)
497
- save_progress(dst, results_df, train_metrics_df)
498
- save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
505
+ _save_progress(dst, results_df, train_metrics_df)
506
+ _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
499
507
  print(f'Saved model: {dst}')
500
508
  return
spacr/utils.py CHANGED
@@ -1,4 +1,4 @@
1
- import os, re, sqlite3, gc, torch, torchvision, time, random, string, shutil, cv2, tarfile, glob
1
+ import sys, os, re, sqlite3, gc, torch, torchvision, time, random, string, shutil, cv2, tarfile, glob
2
2
 
3
3
  import numpy as np
4
4
  from cellpose import models as cp_models
@@ -71,7 +71,8 @@ def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_th
71
71
  outline = np.zeros_like(mask, dtype=np.uint8) # Use uint8 for contour detection efficiency
72
72
 
73
73
  # Find and draw contours
74
- for j in np.unique(mask)[1:]:
74
+ for j in np.unique(mask):
75
+ #for j in np.unique(mask)[1:]:
75
76
  contours = find_contours(mask == j, 0.5)
76
77
  # Convert contours for OpenCV format and draw directly to optimize
77
78
  cv_contours = [np.flip(contour.astype(int), axis=1) for contour in contours]
@@ -781,8 +782,8 @@ def _get_object_settings(object_type, settings):
781
782
  object_settings = {}
782
783
 
783
784
  object_settings['diameter'] = _get_diam(settings['magnification'], obj=object_type)
784
- object_settings['minimum_size'] = (object_settings['diameter']**2)/5
785
- object_settings['maximum_size'] = (object_settings['diameter']**2)*3
785
+ object_settings['minimum_size'] = (object_settings['diameter']**2)/4
786
+ object_settings['maximum_size'] = (object_settings['diameter']**2)*10
786
787
  object_settings['merge'] = False
787
788
  object_settings['resample'] = True
788
789
  object_settings['remove_border_objects'] = False
@@ -793,21 +794,22 @@ def _get_object_settings(object_type, settings):
793
794
  object_settings['model_name'] = 'cyto'
794
795
  else:
795
796
  object_settings['model_name'] = 'cyto2'
796
- object_settings['filter_size'] = True
797
- object_settings['filter_intensity'] = True
797
+ object_settings['filter_size'] = False
798
+ object_settings['filter_intensity'] = False
798
799
  object_settings['restore_type'] = settings.get('cell_restore_type', None)
799
800
 
800
801
  elif object_type == 'nucleus':
801
802
  object_settings['model_name'] = 'nuclei'
802
- object_settings['filter_size'] = True
803
- object_settings['filter_intensity'] = True
803
+ object_settings['filter_size'] = False
804
+ object_settings['filter_intensity'] = False
804
805
  object_settings['restore_type'] = settings.get('nucleus_restore_type', None)
805
806
 
806
807
  elif object_type == 'pathogen':
807
808
  object_settings['model_name'] = 'cyto'
808
809
  object_settings['filter_size'] = True
809
- object_settings['filter_intensity'] = True
810
+ object_settings['filter_intensity'] = False
810
811
  object_settings['restore_type'] = settings.get('pathogen_restore_type', None)
812
+ object_settings['merge'] = settings['merge_pathogens']
811
813
 
812
814
  else:
813
815
  print(f'Object type: {object_type} not supported. Supported object types are : cell, nucleus and pathogen')
@@ -884,17 +886,15 @@ def _get_cellpose_channels(src, nucleus_channel, pathogen_channel, cell_channel)
884
886
 
885
887
  if not pathogen_channel is None:
886
888
  if not nucleus_channel is None:
887
- cellpose_channels['pathogen'] = [0,1]
889
+ if not pathogen_channel is None:
890
+ cellpose_channels['pathogen'] = [0,2]
891
+ else:
892
+ cellpose_channels['pathogen'] = [0,1]
888
893
  else:
889
894
  cellpose_channels['pathogen'] = [0,0]
890
895
 
891
896
  if not cell_channel is None:
892
897
  if not nucleus_channel is None:
893
- if not pathogen_channel is None:
894
- cellpose_channels['cell'] = [0,2]
895
- else:
896
- cellpose_channels['cell'] = [0,1]
897
- elif not pathogen_channel is None:
898
898
  cellpose_channels['cell'] = [0,1]
899
899
  else:
900
900
  cellpose_channels['cell'] = [0,0]
@@ -1069,7 +1069,7 @@ class Cache:
1069
1069
  cache (OrderedDict): The cache data structure.
1070
1070
  """
1071
1071
 
1072
- def _init__(self, max_size):
1072
+ def __init__(self, max_size):
1073
1073
  self.cache = OrderedDict()
1074
1074
  self.max_size = max_size
1075
1075
 
@@ -1100,7 +1100,7 @@ class ScaledDotProductAttention(nn.Module):
1100
1100
 
1101
1101
  """
1102
1102
 
1103
- def _init__(self, d_k):
1103
+ def __init__(self, d_k):
1104
1104
  super(ScaledDotProductAttention, self).__init__()
1105
1105
  self.d_k = d_k
1106
1106
 
@@ -1131,7 +1131,7 @@ class SelfAttention(nn.Module):
1131
1131
  d_k (int): Dimensionality of the key and query vectors.
1132
1132
  """
1133
1133
 
1134
- def _init__(self, in_channels, d_k):
1134
+ def __init__(self, in_channels, d_k):
1135
1135
  super(SelfAttention, self).__init__()
1136
1136
  self.W_q = nn.Linear(in_channels, d_k)
1137
1137
  self.W_k = nn.Linear(in_channels, d_k)
@@ -1155,7 +1155,7 @@ class SelfAttention(nn.Module):
1155
1155
  return output
1156
1156
 
1157
1157
  class ScaledDotProductAttention(nn.Module):
1158
- def _init__(self, d_k):
1158
+ def __init__(self, d_k):
1159
1159
  """
1160
1160
  Initializes the ScaledDotProductAttention module.
1161
1161
 
@@ -1192,7 +1192,7 @@ class SelfAttention(nn.Module):
1192
1192
  in_channels (int): Number of input channels.
1193
1193
  d_k (int): Dimensionality of the key and query vectors.
1194
1194
  """
1195
- def _init__(self, in_channels, d_k):
1195
+ def __init__(self, in_channels, d_k):
1196
1196
  super(SelfAttention, self).__init__()
1197
1197
  self.W_q = nn.Linear(in_channels, d_k)
1198
1198
  self.W_k = nn.Linear(in_channels, d_k)
@@ -1223,7 +1223,7 @@ class EarlyFusion(nn.Module):
1223
1223
  Args:
1224
1224
  in_channels (int): Number of input channels.
1225
1225
  """
1226
- def _init__(self, in_channels):
1226
+ def __init__(self, in_channels):
1227
1227
  super(EarlyFusion, self).__init__()
1228
1228
  self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=1, stride=1)
1229
1229
 
@@ -1242,7 +1242,7 @@ class EarlyFusion(nn.Module):
1242
1242
 
1243
1243
  # Spatial Attention Mechanism
1244
1244
  class SpatialAttention(nn.Module):
1245
- def _init__(self, kernel_size=7):
1245
+ def __init__(self, kernel_size=7):
1246
1246
  """
1247
1247
  Initializes the SpatialAttention module.
1248
1248
 
@@ -1287,7 +1287,7 @@ class MultiScaleBlockWithAttention(nn.Module):
1287
1287
  forward: Forward method for the module.
1288
1288
  """
1289
1289
 
1290
- def _init__(self, in_channels, out_channels):
1290
+ def __init__(self, in_channels, out_channels):
1291
1291
  super(MultiScaleBlockWithAttention, self).__init__()
1292
1292
  self.dilated_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, dilation=1, padding=1)
1293
1293
  self.spatial_attention = nn.Conv2d(out_channels, out_channels, kernel_size=1)
@@ -1320,7 +1320,7 @@ class MultiScaleBlockWithAttention(nn.Module):
1320
1320
 
1321
1321
  # Final Classifier
1322
1322
  class CustomCellClassifier(nn.Module):
1323
- def _init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
1323
+ def __init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
1324
1324
  super(CustomCellClassifier, self).__init__()
1325
1325
  self.early_fusion = EarlyFusion(in_channels=3)
1326
1326
 
@@ -1349,7 +1349,7 @@ class CustomCellClassifier(nn.Module):
1349
1349
 
1350
1350
  #CNN and Transformer class, pick any Torch model.
1351
1351
  class TorchModel(nn.Module):
1352
- def _init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
1352
+ def __init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
1353
1353
  super(TorchModel, self).__init__()
1354
1354
  self.model_name = model_name
1355
1355
  self.use_checkpoint = use_checkpoint
@@ -1423,7 +1423,7 @@ class TorchModel(nn.Module):
1423
1423
  return logits
1424
1424
 
1425
1425
  class FocalLossWithLogits(nn.Module):
1426
- def _init__(self, alpha=1, gamma=2):
1426
+ def __init__(self, alpha=1, gamma=2):
1427
1427
  super(FocalLossWithLogits, self).__init__()
1428
1428
  self.alpha = alpha
1429
1429
  self.gamma = gamma
@@ -1435,7 +1435,7 @@ class FocalLossWithLogits(nn.Module):
1435
1435
  return focal_loss.mean()
1436
1436
 
1437
1437
  class ResNet(nn.Module):
1438
- def _init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
1438
+ def __init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
1439
1439
  super(ResNet, self).__init__()
1440
1440
 
1441
1441
  resnet_map = {
@@ -1788,25 +1788,24 @@ def annotate_predictions(csv_loc):
1788
1788
  df['cond'] = df.apply(assign_condition, axis=1)
1789
1789
  return df
1790
1790
 
1791
- def init_globals(counter_, lock_):
1791
+ def initiate_counter(counter_, lock_):
1792
1792
  global counter, lock
1793
1793
  counter = counter_
1794
1794
  lock = lock_
1795
1795
 
1796
- def add_images_to_tar(args):
1797
- global counter, lock, total_images
1798
- paths_chunk, tar_path = args
1796
+ def add_images_to_tar(paths_chunk, tar_path, total_images):
1799
1797
  with tarfile.open(tar_path, 'w') as tar:
1800
- for img_path in paths_chunk:
1798
+ for i, img_path in enumerate(paths_chunk):
1801
1799
  arcname = os.path.basename(img_path)
1802
1800
  try:
1803
1801
  tar.add(img_path, arcname=arcname)
1804
1802
  with lock:
1805
1803
  counter.value += 1
1806
- print(f"\rProcessed: {counter.value}/{total_images}", end='', flush=True)
1804
+ if counter.value % 100 == 0: # Print every 100 updates
1805
+ progress = (counter.value / total_images) * 100
1806
+ print(f"Progress: {counter.value}/{total_images} ({progress:.2f}%)", end='\r', file=sys.stdout, flush=True)
1807
1807
  except FileNotFoundError:
1808
1808
  print(f"File not found: {img_path}")
1809
- return tar_path
1810
1809
 
1811
1810
  def generate_fraction_map(df, gene_column, min_frequency=0.0):
1812
1811
  df['fraction'] = df['count']/df['well_read_sum']
@@ -2255,8 +2254,8 @@ def dice_coefficient(mask1, mask2):
2255
2254
  def extract_boundaries(mask, dilation_radius=1):
2256
2255
  binary_mask = (mask > 0).astype(np.uint8)
2257
2256
  struct_elem = np.ones((dilation_radius*2+1, dilation_radius*2+1))
2258
- dilated = binary_dilation(binary_mask, footprint=struct_elem)
2259
- eroded = binary_erosion(binary_mask, footprint=struct_elem)
2257
+ dilated = morphology.binary_dilation(binary_mask, footprint=struct_elem)
2258
+ eroded = morphology.binary_erosion(binary_mask, footprint=struct_elem)
2260
2259
  boundary = dilated ^ eroded
2261
2260
  return boundary
2262
2261
 
@@ -2669,6 +2668,13 @@ def _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size,
2669
2668
  print(f'Number of objects before filtration: {num_objects}')
2670
2669
  plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
2671
2670
 
2671
+ if merge:
2672
+ mask = merge_touching_objects(mask, threshold=0.66)
2673
+ if plot and idx == 0:
2674
+ num_objects = mask_object_count(mask)
2675
+ print(f'Number of objects after merging adjacent objects, : {num_objects}')
2676
+ plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
2677
+
2672
2678
  if filter_size:
2673
2679
  props = measure.regionprops_table(mask, properties=['label', 'area'])
2674
2680
  valid_labels = props['label'][np.logical_and(props['area'] > minimum_size, props['area'] < maximum_size)]
@@ -2714,13 +2720,6 @@ def _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size,
2714
2720
  print(f'Number of objects after removing border objects, : {num_objects}')
2715
2721
  plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
2716
2722
 
2717
- if merge:
2718
- mask = merge_touching_objects(mask, threshold=0.25)
2719
- if plot and idx == 0:
2720
- num_objects = mask_object_count(mask)
2721
- print(f'Number of objects after merging adjacent objects, : {num_objects}')
2722
- plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
2723
-
2724
2723
  mask_stack.append(mask)
2725
2724
 
2726
2725
  return mask_stack
@@ -2789,6 +2788,7 @@ def _run_test_mode(src, regex, timelapse=False):
2789
2788
 
2790
2789
  # Prepare for random selection
2791
2790
  set_identifiers = list(images_by_set.keys())
2791
+ random.seed(42)
2792
2792
  random.shuffle(set_identifiers) # Randomize the order
2793
2793
 
2794
2794
  # Select a subset based on the test_images count
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: spacr
3
- Version: 0.0.20
3
+ Version: 0.0.21
4
4
  Summary: Spatial phenotype analysis of crisp screens (SpaCr)
5
5
  Home-page: https://github.com/EinarOlafsson/spacr
6
6
  Author: Einar Birnir Olafsson
@@ -19,6 +19,7 @@ Requires-Dist: scikit-image >=0.22.0
19
19
  Requires-Dist: scikit-learn >=1.4.1
20
20
  Requires-Dist: seaborn >=0.13.2
21
21
  Requires-Dist: matplotlib >=3.8.3
22
+ Requires-Dist: shap >=0.45.0
22
23
  Requires-Dist: pillow >=10.2.0
23
24
  Requires-Dist: imageio >=2.34.0
24
25
  Requires-Dist: scipy >=1.12.0
@@ -31,6 +32,8 @@ Requires-Dist: IPython >=8.18.1
31
32
  Requires-Dist: opencv-python-headless >=4.9.0.80
32
33
  Requires-Dist: umap >=0.1.1
33
34
  Requires-Dist: ttkthemes >=3.2.2
35
+ Requires-Dist: xgboost >=2.0.3
36
+ Requires-Dist: PyWavelets >=1.6.0
34
37
  Requires-Dist: lxml >=5.1.0
35
38
  Provides-Extra: dev
36
39
  Requires-Dist: pytest >=3.9 ; extra == 'dev'
@@ -49,7 +52,7 @@ Requires-Dist: opencv-python-headless ; extra == 'headless'
49
52
  <tr>
50
53
  <td>
51
54
 
52
- Spatial phenotype analysis of crisp screens (SpaCr). A collection of functions for generating cellpose masks -> single object images and measurements -> annotation and classification of single object images. Spacr uses batch normalization to facilitate accurate segmentation of objects with low foreground representation.
55
+ Spatial phenotype analysis of CRISPR-Cas9 screens (SpaCr). The spatial organization of organelles and proteins within cells constitutes a key level of functional regulation. In the context of infectious disease, the spatial relationships between host cell structures and intracellular pathogens are critical to understand host clearance mechanisms and how pathogens evade them. Spacr is a Python-based software package for generating single cell image data for deep-learning sub-cellular/cellular phenotypic classification from pooled genetic CRISPR-Cas9 screens. Spacr provides a flexible toolset to extract single cell images and measurements from high content cell painting experiments, train deep-learning models to classify cellular/ subcellular phenotypes, simulate and analyze pooled CRISPR-Cas9 imaging screens.
53
56
 
54
57
  </td>
55
58
  <td>
@@ -0,0 +1,33 @@
1
+ spacr/__init__.py,sha256=mDi-Qu5r1vZnqIbUBV1JAoSq-mxmMEOmni1JSG2e4Wo,879
2
+ spacr/__main__.py,sha256=_qRkhbFrH_cXr7AZs6KHL8Hh4VApqNdpNCtiKn2ePTo,285
3
+ spacr/alpha.py,sha256=1LUtTaeVHqcTMxoCMA7mlsNYyaR0KEaXglwXsUQLwKo,13776
4
+ spacr/annotate_app.py,sha256=jtm1HgWngv9BTsyVzdmBqg1flbmbXs9JkP6CB4wmOH8,19951
5
+ spacr/cli.py,sha256=507jfOOEV8BoL4eeUcblvH-iiDHdBrEVJLu1ghAAPSc,1800
6
+ spacr/core.py,sha256=_R8gXNnjf680yrnbCi2piWQUz7PDbqWYn7SL5MACLfo,156457
7
+ spacr/foldseek.py,sha256=cWtLzvFF2O_mq5I71UMiuU9DTvDCp7wl6aaWAZRrBZc,33970
8
+ spacr/get_alfafold_structures.py,sha256=n0g8gne-oyAV3Uo6qxZoJq5X1cUUyD8u0pOC_W2PX40,3541
9
+ spacr/graph_learning.py,sha256=sD4eOC7Q16rr7WO20mCi_E16_LqioGUUgPamAHIIeNI,12568
10
+ spacr/graph_learning_lap.py,sha256=MyNRLb63gsjBlui-ByZ0anHugYulL6M-OsGm8rnGBmE,3385
11
+ spacr/gui_classify_app.py,sha256=-I06tVoA3U0jaAoTs32H1Y5ACMz6QBaEM1NEfg5w-9c,7965
12
+ spacr/gui_mask_app.py,sha256=wb1w_-U1RTJFfRCGfdrkku4FVULSA_2gFZyVM9oQlB0,8752
13
+ spacr/gui_measure_app.py,sha256=9mAw3Tiuq61uKTzMVslr0MgD8m1Lv5PNI0K4-gQiuXE,8061
14
+ spacr/gui_sim_app.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
+ spacr/gui_utils.py,sha256=NlgGFt62yqYnusv4yT1yjPcjd-v29N4IYukkJkEdL28,30493
16
+ spacr/io.py,sha256=gcQJ-WOzwDzRZsIzzIoMKIbhFN5AP9lYy2HCv9fG8CU,108683
17
+ spacr/logger.py,sha256=7Zqr3TuuOQLWT32gYr2q1qvv7x0a2JhLANmZcnBXAW8,670
18
+ spacr/mask_app.py,sha256=B6-zYXVFg-cc58gLcz-Ry6LClO2jxLitL6B2ACb0HTw,39278
19
+ spacr/measure.py,sha256=_f6UDugw75rILEg0uo2-QsUGUvc4AQdrdHl-BPZk74I,54686
20
+ spacr/old_code.py,sha256=KxljHpKNsV5EfX9ifN2xJTnUeqAhyabZyfDWd5THOOc,11226
21
+ spacr/plot.py,sha256=VtDKTJ_zo8CAVC3ILuIN_wUP6197vq089wNZuom7T8g,61655
22
+ spacr/sim.py,sha256=2NR5hm--HVcYQnj1SCHoUCVbh_b2XUjjjfoAUIXFwnQ,72997
23
+ spacr/timelapse.py,sha256=plPjR8nZ7_Q50VAvMvHK2TUE4F-vh7R23JnI6tSW02g,39661
24
+ spacr/train.py,sha256=lp66dWYkiMMlgdYlMjAsJnkIZFWLizKB-xwyVnKgFBs,25904
25
+ spacr/umap.py,sha256=4QSrQ16Og-Ijq-SwguMQT2f20UWz1LE5HQeSLmzSl8c,29370
26
+ spacr/utils.py,sha256=r0BjycEXqr48m0xEW9o14bLPj62JNpmEUBaBwR817Bw,121205
27
+ spacr/version.py,sha256=axH5tnGwtgSnJHb5IDhiu4Zjk5GhLyAEDRe-rnaoFOA,409
28
+ spacr-0.0.21.dist-info/LICENSE,sha256=SR-2MeGc6SCM1UORJYyarSWY_A-JaOMFDj7ReSs9tRM,1083
29
+ spacr-0.0.21.dist-info/METADATA,sha256=u93xu8dbff0h_8p25GP5yi49sUOh7GS-axPRbik4qRk,4973
30
+ spacr-0.0.21.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
31
+ spacr-0.0.21.dist-info/entry_points.txt,sha256=5uyJaAxWCbjWYwP15InAKU1yFxTwyuvCGtIGceso1es,290
32
+ spacr-0.0.21.dist-info/top_level.txt,sha256=GJPU8FgwRXGzKeut6JopsSRY2R8T3i9lDgya42tLInY,6
33
+ spacr-0.0.21.dist-info/RECORD,,
@@ -1,31 +0,0 @@
1
- spacr/__init__.py,sha256=mDi-Qu5r1vZnqIbUBV1JAoSq-mxmMEOmni1JSG2e4Wo,879
2
- spacr/__main__.py,sha256=_qRkhbFrH_cXr7AZs6KHL8Hh4VApqNdpNCtiKn2ePTo,285
3
- spacr/alpha.py,sha256=Q1vnqO0hvU1G7QP26amFwJY2RjZ68zIc3jYoqQSBMrw,462
4
- spacr/annotate_app.py,sha256=IPgZfS4TrSqbJr81P1FWUNOgCPPcS6EdQjUsXRwY-4E,19932
5
- spacr/cli.py,sha256=507jfOOEV8BoL4eeUcblvH-iiDHdBrEVJLu1ghAAPSc,1800
6
- spacr/core.py,sha256=VTk81S80PKq5pGeUEdduubjbsAyPWKDyRu11EWDO2ms,120307
7
- spacr/graph_learning.py,sha256=sD4eOC7Q16rr7WO20mCi_E16_LqioGUUgPamAHIIeNI,12568
8
- spacr/graph_learning_lap.py,sha256=MyNRLb63gsjBlui-ByZ0anHugYulL6M-OsGm8rnGBmE,3385
9
- spacr/gui_classify_app.py,sha256=-I06tVoA3U0jaAoTs32H1Y5ACMz6QBaEM1NEfg5w-9c,7965
10
- spacr/gui_mask_app.py,sha256=xpQ_kh-8lTb9xnyKNrZGtf9JnKqutBFu2-LdYzQShh0,8079
11
- spacr/gui_measure_app.py,sha256=9mAw3Tiuq61uKTzMVslr0MgD8m1Lv5PNI0K4-gQiuXE,8061
12
- spacr/gui_sim_app.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
- spacr/gui_utils.py,sha256=dpbrsDVebjm8ZmkYYOmIZCbkQYyc6JyMIPA7C0r4Xxw,29631
14
- spacr/io.py,sha256=bCe40kli7jx7hKQyDNuZYCpZ44-Brj4TZVl9r02cFHo,102042
15
- spacr/logger.py,sha256=7Zqr3TuuOQLWT32gYr2q1qvv7x0a2JhLANmZcnBXAW8,670
16
- spacr/mask_app.py,sha256=B6-zYXVFg-cc58gLcz-Ry6LClO2jxLitL6B2ACb0HTw,39278
17
- spacr/measure.py,sha256=LvF6D-TydhXPwKkeAKuIvzHn14qVlqsn6h7ENnoKn5s,50795
18
- spacr/old_code.py,sha256=KxljHpKNsV5EfX9ifN2xJTnUeqAhyabZyfDWd5THOOc,11226
19
- spacr/plot.py,sha256=qSM0NzVQYqMYioRc_BPCkUipLRZH7_GaHKGpLXGL8oI,55040
20
- spacr/sim.py,sha256=tl40lgTMeeJSyBq_c-Rn54C9Ri0FJ2zLkLLLPLSjz3o,51534
21
- spacr/timelapse.py,sha256=wAEMv7oPyusLph3RPmF4F6UGmLfMZmrupYfsuaeJ9vI,34003
22
- spacr/train.py,sha256=r77zLvLFMzx6MJxXG3JjynD8qTWYM9pNgrChEXYQhtY,25631
23
- spacr/umap.py,sha256=4QSrQ16Og-Ijq-SwguMQT2f20UWz1LE5HQeSLmzSl8c,29370
24
- spacr/utils.py,sha256=kEQxucklUdogxjOSQxKdA1R_NU5qYj6dJPiQKIB8un4,120992
25
- spacr/version.py,sha256=axH5tnGwtgSnJHb5IDhiu4Zjk5GhLyAEDRe-rnaoFOA,409
26
- spacr-0.0.20.dist-info/LICENSE,sha256=SR-2MeGc6SCM1UORJYyarSWY_A-JaOMFDj7ReSs9tRM,1083
27
- spacr-0.0.20.dist-info/METADATA,sha256=S8bOdo2ZqIj-ScjQRBxQnBQNXXFV-HKfidMs7MRIsIc,4381
28
- spacr-0.0.20.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
29
- spacr-0.0.20.dist-info/entry_points.txt,sha256=5uyJaAxWCbjWYwP15InAKU1yFxTwyuvCGtIGceso1es,290
30
- spacr-0.0.20.dist-info/top_level.txt,sha256=GJPU8FgwRXGzKeut6JopsSRY2R8T3i9lDgya42tLInY,6
31
- spacr-0.0.20.dist-info/RECORD,,