spacr-0.3.81-py3-none-any.whl → spacr-0.4.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +2 -6
- spacr/core.py +27 -13
- spacr/deep_spacr.py +285 -5
- spacr/gui_core.py +69 -38
- spacr/gui_elements.py +193 -3
- spacr/gui_utils.py +1 -1
- spacr/io.py +5 -176
- spacr/measure.py +10 -6
- spacr/ml.py +369 -46
- spacr/plot.py +203 -92
- spacr/settings.py +53 -17
- spacr/sp_stats.py +221 -0
- spacr/submodules.py +283 -2
- spacr/toxo.py +98 -75
- spacr/utils.py +144 -52
- {spacr-0.3.81.dist-info → spacr-0.4.1.dist-info}/METADATA +2 -1
- {spacr-0.3.81.dist-info → spacr-0.4.1.dist-info}/RECORD +21 -20
- {spacr-0.3.81.dist-info → spacr-0.4.1.dist-info}/LICENSE +0 -0
- {spacr-0.3.81.dist-info → spacr-0.4.1.dist-info}/WHEEL +0 -0
- {spacr-0.3.81.dist-info → spacr-0.4.1.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.81.dist-info → spacr-0.4.1.dist-info}/top_level.txt +0 -0
spacr/sp_stats.py
ADDED
@@ -0,0 +1,221 @@
+from scipy.stats import shapiro, normaltest, levene, ttest_ind, mannwhitneyu, kruskal, f_oneway
+from statsmodels.stats.multicomp import pairwise_tukeyhsd
+import scikit_posthocs as sp
+import numpy as np
+import pandas as pd
+from scipy.stats import chi2_contingency, fisher_exact
+import itertools
+from statsmodels.stats.multitest import multipletests
+
+
+def choose_p_adjust_method(num_groups, num_data_points):
+    """
+    Selects the most appropriate p-value adjustment method based on data characteristics.
+
+    Parameters:
+    - num_groups: Number of unique groups being compared
+    - num_data_points: Number of data points per group (assuming balanced groups)
+
+    Returns:
+    - A string representing the recommended p-adjustment method
+    """
+    num_comparisons = (num_groups * (num_groups - 1)) // 2  # Number of pairwise comparisons
+
+    # Decision logic for choosing the adjustment method
+    if num_comparisons <= 10 and num_data_points > 5:
+        return 'holm'  # Balanced between power and Type I error control
+    elif num_comparisons > 10 and num_data_points <= 5:
+        return 'fdr_bh'  # FDR control for large number of comparisons and small sample size
+    elif num_comparisons <= 10:
+        return 'sidak'  # Less conservative than Bonferroni, good for independent comparisons
+    else:
+        return 'bonferroni'  # Very conservative, use for strict control of Type I errors
+
+def perform_normality_tests(df, grouping_column, data_columns):
+    """Perform normality tests for each group and data column."""
+    unique_groups = df[grouping_column].unique()
+    normality_results = []
+
+    for column in data_columns:
+        for group in unique_groups:
+            data = df.loc[df[grouping_column] == group, column].dropna()
+            n_samples = len(data)
+
+            if n_samples < 3:
+                # Skip test if there aren't enough data points
+                print(f"Skipping normality test for group '{group}' on column '{column}' - Not enough data.")
+                normality_results.append({
+                    'Comparison': f'Normality test for {group} on {column}',
+                    'Test Statistic': None,
+                    'p-value': None,
+                    'Test Name': 'Skipped',
+                    'Column': column,
+                    'n': n_samples
+                })
+                continue
+
+            # Choose the appropriate normality test based on the sample size
+            if n_samples >= 8:
+                stat, p_value = normaltest(data)
+                test_name = "D'Agostino-Pearson test"
+            else:
+                stat, p_value = shapiro(data)
+                test_name = "Shapiro-Wilk test"
+
+            normality_results.append({
+                'Comparison': f'Normality test for {group} on {column}',
+                'Test Statistic': stat,
+                'p-value': p_value,
+                'Test Name': test_name,
+                'Column': column,
+                'n': n_samples
+            })
+
+        # Check if all groups are normally distributed (p > 0.05)
+        normal_p_values = [result['p-value'] for result in normality_results if result['Column'] == column and result['p-value'] is not None]
+        is_normal = all(p > 0.05 for p in normal_p_values)
+
+    return is_normal, normality_results
+
+
+def perform_levene_test(df, grouping_column, data_column):
+    """Perform Levene's test for equal variance."""
+    unique_groups = df[grouping_column].unique()
+    grouped_data = [df.loc[df[grouping_column] == group, data_column].dropna() for group in unique_groups]
+    stat, p_value = levene(*grouped_data)
+    return stat, p_value
+
+def perform_statistical_tests(df, grouping_column, data_columns, paired=False):
+    """Perform statistical tests for each data column."""
+    unique_groups = df[grouping_column].unique()
+    test_results = []
+
+    for column in data_columns:
+        grouped_data = [df.loc[df[grouping_column] == group, column].dropna() for group in unique_groups]
+        if len(unique_groups) == 2:  # For two groups
+            if paired:
+                print("Performing paired tests (not implemented in this template).")
+                continue  # Extend as needed
+            else:
+                # Check normality for two groups
+                is_normal, _ = perform_normality_tests(df, grouping_column, [column])
+                if is_normal:
+                    stat, p = ttest_ind(grouped_data[0], grouped_data[1])
+                    test_name = 'T-test'
+                else:
+                    stat, p = mannwhitneyu(grouped_data[0], grouped_data[1])
+                    test_name = 'Mann-Whitney U test'
+        else:
+            # Check normality for multiple groups
+            is_normal, _ = perform_normality_tests(df, grouping_column, [column])
+            if is_normal:
+                stat, p = f_oneway(*grouped_data)
+                test_name = 'One-way ANOVA'
+            else:
+                stat, p = kruskal(*grouped_data)
+                test_name = 'Kruskal-Wallis test'
+
+        test_results.append({
+            'Column': column,
+            'Test Name': test_name,
+            'Test Statistic': stat,
+            'p-value': p,
+            'Groups': len(unique_groups)
+        })
+
+    return test_results
+
+
+def perform_posthoc_tests(df, grouping_column, data_column, is_normal):
+    """Perform post-hoc tests for multiple groups with both original and adjusted p-values."""
+    unique_groups = df[grouping_column].unique()
+    posthoc_results = []
+
+    if len(unique_groups) > 2:
+        num_groups = len(unique_groups)
+        num_data_points = len(df[data_column].dropna()) // num_groups  # Assuming roughly equal data points per group
+        p_adjust_method = choose_p_adjust_method(num_groups, num_data_points)
+
+        if is_normal:
+            # Tukey's HSD automatically adjusts p-values
+            tukey_result = pairwise_tukeyhsd(df[data_column], df[grouping_column], alpha=0.05)
+            for comparison, p_value in zip(tukey_result._results_table.data[1:], tukey_result.pvalues):
+                posthoc_results.append({
+                    'Comparison': f"{comparison[0]} vs {comparison[1]}",
+                    'Original p-value': None,  # Tukey HSD does not provide raw p-values
+                    'Adjusted p-value': p_value,
+                    'Adjusted Method': 'Tukey HSD',
+                    'Test Name': 'Tukey HSD'
+                })
+        else:
+            # Dunn's test with p-value adjustment
+            raw_dunn_result = sp.posthoc_dunn(df, val_col=data_column, group_col=grouping_column, p_adjust=None)
+            adjusted_dunn_result = sp.posthoc_dunn(df, val_col=data_column, group_col=grouping_column, p_adjust=p_adjust_method)
+            for i, group_a in enumerate(adjusted_dunn_result.index):
+                for j, group_b in enumerate(adjusted_dunn_result.columns):
+                    if i < j:  # Only consider unique pairs
+                        posthoc_results.append({
+                            'Comparison': f"{group_a} vs {group_b}",
+                            'Original p-value': raw_dunn_result.iloc[i, j],
+                            'Adjusted p-value': adjusted_dunn_result.iloc[i, j],
+                            'Adjusted Method': p_adjust_method,
+                            'Test Name': "Dunn's Post-hoc"
+                        })
+
+    return posthoc_results
+
+def chi_pairwise(raw_counts, verbose=False):
+    """
+    Perform pairwise chi-square or Fisher's exact tests between all unique group pairs
+    and apply p-value correction.
+
+    Parameters:
+    - raw_counts (DataFrame): Contingency table with group-wise counts.
+    - verbose (bool): Whether to print results for each pair.
+
+    Returns:
+    - pairwise_df (DataFrame): DataFrame with pairwise test results, including corrected p-values.
+    """
+    pairwise_results = []
+    groups = raw_counts.index.unique()  # Use index from raw_counts for group pairs
+    raw_p_values = []  # Store raw p-values for correction later
+
+    # Calculate the number of groups and average number of data points per group
+    num_groups = len(groups)
+    num_data_points = raw_counts.sum(axis=1).mean()  # Average total data points per group
+    p_adjust_method = choose_p_adjust_method(num_groups, num_data_points)
+
+    for group1, group2 in itertools.combinations(groups, 2):
+        contingency_table = raw_counts.loc[[group1, group2]].values
+        if contingency_table.shape[1] == 2:  # Fisher's Exact Test for 2x2 tables
+            oddsratio, p_value = fisher_exact(contingency_table)
+            test_name = "Fisher's Exact Test"
+        else:  # Chi-Square Test for larger tables
+            chi2_stat, p_value, _, _ = chi2_contingency(contingency_table)
+            test_name = 'Pairwise Chi-Square Test'
+
+        pairwise_results.append({
+            'Group 1': group1,
+            'Group 2': group2,
+            'Test Name': test_name,
+            'p-value': p_value
+        })
+        raw_p_values.append(p_value)
+
+    # Apply p-value correction
+    corrected_p_values = multipletests(raw_p_values, method=p_adjust_method)[1]
+
+    # Add corrected p-values to results
+    for i, result in enumerate(pairwise_results):
+        result['p-value_adj'] = corrected_p_values[i]
+
+    pairwise_df = pd.DataFrame(pairwise_results)
+
+    pairwise_df['adj'] = p_adjust_method
+
+    if verbose:
+        # Print pairwise results
+        print("\nPairwise Frequency Analysis Results:")
+        print(pairwise_df.to_string(index=False))
+
+    return pairwise_df
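
A minimal usage sketch of the new module, not taken from the package itself: the DataFrame and its 'condition'/'score' columns are invented for illustration, while the import path spacr.sp_stats follows from the file added above.

import numpy as np
import pandas as pd
from spacr.sp_stats import perform_normality_tests, perform_levene_test, perform_statistical_tests, perform_posthoc_tests

# Toy data: three hypothetical conditions with 12 measurements each
rng = np.random.default_rng(0)
df = pd.DataFrame({
    'condition': np.repeat(['ctrl', 'kd1', 'kd2'], 12),
    'score': np.concatenate([rng.normal(m, 1.0, 12) for m in (0.0, 0.5, 1.5)]),
})

is_normal, normality = perform_normality_tests(df, 'condition', ['score'])  # per-group normality
levene_stat, levene_p = perform_levene_test(df, 'condition', 'score')       # equal-variance check
omnibus = perform_statistical_tests(df, 'condition', ['score'])             # ANOVA or Kruskal-Wallis, chosen by normality
posthoc = perform_posthoc_tests(df, 'condition', 'score', is_normal)        # Tukey HSD or Dunn's, adjusted per choose_p_adjust_method
print(pd.DataFrame(omnibus))
print(pd.DataFrame(posthoc))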
spacr/submodules.py
CHANGED
@@ -1,3 +1,6 @@
+
+
+
import seaborn as sns
import os, random, sqlite3, re, shap
import pandas as pd
@@ -10,7 +13,10 @@ from IPython.display import display
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from math import pi
-from scipy.stats import chi2_contingency
+from scipy.stats import chi2_contingency, pearsonr
+from scipy.spatial.distance import cosine
+
+from sklearn.metrics import mean_absolute_error

import matplotlib.pyplot as plt
from natsort import natsorted
@@ -1035,7 +1041,7 @@ def analyze_class_proportion(settings):
    from .io import _read_and_merge_data
    from .settings import set_analyze_class_proportion_defaults
    from .plot import plot_plates, plot_proportion_stacked_bars
-    from .
+    from .sp_stats import perform_normality_tests, perform_levene_test, perform_statistical_tests, perform_posthoc_tests

    settings = set_analyze_class_proportion_defaults(settings)
    save_settings(settings, name='analyze_class_proportion', show=True)
@@ -1132,3 +1138,278 @@ def analyze_class_proportion(settings):
    print("Statistical analysis results saved.")

    return output
+
+def generate_score_heatmap(settings):
+
+    def group_cv_score(csv, plate=1, column='c3', data_column='pred'):
+
+        df = pd.read_csv(csv)
+        if 'col' in df.columns:
+            df = df[df['col']==column]
+        elif 'column' in df.columns:
+            df['col'] = df['column']
+            df = df[df['col']==column]
+        if not plate is None:
+            df['plate'] = f"plate{plate}"
+        grouped_df = df.groupby(['plate', 'row', 'col'])[data_column].mean().reset_index()
+        grouped_df['prc'] = grouped_df['plate'].astype(str) + '_' + grouped_df['row'].astype(str) + '_' + grouped_df['col'].astype(str)
+        return grouped_df
+
+    def calculate_fraction_mixed_condition(csv, plate=1, column='c3', control_sgrnas = ['TGGT1_220950_1', 'TGGT1_233460_4']):
+        df = pd.read_csv(csv)
+        df = df[df['column_name']==column]
+        if plate not in df.columns:
+            df['plate'] = f"plate{plate}"
+        df = df[df['grna_name'].str.match(f'^{control_sgrnas[0]}$|^{control_sgrnas[1]}$')]
+        grouped_df = df.groupby(['plate', 'row_name', 'column_name'])['count'].sum().reset_index()
+        grouped_df = grouped_df.rename(columns={'count': 'total_count'})
+        merged_df = pd.merge(df, grouped_df, on=['plate', 'row_name', 'column_name'])
+        merged_df['fraction'] = merged_df['count'] / merged_df['total_count']
+        merged_df['prc'] = merged_df['plate'].astype(str) + '_' + merged_df['row_name'].astype(str) + '_' + merged_df['column_name'].astype(str)
+        return merged_df
+
+    def plot_multi_channel_heatmap(df, column='c3', cmap='coolwarm'):
+        """
+        Plot a heatmap with multiple channels as columns.
+
+        Parameters:
+        - df: DataFrame with scores for different channels.
+        - column: Column to filter by (default is 'c3').
+        """
+        # Extract row number and convert to integer for sorting
+        df['row_num'] = df['row'].str.extract(r'(\d+)').astype(int)
+
+        # Filter and sort by plate, row, and column
+        df = df[df['col'] == column]
+        df = df.sort_values(by=['plate', 'row_num', 'col'])
+
+        # Drop temporary 'row_num' column after sorting
+        df = df.drop('row_num', axis=1)
+
+        # Create a new column combining plate, row, and column for the index
+        df['plate_row_col'] = df['plate'] + '-' + df['row'] + '-' + df['col']
+
+        # Set 'plate_row_col' as the index
+        df.set_index('plate_row_col', inplace=True)
+
+        # Extract only numeric data for the heatmap
+        heatmap_data = df.select_dtypes(include=[float, int])
+
+        # Plot heatmap with square boxes, no annotations, and 'viridis' colormap
+        plt.figure(figsize=(12, 8))
+        sns.heatmap(
+            heatmap_data,
+            cmap=cmap,
+            cbar=True,
+            square=True,
+            annot=False
+        )
+
+        plt.title("Heatmap of Prediction Scores for All Channels")
+        plt.xlabel("Channels")
+        plt.ylabel("Plate-Row-Column")
+        plt.tight_layout()
+
+        # Save the figure object and return it
+        fig = plt.gcf()
+        plt.show()
+
+        return fig
+
+
+    def combine_classification_scores(folders, csv_name, data_column, plate=1, column='c3'):
+        # Ensure `folders` is a list
+        if isinstance(folders, str):
+            folders = [folders]
+
+        ls = []  # Initialize ls to store found CSV file paths
+
+        # Iterate over the provided folders
+        for folder in folders:
+            sub_folders = os.listdir(folder)  # Get sub-folder list
+            for sub_folder in sub_folders:  # Iterate through sub-folders
+                path = os.path.join(folder, sub_folder)  # Join the full path
+
+                if os.path.isdir(path):  # Check if it's a directory
+                    csv = os.path.join(path, csv_name)  # Join path to the CSV file
+                    if os.path.exists(csv):  # If CSV exists, add to list
+                        ls.append(csv)
+                    else:
+                        print(f'No such file: {csv}')
+
+        # Initialize combined DataFrame
+        combined_df = None
+        print(f'Found {len(ls)} CSV files')
+
+        # Loop through all collected CSV files and process them
+        for csv_file in ls:
+            df = pd.read_csv(csv_file)  # Read CSV into DataFrame
+            df = df[df['col']==column]
+            if not plate is None:
+                df['plate'] = f"plate{plate}"
+            # Group the data by 'plate', 'row', and 'col'
+            grouped_df = df.groupby(['plate', 'row', 'col'])[data_column].mean().reset_index()
+            # Use the CSV filename to create a new column name
+            folder_name = os.path.dirname(csv_file).replace(".csv", "")
+            new_column_name = os.path.basename(f"{folder_name}_{data_column}")
+            print(new_column_name)
+            grouped_df = grouped_df.rename(columns={data_column: new_column_name})
+
+            # Merge into the combined DataFrame
+            if combined_df is None:
+                combined_df = grouped_df
+            else:
+                combined_df = pd.merge(combined_df, grouped_df, on=['plate', 'row', 'col'], how='outer')
+        combined_df['prc'] = combined_df['plate'].astype(str) + '_' + combined_df['row'].astype(str) + '_' + combined_df['col'].astype(str)
+        return combined_df
+
+    def calculate_mae(df):
+        """
+        Calculate the MAE between each channel's predictions and the fraction column for all rows.
+        """
+        # Extract numeric columns excluding 'fraction' and 'prc'
+        channels = df.drop(columns=['fraction', 'prc']).select_dtypes(include=[float, int])
+
+        mae_data = []
+
+        # Compute MAE for each channel with 'fraction' for all rows
+        for column in channels.columns:
+            for index, row in df.iterrows():
+                mae = mean_absolute_error([row['fraction']], [row[column]])
+                mae_data.append({'Channel': column, 'MAE': mae, 'Row': row['prc']})
+
+        # Convert the list of dictionaries to a DataFrame
+        mae_df = pd.DataFrame(mae_data)
+        return mae_df
+
+    result_df = combine_classification_scores(settings['folders'], settings['csv_name'], settings['data_column'], settings['plate'], settings['column'], )
+    df = calculate_fraction_mixed_condition(settings['csv'], settings['plate'], settings['column'], settings['control_sgrnas'])
+    df = df[df['grna_name']==settings['fraction_grna']]
+    fraction_df = df[['fraction', 'prc']]
+    merged_df = pd.merge(fraction_df, result_df, on=['prc'])
+    cv_df = group_cv_score(settings['cv_csv'], settings['plate'], settings['column'], settings['data_column_cv'])
+    cv_df = cv_df[[settings['data_column_cv'], 'prc']]
+    merged_df = pd.merge(merged_df, cv_df, on=['prc'])
+
+    fig = plot_multi_channel_heatmap(merged_df, settings['column'], settings['cmap'])
+    if 'row_number' in merged_df.columns:
+        merged_df = merged_df.drop('row_num', axis=1)
+    mae_df = calculate_mae(merged_df)
+    if 'row_number' in mae_df.columns:
+        mae_df = mae_df.drop('row_num', axis=1)
+
+    if not settings['dst'] is None:
+        mae_dst = os.path.join(settings['dst'], f"mae_scores_comparison_plate_{settings['plate']}.csv")
+        merged_dst = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plate']}_data.csv")
+        heatmap_save = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plate']}.pdf")
+        mae_df.to_csv(mae_dst, index=False)
+        merged_df.to_csv(merged_dst, index=False)
+        fig.savefig(heatmap_save, format='pdf', dpi=600, bbox_inches='tight')
+    return merged_df
+
+def post_regression_analysis(csv_file, grna_dict, grna_list, save=False):
+
+    def _analyze_and_visualize_grna_correlation(df, grna_list, save_folder, save=False):
+        """
+        Analyze and visualize the correlation matrix of gRNAs based on their fractions and overlap.
+
+        Parameters:
+        df (pd.DataFrame): DataFrame with columns ['grna', 'fraction', 'prc'].
+        grna_list (list): List of gRNAs to include in the correlation analysis.
+        save_folder (str): Path to the folder where figures and data will be saved.
+
+        Returns:
+        pd.DataFrame: Correlation matrix of the gRNAs.
+        """
+        # Filter the DataFrame to include only rows with gRNAs in the list
+        filtered_df = df[df['grna'].isin(grna_list)]
+
+        # Pivot the data to create a prc-by-gRNA matrix, using fractions as values
+        pivot_df = filtered_df.pivot_table(index='prc', columns='grna', values='fraction', aggfunc='sum').fillna(0)
+
+        # Compute the correlation matrix
+        correlation_matrix = pivot_df.corr()
+
+        if save:
+            # Save the correlation matrix
+            correlation_matrix.to_csv(os.path.join(save_folder, 'correlation_matrix.csv'))
+
+        # Visualize the correlation matrix as a heatmap
+        plt.figure(figsize=(10, 8))
+        sns.heatmap(correlation_matrix, annot=False, cmap='coolwarm', cbar=True)
+        plt.title('gRNA Correlation Matrix')
+        plt.xlabel('gRNAs')
+        plt.ylabel('gRNAs')
+        plt.tight_layout()
+
+        if save:
+            correlation_fig_path = os.path.join(save_folder, 'correlation_matrix_heatmap.pdf')
+            plt.savefig(correlation_fig_path, dpi=300)
+
+        plt.show()
+
+        return correlation_matrix
+
+    def _compute_effect_sizes(correlation_matrix, grna_dict, save_folder, save=False):
+        """
+        Compute and visualize the effect sizes of gRNAs given fixed effect sizes for a subset of gRNAs.
+
+        Parameters:
+        correlation_matrix (pd.DataFrame): Correlation matrix of gRNAs.
+        grna_dict (dict): Dictionary of gRNAs with fixed effect sizes {grna_name: effect_size}.
+        save_folder (str): Path to the folder where figures and data will be saved.
+
+        Returns:
+        pd.Series: Effect sizes of all gRNAs.
+        """
+        # Ensure the matrix is symmetric and normalize values to 0-1
+        corr_matrix = correlation_matrix.copy()
+        corr_matrix = (corr_matrix - corr_matrix.min().min()) / (corr_matrix.max().max() - corr_matrix.min().min())
+
+        # Initialize the effect sizes with dtype float
+        effect_sizes = pd.Series(0.0, index=corr_matrix.index)
+
+        # Set the effect sizes for the specified gRNAs
+        for grna, size in grna_dict.items():
+            effect_sizes[grna] = size
+
+        # Propagate the effect sizes
+        for grna in corr_matrix.index:
+            if grna not in grna_dict:
+                # Weighted sum of correlations with the fixed gRNAs
+                effect_sizes[grna] = np.dot(corr_matrix.loc[grna], effect_sizes) / np.sum(corr_matrix.loc[grna])
+
+        if save:
+            # Save the effect sizes
+            effect_sizes.to_csv(os.path.join(save_folder, 'effect_sizes.csv'))
+
+        # Visualization
+        plt.figure(figsize=(10, 6))
+        sns.barplot(x=effect_sizes.index, y=effect_sizes.values, palette="viridis", hue=None, legend=False)
+
+        #for i, val in enumerate(effect_sizes.values):
+        #    plt.text(i, val + 0.02, f"{val:.2f}", ha='center', va='bottom', fontsize=9)
+        plt.title("Effect Sizes of gRNAs")
+        plt.xlabel("gRNAs")
+        plt.ylabel("Effect Size")
+        plt.xticks(rotation=45)
+        plt.tight_layout()
+
+        if save:
+            effect_sizes_fig_path = os.path.join(save_folder, 'effect_sizes_barplot.pdf')
+            plt.savefig(effect_sizes_fig_path, dpi=300)
+
+        plt.show()
+
+        return effect_sizes
+
+    # Ensure the save folder exists
+    save_folder = os.path.join(os.path.dirname(csv_file), 'post_regression_analysis_results')
+    os.makedirs(save_folder, exist_ok=True)
+
+    # Load the data
+    df = pd.read_csv(csv_file)
+
+    # Perform analysis
+    correlation_matrix = _analyze_and_visualize_grna_correlation(df, grna_list, save_folder, save)
+    effect_sizes = _compute_effect_sizes(correlation_matrix, grna_dict, save_folder, save)