spacr 0.0.18__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/alpha.py +291 -14
- spacr/annotate_app.py +2 -2
- spacr/core.py +1377 -296
- spacr/foldseek.py +793 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/graph_learning.py +259 -65
- spacr/graph_learning_lap.py +73 -71
- spacr/gui_classify_app.py +5 -21
- spacr/gui_mask_app.py +36 -30
- spacr/gui_measure_app.py +10 -24
- spacr/gui_utils.py +82 -54
- spacr/io.py +505 -205
- spacr/measure.py +160 -80
- spacr/old_code.py +155 -1
- spacr/plot.py +243 -99
- spacr/sim.py +666 -119
- spacr/timelapse.py +343 -52
- spacr/train.py +18 -10
- spacr/utils.py +252 -151
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/METADATA +32 -27
- spacr-0.0.21.dist-info/RECORD +33 -0
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/WHEEL +1 -1
- spacr/gui_temp.py +0 -212
- spacr/test_annotate_app.py +0 -58
- spacr/test_plot.py +0 -43
- spacr/test_train.py +0 -39
- spacr/test_utils.py +0 -33
- spacr-0.0.18.dist-info/RECORD +0 -36
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/LICENSE +0 -0
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/entry_points.txt +0 -0
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/top_level.txt +0 -0
spacr/timelapse.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
import cv2, os, re, glob, random, btrack
|
1
|
+
import cv2, os, re, glob, random, btrack, sqlite3
|
2
2
|
import numpy as np
|
3
3
|
import pandas as pd
|
4
4
|
from collections import defaultdict
|
@@ -9,6 +9,10 @@ from IPython.display import Image as ipyimage
|
|
9
9
|
import trackpy as tp
|
10
10
|
from btrack import datasets as btrack_datasets
|
11
11
|
from skimage.measure import regionprops
|
12
|
+
from scipy.signal import find_peaks
|
13
|
+
from scipy.optimize import curve_fit
|
14
|
+
from scipy.integrate import trapz
|
15
|
+
import matplotlib.pyplot as plt
|
12
16
|
|
13
17
|
from .logger import log_function_call
|
14
18
|
|
@@ -144,56 +148,6 @@ def _sort_key(file_path):
|
|
144
148
|
# Return a tuple that sorts this file as "earliest" or "lowest"
|
145
149
|
return ('', '', '', 0)
|
146
150
|
|
147
|
-
def _save_mask_timelapse_as_gif(masks, path, cmap, norm, filenames):
    """
    Render a sequence of label masks as an animated GIF.

    Each frame shows one mask drawn with ``cmap``/``norm``, the frame index as
    the title, the matching entry of ``filenames`` as a caption at the bottom
    of the figure, and every non-zero label value written at the mean
    coordinate of its pixels.

    Parameters:
    masks (list): Mask frames, one per animation frame.
    path (str): Destination of the GIF file.
    cmap: Colormap passed to ``imshow``.
    norm: Normalization passed to ``imshow``.
    filenames (list): Caption text for each frame, indexed like ``masks``.

    Returns:
    None
    """
    figure, axes = plt.subplots(figsize=(50, 50), facecolor='black')
    axes.set_facecolor('black')
    axes.axis('off')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)

    # The caption is attached to the figure (not the axes), so ax.clear() does
    # not remove it; keep a handle to the previous one and drop it explicitly.
    state = {'caption': None}

    def _draw_frame(index):
        """Redraw the axes for frame number ``index``."""
        previous_caption = state['caption']
        if previous_caption is not None:
            previous_caption.remove()

        axes.clear()
        axes.axis('off')

        mask = masks[index]
        axes.imshow(mask, cmap=cmap, norm=norm)
        axes.set_title(f'Frame: {index}', fontsize=24, color='white')
        state['caption'] = figure.text(0.5, 0.01, filenames[index], ha='center', va='center', fontsize=20, color='white')

        for label in np.unique(mask):
            if label != 0:  # 0 is treated as background and gets no text
                # unpacking into two values assumes a 2-D mask -- TODO confirm
                centre_row, centre_col = np.mean(np.where(mask == label), axis=1)
                axes.text(centre_col, centre_row, str(label), color='white', fontsize=24, ha='center', va='center')

    animation = FuncAnimation(figure, _draw_frame, frames=len(masks), blit=False)
    animation.save(path, writer='pillow', fps=2, dpi=80)  # Adjust DPI for size/quality
    plt.close(figure)
    print(f'Saved timelapse to {path}')
|
196
|
-
|
197
151
|
def _masks_to_gif(masks, gif_folder, name, filenames, object_type):
|
198
152
|
"""
|
199
153
|
Converts a sequence of masks into a GIF file.
|
@@ -208,6 +162,9 @@ def _masks_to_gif(masks, gif_folder, name, filenames, object_type):
|
|
208
162
|
Returns:
|
209
163
|
None
|
210
164
|
"""
|
165
|
+
|
166
|
+
from .io import _save_mask_timelapse_as_gif
|
167
|
+
|
211
168
|
def _display_gif(path):
|
212
169
|
with open(path, 'rb') as file:
|
213
170
|
display(ipyimage(file.read()))
|
@@ -220,7 +177,7 @@ def _masks_to_gif(masks, gif_folder, name, filenames, object_type):
|
|
220
177
|
norm = plt.cm.colors.Normalize(vmin=0, vmax=highest_label)
|
221
178
|
|
222
179
|
save_path_gif = os.path.join(gif_folder, f'timelapse_masks_{object_type}_{name}.gif')
|
223
|
-
_save_mask_timelapse_as_gif(masks, save_path_gif, cmap, norm, filenames)
|
180
|
+
_save_mask_timelapse_as_gif(masks, None, save_path_gif, cmap, norm, filenames)
|
224
181
|
#_display_gif(save_path_gif)
|
225
182
|
|
226
183
|
def _timelapse_masks_to_gif(folder_path, mask_channels, object_types):
|
@@ -450,6 +407,8 @@ def _trackpy_track_cells(src, name, batch_filenames, object_type, masks, timelap
|
|
450
407
|
from .plot import _visualize_and_save_timelapse_stack_with_tracks
|
451
408
|
from .utils import _masks_to_masks_stack
|
452
409
|
|
410
|
+
print(f'Tracking objects with trackpy')
|
411
|
+
|
453
412
|
if timelapse_displacement is None:
|
454
413
|
features = _prepare_for_tracking(masks)
|
455
414
|
timelapse_displacement = _find_optimal_search_range(features, initial_search_range=500, increment=10, max_attempts=49, memory=3)
|
@@ -574,3 +533,335 @@ def _btrack_track_cells(src, name, batch_filenames, object_type, plot, save, mas
|
|
574
533
|
|
575
534
|
mask_stack = _masks_to_masks_stack(masks)
|
576
535
|
return mask_stack
|
536
|
+
|
537
|
+
def exponential_decay(x, a, b, c):
    """
    Evaluate the decay model ``a * exp(-b * x) + c``.

    Parameters:
    x: Scalar or array-like position(s) at which to evaluate the model.
    a: Scale applied to the exponential term.
    b: Decay rate (the exponent is ``-b * x``).
    c: Constant offset added to the result.

    Returns:
    The model value(s), with the shape NumPy produces for ``x``.
    """
    decay_term = np.exp(-b * x)
    return a * decay_term + c
|
539
|
+
|
540
|
+
def preprocess_pathogen_data(pathogen_df):
    """
    Collapse a per-pathogen table into one row per host cell and timepoint.

    Rows are grouped by ``plate``, ``row``, ``col``, ``field``, ``timeid`` and
    ``pathogen_cell_id``. For every group the number of rows is stored as
    ``parasite_count``; every other numeric column is averaged and every other
    non-numeric column keeps its first value. The pathogen's own
    ``object_label`` is dropped and ``pathogen_cell_id`` is renamed to
    ``object_label`` so the result can be merged onto a cell table.

    Parameters:
    pathogen_df (pandas.DataFrame): Table containing at least the six grouping
        columns listed above.

    Returns:
    pandas.DataFrame: One row per group with the aggregated columns and
        ``parasite_count``. The input frame is not modified.
    """
    group_keys = ['plate', 'row', 'col', 'field', 'timeid', 'pathogen_cell_id']
    grouped = pathogen_df.groupby(group_keys)

    # Group by identifiers and count the number of parasites
    parasite_counts = grouped.size().reset_index(name='parasite_count')

    def _is_number(dtype):
        """Return True when a column of ``dtype`` should be averaged."""
        try:
            return bool(np.issubdtype(dtype, np.number))
        except TypeError:
            # np.issubdtype cannot interpret pandas extension dtypes (nullable
            # Int64, string, categorical, ...); ask pandas instead, keeping
            # booleans on the 'first' side as the NumPy check does.
            return pd.api.types.is_numeric_dtype(dtype) and not pd.api.types.is_bool_dtype(dtype)

    # Aggregate numerical columns and take the first of object columns
    skipped = set(group_keys) | {'parasite_count'}
    agg_funcs = {
        col: 'mean' if _is_number(pathogen_df[col].dtype) else 'first'
        for col in pathogen_df.columns
        if col not in skipped
    }

    if agg_funcs:
        pathogen_agg = grouped.agg(agg_funcs).reset_index()
        # Merge the counts back into the aggregated data
        pathogen_agg = pathogen_agg.merge(parasite_counts, on=group_keys)
    else:
        # Nothing besides the keys to aggregate: the counts are the result.
        pathogen_agg = parasite_counts

    # Remove the object_label column as it corresponds to the pathogen ID not the cell ID
    if 'object_label' in pathogen_agg.columns:
        pathogen_agg = pathogen_agg.drop(columns=['object_label'])

    # Change the name of pathogen_cell_id to object_label
    return pathogen_agg.rename(columns={'pathogen_cell_id': 'object_label'})
|
559
|
+
|
560
|
+
def plot_data(measurement, group, ax, label, marker='o', linestyle='-'):
    """
    Draw one line of ``delta_<measurement>`` against ``time`` on ``ax``.

    Parameters:
    measurement (str): Measurement name; the plotted column is
        ``'delta_' + measurement``.
    group: Mapping/DataFrame providing the ``'time'`` and delta columns.
    ax: Object whose ``plot`` method receives the data.
    label (str): Legend label forwarded to ``ax.plot``.
    marker (str): Marker style forwarded to ``ax.plot``.
    linestyle (str): Line style forwarded to ``ax.plot``.

    Returns:
    None
    """
    delta_column = 'delta_' + measurement
    times = group['time']
    deltas = group[delta_column]
    ax.plot(times, deltas, marker=marker, linestyle=linestyle, label=label)
|
562
|
+
|
563
|
+
def infected_vs_noninfected(result_df, measurement):
    """
    Plot per-cell delta traces, split into ever-infected and never-infected cells.

    A cell (one value of ``plate_row_column_field_object``) is drawn on the
    top axis when the maximum of its ``parasite_count`` is greater than 0 and
    on the bottom axis when that maximum equals 0. The figure is shown with
    ``plt.show()``.

    Parameters:
    result_df (pandas.DataFrame): Must contain ``plate_row_column_field_object``,
        ``parasite_count``, ``time`` and ``'delta_' + measurement``.
    measurement (str): Name of the measurement whose delta column is plotted.

    Returns:
    The created matplotlib figure, so that the caller can save it.
    """
    id_column = 'plate_row_column_field_object'

    # Separate the cells into two groups based on their highest parasite count.
    # Computed once and reused for both selections.
    max_parasites = result_df.groupby(id_column)['parasite_count'].transform('max')
    infected_cells_df = result_df[max_parasites > 0]
    uninfected_cells_df = result_df[max_parasites == 0]

    # Plotting
    fig, axs = plt.subplots(2, 1, figsize=(12, 10), sharex=True)

    # sort=False keeps cells in order of first appearance, so the order in
    # which lines are added to each axis does not depend on the ID strings.
    # Plot for cells that were infected at some time
    for _, group in infected_cells_df.groupby(id_column, sort=False):
        plot_data(measurement, group, axs[0], 'Infected', marker='x')

    # Plot for cells that were never infected
    for _, group in uninfected_cells_df.groupby(id_column, sort=False):
        plot_data(measurement, group, axs[1], 'Uninfected')

    # Set the titles and labels
    axs[0].set_title('Cells Infected at Some Time')
    axs[1].set_title('Cells Never Infected')
    all_timepoints = sorted(result_df['time'].unique())
    for ax in axs:
        ax.set_xlabel('Time')
        ax.set_ylabel('Normalized Delta ' + measurement)
        ax.set_xticks(all_timepoints)
        ax.set_xticklabels(all_timepoints, rotation=45, ha="right")

    plt.tight_layout()
    plt.show()
    return fig
|
593
|
+
|
594
|
+
def save_figure(fig, src, figure_number):
    """
    Save ``fig`` as ``figure_<figure_number>.pdf`` in a ``results`` folder.

    The ``results`` folder is created (if needed) next to ``src``, i.e. inside
    ``os.path.dirname(src)``.

    Parameters:
    fig: Object exposing ``savefig(path)``.
    src (str): Path whose parent directory receives the ``results`` folder.
    figure_number: Value interpolated into the file name.

    Returns:
    None
    """
    results_dir = os.path.join(os.path.dirname(src), 'results')
    os.makedirs(results_dir, exist_ok=True)
    target = os.path.join(results_dir, f'figure_{figure_number}.pdf')
    fig.savefig(target)
    print(f'Saved figure:{target}')
|
601
|
+
|
602
|
+
def save_results_dataframe(df, src, results_name):
    """
    Write ``df`` to ``<results_name>.csv`` in a ``results`` folder.

    The ``results`` folder is created (if needed) inside
    ``os.path.dirname(src)``. The index is written to the file.

    Parameters:
    df: Object exposing ``to_csv(path, index=True)``.
    src (str): Path whose parent directory receives the ``results`` folder.
    results_name (str): File name without the ``.csv`` extension.

    Returns:
    None
    """
    results_dir = os.path.join(os.path.dirname(src), 'results')
    os.makedirs(results_dir, exist_ok=True)
    target = os.path.join(results_dir, f'{results_name}.csv')
    df.to_csv(target, index=True)
    print(f'Saved results:{target}')
|
609
|
+
|
610
|
+
def summarize_per_well(peak_details_df):
    """
    Summarize detected peaks per well.

    ``ID`` is split on ``'_'`` into ``plate``, ``row``, ``column``, ``field``
    and ``object_number`` (so it must have exactly five parts), and ``well_ID``
    is built as ``row + '_' + column``. These columns are written into
    ``peak_details_df`` itself, i.e. the input frame is modified in place.

    Only rows with a non-null ``amplitude`` contribute to ``peaks_per_well``,
    ``unique_IDs_with_amplitude`` and the per-well means of the numeric
    columns; wells without any such row are absent from the result.
    ``cells_per_well`` is counted over all rows of the well, including rows
    without a peak.

    Parameters:
    peak_details_df (pandas.DataFrame): Needs at least ``ID`` and ``amplitude``.

    Returns:
    pandas.DataFrame: One row per well with ``well_ID``, ``peaks_per_well``,
        ``unique_IDs_with_amplitude``, the numeric means, ``cells_per_well``
        and ``peaks_per_cell``.
    """
    # Step 1: Split the 'ID' column
    split_columns = peak_details_df['ID'].str.split('_', expand=True)
    peak_details_df[['plate', 'row', 'column', 'field', 'object_number']] = split_columns

    # Step 2: Create 'well_ID' by combining 'row' and 'column'
    peak_details_df['well_ID'] = peak_details_df['row'] + '_' + peak_details_df['column']

    # Filter entries where 'amplitude' is not null (rows that describe a peak)
    with_peaks = peak_details_df[peak_details_df['amplitude'].notna()]

    # Every numeric column of the filtered frame is averaged per well
    numeric_cols = with_peaks.select_dtypes(include=['number']).columns

    # Step 3: Calculate summary statistics over rows that describe a peak
    summary_df = with_peaks.groupby('well_ID').agg(
        peaks_per_well=('ID', 'size'),
        unique_IDs_with_amplitude=('ID', 'nunique'),  # Count unique IDs per well with non-null amplitude
        **{col: (col, 'mean') for col in numeric_cols}
    ).reset_index()

    # Step 4: Count cells per well over ALL rows.
    # NOTE(review): this counts distinct object_number values, which would
    # coincide across fields of the same well if labels restart per field --
    # confirm whether 'ID' should be counted instead.
    cells_per_well = peak_details_df.groupby('well_ID').agg(
        cells_per_well=('object_number', 'nunique'),
    ).reset_index()

    # Join on well_ID rather than copying the column by position: the two
    # frames cover different sets of wells whenever a well has no peaks, and a
    # positional copy would then assign counts to the wrong wells.
    summary_df = summary_df.merge(cells_per_well, on='well_ID', how='left')
    summary_df['peaks_per_cell'] = summary_df['peaks_per_well'] / summary_df['cells_per_well']

    return summary_df
|
640
|
+
|
641
|
+
def summarize_per_well_inf_non_inf(peak_details_df):
    """
    Summarize peak rows per well, separately for infected and non-infected rows.

    ``ID`` is split on ``'_'`` into ``plate``, ``row``, ``column``, ``field``
    and ``object_number``; ``well_ID`` is ``row + '_' + column``; and
    ``infected_status`` is ``'infected'`` where the ``infected`` column is
    greater than 0, otherwise ``'non_infected'``. All of these columns are
    written into ``peak_details_df`` itself (the input is modified in place).

    Parameters:
    peak_details_df (pandas.DataFrame): Needs at least ``ID`` and ``infected``.

    Returns:
    pandas.DataFrame: One row per (``well_ID``, ``infected_status``) with
        ``cells_per_well`` (distinct ``object_number``), ``peaks_per_well``
        (row count), the mean of every numeric column, and ``peaks_per_cell``.
    """
    def _status(value):
        """Map an ``infected`` value to its status label."""
        if value > 0:
            return 'infected'
        return 'non_infected'

    # Step 1: derive the identifier columns from 'ID'
    id_parts = peak_details_df['ID'].str.split('_', expand=True)
    peak_details_df[['plate', 'row', 'column', 'field', 'object_number']] = id_parts

    # Step 2: a well is identified by its row and column
    peak_details_df['well_ID'] = peak_details_df['row'] + '_' + peak_details_df['column']

    # The 'infected' column marks infection when > 0; classify every row
    peak_details_df['infected_status'] = peak_details_df['infected'].map(_status)

    # Step 3: average every numeric column alongside the two counts
    mean_aggregations = {
        col: (col, 'mean')
        for col in peak_details_df.select_dtypes(include=['number']).columns
    }
    per_well_and_status = peak_details_df.groupby(['well_ID', 'infected_status'])
    summary_df = per_well_and_status.agg(
        cells_per_well=('object_number', 'nunique'),
        peaks_per_well=('ID', 'size'),
        **mean_aggregations
    ).reset_index()

    # Calculate peaks per cell
    summary_df['peaks_per_cell'] = summary_df['peaks_per_well'] / summary_df['cells_per_well']

    return summary_df
|
667
|
+
|
668
|
+
def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intensity', size_filter='cell_area', fluctuation_threshold=0.25, num_lines=None, peak_height=0.01, pathogen=None, cytoplasm=None, remove_transient=True, verbose=False, transience_threshold=0.9):
    """
    Detect peaks in per-cell time traces of ``measurement`` stored in a SQLite database.

    Steps performed:
    1. Load the ``cell`` table and optionally left-merge the ``pathogen`` and
       ``cytoplasm`` tables onto it.
    2. Derive ``plate``, ``row``, ``column``, ``field`` and ``time`` from the
       underscore-separated ``prcf`` column (the fifth part must contain
       ``t<digits>``) and build a per-cell ID.
    3. Fit one exponential decay to ``measurement`` over ``time`` for the whole
       dataset and divide the measurement by the fitted curve.
    4. For every cell that passes the transience and size-fluctuation filters,
       take the frame-to-frame difference of the corrected signal, find peaks
       in it and integrate it over time.
    5. Save the per-peak, per-cell and per-well tables as CSV and the plots as
       PDF in a ``results`` folder next to ``db_loc``.

    Parameters:
    db_loc (str): Path of the SQLite database.
    measurement (str): Column of the cell table to analyze.
    size_filter (str): Column used for the size-fluctuation filter.
    fluctuation_threshold (float): Cells whose ``std / mean`` of
        ``size_filter`` exceeds this value (or is NaN) are discarded.
    num_lines (int or None): If set, and smaller than the number of cells,
        only this many randomly chosen cells are drawn in the overview plot.
    peak_height (float): ``height`` argument passed to ``find_peaks``.
    pathogen: If truthy, merge the pathogen table and produce the
        infected / non-infected comparison.
    cytoplasm: If truthy, merge the cytoplasm table.
    remove_transient (bool): Discard cells observed in no more than
        ``int(transience_threshold * number_of_timepoints)`` frames.
    verbose (bool): Print progress details.
    transience_threshold (float): Fraction used by ``remove_transient``.

    Returns:
    tuple or None: ``(result_df, peak_details_df, fig)`` where ``fig`` is the
        overview figure. ``None`` is returned when the curve fit raises
        ``RuntimeError`` or when no cell passes the filters.
    """
    # scipy.integrate.trapz was renamed to trapezoid (SciPy 1.6) and the old
    # name was removed in SciPy 1.14, so prefer the new name when available.
    try:
        from scipy.integrate import trapezoid
    except ImportError:
        from scipy.integrate import trapz as trapezoid

    id_col = 'plate_row_column_field_object'
    delta_col = 'delta_' + measurement
    corrected_col = 'corrected_' + measurement

    # Load data
    cell_df = _load_cell_measurements(db_loc, pathogen, cytoplasm)

    # Prepare DataFrame: split 'prcf' into its identifier components
    prcf_components = cell_df['prcf'].str.split('_', expand=True)
    cell_df['plate'] = prcf_components[0]
    cell_df['row'] = prcf_components[1]
    cell_df['column'] = prcf_components[2]
    cell_df['field'] = prcf_components[3]
    cell_df['time'] = prcf_components[4].str.extract(r't(\d+)').astype(int)
    cell_df['object_number'] = cell_df['object_label']
    cell_df[id_col] = cell_df['plate'].astype(str) + '_' + cell_df['row'].astype(str) + '_' + cell_df['column'].astype(str) + '_' + cell_df['field'].astype(str) + '_' + cell_df['object_label'].astype(str)

    df = cell_df.copy()

    # Fit exponential decay model to all scaled fluorescence data
    try:
        params, _ = curve_fit(exponential_decay, df['time'], df[measurement], p0=[max(df[measurement]), 0.01, min(df[measurement])], maxfev=10000)
        df[corrected_col] = df[measurement] / exponential_decay(df['time'], *params)
    except RuntimeError as e:
        print(f"Curve fitting failed for the entire dataset with error: {e}")
        return
    if verbose:
        print(f'Analyzing: {len(df)} objects')

    # Normalizing corrected fluorescence for each cell
    corrected_dfs = []
    peak_details_list = []
    total_timepoints = df['time'].nunique()
    size_filter_removed = 0
    transience_removed = 0
    # Same for every cell, so computed once (only needed when filtering).
    threshold = int(transience_threshold * total_timepoints) if remove_transient else None

    for unique_id, group in df.groupby(id_col):
        group = group.sort_values('time')

        if remove_transient:
            if verbose:
                print(f'Group length: {len(group)} Timelapse length: {total_timepoints}, threshold:{threshold}')
            if len(group) <= threshold:
                transience_removed += 1
                if verbose:
                    print(f'removed group {unique_id} due to transience')
                continue

        size_diff = group[size_filter].std() / group[size_filter].mean()
        # Written as "not <=" so that a NaN ratio also counts as a failure.
        if not size_diff <= fluctuation_threshold:
            size_filter_removed += 1
            continue

        group[delta_col] = group[corrected_col].diff().fillna(0)
        corrected_dfs.append(group)

        # Detect peaks
        peaks, properties = find_peaks(group[delta_col], height=peak_height)

        # Area under the whole delta trace, and under its positive part only
        above_zero_auc = trapezoid(y=group[delta_col].clip(lower=0), x=group['time'])
        auc = trapezoid(y=group[delta_col], x=group['time'])
        is_infected = 1 if (group['parasite_count'] > 0).any() else 0

        # Fields shared by every row written for this cell
        cell_record = {
            'ID': unique_id,
            'plate': group['plate'].iloc[0],
            'row': group['row'].iloc[0],
            'column': group['column'].iloc[0],
            'field': group['field'].iloc[0],
            'object_number': group['object_number'].iloc[0],
        }

        if len(peaks) == 0:
            # Placeholder row so that cells without peaks stay in the table
            peak_details_list.append({
                **cell_record,
                'time': np.nan,  # The time of the peak
                'amplitude': np.nan,
                'delta': np.nan,
                'AUC': auc,
                'AUC_positive': above_zero_auc,
                'AUC_peak': np.nan,
                'infected': is_infected
            })

        for i, peak in enumerate(peaks):
            # Integrate over the peak sample and its direct neighbours
            start_idx = max(peak - 1, 0)
            end_idx = min(peak + 1, len(group) - 1)
            peak_segment_y = group[delta_col].iloc[start_idx:end_idx + 1]
            peak_segment_x = group['time'].iloc[start_idx:end_idx + 1]

            # NOTE(review): peak rows store the parasite count at the peak
            # frame, whereas placeholder rows store a 0/1 "ever infected"
            # flag -- confirm that this difference is intended.
            peak_details_list.append({
                **cell_record,
                'time': group['time'].iloc[peak],  # The time of the peak
                'amplitude': properties['peak_heights'][i],
                'delta': group[delta_col].iloc[peak],
                'AUC': auc,
                'AUC_positive': above_zero_auc,
                'AUC_peak': trapezoid(y=peak_segment_y, x=peak_segment_x),
                'infected': group['parasite_count'].iloc[peak]
            })

    if verbose:
        print(f'Removed {size_filter_removed} objects due to size filter fluctuation')
        print(f'Removed {transience_removed} objects due to transience')

    if len(corrected_dfs) > 0:
        result_df = pd.concat(corrected_dfs)
    else:
        print("No suitable cells found for analysis")
        return

    peak_details_df = pd.DataFrame(peak_details_list)
    summary_df = summarize_per_well(peak_details_df)
    summary_df_inf_non_inf = summarize_per_well_inf_non_inf(peak_details_df)

    save_results_dataframe(df=peak_details_df, src=db_loc, results_name='peak_details')
    save_results_dataframe(df=result_df, src=db_loc, results_name='results')
    save_results_dataframe(df=summary_df, src=db_loc, results_name='well_results')
    save_results_dataframe(df=summary_df_inf_non_inf, src=db_loc, results_name='well_results_inf_non_inf')

    # Plotting
    fig, ax = plt.subplots(figsize=(10, 8))
    sampled_groups = result_df[id_col].unique()
    if num_lines is not None and 0 < num_lines < len(sampled_groups):
        sampled_groups = np.random.choice(sampled_groups, size=num_lines, replace=False)

    for group_id in sampled_groups:
        group = result_df[result_df[id_col] == group_id]
        ax.plot(group['time'], group[delta_col], marker='o', linestyle='-')

    timepoints = sorted(df['time'].unique())
    ax.set_xticks(timepoints)
    ax.set_xticklabels(timepoints, rotation=45, ha="right")
    ax.set_title(f'Normalized Delta of {measurement} Over Time (Corrected for Photobleaching)')
    ax.set_xlabel('Time')
    ax.set_ylabel('Normalized Delta ' + measurement)
    plt.tight_layout()

    plt.show()

    save_figure(fig, src=db_loc, figure_number=1)

    if pathogen:
        # Save the comparison figure when the helper hands it back; if it
        # returns nothing, fall back to the overview figure.
        comparison_fig = infected_vs_noninfected(result_df, measurement)
        save_figure(comparison_fig if comparison_fig is not None else fig, src=db_loc, figure_number=2)

        # Identify cells with and without pathogens
        max_parasites = result_df.groupby(id_col)['parasite_count'].transform('max')
        infected_cells = result_df[max_parasites > 0][id_col].unique()
        noninfected_cells = result_df[max_parasites == 0][id_col].unique()

        # Peaks in infected and noninfected cells. Placeholder rows of cells
        # without peaks (null amplitude) are not peaks and must not be counted.
        detected_peaks = peak_details_df[peak_details_df['amplitude'].notna()]
        infected_peaks = detected_peaks[detected_peaks['ID'].isin(infected_cells)]
        noninfected_peaks = detected_peaks[detected_peaks['ID'].isin(noninfected_cells)]

        # Calculate the average number of peaks per cell
        avg_inf_peaks_per_cell = len(infected_peaks) / len(infected_cells) if len(infected_cells) > 0 else 0
        avg_non_inf_peaks_per_cell = len(noninfected_peaks) / len(noninfected_cells) if len(noninfected_cells) > 0 else 0

        print(f'Average number of peaks per infected cell: {avg_inf_peaks_per_cell:.2f}')
        print(f'Average number of peaks per non-infected cell: {avg_non_inf_peaks_per_cell:.2f}')

    print(f'done')
    return result_df, peak_details_df, fig


def _load_cell_measurements(db_loc, pathogen, cytoplasm):
    """Read the ``cell`` table and left-merge the optional pathogen and cytoplasm tables onto it."""
    merge_keys = ['plate', 'row', 'col', 'field', 'timeid', 'object_label']

    conn = sqlite3.connect(db_loc)
    try:
        # Load cell table
        cell_df = pd.read_sql("SELECT * FROM cell", conn)

        if pathogen:
            pathogen_df = pd.read_sql("SELECT * FROM pathogen", conn)
            pathogen_df['pathogen_cell_id'] = pathogen_df['pathogen_cell_id'].astype(float).astype('Int64')
            pathogen_df = preprocess_pathogen_data(pathogen_df)
            cell_df = cell_df.merge(pathogen_df, on=merge_keys, how='left', suffixes=('', '_pathogen'))
            # Cells without a matching pathogen row carry no parasites
            cell_df['parasite_count'] = cell_df['parasite_count'].fillna(0)
            print(f'After pathogen merge: {len(cell_df)} objects')

        # Optionally load cytoplasm table and merge
        if cytoplasm:
            cytoplasm_df = pd.read_sql("SELECT * FROM cytoplasm", conn)
            cell_df = cell_df.merge(cytoplasm_df, on=merge_keys, how='left', suffixes=('', '_cytoplasm'))
            print(f'After cytoplasm merge: {len(cell_df)} objects')
    finally:
        # Release the database handle even when a read or merge fails
        conn.close()

    # The per-cell analysis always reads 'parasite_count'; without pathogen
    # data every cell counts as uninfected.
    if 'parasite_count' not in cell_df.columns:
        cell_df['parasite_count'] = 0

    return cell_df
|
spacr/train.py
CHANGED
@@ -194,8 +194,8 @@ def test_model_performance(loaders, model, loader_name_list, epoch, train_mode,
|
|
194
194
|
|
195
195
|
def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
196
196
|
|
197
|
-
from .io import
|
198
|
-
from .utils import pick_best_model
|
197
|
+
from .io import _save_settings, _copy_missclassified
|
198
|
+
from .utils import pick_best_model
|
199
199
|
from .core import generate_loaders
|
200
200
|
|
201
201
|
settings['src'] = src
|
@@ -208,7 +208,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
208
208
|
model = torch.load(custom_model_path)
|
209
209
|
|
210
210
|
if settings['train']:
|
211
|
-
|
211
|
+
_save_settings(settings, src)
|
212
212
|
torch.cuda.empty_cache()
|
213
213
|
torch.cuda.memory.empty_cache()
|
214
214
|
gc.collect()
|
@@ -227,7 +227,9 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
227
227
|
validation_split=settings['val_split'],
|
228
228
|
pin_memory=settings['pin_memory'],
|
229
229
|
normalize=settings['normalize'],
|
230
|
-
|
230
|
+
channels=settings['channels'],
|
231
|
+
verbose=settings['verbose'])
|
232
|
+
|
231
233
|
|
232
234
|
if settings['test']:
|
233
235
|
test, _, plate_names_test = generate_loaders(src,
|
@@ -240,6 +242,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
|
|
240
242
|
validation_split=0.0,
|
241
243
|
pin_memory=settings['pin_memory'],
|
242
244
|
normalize=settings['normalize'],
|
245
|
+
channels=settings['channels'],
|
243
246
|
verbose=settings['verbose'])
|
244
247
|
if model == None:
|
245
248
|
model_path = pick_best_model(src+'/model')
|
@@ -330,8 +333,8 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
330
333
|
None
|
331
334
|
"""
|
332
335
|
|
333
|
-
from .io import
|
334
|
-
from .utils import
|
336
|
+
from .io import _save_model, _save_progress
|
337
|
+
from .utils import compute_irm_penalty, calculate_loss, choose_model #evaluate_model_performance,
|
335
338
|
|
336
339
|
print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders)}')
|
337
340
|
|
@@ -347,6 +350,11 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
347
350
|
break
|
348
351
|
|
349
352
|
model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
|
353
|
+
|
354
|
+
if model is None:
|
355
|
+
print(f'Model {model_type} not found')
|
356
|
+
return
|
357
|
+
|
350
358
|
model.to(device)
|
351
359
|
|
352
360
|
if optimizer_type == 'adamw':
|
@@ -421,10 +429,10 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
421
429
|
if schedule == 'step_lr':
|
422
430
|
scheduler.step()
|
423
431
|
|
424
|
-
|
432
|
+
_save_progress(dst, results_df, train_metrics_df)
|
425
433
|
clear_output(wait=True)
|
426
434
|
display(results_df)
|
427
|
-
|
435
|
+
_save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
|
428
436
|
|
429
437
|
if train_mode == 'irm':
|
430
438
|
dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
|
@@ -494,7 +502,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
|
|
494
502
|
|
495
503
|
clear_output(wait=True)
|
496
504
|
display(results_df)
|
497
|
-
|
498
|
-
|
505
|
+
_save_progress(dst, results_df, train_metrics_df)
|
506
|
+
_save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
|
499
507
|
print(f'Saved model: {dst}')
|
500
508
|
return
|