spacr 0.3.47__py3-none-any.whl → 0.3.52__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/chat_bot.py +31 -0
- spacr/gui_elements.py +33 -7
- spacr/gui_utils.py +11 -12
- spacr/measure.py +4 -1
- spacr/ml.py +453 -141
- spacr/plot.py +612 -52
- spacr/sequencing.py +5 -2
- spacr/settings.py +15 -31
- spacr/toxo.py +447 -159
- spacr/utils.py +35 -4
- {spacr-0.3.47.dist-info → spacr-0.3.52.dist-info}/METADATA +3 -1
- {spacr-0.3.47.dist-info → spacr-0.3.52.dist-info}/RECORD +16 -15
- {spacr-0.3.47.dist-info → spacr-0.3.52.dist-info}/LICENSE +0 -0
- {spacr-0.3.47.dist-info → spacr-0.3.52.dist-info}/WHEEL +0 -0
- {spacr-0.3.47.dist-info → spacr-0.3.52.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.47.dist-info → spacr-0.3.52.dist-info}/top_level.txt +0 -0
spacr/toxo.py
CHANGED
@@ -7,26 +7,63 @@ import pandas as pd
|
|
7
7
|
from scipy.stats import fisher_exact
|
8
8
|
from IPython.display import display
|
9
9
|
from matplotlib.legend import Legend
|
10
|
+
from matplotlib.transforms import Bbox
|
11
|
+
from brokenaxes import brokenaxes
|
10
12
|
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
13
|
+
import os
|
14
|
+
import pandas as pd
|
15
|
+
import seaborn as sns
|
16
|
+
import matplotlib.pyplot as plt
|
17
|
+
from scipy.spatial.distance import cosine
|
18
|
+
from scipy.stats import pearsonr
|
19
|
+
import pandas as pd
|
20
|
+
import matplotlib.pyplot as plt
|
21
|
+
import seaborn as sns
|
22
|
+
from sklearn.metrics import mean_absolute_error
|
23
|
+
|
24
|
+
|
25
|
+
from matplotlib.gridspec import GridSpec
|
26
|
+
|
27
|
+
|
28
|
+
def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location',
|
29
|
+
point_size=50, figsize=20, threshold=0,
|
30
|
+
save_path=None, x_lim=[-0.5, 0.5], y_lims=[[0, 6], [9, 15]]):
|
31
|
+
|
32
|
+
markers = [
|
33
|
+
'o', # Circle
|
34
|
+
'X', # X-shaped marker
|
35
|
+
'^', # Upward triangle
|
36
|
+
's', # Square
|
37
|
+
'v', # Downward triangle
|
38
|
+
'P', # Plus-filled pentagon
|
39
|
+
'*', # Star
|
40
|
+
'+', # Plus
|
41
|
+
'x', # Cross
|
42
|
+
'.', # Point
|
43
|
+
',', # Pixel
|
44
|
+
'd', # Diamond
|
45
|
+
'D', # Thin diamond
|
46
|
+
'h', # Hexagon 1
|
47
|
+
'H', # Hexagon 2
|
48
|
+
'p', # Pentagon
|
49
|
+
'|', # Vertical line
|
50
|
+
'_', # Horizontal line
|
51
|
+
]
|
52
|
+
|
53
|
+
plt.rcParams.update({'font.size': 14})
|
54
|
+
|
55
|
+
# Load data
|
21
56
|
if isinstance(data_path, pd.DataFrame):
|
22
57
|
data = data_path
|
23
58
|
else:
|
24
59
|
data = pd.read_csv(data_path)
|
25
|
-
|
60
|
+
|
61
|
+
fontsize = 18
|
62
|
+
|
63
|
+
plt.rcParams.update({'font.size': fontsize})
|
26
64
|
data['variable'] = data['feature'].str.extract(r'\[(.*?)\]')
|
27
65
|
data['variable'].fillna(data['feature'], inplace=True)
|
28
|
-
|
29
|
-
data['gene_nr'] = split_columns[0]
|
66
|
+
data['gene_nr'] = data['variable'].str.split('_').str[0]
|
30
67
|
data = data[data['variable'] != 'Intercept']
|
31
68
|
|
32
69
|
# Load metadata
|
@@ -34,173 +71,110 @@ def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location
|
|
34
71
|
metadata = metadata_path
|
35
72
|
else:
|
36
73
|
metadata = pd.read_csv(metadata_path)
|
37
|
-
|
38
74
|
metadata['gene_nr'] = metadata['gene_nr'].astype(str)
|
39
75
|
data['gene_nr'] = data['gene_nr'].astype(str)
|
40
76
|
|
41
|
-
|
42
|
-
merged_data = pd.merge(data, metadata[['gene_nr', 'tagm_location']], on='gene_nr', how='left')
|
43
|
-
|
44
|
-
merged_data.loc[merged_data['gene_nr'].str.startswith('4'), metadata_column] = 'GT1_gene'
|
45
|
-
merged_data.loc[merged_data['gene_nr'] == 'Intercept', metadata_column] = 'Intercept'
|
46
|
-
merged_data.loc[merged_data['condition'] == 'control', metadata_column] = 'control'
|
77
|
+
merged_data = pd.merge(data, metadata[['gene_nr', metadata_column]], on='gene_nr', how='left')
|
47
78
|
merged_data[metadata_column].fillna('unknown', inplace=True)
|
48
79
|
|
49
|
-
#
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
ordered=True)
|
54
|
-
|
55
|
-
# Create subplots with a broken y-axis
|
56
|
-
figsize_2 = figsize / 2
|
57
|
-
fig, (ax1, ax2) = plt.subplots(
|
58
|
-
2, 1, figsize=(figsize, figsize),
|
59
|
-
sharex=True, gridspec_kw={'height_ratios': [1, 3]}
|
60
|
-
)
|
61
|
-
|
62
|
-
# Define color palette
|
63
|
-
palette = {
|
64
|
-
'pc': 'red',
|
65
|
-
'nc': 'green',
|
66
|
-
'control': 'white',
|
67
|
-
'other': 'gray'}
|
68
|
-
|
69
|
-
# Scatter plot on both axes with legend completely disabled
|
70
|
-
sns.scatterplot(
|
71
|
-
data=merged_data,
|
72
|
-
x='coefficient',
|
73
|
-
y='-log10(p_value)',
|
74
|
-
hue='condition',
|
75
|
-
style=metadata_column if metadata_column else None,
|
76
|
-
s=point_size,
|
77
|
-
edgecolor='black',
|
78
|
-
palette=palette,
|
79
|
-
legend=False, # Disable automatic legend
|
80
|
-
alpha=0.6,
|
81
|
-
ax=ax2 # Lower plot
|
82
|
-
)
|
83
|
-
|
84
|
-
sns.scatterplot(
|
85
|
-
data=merged_data[merged_data['-log10(p_value)'] > 10],
|
86
|
-
x='coefficient',
|
87
|
-
y='-log10(p_value)',
|
88
|
-
hue='condition',
|
89
|
-
style=metadata_column if metadata_column else None,
|
90
|
-
s=point_size,
|
91
|
-
edgecolor='black',
|
92
|
-
palette=palette,
|
93
|
-
legend=False, # No legend on the upper plot
|
94
|
-
alpha=0.6,
|
95
|
-
ax=ax1 # Upper plot
|
96
|
-
)
|
97
|
-
|
98
|
-
# Ensure no previous legends on ax1 or ax2
|
99
|
-
if ax1.get_legend() is not None:
|
100
|
-
ax1.get_legend().remove()
|
80
|
+
# Define palette and markers
|
81
|
+
palette = {'pc': 'red', 'nc': 'green', 'control': 'white', 'other': 'gray'}
|
82
|
+
marker_dict = {val: marker for val, marker in zip(
|
83
|
+
merged_data[metadata_column].unique(), markers)}
|
101
84
|
|
102
|
-
|
103
|
-
|
85
|
+
# Create the figure with custom spacing
|
86
|
+
fig = plt.figure(figsize=(figsize,figsize))
|
87
|
+
gs = GridSpec(2, 1, height_ratios=[1, 3], hspace=0.05)
|
104
88
|
|
105
|
-
|
106
|
-
|
89
|
+
ax_upper = fig.add_subplot(gs[0])
|
90
|
+
ax_lower = fig.add_subplot(gs[1], sharex=ax_upper)
|
107
91
|
|
108
|
-
#
|
109
|
-
|
110
|
-
print(f"Labels: {labels}")
|
92
|
+
# Hide x-axis labels on the upper plot
|
93
|
+
ax_upper.tick_params(axis='x', which='both', bottom=False, labelbottom=False)
|
111
94
|
|
112
|
-
|
113
|
-
n_color_entries = len(set(merged_data['condition']))
|
114
|
-
shape_handles = handles[n_color_entries:]
|
115
|
-
shape_labels = labels[n_color_entries:]
|
95
|
+
hit_list = []
|
116
96
|
|
117
|
-
#
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
handletextpad=2.0, labelspacing=1.5, borderaxespad=1.0,
|
122
|
-
markerscale=2.0, prop={'size': 14}
|
123
|
-
)
|
124
|
-
ax2.add_artist(legend)
|
125
|
-
|
126
|
-
if isinstance(split_axis_lims, list):
|
127
|
-
if len(split_axis_lims) == 4:
|
128
|
-
ylim_min_ax1 = split_axis_lims[0]
|
129
|
-
if split_axis_lims[1] is None:
|
130
|
-
ylim_max_ax1 = merged_data['-log10(p_value)'].max() + 5
|
131
|
-
else:
|
132
|
-
ylim_max_ax1 = split_axis_lims[1]
|
133
|
-
ylim_min_ax2 = split_axis_lims[2]
|
134
|
-
ylim_max_ax2 = split_axis_lims[3]
|
135
|
-
else:
|
136
|
-
ylim_min_ax1 = None
|
137
|
-
ylim_max_ax1 = merged_data['-log10(p_value)'].max() + 5
|
138
|
-
ylim_min_ax2 = 0
|
139
|
-
ylim_max_ax2 = None
|
140
|
-
|
141
|
-
# Set axis limits and hide unnecessary parts
|
142
|
-
ax1.set_ylim(ylim_min_ax1, ylim_max_ax1)
|
143
|
-
ax2.set_ylim(0, ylim_max_ax2)
|
97
|
+
# Scatter plot on both axes
|
98
|
+
for _, row in merged_data.iterrows():
|
99
|
+
y_val = -np.log10(row['p_value'])
|
100
|
+
ax = ax_upper if y_val > y_lims[1][0] else ax_lower
|
144
101
|
|
145
|
-
|
146
|
-
|
147
|
-
|
102
|
+
ax.scatter(
|
103
|
+
row['coefficient'], y_val,
|
104
|
+
color=palette.get(row['condition'], 'gray'),
|
105
|
+
marker=marker_dict.get(row[metadata_column], 'o'),
|
106
|
+
s=point_size, edgecolor='black', alpha=0.6
|
107
|
+
)
|
148
108
|
|
149
|
-
|
150
|
-
|
151
|
-
ax1.tick_params(labelbottom=False)
|
152
|
-
|
153
|
-
ax1.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
|
154
|
-
|
155
|
-
# Add vertical threshold lines to both plots
|
156
|
-
if threshold > 0:
|
157
|
-
for ax in (ax1, ax2):
|
158
|
-
ax.axvline(x=-abs(threshold), linestyle='--', color='black')
|
159
|
-
ax.axvline(x=abs(threshold), linestyle='--', color='black')
|
109
|
+
if row['p_value'] <= 0.05 and abs(row['coefficient']) >= abs(threshold):
|
110
|
+
hit_list.append(row['variable'])
|
160
111
|
|
161
|
-
#
|
162
|
-
|
112
|
+
# Set axis limits
|
113
|
+
ax_upper.set_ylim(y_lims[1])
|
114
|
+
ax_lower.set_ylim(y_lims[0])
|
115
|
+
ax_lower.set_xlim(x_lim)
|
163
116
|
|
164
|
-
|
165
|
-
|
166
|
-
|
117
|
+
ax_lower.spines['top'].set_visible(False)
|
118
|
+
ax_upper.spines['top'].set_visible(False)
|
119
|
+
ax_upper.spines['bottom'].set_visible(False)
|
167
120
|
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
row['variable'],
|
176
|
-
fontsize=fontsize,
|
177
|
-
ha='center',
|
178
|
-
va='bottom',
|
179
|
-
)
|
121
|
+
# Set x-axis and y-axis titles
|
122
|
+
ax_lower.set_xlabel('Coefficient') # X-axis title on the lower graph
|
123
|
+
ax_lower.set_ylabel('-log10(p-value)') # Y-axis title on the lower graph
|
124
|
+
ax_upper.set_ylabel('-log10(p-value)') # Y-axis title on the upper graph
|
125
|
+
|
126
|
+
for ax in [ax_upper, ax_lower]:
|
127
|
+
ax.spines['right'].set_visible(False)
|
180
128
|
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
texts_ax2.append(text)
|
129
|
+
# Add threshold lines to both axes
|
130
|
+
for ax in [ax_upper, ax_lower]:
|
131
|
+
ax.axvline(x=-abs(threshold), linestyle='--', color='black')
|
132
|
+
ax.axvline(x=abs(threshold), linestyle='--', color='black')
|
186
133
|
|
187
|
-
|
188
|
-
adjust_text(texts_ax1, arrowprops=dict(arrowstyle='-', color='black'), ax=ax1, expand_points=(padd, padd), fontsize=fontsize)
|
189
|
-
adjust_text(texts_ax2, arrowprops=dict(arrowstyle='-', color='black'), ax=ax2, expand_points=(padd, padd), fontsize=fontsize)
|
134
|
+
ax_lower.axhline(y=-np.log10(0.05), linestyle='--', color='black')
|
190
135
|
|
191
|
-
#
|
192
|
-
|
136
|
+
# Annotate significant points
|
137
|
+
texts_upper, texts_lower = [], [] # Collect text annotations separately
|
193
138
|
|
194
|
-
|
195
|
-
|
196
|
-
|
139
|
+
for _, row in merged_data.iterrows():
|
140
|
+
y_val = -np.log10(row['p_value'])
|
141
|
+
if row['p_value'] > 0.05 or abs(row['coefficient']) < abs(threshold):
|
142
|
+
continue
|
197
143
|
|
198
|
-
|
199
|
-
|
200
|
-
|
144
|
+
ax = ax_upper if y_val > y_lims[1][0] else ax_lower
|
145
|
+
text = ax.text(row['coefficient'], y_val, row['variable'],
|
146
|
+
fontsize=fontsize, ha='center', va='bottom')
|
201
147
|
|
202
|
-
|
148
|
+
if ax == ax_upper:
|
149
|
+
texts_upper.append(text)
|
150
|
+
else:
|
151
|
+
texts_lower.append(text)
|
152
|
+
|
153
|
+
# Adjust text positions to avoid overlap
|
154
|
+
adjust_text(texts_upper, ax=ax_upper, arrowprops=dict(arrowstyle='-', color='black'))
|
155
|
+
adjust_text(texts_lower, ax=ax_lower, arrowprops=dict(arrowstyle='-', color='black'))
|
156
|
+
|
157
|
+
# Add a single legend on the lower axis
|
158
|
+
handles = [plt.Line2D([0], [0], marker=m, color='w', markerfacecolor='gray', markersize=10)
|
159
|
+
for m in marker_dict.values()]
|
160
|
+
labels = marker_dict.keys()
|
161
|
+
ax_lower.legend(handles,
|
162
|
+
labels,
|
163
|
+
bbox_to_anchor=(1.05, 1),
|
164
|
+
loc='upper left',
|
165
|
+
borderaxespad=0.25,
|
166
|
+
labelspacing=2,
|
167
|
+
handletextpad=0.25,
|
168
|
+
markerscale=2,
|
169
|
+
prop={'size': fontsize})
|
170
|
+
|
171
|
+
|
172
|
+
# Save and show the plot
|
173
|
+
if save_path:
|
174
|
+
plt.savefig(save_path, format='pdf', bbox_inches='tight')
|
203
175
|
plt.show()
|
176
|
+
|
177
|
+
return hit_list
|
204
178
|
|
205
179
|
def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=['Computed GO Processes', 'Curated GO Components', 'Curated GO Functions', 'Curated GO Processes']):
|
206
180
|
"""
|
@@ -341,4 +315,318 @@ def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=
|
|
341
315
|
|
342
316
|
# Show the combined plot
|
343
317
|
plt.tight_layout()
|
344
|
-
plt.show()
|
318
|
+
plt.show()
|
319
|
+
|
320
|
+
def plot_gene_phenotypes(data, gene_list, x_column='Gene ID',
                         data_column='T.gondii GT1 CRISPR Phenotype - Mean Phenotype',
                         error_column='T.gondii GT1 CRISPR Phenotype - Standard Error',
                         save_path=None):
    """
    Plot the ranked mean phenotype with a standard-error band and highlight selected genes.

    Rows are sorted by ``data_column`` and plotted against their 1-based rank.
    The band spans ``data_column`` +/- ``error_column``. Genes from
    ``gene_list`` that are present in ``data`` are drawn as large markers and
    labelled with the name given in ``gene_list``.

    Args:
        data (pd.DataFrame): The input DataFrame containing gene data. It is
            copied, so the caller's frame is not modified.
        gene_list (list): Gene names to highlight. Each name is reduced to its
            ID the same way as the values of ``x_column`` (second
            ``_``-separated field when the value is a string containing ``_``,
            otherwise ``str(value)``).
        x_column (str): Column holding the gene identifiers.
        data_column (str): Column holding the mean phenotype (y values).
        error_column (str): Column holding the standard error of the phenotype.
        save_path (str): Optional. If provided, the figure is saved there as a PDF.
    """
    def extract_gene_id(gene):
        # 'TGGT1_220950' -> '220950'; anything without '_' is just stringified.
        if isinstance(gene, str) and '_' in gene:
            return gene.split('_')[1]
        return str(gene)

    # Work on a private copy: the original wrote the converted columns back
    # into the caller's DataFrame via ``data.loc[:, col] = ...``, and that form
    # of assignment can keep the column as object dtype on recent pandas.
    data = data.copy()
    for numeric_col in (data_column, error_column):
        data[numeric_col] = pd.to_numeric(data[numeric_col], errors='coerce')

    # Rows whose phenotype or error could not be parsed are dropped.
    data = data.dropna(subset=[data_column, error_column])

    data['x'] = data[x_column].apply(extract_gene_id)

    # Sort by the data_column and assign 1-based ranks.
    data = data.sort_values(by=data_column).reset_index(drop=True)
    data['rank'] = range(1, len(data) + 1)

    # Prepare the x, y, and error values for plotting.
    x = data['rank']
    y = data[data_column]
    yerr = data[error_column]

    line_color = (0/255, 155/255, 155/255)
    highlight_color = (155/255, 55/255, 155/255)

    plt.figure(figsize=(10, 10))

    # Mean phenotype line with the standard-error band around it.
    plt.plot(x, y, label='Mean Phenotype', color=line_color, linewidth=2)
    plt.fill_between(
        x, y - yerr, y + yerr,
        color=line_color, alpha=0.1, label='Standard Error'
    )

    # Text objects are collected so adjust_text can de-overlap them afterwards.
    texts = []

    for gene in gene_list:
        gene_id = extract_gene_id(gene)
        gene_data = data[data['x'] == gene_id]
        if gene_data.empty:
            # Genes absent from the data are silently skipped.
            continue

        plt.scatter(
            gene_data['rank'],
            gene_data[data_column],
            color=highlight_color,
            s=200,
            alpha=0.6,
            label=f'Highlighted Gene: {gene}',
            zorder=3  # Ensure the points are on top
        )
        # Only the first matching row gets a text label.
        texts.append(
            plt.text(
                gene_data['rank'].values[0],
                gene_data[data_column].values[0],
                gene,
                fontsize=18,
                ha='right'
            )
        )

    # NOTE(review): adjust_text is not imported in the visible part of this
    # file; it is assumed to be imported at module level - confirm.
    adjust_text(texts, arrowprops=dict(arrowstyle='-', color='gray'))

    plt.xlabel('Rank')
    plt.ylabel('Mean Phenotype')
    plt.legend().remove()  # Labels are set on the artists, but no legend is shown
    plt.tight_layout()

    # Save the plot if a path is provided.
    if save_path:
        plt.savefig(save_path, format='pdf', dpi=600, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()
|
405
|
+
|
406
|
+
def plot_gene_heatmaps(data, gene_list, columns, x_column='Gene ID', normalize=False, save_path=None):
    """
    Generate a viridis heatmap of the specified columns for the specified genes.

    Genes are placed on the y-axis and ``columns`` on the x-axis.

    Args:
        data (pd.DataFrame): The input DataFrame containing gene data. It is
            not modified.
        gene_list (list): A list of genes to include in the heatmap. Each entry
            is reduced to its ID the same way as the values of ``x_column``
            (second ``_``-separated field when the value is a string containing
            ``_``, otherwise ``str(value)``), so 'TGGT1_220950' and '220950'
            select the same rows.
        columns (list): A list of column names to visualize as heatmaps.
        x_column (str): Column holding the gene identifiers.
        normalize (bool): If True, normalize the values for each gene between 0 and 1.
        save_path (str): Optional. If provided, the plot will be saved to this path.

    Raises:
        ValueError: If none of the genes in ``gene_list`` is present in ``data``.
    """
    def extract_gene_id(gene):
        if isinstance(gene, str) and '_' in gene:
            return gene.split('_')[1]
        return str(gene)

    # The original compared the raw gene_list against the extracted IDs, so
    # prefixed names never matched. Normalise both sides identically.
    wanted_ids = [extract_gene_id(gene) for gene in gene_list]

    # assign() returns a new frame, so no helper column leaks into the
    # caller's DataFrame.
    data = data.assign(x=data[x_column].apply(extract_gene_id))

    # Filter the data to only include the specified genes.
    filtered_data = data[data['x'].isin(wanted_ids)].set_index('x')[columns]

    if filtered_data.empty:
        # Fail early with a readable message instead of an error from inside
        # the plotting call.
        raise ValueError(
            f"None of the genes in gene_list were found in column '{x_column}'."
        )

    # Normalize each gene's values between 0 and 1 if normalize=True.
    # A gene whose values are all equal has a zero range and becomes NaN.
    if normalize:
        filtered_data = filtered_data.apply(lambda x: (x - x.min()) / (x.max() - x.min()), axis=1)

    # Figure size scales with the number of columns and requested genes.
    width = len(columns) * 4
    height = len(gene_list) * 1

    plt.figure(figsize=(width, height))
    cmap = sns.color_palette("viridis", as_cmap=True)

    # Plot the heatmap with genes on the y-axis and columns on the x-axis.
    sns.heatmap(
        filtered_data,
        cmap=cmap,
        cbar=True,
        annot=False,
        linewidths=0.5,
        square=True
    )

    plt.xticks(rotation=90, ha='center')  # Rotate x-axis labels for better readability
    plt.yticks(rotation=0)  # Keep y-axis labels horizontal
    plt.xlabel('')
    plt.ylabel('')

    plt.tight_layout()

    # Save the plot if a path is provided.
    if save_path:
        plt.savefig(save_path, format='pdf', dpi=600, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()
|
465
|
+
|
466
|
+
def generate_score_heatmap(settings):
    """
    Compare per-well classification scores with the sequencing-derived sgRNA fraction.

    Reads the prediction CSVs found under ``settings['folders']``, the sgRNA
    count CSV and the cross-validation CSV, merges them on a
    ``plate_row_col`` key (``prc``), plots a heatmap of all numeric columns and
    computes the absolute error between every score column and the fraction.

    Args:
        settings (dict): Keys read by this function:
            'folders' (str or list): folder(s) whose sub-folders hold ``csv_name``.
            'csv_name' (str): file name of the per-model score CSV.
            'data_column' (str): score column in those CSVs.
            'plate': plate number used to build the ``plate<N>`` label, or None.
            'column' (str): plate column to keep (e.g. 'c3').
            'csv' (str): path to the sgRNA count CSV.
            'control_sgrnas' (list): sgRNA names whose counts form the total.
            'fraction_grna' (str): sgRNA whose fraction is compared to the scores.
            'cv_csv' (str): path to the cross-validation score CSV.
            'data_column_cv' (str): score column in ``cv_csv``.
            'dst' (str or None): output folder; nothing is written when None.

    Returns:
        pd.DataFrame: The merged table of fraction and scores, one row per well.

    Raises:
        FileNotFoundError: If no ``csv_name`` file is found under ``folders``.
    """

    def group_cv_score(csv, plate=1, column='c3', data_column='pred'):
        """Return the mean of ``data_column`` per well of ``column`` with a ``prc`` key."""
        df = pd.read_csv(csv)
        if 'col' in df.columns:
            df = df[df['col'] == column]
        elif 'column' in df.columns:
            # Older files name the column 'column'; mirror it into 'col'.
            df['col'] = df['column']
            df = df[df['col'] == column]
        df = df.copy()  # detach from the filtered view before assigning
        if plate is not None:
            df['plate'] = f"plate{plate}"
        grouped_df = df.groupby(['plate', 'row', 'col'])[data_column].mean().reset_index()
        grouped_df['prc'] = grouped_df['plate'].astype(str) + '_' + grouped_df['row'].astype(str) + '_' + grouped_df['col'].astype(str)
        return grouped_df

    def calculate_fraction_mixed_condition(csv, plate=1, column='c3', control_sgrnas=('TGGT1_220950_1', 'TGGT1_233460_4')):
        """Return per-well read fractions of each control sgRNA within ``column``."""
        df = pd.read_csv(csv)
        df = df[df['column_name'] == column].copy()
        # The original tested ``plate not in df.columns`` (the integer value,
        # not the name 'plate'), which in practice always overwrote the label.
        # Overwrite whenever a plate is given, as the sibling helpers do.
        if plate is not None:
            df['plate'] = f"plate{plate}"
        # Exact-name match on any number of control sgRNAs (the original regex
        # only looked at the first two entries).
        df = df[df['grna_name'].isin(list(control_sgrnas))]
        grouped_df = df.groupby(['plate', 'row_name', 'column_name'])['count'].sum().reset_index()
        grouped_df = grouped_df.rename(columns={'count': 'total_count'})
        merged_df = pd.merge(df, grouped_df, on=['plate', 'row_name', 'column_name'])
        merged_df['fraction'] = merged_df['count'] / merged_df['total_count']
        merged_df['prc'] = merged_df['plate'].astype(str) + '_' + merged_df['row_name'].astype(str) + '_' + merged_df['column_name'].astype(str)
        return merged_df

    def plot_multi_channel_heatmap(df, column='c3'):
        """
        Plot a heatmap with multiple channels as columns.

        Parameters:
        - df: DataFrame with scores for different channels (not modified).
        - column: Column to filter by (default is 'c3').

        Returns the matplotlib figure.
        """
        # Copy first: the sort key below used to be added to the caller's
        # DataFrame and then leaked into the MAE table and the saved CSV.
        df = df.copy()

        # Extract row number and convert to integer for sorting
        df['row_num'] = df['row'].str.extract(r'(\d+)').astype(int)

        # Filter and sort by plate, row, and column
        df = df[df['col'] == column]
        df = df.sort_values(by=['plate', 'row_num', 'col'])

        # Drop temporary 'row_num' column after sorting
        df = df.drop('row_num', axis=1)

        # Index rows by a combined plate-row-column label
        df['plate_row_col'] = df['plate'] + '-' + df['row'] + '-' + df['col']
        df.set_index('plate_row_col', inplace=True)

        # Extract only numeric data for the heatmap
        heatmap_data = df.select_dtypes(include=[float, int])

        # Plot heatmap with square boxes, no annotations, and 'viridis' colormap
        plt.figure(figsize=(12, 8))
        sns.heatmap(
            heatmap_data,
            cmap="viridis",
            cbar=True,
            square=True,
            annot=False
        )

        plt.title("Heatmap of Prediction Scores for All Channels")
        plt.xlabel("Channels")
        plt.ylabel("Plate-Row-Column")
        plt.tight_layout()

        # Grab the figure before show() so it can still be saved by the caller
        fig = plt.gcf()
        plt.show()

        return fig

    def combine_classification_scores(folders, csv_name, data_column, plate=1, column='c3'):
        """Merge the per-well mean of ``data_column`` from every ``<folder>/<sub>/<csv_name>``."""
        # Ensure `folders` is a list
        if isinstance(folders, str):
            folders = [folders]

        ls = []  # paths of the CSV files that were found

        for folder in folders:
            for sub_folder in os.listdir(folder):
                path = os.path.join(folder, sub_folder)
                if not os.path.isdir(path):
                    continue
                csv = os.path.join(path, csv_name)
                if os.path.exists(csv):
                    ls.append(csv)
                else:
                    print(f'No such file: {csv}')

        print(f'Found {len(ls)} CSV files')

        if not ls:
            # Previously this fell through and failed on ``None['plate']``.
            raise FileNotFoundError(f"No '{csv_name}' found in the sub-folders of {folders}")

        combined_df = None

        for csv_file in ls:
            df = pd.read_csv(csv_file)
            df = df[df['col'] == column].copy()
            if plate is not None:
                df['plate'] = f"plate{plate}"
            # Group the data by 'plate', 'row', and 'col'
            grouped_df = df.groupby(['plate', 'row', 'col'])[data_column].mean().reset_index()
            # The score column is named after the folder holding the CSV
            folder_name = os.path.dirname(csv_file).replace(".csv", "")
            new_column_name = os.path.basename(f"{folder_name}_{data_column}")
            print(new_column_name)
            grouped_df = grouped_df.rename(columns={data_column: new_column_name})

            # Merge into the combined DataFrame
            if combined_df is None:
                combined_df = grouped_df
            else:
                combined_df = pd.merge(combined_df, grouped_df, on=['plate', 'row', 'col'], how='outer')

        combined_df['prc'] = combined_df['plate'].astype(str) + '_' + combined_df['row'].astype(str) + '_' + combined_df['col'].astype(str)
        return combined_df

    def calculate_mae(df):
        """
        Calculate the MAE between each channel's predictions and the fraction column for all rows.

        Returns a long DataFrame with columns 'Channel', 'MAE' and 'Row'.
        """
        # Numeric columns other than 'fraction' are treated as channels
        channels = df.drop(columns=['fraction', 'prc']).select_dtypes(include=[float, int])

        mae_data = []

        # One entry per (channel, well); with a single value the MAE is the
        # absolute difference.
        for column in channels.columns:
            for index, row in df.iterrows():
                mae = mean_absolute_error([row['fraction']], [row[column]])
                mae_data.append({'Channel': column, 'MAE': mae, 'Row': row['prc']})

        mae_df = pd.DataFrame(mae_data)
        return mae_df

    # Per-model prediction scores, one column per model folder
    result_df = combine_classification_scores(settings['folders'], settings['csv_name'], settings['data_column'], settings['plate'], settings['column'])

    # Fraction of the sgRNA of interest in each well
    df = calculate_fraction_mixed_condition(settings['csv'], settings['plate'], settings['column'], settings['control_sgrnas'])
    df = df[df['grna_name'] == settings['fraction_grna']]
    fraction_df = df[['fraction', 'prc']]
    merged_df = pd.merge(fraction_df, result_df, on=['prc'])

    # Cross-validation scores
    cv_df = group_cv_score(settings['cv_csv'], settings['plate'], settings['column'], settings['data_column_cv'])
    cv_df = cv_df[[settings['data_column_cv'], 'prc']]
    merged_df = pd.merge(merged_df, cv_df, on=['prc'])

    fig = plot_multi_channel_heatmap(merged_df, settings['column'])

    # Defensive cleanup of the sort helper column. The original checked for
    # 'row_number' while the column is called 'row_num', so it was never dropped.
    if 'row_num' in merged_df.columns:
        merged_df = merged_df.drop('row_num', axis=1)
    mae_df = calculate_mae(merged_df)

    if settings['dst'] is not None:
        mae_dst = os.path.join(settings['dst'], f"mae_scores_comparison_plate_{settings['plate']}.csv")
        merged_dst = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plate']}_data.csv")
        heatmap_save = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plate']}.pdf")
        mae_df.to_csv(mae_dst, index=False)
        merged_df.to_csv(merged_dst, index=False)
        fig.savefig(heatmap_save, format='pdf', dpi=600, bbox_inches='tight')
    return merged_df
|