spacr 0.0.17__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/timelapse.py CHANGED
@@ -1,4 +1,4 @@
1
- import cv2, os, re, glob, random, btrack
1
+ import cv2, os, re, glob, random, btrack, sqlite3
2
2
  import numpy as np
3
3
  import pandas as pd
4
4
  from collections import defaultdict
@@ -9,6 +9,9 @@ from IPython.display import Image as ipyimage
9
9
  import trackpy as tp
10
10
  from btrack import datasets as btrack_datasets
11
11
  from skimage.measure import regionprops
12
+ from scipy.signal import find_peaks
13
+ from scipy.optimize import curve_fit
14
+ import matplotlib.pyplot as plt
12
15
 
13
16
  from .logger import log_function_call
14
17
 
@@ -144,56 +147,6 @@ def _sort_key(file_path):
144
147
  # Return a tuple that sorts this file as "earliest" or "lowest"
145
148
  return ('', '', '', 0)
146
149
 
147
- def _save_mask_timelapse_as_gif(masks, path, cmap, norm, filenames):
148
- """
149
- Save a timelapse of masks as a GIF.
150
-
151
- Parameters:
152
- masks (list): List of mask frames.
153
- path (str): Path to save the GIF.
154
- cmap: Colormap for displaying the masks.
155
- norm: Normalization for the masks.
156
- filenames (list): List of filenames corresponding to each mask frame.
157
-
158
- Returns:
159
- None
160
- """
161
- def _update(frame):
162
- """
163
- Update the plot with the given frame.
164
-
165
- Parameters:
166
- frame (int): The frame number to update the plot with.
167
-
168
- Returns:
169
- None
170
- """
171
- nonlocal filename_text_obj
172
- if filename_text_obj is not None:
173
- filename_text_obj.remove()
174
- ax.clear()
175
- ax.axis('off')
176
- current_mask = masks[frame]
177
- ax.imshow(current_mask, cmap=cmap, norm=norm)
178
- ax.set_title(f'Frame: {frame}', fontsize=24, color='white')
179
- filename_text = filenames[frame]
180
- filename_text_obj = fig.text(0.5, 0.01, filename_text, ha='center', va='center', fontsize=20, color='white')
181
- for label_value in np.unique(current_mask):
182
- if label_value == 0: continue # Skip background
183
- y, x = np.mean(np.where(current_mask == label_value), axis=1)
184
- ax.text(x, y, str(label_value), color='white', fontsize=24, ha='center', va='center')
185
-
186
- fig, ax = plt.subplots(figsize=(50, 50), facecolor='black')
187
- ax.set_facecolor('black')
188
- ax.axis('off')
189
- plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
190
-
191
- filename_text_obj = None
192
- anim = FuncAnimation(fig, _update, frames=len(masks), blit=False)
193
- anim.save(path, writer='pillow', fps=2, dpi=80) # Adjust DPI for size/quality
194
- plt.close(fig)
195
- print(f'Saved timelapse to {path}')
196
-
197
150
  def _masks_to_gif(masks, gif_folder, name, filenames, object_type):
198
151
  """
199
152
  Converts a sequence of masks into a GIF file.
@@ -208,6 +161,9 @@ def _masks_to_gif(masks, gif_folder, name, filenames, object_type):
208
161
  Returns:
209
162
  None
210
163
  """
164
+
165
+ from .io import _save_mask_timelapse_as_gif
166
+
211
167
  def _display_gif(path):
212
168
  with open(path, 'rb') as file:
213
169
  display(ipyimage(file.read()))
@@ -220,7 +176,7 @@ def _masks_to_gif(masks, gif_folder, name, filenames, object_type):
220
176
  norm = plt.cm.colors.Normalize(vmin=0, vmax=highest_label)
221
177
 
222
178
  save_path_gif = os.path.join(gif_folder, f'timelapse_masks_{object_type}_{name}.gif')
223
- _save_mask_timelapse_as_gif(masks, save_path_gif, cmap, norm, filenames)
179
+ _save_mask_timelapse_as_gif(masks, None, save_path_gif, cmap, norm, filenames)
224
180
  #_display_gif(save_path_gif)
225
181
 
226
182
  def _timelapse_masks_to_gif(folder_path, mask_channels, object_types):
@@ -450,6 +406,8 @@ def _trackpy_track_cells(src, name, batch_filenames, object_type, masks, timelap
450
406
  from .plot import _visualize_and_save_timelapse_stack_with_tracks
451
407
  from .utils import _masks_to_masks_stack
452
408
 
409
+ print(f'Tracking objects with trackpy')
410
+
453
411
  if timelapse_displacement is None:
454
412
  features = _prepare_for_tracking(masks)
455
413
  timelapse_displacement = _find_optimal_search_range(features, initial_search_range=500, increment=10, max_attempts=49, memory=3)
@@ -574,3 +532,206 @@ def _btrack_track_cells(src, name, batch_filenames, object_type, plot, save, mas
574
532
 
575
533
  mask_stack = _masks_to_masks_stack(masks)
576
534
  return mask_stack
535
+
536
def exponential_decay(x, a, b, c):
    """
    Evaluate the exponential-decay model ``a * exp(-b * x) + c``.

    This is the model passed to ``scipy.optimize.curve_fit`` for photobleaching correction.

    Parameters:
        x: Scalar or array-like of time values.
        a (float): Amplitude of the decaying term.
        b (float): Decay rate.
        c (float): Constant offset.

    Returns:
        The model evaluated at ``x`` (same shape as ``x``).
    """
    decaying_term = np.exp(-b * x)
    return a * decaying_term + c
538
+
539
def preprocess_pathogen_data(pathogen_df):
    """
    Collapse per-pathogen rows into one row per host cell and timepoint.

    Rows are grouped by ``plate``, ``row``, ``col``, ``field``, ``timeid`` and
    ``pathogen_cell_id``. Numeric (non-boolean) columns are averaged, every other
    column keeps its first value, and a ``parasite_count`` column holds the number
    of rows in each group. The per-pathogen ``object_label`` column is dropped and
    ``pathogen_cell_id`` is renamed to ``object_label`` so the result can be merged
    onto the cell table.

    Parameters:
        pathogen_df (pd.DataFrame): Pathogen measurements containing the six grouping columns.

    Returns:
        pd.DataFrame: One row per group, grouping columns first, ``parasite_count`` last.
    """
    group_keys = ['plate', 'row', 'col', 'field', 'timeid', 'pathogen_cell_id']
    excluded = set(group_keys) | {'parasite_count'}

    def _aggregation_for(dtype):
        # pandas' dtype helpers understand extension dtypes (nullable Int64, boolean,
        # category, string) where np.issubdtype raises TypeError. Booleans stay on
        # 'first', matching np.issubdtype(bool, np.number) being False.
        if pd.api.types.is_numeric_dtype(dtype) and not pd.api.types.is_bool_dtype(dtype):
            return 'mean'
        return 'first'

    agg_funcs = {
        col: _aggregation_for(pathogen_df[col].dtype)
        for col in pathogen_df.columns
        if col not in excluded
    }

    # Group once and reuse it for both the aggregation and the per-group counts.
    grouped = pathogen_df.groupby(group_keys)
    counts = grouped.size()
    if agg_funcs:
        pathogen_agg = grouped.agg(agg_funcs)
    else:
        # Only grouping columns present: nothing to aggregate beyond the counts.
        pathogen_agg = pd.DataFrame(index=counts.index)
    pathogen_agg['parasite_count'] = counts
    pathogen_agg = pathogen_agg.reset_index()

    # object_label identifies the pathogen, not the host cell, so it must not be kept.
    pathogen_agg = pathogen_agg.drop(columns=['object_label'], errors='ignore')

    # The host-cell id becomes the merge key against the cell table.
    return pathogen_agg.rename(columns={'pathogen_cell_id': 'object_label'})
558
+
559
def plot_data(measurement, group, ax, label, marker='o', linestyle='-'):
    """
    Draw one object's delta trace onto ``ax``.

    Parameters:
        measurement (str): Measurement name; y values come from column ``'delta_' + measurement``.
        group (pd.DataFrame): Rows of a single object, including a ``time`` column.
        ax: Axes object to draw on (its ``plot`` method is called once).
        label (str): Label attached to the line.
        marker (str): Marker style passed through to ``ax.plot``.
        linestyle (str): Line style passed through to ``ax.plot``.

    Returns:
        None
    """
    x_values = group['time']
    y_values = group['delta_' + measurement]
    line_style = {'marker': marker, 'linestyle': linestyle, 'label': label}
    ax.plot(x_values, y_values, **line_style)
561
+
562
def infected_vs_noninfected(result_df, measurement):
    """
    Plot per-object delta traces in two panels split by infection status.

    An object (``plate_row_column_field_object``) goes in the top panel if its
    ``parasite_count`` is > 0 at any timepoint, and in the bottom panel if its
    maximum ``parasite_count`` is 0. Objects whose ``parasite_count`` is entirely
    missing compare False to both and appear in neither panel.

    Parameters:
        result_df (pd.DataFrame): Needs ``plate_row_column_field_object``, ``time``,
            ``parasite_count`` and ``'delta_' + measurement`` columns.
        measurement (str): Name of the measurement whose delta is plotted.

    Returns:
        None
    """
    track_key = 'plate_row_column_field_object'

    # Each object's peak parasite count, computed once and reused for both splits.
    max_parasites = result_df.groupby(track_key)['parasite_count'].transform('max')
    infected_cells_df = result_df[max_parasites > 0]
    uninfected_cells_df = result_df[max_parasites == 0]

    _, axs = plt.subplots(2, 1, figsize=(12, 10), sharex=True)

    # Iterating the groupby avoids re-scanning the frame once per object; sort=False
    # keeps first-appearance order, the same order .unique() would give.
    for _, group in infected_cells_df.groupby(track_key, sort=False):
        plot_data(measurement, group, axs[0], 'Infected', marker='x')

    for _, group in uninfected_cells_df.groupby(track_key, sort=False):
        plot_data(measurement, group, axs[1], 'Uninfected')

    axs[0].set_title('Cells Infected at Some Time')
    axs[1].set_title('Cells Never Infected')

    all_timepoints = sorted(result_df['time'].unique())
    for ax in axs:
        ax.set_xlabel('Time')
        ax.set_ylabel('Normalized Delta ' + measurement)
        ax.set_xticks(all_timepoints)
        ax.set_xticklabels(all_timepoints, rotation=45, ha="right")

    plt.tight_layout()
    plt.show()
592
+
593
def _load_calcium_tables(db_loc, pathogen, cytoplasm):
    """Read the ``cell`` table from ``db_loc`` and optionally left-merge pathogen and cytoplasm data."""
    merge_keys = ['plate', 'row', 'col', 'field', 'timeid', 'object_label']
    conn = sqlite3.connect(db_loc)
    try:
        cell_df = pd.read_sql("SELECT * FROM cell", conn)

        if pathogen:
            pathogen_df = pd.read_sql("SELECT * FROM pathogen", conn)
            # Cast via float to nullable Int64 so missing ids become <NA> instead of failing.
            pathogen_df['pathogen_cell_id'] = pathogen_df['pathogen_cell_id'].astype(float).astype('Int64')
            pathogen_df = preprocess_pathogen_data(pathogen_df)
            cell_df = cell_df.merge(pathogen_df, on=merge_keys, how='left', suffixes=('', '_pathogen'))
            # Cells without a matching pathogen row carry no parasites.
            cell_df['parasite_count'] = cell_df['parasite_count'].fillna(0)
            print(f'After pathogen merge: {len(cell_df)} objects')

        if cytoplasm:
            cytoplasm_df = pd.read_sql("SELECT * FROM cytoplasm", conn)
            cell_df = cell_df.merge(cytoplasm_df, on=merge_keys, how='left', suffixes=('', '_cytoplasm'))
            print(f'After cytoplasm merge: {len(cell_df)} objects')
    finally:
        # Release the database handle even if a query or merge raises.
        conn.close()
    return cell_df


def _add_prcf_columns(cell_df):
    """Derive plate/row/column/field/time from ``prcf`` and build the per-object track key (in place)."""
    prcf_components = cell_df['prcf'].str.split('_', expand=True)
    cell_df['plate'] = prcf_components[0]
    cell_df['row'] = prcf_components[1]
    cell_df['column'] = prcf_components[2]
    cell_df['field'] = prcf_components[3]
    # The timepoint is the digits following 't' in the fifth token (raw-string regex).
    cell_df['time'] = prcf_components[4].str.extract(r't(\d+)', expand=False).astype(int)
    cell_df['object_number'] = cell_df['object_label']
    cell_df['plate_row_column_field_object'] = (
        cell_df['plate'].astype(str) + '_'
        + cell_df['row'].astype(str) + '_'
        + cell_df['column'].astype(str) + '_'
        + cell_df['field'].astype(str) + '_'
        + cell_df['object_label'].astype(str)
    )
    return cell_df


def _plot_delta_traces(result_df, timepoints, measurement, num_lines):
    """Plot the delta trace of every kept object, or of ``num_lines`` randomly sampled ones."""
    track_key = 'plate_row_column_field_object'
    delta_col = 'delta_' + measurement
    _, ax = plt.subplots(figsize=(10, 8))

    group_ids = result_df[track_key].unique()
    if num_lines is not None and 0 < num_lines < len(group_ids):
        group_ids = np.random.choice(group_ids, size=num_lines, replace=False)

    # Index the traces once instead of filtering the whole frame for every object.
    traces = {key: trace for key, trace in result_df.groupby(track_key, sort=False)}
    for group_id in group_ids:
        trace = traces[group_id]
        ax.plot(trace['time'], trace[delta_col], marker='o', linestyle='-')

    ax.set_xticks(timepoints)
    ax.set_xticklabels(timepoints, rotation=45, ha="right")
    ax.set_title(f'Normalized Delta of {measurement} Over Time (Corrected for Photobleaching)')
    ax.set_xlabel('Time')
    ax.set_ylabel('Normalized Delta ' + measurement)
    plt.tight_layout()
    plt.show()


def _print_peaks_per_infection_status(result_df, peak_details_df):
    """Print the average number of detected peaks per ever-infected and per never-infected object."""
    track_key = 'plate_row_column_field_object'
    max_parasites = result_df.groupby(track_key)['parasite_count'].transform('max')
    infected_cells = result_df.loc[max_parasites > 0, track_key].unique()
    noninfected_cells = result_df.loc[max_parasites == 0, track_key].unique()

    n_infected_peaks = peak_details_df['ID'].isin(infected_cells).sum()
    n_noninfected_peaks = peak_details_df['ID'].isin(noninfected_cells).sum()

    avg_inf_peaks_per_cell = n_infected_peaks / len(infected_cells) if len(infected_cells) > 0 else 0
    avg_non_inf_peaks_per_cell = n_noninfected_peaks / len(noninfected_cells) if len(noninfected_cells) > 0 else 0

    print(f'Average number of peaks per infected cell: {avg_inf_peaks_per_cell:.2f}')
    print(f'Average number of peaks per non-infected cell: {avg_non_inf_peaks_per_cell:.2f}')


def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intensity', size_filter='cell_area', fluctuation_threshold=0.25, num_lines=None, peak_height=0.01, pathogen=None, cytoplasm=None, remove_transient=True, verbose=False, transience_threshold=0.9):
    """
    Detect peaks in the frame-to-frame change of a per-cell signal across a timelapse.

    Loads the ``cell`` table from the SQLite database (optionally merged with the
    ``pathogen`` and ``cytoplasm`` tables). Divides ``measurement`` by a single
    exponential-decay fit over all objects (photobleaching correction). Then, for
    each tracked object (``plate_row_column_field_object``) it:

    - drops transient tracks,
    - drops objects whose ``size_filter`` std/mean exceeds ``fluctuation_threshold``,
    - computes the delta of the corrected signal,
    - finds peaks in that delta with ``scipy.signal.find_peaks``.

    Parameters:
        db_loc (str): Path to the SQLite database.
        measurement (str): Column to analyse.
        size_filter (str): Column whose std/mean is used to reject unstable objects.
        fluctuation_threshold (float): Maximum allowed std/mean of ``size_filter``.
        num_lines (int or None): If set, plot only this many randomly sampled objects.
        peak_height (float): Minimum peak height passed to ``find_peaks``.
        pathogen: If truthy, merge the ``pathogen`` table, plot infected vs. uninfected
            objects and print average peaks per infection status.
        cytoplasm: If truthy, merge the ``cytoplasm`` table.
        remove_transient (bool): Drop objects tracked in fewer than
            ``int(transience_threshold * number_of_timepoints)`` timepoints.
        verbose (bool): Print progress details.
        transience_threshold (float): Fraction of timepoints an object must be tracked in.

    Returns:
        tuple(pd.DataFrame, pd.DataFrame) or None: ``(result_df, peak_details_df)``.
        ``result_df`` holds the kept objects' rows with the corrected and delta columns.
        ``peak_details_df`` has one row per peak; its ``infected`` column is the parasite
        count at the peak, or NaN without pathogen data. Returns None if the
        photobleaching fit fails or no object passes the filters.
    """
    track_key = 'plate_row_column_field_object'
    corrected_col = 'corrected_' + measurement
    delta_col = 'delta_' + measurement

    cell_df = _load_calcium_tables(db_loc, pathogen, cytoplasm)
    df = _add_prcf_columns(cell_df).copy()

    # Fit one exponential-decay model to all objects and divide it out.
    try:
        params, _ = curve_fit(
            exponential_decay, df['time'], df[measurement],
            p0=[df[measurement].max(), 0.01, df[measurement].min()], maxfev=10000,
        )
    except RuntimeError as e:
        print(f"Curve fitting failed for the entire dataset with error: {e}")
        return
    df[corrected_col] = df[measurement] / exponential_decay(df['time'], *params)

    if verbose:
        print(f'Analyzing: {len(df)} objects')

    # parasite_count only exists when pathogen data was merged in.
    has_parasite_counts = 'parasite_count' in df.columns
    total_timepoints = df['time'].nunique()
    # Minimum number of timepoints an object must be tracked in to be kept.
    min_track_length = int(transience_threshold * total_timepoints)

    corrected_dfs = []
    peak_details_list = []
    size_filter_removed = 0
    transience_removed = 0

    for unique_id, group in df.groupby(track_key):
        group = group.sort_values('time')

        if remove_transient:
            if verbose:
                print(f'Group length: {len(group)} Timelapse length: {total_timepoints}, threshold:{min_track_length}')
            # Short tracks are transient and discarded; persistent tracks are kept.
            if len(group) < min_track_length:
                transience_removed += 1
                continue

        # Reject objects whose size fluctuates too much. A NaN ratio (e.g. std of a
        # single-row track) fails the comparison and is rejected too.
        size_diff = group[size_filter].std() / group[size_filter].mean()
        if not size_diff <= fluctuation_threshold:
            size_filter_removed += 1
            continue

        group[delta_col] = group[corrected_col].diff().fillna(0)
        corrected_dfs.append(group)

        peaks, properties = find_peaks(group[delta_col], height=peak_height)
        for peak, amplitude in zip(peaks, properties['peak_heights']):
            peak_details_list.append({
                'ID': unique_id,
                'plate': group['plate'].iloc[0],
                'row': group['row'].iloc[0],
                'column': group['column'].iloc[0],
                'field': group['field'].iloc[0],
                'object_number': group['object_number'].iloc[0],
                'time': group['time'].iloc[peak],
                'amplitude': amplitude,
                'delta': group[delta_col].iloc[peak],
                'infected': group['parasite_count'].iloc[peak] if has_parasite_counts else np.nan,
            })

    if verbose:
        print(f'Removed {size_filter_removed} objects due to size filter fluctuation')
        print(f'Removed {transience_removed} objects due to transience')

    if not corrected_dfs:
        print("No suitable cells found for analysis")
        return

    result_df = pd.concat(corrected_dfs)
    # Explicit columns keep the schema (including 'ID') even when no peak was found.
    peak_details_df = pd.DataFrame(
        peak_details_list,
        columns=['ID', 'plate', 'row', 'column', 'field', 'object_number',
                 'time', 'amplitude', 'delta', 'infected'],
    )

    _plot_delta_traces(result_df, sorted(df['time'].unique()), measurement, num_lines)

    if pathogen:
        infected_vs_noninfected(result_df, measurement)
        _print_peaks_per_infection_status(result_df, peak_details_df)

    return result_df, peak_details_df