spacr 0.3.46__py3-none-any.whl → 0.3.50__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/chat_bot.py +31 -0
- spacr/gui_elements.py +33 -7
- spacr/ml.py +478 -76
- spacr/plot.py +488 -47
- spacr/sequencing.py +122 -1
- spacr/settings.py +2 -1
- spacr/toxo.py +266 -147
- spacr/utils.py +27 -4
- {spacr-0.3.46.dist-info → spacr-0.3.50.dist-info}/METADATA +2 -1
- {spacr-0.3.46.dist-info → spacr-0.3.50.dist-info}/RECORD +14 -13
- {spacr-0.3.46.dist-info → spacr-0.3.50.dist-info}/LICENSE +0 -0
- {spacr-0.3.46.dist-info → spacr-0.3.50.dist-info}/WHEEL +0 -0
- {spacr-0.3.46.dist-info → spacr-0.3.50.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.46.dist-info → spacr-0.3.50.dist-info}/top_level.txt +0 -0
spacr/sequencing.py
CHANGED
@@ -2,6 +2,11 @@ import os, gzip, re, time, gzip
|
|
2
2
|
import pandas as pd
|
3
3
|
from multiprocessing import Pool, cpu_count, Queue, Process
|
4
4
|
from Bio.Seq import Seq
|
5
|
+
import matplotlib.pyplot as plt
|
6
|
+
import seaborn as sns
|
7
|
+
import numpy as np
|
8
|
+
from .plot import plot_plates
|
9
|
+
from IPython.display import display
|
5
10
|
|
6
11
|
# Function to map sequences to names (same as your original)
|
7
12
|
def map_sequences_to_names(csv_file, sequences, rc):
|
@@ -480,4 +485,120 @@ def barecodes_reverse_complement(csv_file):
|
|
480
485
|
# Save the DataFrame with the reverse complement sequences
|
481
486
|
df.to_csv(new_filename, index=False)
|
482
487
|
|
483
|
-
print(f"Reverse complement file saved as {new_filename}")
|
488
|
+
print(f"Reverse complement file saved as {new_filename}")
|
489
|
+
|
490
|
+
def graph_sequencing_stats(settings):
    """Pick a per-well read-fraction cutoff and plot unique gRNA counts per well.

    Every CSV listed in ``settings['count_data']`` is read as one plate
    (labelled ``plate1``, ``plate2``, ... in list order). Within each well the
    ``count`` column is converted to a fraction of that well's total count.
    A sweep over fraction thresholds then finds the cutoff at which the mean
    number of distinct gRNAs per well is closest to the requested target. The
    cutoff is applied and the resulting per-well unique counts are displayed
    and passed to ``plot_plates``.

    Args:
        settings (dict): Keys read by this function:
            ``count_data`` (str or list of str): path(s) to the count CSV(s).
                Each file must provide ``row_name``, ``column_name`` and
                ``count`` columns; after ``correct_metadata_column_names``
                the frame must expose ``row``, ``column`` and ``grna``.
            ``control_wells`` (iterable): values of ``filter_column`` whose
                rows are removed before the threshold search.
            ``filter_column`` (str): column compared against ``control_wells``.
            ``target_unique_count`` (number): desired mean number of distinct
                gRNAs per well.
            ``log_x`` / ``log_y`` (bool): use log scaling on the sweep plot.

    Returns:
        The fraction threshold whose mean per-well unique gRNA count is
        closest to ``settings['target_unique_count']``.

    Raises:
        ValueError: if no threshold in the sweep leaves any rows, so no
            unique count can be computed.
    """

    from .utils import correct_metadata_column_names

    # Columns that together identify one well.
    well_keys = ['plate', 'row', 'column']

    def _plot_density(df, dependent_variable, dst=None):
        """Plot a density plot of the dependent variable."""
        plt.figure(figsize=(10, 10))
        sns.kdeplot(df[dependent_variable], fill=True, alpha=0.6)
        plt.title(f'Density Plot of {dependent_variable}')
        plt.xlabel(dependent_variable)
        plt.ylabel('Density')
        if dst is not None:
            filename = os.path.join(dst, 'dependent_variable_density.pdf')
            plt.savefig(filename, format='pdf')
            print(f'Saved density plot to {filename}')
        plt.show()

    def find_and_visualize_fraction_threshold(df, target_unique_count=5, log_x=False, log_y=False, dst=None):
        """
        Find the fraction threshold where the recalculated unique count matches the target value,
        and visualize the relationship between fraction thresholds and unique counts.
        """

        def _line_plot(df, x='fraction_threshold', y='unique_count', log_x=False, log_y=False):
            """Draw ``y`` against ``x`` as a line and return ``(fig, ax)``."""
            if x not in df.columns or y not in df.columns:
                raise ValueError(f"Columns '{x}' and/or '{y}' not found in the DataFrame.")
            fig, ax = plt.subplots(figsize=(10, 10))
            ax.plot(df[x], df[y], linestyle='-', color=(0 / 255, 155 / 255, 155 / 255), label=f"{y}")
            ax.set_xlabel(x)
            ax.set_ylabel(y)
            ax.set_title(f'{y} vs {x}')
            if log_x:
                ax.set_xscale('log')
            if log_y:
                ax.set_yscale('log')
            fig.tight_layout()
            return fig, ax

        fraction_thresholds = np.linspace(0.001, 0.99, 1000)

        # Mean number of distinct gRNAs per well at every candidate threshold.
        # A threshold that removes every row yields NaN.
        results = [
            (threshold,
             df[df['fraction'] >= threshold].groupby(well_keys)['grna'].nunique().mean())
            for threshold in fraction_thresholds
        ]
        results_df = pd.DataFrame(results, columns=['fraction_threshold', 'unique_count'])

        if results_df['unique_count'].isna().all():
            raise ValueError("No rows remain at any fraction threshold; cannot determine a threshold.")

        closest_index = (results_df['unique_count'] - target_unique_count).abs().argmin()
        closest_row = results_df.iloc[closest_index]

        print(f"Closest Fraction Threshold: {closest_row['fraction_threshold']}")
        print(f"Unique Count at Threshold: {closest_row['unique_count']}")

        fig, ax = _line_plot(df=results_df, x='fraction_threshold', y='unique_count', log_x=log_x, log_y=log_y)

        ax.axvline(x=closest_row['fraction_threshold'], color='black', linestyle='--',
                   label=f'Closest Threshold ({closest_row["fraction_threshold"]:.4f})')
        ax.axhline(y=target_unique_count, color='black', linestyle='--',
                   label=f'Target Unique Count ({target_unique_count})')
        # Build the legend only after the reference lines exist, otherwise
        # their labels are never shown.
        ax.legend()

        # A lower limit of 0 is invalid on a log axis, so only the upper
        # bound is applied there.
        if log_x:
            ax.set_xlim(right=0.1)
        else:
            ax.set_xlim(0, 0.1)
        if log_y:
            ax.set_ylim(top=20)
        else:
            ax.set_ylim(0, 20)

        if dst is not None:
            fig_path = os.path.join(dst, 'results')
            os.makedirs(fig_path, exist_ok=True)
            fig_file_path = os.path.join(fig_path, 'fraction_threshold.pdf')
            fig.savefig(fig_file_path, format='pdf', dpi=600, bbox_inches='tight')
            print(f"Saved {fig_file_path}")
        plt.show()

        return closest_row['fraction_threshold']

    # Accept a single path or a list of paths without rewriting the caller's
    # settings dict.
    count_files = settings['count_data']
    if isinstance(count_files, str):
        count_files = [count_files]

    dfs = []
    for i, count_data in enumerate(count_files):
        df = pd.read_csv(count_data)
        df['plate'] = f'plate{i+1}'
        df['prc'] = df['plate'].astype(str) + '_' + df['row_name'].astype(str) + '_' + df['column_name'].astype(str)
        df['total_count'] = df.groupby(['prc'])['count'].transform('sum')
        df['fraction'] = df['count'] / df['total_count']
        dfs.append(df)

    df = pd.concat(dfs, axis=0, ignore_index=True)

    df = correct_metadata_column_names(df)

    for c in settings['control_wells']:
        df = df[df[settings['filter_column']] != c]

    # Figures are written next to the first count file.
    dst = os.path.dirname(count_files[0])

    closest_threshold = find_and_visualize_fraction_threshold(df, settings['target_unique_count'], log_x=settings['log_x'], log_y=settings['log_y'], dst=dst)

    # Apply the closest threshold to the DataFrame
    df = df[df['fraction'] >= closest_threshold]

    # Distinct gRNAs per well, computed once and reused for the summary
    # statistics and for the merge below.
    per_well = df.groupby(well_keys)['grna'].nunique()
    unique_counts = per_well.reset_index(name='unique_counts')
    unique_count_mean = per_well.mean()
    unique_count_std = per_well.std()

    # Merge the unique counts back into the original DataFrame
    df = pd.merge(df, unique_counts, on=well_keys, how='left')

    print(f"unique_count mean: {unique_count_mean} std: {unique_count_std}")
    display(df)
    # NOTE: _plot_density(df, dependent_variable='unique_counts') is available
    # above but is not invoked here.
    plot_plates(df=df, variable='unique_counts', grouping='mean', min_max='allq', cmap='viridis',min_count=0, verbose=True, dst=dst)

    return closest_threshold
|
spacr/settings.py
CHANGED
@@ -549,7 +549,8 @@ def get_perform_regression_default_settings(settings):
|
|
549
549
|
settings.setdefault('filter_column','column')
|
550
550
|
settings.setdefault('plate','plate1')
|
551
551
|
settings.setdefault('class_1_threshold',None)
|
552
|
-
settings.setdefault('metadata_files',['/home/carruthers/Documents/
|
552
|
+
settings.setdefault('metadata_files',['/home/carruthers/Documents/TGGT1_Summary.csv','/home/carruthers/Documents/TGME49_Summary.csv'])
|
553
|
+
settings.setdefault('volcano','gene')
|
553
554
|
settings.setdefault('toxo', True)
|
554
555
|
|
555
556
|
if settings['regression_type'] == 'quantile':
|
spacr/toxo.py
CHANGED
@@ -6,25 +6,53 @@ from adjustText import adjust_text
|
|
6
6
|
import pandas as pd
|
7
7
|
from scipy.stats import fisher_exact
|
8
8
|
from IPython.display import display
|
9
|
+
from matplotlib.legend import Legend
|
10
|
+
from matplotlib.transforms import Bbox
|
11
|
+
from brokenaxes import brokenaxes
|
9
12
|
|
10
|
-
def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', point_size=50, figsize=20, threshold=0, split_axis_lims = [10, None, None, 10], save_path=None):
|
11
|
-
"""
|
12
|
-
Create a volcano plot with the ability to control the shape of points based on a categorical column,
|
13
|
-
color points based on a condition, annotate specific points based on p-value and coefficient thresholds,
|
14
|
-
and control the size of points.
|
15
|
-
"""
|
16
|
-
volcano_path = save_path
|
17
13
|
|
18
|
-
|
14
|
+
from matplotlib.gridspec import GridSpec
|
15
|
+
|
16
|
+
|
17
|
+
def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location',
|
18
|
+
point_size=50, figsize=20, threshold=0,
|
19
|
+
save_path=None, x_lim=[-0.5, 0.5], y_lims=[[0, 6], [9, 15]]):
|
20
|
+
|
21
|
+
markers = [
|
22
|
+
'o', # Circle
|
23
|
+
'X', # X-shaped marker
|
24
|
+
'^', # Upward triangle
|
25
|
+
's', # Square
|
26
|
+
'v', # Downward triangle
|
27
|
+
'P', # Plus-filled pentagon
|
28
|
+
'*', # Star
|
29
|
+
'+', # Plus
|
30
|
+
'x', # Cross
|
31
|
+
'.', # Point
|
32
|
+
',', # Pixel
|
33
|
+
'd', # Diamond
|
34
|
+
'D', # Thin diamond
|
35
|
+
'h', # Hexagon 1
|
36
|
+
'H', # Hexagon 2
|
37
|
+
'p', # Pentagon
|
38
|
+
'|', # Vertical line
|
39
|
+
'_', # Horizontal line
|
40
|
+
]
|
41
|
+
|
42
|
+
plt.rcParams.update({'font.size': 14})
|
43
|
+
|
44
|
+
# Load data
|
19
45
|
if isinstance(data_path, pd.DataFrame):
|
20
46
|
data = data_path
|
21
47
|
else:
|
22
48
|
data = pd.read_csv(data_path)
|
23
|
-
|
49
|
+
|
50
|
+
fontsize = 18
|
51
|
+
|
52
|
+
plt.rcParams.update({'font.size': fontsize})
|
24
53
|
data['variable'] = data['feature'].str.extract(r'\[(.*?)\]')
|
25
54
|
data['variable'].fillna(data['feature'], inplace=True)
|
26
|
-
|
27
|
-
data['gene_nr'] = split_columns[0]
|
55
|
+
data['gene_nr'] = data['variable'].str.split('_').str[0]
|
28
56
|
data = data[data['variable'] != 'Intercept']
|
29
57
|
|
30
58
|
# Load metadata
|
@@ -32,165 +60,110 @@ def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location
|
|
32
60
|
metadata = metadata_path
|
33
61
|
else:
|
34
62
|
metadata = pd.read_csv(metadata_path)
|
35
|
-
|
36
63
|
metadata['gene_nr'] = metadata['gene_nr'].astype(str)
|
37
64
|
data['gene_nr'] = data['gene_nr'].astype(str)
|
38
65
|
|
39
|
-
|
40
|
-
merged_data
|
66
|
+
merged_data = pd.merge(data, metadata[['gene_nr', metadata_column]], on='gene_nr', how='left')
|
67
|
+
merged_data[metadata_column].fillna('unknown', inplace=True)
|
41
68
|
|
42
|
-
|
43
|
-
|
44
|
-
|
69
|
+
# Define palette and markers
|
70
|
+
palette = {'pc': 'red', 'nc': 'green', 'control': 'white', 'other': 'gray'}
|
71
|
+
marker_dict = {val: marker for val, marker in zip(
|
72
|
+
merged_data[metadata_column].unique(), markers)}
|
45
73
|
|
46
|
-
#
|
47
|
-
|
48
|
-
|
49
|
-
categories=['other','pc', 'nc', 'control'],
|
50
|
-
ordered=True)
|
51
|
-
|
74
|
+
# Create the figure with custom spacing
|
75
|
+
fig = plt.figure(figsize=(figsize,figsize))
|
76
|
+
gs = GridSpec(2, 1, height_ratios=[1, 3], hspace=0.05)
|
52
77
|
|
53
|
-
|
78
|
+
ax_upper = fig.add_subplot(gs[0])
|
79
|
+
ax_lower = fig.add_subplot(gs[1], sharex=ax_upper)
|
54
80
|
|
55
|
-
#
|
56
|
-
|
57
|
-
fig, (ax1, ax2) = plt.subplots(
|
58
|
-
2, 1, figsize=(figsize, figsize),
|
59
|
-
sharex=True, gridspec_kw={'height_ratios': [1, 3]}
|
60
|
-
)
|
81
|
+
# Hide x-axis labels on the upper plot
|
82
|
+
ax_upper.tick_params(axis='x', which='both', bottom=False, labelbottom=False)
|
61
83
|
|
62
|
-
|
63
|
-
palette = {
|
64
|
-
'pc': 'red',
|
65
|
-
'nc': 'green',
|
66
|
-
'control': 'white',
|
67
|
-
'other': 'gray'}
|
84
|
+
hit_list = []
|
68
85
|
|
69
86
|
# Scatter plot on both axes
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
y='-log10(p_value)',
|
74
|
-
hue='condition', # Keep colors but prevent them from showing in the final legend
|
75
|
-
style=metadata_column if metadata_column else None, # Shape-based legend
|
76
|
-
s=point_size,
|
77
|
-
edgecolor='black',
|
78
|
-
palette=palette,
|
79
|
-
legend='brief', # Capture the full legend initially
|
80
|
-
alpha=0.8,
|
81
|
-
ax=ax2 # Lower plot
|
82
|
-
)
|
87
|
+
for _, row in merged_data.iterrows():
|
88
|
+
y_val = -np.log10(row['p_value'])
|
89
|
+
ax = ax_upper if y_val > y_lims[1][0] else ax_lower
|
83
90
|
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
s=point_size,
|
91
|
-
palette=palette,
|
92
|
-
edgecolor='black',
|
93
|
-
legend=False, # Suppress legend for upper plot
|
94
|
-
alpha=0.8,
|
95
|
-
ax=ax1 # Upper plot
|
96
|
-
)
|
97
|
-
|
98
|
-
if isinstance(split_axis_lims, list):
|
99
|
-
if len(split_axis_lims) == 4:
|
100
|
-
ylim_min_ax1 = split_axis_lims[0]
|
101
|
-
if split_axis_lims[1] is None:
|
102
|
-
ylim_max_ax1 = merged_data['-log10(p_value)'].max() + 5
|
103
|
-
else:
|
104
|
-
ylim_max_ax1 = split_axis_lims[1]
|
105
|
-
ylim_min_ax2 = split_axis_lims[2]
|
106
|
-
ylim_max_ax2 = split_axis_lims[3]
|
107
|
-
else:
|
108
|
-
ylim_min_ax1 = None
|
109
|
-
ylim_max_ax1 = merged_data['-log10(p_value)'].max() + 5
|
110
|
-
ylim_min_ax2 = 0
|
111
|
-
ylim_max_ax2 = None
|
112
|
-
|
113
|
-
# Set axis limits and hide unnecessary parts
|
114
|
-
ax1.set_ylim(ylim_min_ax1, ylim_max_ax1)
|
115
|
-
ax2.set_ylim(0, ylim_max_ax2)
|
116
|
-
ax1.spines['bottom'].set_visible(False)
|
117
|
-
ax2.spines['top'].set_visible(False)
|
118
|
-
ax1.tick_params(labelbottom=False)
|
119
|
-
|
120
|
-
if ax1.get_legend() is not None:
|
121
|
-
ax1.legend_.remove()
|
122
|
-
ax1.get_legend().remove() # Extract handles and labels from the legend
|
123
|
-
handles, labels = ax2.get_legend_handles_labels()
|
124
|
-
|
125
|
-
# Identify shape-based legend entries (skip color-based entries)
|
126
|
-
shape_handles = handles[len(set(merged_data['condition'])):]
|
127
|
-
shape_labels = labels[len(set(merged_data['condition'])):]
|
128
|
-
|
129
|
-
# Set the legend with only shape-based entries
|
130
|
-
ax2.legend(
|
131
|
-
shape_handles,
|
132
|
-
shape_labels,
|
133
|
-
bbox_to_anchor=(1.05, 1),
|
134
|
-
loc='upper left',
|
135
|
-
borderaxespad=0.
|
136
|
-
)
|
137
|
-
|
138
|
-
ax1.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
|
139
|
-
|
140
|
-
# Add vertical threshold lines to both plots
|
141
|
-
if threshold > 0:
|
142
|
-
for ax in (ax1, ax2):
|
143
|
-
ax.axvline(x=-abs(threshold), linestyle='--', color='black')
|
144
|
-
ax.axvline(x=abs(threshold), linestyle='--', color='black')
|
91
|
+
ax.scatter(
|
92
|
+
row['coefficient'], y_val,
|
93
|
+
color=palette.get(row['condition'], 'gray'),
|
94
|
+
marker=marker_dict.get(row[metadata_column], 'o'),
|
95
|
+
s=point_size, edgecolor='black', alpha=0.6
|
96
|
+
)
|
145
97
|
|
146
|
-
# Add a horizontal line at p-value threshold (0.05)
|
147
|
-
ax2.axhline(y=-np.log10(0.05), color='black', linestyle='--')
|
148
|
-
|
149
|
-
# Annotate significant points on both axes
|
150
|
-
texts_ax1 = []
|
151
|
-
texts_ax2 = []
|
152
|
-
|
153
|
-
for i, row in merged_data.iterrows():
|
154
98
|
if row['p_value'] <= 0.05 and abs(row['coefficient']) >= abs(threshold):
|
155
|
-
|
156
|
-
#ax = ax1 if row['-log10(p_value)'] > 10 else ax2
|
99
|
+
hit_list.append(row['variable'])
|
157
100
|
|
158
|
-
|
101
|
+
# Set axis limits
|
102
|
+
ax_upper.set_ylim(y_lims[1])
|
103
|
+
ax_lower.set_ylim(y_lims[0])
|
104
|
+
ax_lower.set_xlim(x_lim)
|
159
105
|
|
106
|
+
ax_lower.spines['top'].set_visible(False)
|
107
|
+
ax_upper.spines['top'].set_visible(False)
|
108
|
+
ax_upper.spines['bottom'].set_visible(False)
|
160
109
|
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
va='bottom',
|
169
|
-
)
|
110
|
+
# Set x-axis and y-axis titles
|
111
|
+
ax_lower.set_xlabel('Coefficient') # X-axis title on the lower graph
|
112
|
+
ax_lower.set_ylabel('-log10(p-value)') # Y-axis title on the lower graph
|
113
|
+
ax_upper.set_ylabel('-log10(p-value)') # Y-axis title on the upper graph
|
114
|
+
|
115
|
+
for ax in [ax_upper, ax_lower]:
|
116
|
+
ax.spines['right'].set_visible(False)
|
170
117
|
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
texts_ax2.append(text)
|
118
|
+
# Add threshold lines to both axes
|
119
|
+
for ax in [ax_upper, ax_lower]:
|
120
|
+
ax.axvline(x=-abs(threshold), linestyle='--', color='black')
|
121
|
+
ax.axvline(x=abs(threshold), linestyle='--', color='black')
|
176
122
|
|
177
|
-
|
178
|
-
adjust_text(texts_ax1, arrowprops=dict(arrowstyle='-', color='black'), ax=ax1)
|
179
|
-
adjust_text(texts_ax2, arrowprops=dict(arrowstyle='-', color='black'), ax=ax2)
|
123
|
+
ax_lower.axhline(y=-np.log10(0.05), linestyle='--', color='black')
|
180
124
|
|
181
|
-
#
|
182
|
-
|
125
|
+
# Annotate significant points
|
126
|
+
texts_upper, texts_lower = [], [] # Collect text annotations separately
|
183
127
|
|
184
|
-
|
185
|
-
|
186
|
-
|
128
|
+
for _, row in merged_data.iterrows():
|
129
|
+
y_val = -np.log10(row['p_value'])
|
130
|
+
if row['p_value'] > 0.05 or abs(row['coefficient']) < abs(threshold):
|
131
|
+
continue
|
187
132
|
|
188
|
-
|
189
|
-
|
190
|
-
|
133
|
+
ax = ax_upper if y_val > y_lims[1][0] else ax_lower
|
134
|
+
text = ax.text(row['coefficient'], y_val, row['variable'],
|
135
|
+
fontsize=fontsize, ha='center', va='bottom')
|
191
136
|
|
192
|
-
|
137
|
+
if ax == ax_upper:
|
138
|
+
texts_upper.append(text)
|
139
|
+
else:
|
140
|
+
texts_lower.append(text)
|
141
|
+
|
142
|
+
# Adjust text positions to avoid overlap
|
143
|
+
adjust_text(texts_upper, ax=ax_upper, arrowprops=dict(arrowstyle='-', color='black'))
|
144
|
+
adjust_text(texts_lower, ax=ax_lower, arrowprops=dict(arrowstyle='-', color='black'))
|
145
|
+
|
146
|
+
# Add a single legend on the lower axis
|
147
|
+
handles = [plt.Line2D([0], [0], marker=m, color='w', markerfacecolor='gray', markersize=10)
|
148
|
+
for m in marker_dict.values()]
|
149
|
+
labels = marker_dict.keys()
|
150
|
+
ax_lower.legend(handles,
|
151
|
+
labels,
|
152
|
+
bbox_to_anchor=(1.05, 1),
|
153
|
+
loc='upper left',
|
154
|
+
borderaxespad=0.25,
|
155
|
+
labelspacing=2,
|
156
|
+
handletextpad=0.25,
|
157
|
+
markerscale=2,
|
158
|
+
prop={'size': fontsize})
|
159
|
+
|
160
|
+
|
161
|
+
# Save and show the plot
|
162
|
+
if save_path:
|
163
|
+
plt.savefig(save_path, format='pdf', bbox_inches='tight')
|
193
164
|
plt.show()
|
165
|
+
|
166
|
+
return hit_list
|
194
167
|
|
195
168
|
def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=['Computed GO Processes', 'Curated GO Components', 'Curated GO Functions', 'Curated GO Processes']):
|
196
169
|
"""
|
@@ -331,4 +304,150 @@ def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=
|
|
331
304
|
|
332
305
|
# Show the combined plot
|
333
306
|
plt.tight_layout()
|
307
|
+
plt.show()
|
308
|
+
|
309
|
+
def plot_gene_phenotypes(data, gene_list, x_column='Gene ID', data_column='T.gondii GT1 CRISPR Phenotype - Mean Phenotype',error_column='T.gondii GT1 CRISPR Phenotype - Standard Error', save_path=None):
    """
    Plot a line graph for the mean phenotype with standard error shading and highlighted genes.

    Rows are ranked by ``data_column`` (ascending) and plotted rank vs. value,
    with a band of +/- ``error_column`` around the line. Genes from
    ``gene_list`` that are found in the data are drawn as large markers and
    labelled.

    Args:
        data (pd.DataFrame): The input DataFrame containing gene data. It is
            not modified; the function works on a copy.
        gene_list (list): A list of gene names to highlight on the plot. Each
            entry is reduced with the same rule as ``x_column`` values (the
            part after the first underscore, if any) before matching.
        x_column (str): Column holding the gene identifier.
        data_column (str): Column holding the value to rank and plot. Values
            that cannot be parsed as numbers are dropped.
        error_column (str): Column holding the error used for the shaded band.
            Values that cannot be parsed as numbers are dropped.
        save_path (str): Optional. If provided, the figure is saved there as PDF.
    """
    # Ensure x_column is properly processed
    def extract_gene_id(gene):
        if isinstance(gene, str) and '_' in gene:
            return gene.split('_')[1]
        return str(gene)

    # Work on a private copy so the caller's DataFrame is left untouched and
    # the column assignments below never target a view of another frame.
    data = data.copy()
    data[data_column] = pd.to_numeric(data[data_column], errors='coerce')
    data[error_column] = pd.to_numeric(data[error_column], errors='coerce')
    data = data.dropna(subset=[data_column, error_column]).copy()

    data['x'] = data[x_column].apply(extract_gene_id)

    # Sort by the data_column and assign ranks
    data = data.sort_values(by=data_column).reset_index(drop=True)
    data['rank'] = range(1, len(data) + 1)

    # Prepare the x, y, and error values for plotting
    x = data['rank']
    y = data[data_column]
    yerr = data[error_column]

    # Create the plot
    plt.figure(figsize=(10, 10))

    # Plot the mean phenotype with standard error shading
    plt.plot(x, y, label='Mean Phenotype', color=(0/255, 155/255, 155/255), linewidth=2)
    plt.fill_between(
        x, y - yerr, y + yerr,
        color=(0/255, 155/255, 155/255), alpha=0.1, label='Standard Error'
    )

    # Prepare for adjustText
    texts = []  # Store text objects for adjustment

    # Highlight the genes in the gene_list
    for gene in gene_list:
        gene_id = extract_gene_id(gene)
        gene_data = data[data['x'] == gene_id]
        if gene_data.empty:
            continue
        # Scatter the highlighted points in purple and add labels for adjustment
        plt.scatter(
            gene_data['rank'],
            gene_data[data_column],
            color=(155/255, 55/255, 155/255),
            s=200,
            alpha=0.6,
            label=f'Highlighted Gene: {gene}',
            zorder=3  # Ensure the points are on top
        )
        # Add the text label next to the first matching row for this gene
        texts.append(
            plt.text(
                gene_data['rank'].values[0],
                gene_data[data_column].values[0],
                gene,
                fontsize=18,
                ha='right'
            )
        )

    # Adjust text to avoid overlap with lines drawn from points to text.
    # Skipped when no gene matched, so adjust_text never sees an empty list.
    if texts:
        adjust_text(texts, arrowprops=dict(arrowstyle='-', color='gray'))

    # Label the plot
    plt.xlabel('Rank')
    plt.ylabel('Mean Phenotype')
    plt.legend().remove()  # Remove the legend if not needed
    plt.tight_layout()

    # Save the plot if a path is provided
    if save_path:
        plt.savefig(save_path, format='pdf', dpi=600, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()
|
394
|
+
|
395
|
+
def plot_gene_heatmaps(data, gene_list, columns, x_column='Gene ID', normalize=False, save_path=None):
    """
    Generate a viridis heatmap of the specified columns for the selected genes.

    Args:
        data (pd.DataFrame): The input DataFrame containing gene data. It is
            not modified; the function works on a copy.
        gene_list (list): A list of genes to include in the heatmap. Each
            entry is reduced with the same rule as ``x_column`` values (the
            part after the first underscore, if any) before matching, so both
            prefixed and bare identifiers are accepted.
        columns (list): A list of column names to visualize as heatmaps.
        x_column (str): Column holding the gene identifier.
        normalize (bool): If True, normalize the values for each gene between 0 and 1.
            A gene whose values are all equal has zero range and comes out as NaN.
        save_path (str): Optional. If provided, the plot will be saved to this path.

    Raises:
        ValueError: if no row matches ``gene_list`` or ``columns`` is empty,
            so there is nothing to draw.
    """
    # Ensure x_column is properly processed
    def extract_gene_id(gene):
        if isinstance(gene, str) and '_' in gene:
            return gene.split('_')[1]
        return str(gene)

    # Work on a copy so the helper column is not added to the caller's frame.
    data = data.copy()
    data['x'] = data[x_column].apply(extract_gene_id)

    # Normalise the requested genes the same way as the data column, so that
    # identifiers such as '<prefix>_<id>' match (as in plot_gene_phenotypes).
    wanted_ids = {extract_gene_id(gene) for gene in gene_list}

    # Filter the data to only include the specified genes
    filtered_data = data[data['x'].isin(wanted_ids)].set_index('x')[columns]

    if filtered_data.empty:
        raise ValueError("No data to plot: none of the genes in gene_list were found, or columns is empty.")

    # Normalize each gene's values between 0 and 1 if normalize=True
    if normalize:
        filtered_data = filtered_data.apply(lambda x: (x - x.min()) / (x.max() - x.min()), axis=1)

    # Define the figure size dynamically based on the number of genes and columns
    width = len(columns) * 4
    height = len(gene_list) * 1

    # Create the heatmap
    plt.figure(figsize=(width, height))
    cmap = sns.color_palette("viridis", as_cmap=True)

    # Plot the heatmap with genes on the y-axis and columns on the x-axis
    sns.heatmap(
        filtered_data,
        cmap=cmap,
        cbar=True,
        annot=False,
        linewidths=0.5,
        square=True
    )

    # Set the labels
    plt.xticks(rotation=90, ha='center')  # Rotate x-axis labels for better readability
    plt.yticks(rotation=0)  # Keep y-axis labels horizontal
    plt.xlabel('')
    plt.ylabel('')

    # Adjust layout to ensure the plot fits well
    plt.tight_layout()

    # Save the plot if a path is provided
    if save_path:
        plt.savefig(save_path, format='pdf', dpi=600, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()
|
spacr/utils.py
CHANGED
@@ -4067,7 +4067,7 @@ def generate_path_list_from_db(db_path, file_metadata):
|
|
4067
4067
|
|
4068
4068
|
return all_paths
|
4069
4069
|
|
4070
|
-
def correct_paths(df, base_path):
|
4070
|
+
def correct_paths(df, base_path, folder='data'):
|
4071
4071
|
|
4072
4072
|
if isinstance(df, pd.DataFrame):
|
4073
4073
|
|
@@ -4083,9 +4083,9 @@ def correct_paths(df, base_path):
|
|
4083
4083
|
adjusted_image_paths = []
|
4084
4084
|
for path in image_paths:
|
4085
4085
|
if base_path not in path:
|
4086
|
-
parts = path.split('/
|
4086
|
+
parts = path.split(f'/{folder}/')
|
4087
4087
|
if len(parts) > 1:
|
4088
|
-
new_path = os.path.join(base_path, '
|
4088
|
+
new_path = os.path.join(base_path, f'{folder}', parts[1])
|
4089
4089
|
adjusted_image_paths.append(new_path)
|
4090
4090
|
else:
|
4091
4091
|
adjusted_image_paths.append(path)
|
@@ -5209,4 +5209,27 @@ def fill_holes_in_mask(mask):
|
|
5209
5209
|
# Assign the original label back to the filled object
|
5210
5210
|
filled_mask[filled_object] = i
|
5211
5211
|
|
5212
|
-
return filled_mask
|
5212
|
+
return filled_mask
|
5213
|
+
|
5214
|
+
def correct_metadata_column_names(df):
    """Return a copy of *df* whose metadata columns use the standard names.

    The aliases ``plate_name``, ``column_name``/``col``, ``row_name`` and
    ``grna_name`` are renamed to ``plate``, ``column``, ``row`` and ``grna``.
    An alias is only renamed when the standard name is not already taken,
    so the result never contains two columns with the same name. If a
    ``plate_row`` column exists it is split at its first underscore into
    ``plate`` and ``row`` (overwriting any existing columns of those names).

    Args:
        df (pd.DataFrame): Frame to normalise. It is not modified.

    Returns:
        pd.DataFrame: A new frame with the corrected column names.
    """
    # (alias, standard name), in priority order: when two aliases map to the
    # same standard name the first one present wins.
    aliases = (
        ('plate_name', 'plate'),
        ('column_name', 'column'),
        ('col', 'column'),
        ('row_name', 'row'),
        ('grna_name', 'grna'),
    )

    renames = {}
    for old_name, new_name in aliases:
        target_taken = new_name in df.columns or new_name in renames.values()
        if old_name in df.columns and not target_taken:
            renames[old_name] = new_name

    # rename() returns a new frame, so the assignments below never touch the
    # caller's DataFrame even when there is nothing to rename.
    df = df.rename(columns=renames)

    if 'plate_row' in df.columns:
        # n=1 keeps any further underscores inside the row part instead of
        # producing extra columns; reindex guarantees both parts exist even
        # if no value contains an underscore.
        parts = df['plate_row'].str.split('_', n=1, expand=True).reindex(columns=[0, 1])
        df['plate'] = parts[0]
        df['row'] = parts[1]

    return df
|
5228
|
+
|
5229
|
+
def control_filelist(folder, mode='column', values=['01','02']):
    """List the files in *folder* whose well column or row is in *values*.

    The well is taken from the second underscore-separated field of the file
    name. Its first character is treated as the row and the remainder as the
    column. File names without a second field are skipped.

    Args:
        folder (str): Directory whose entries are listed.
        mode (str): ``'column'`` to match on the well's column part, or
            ``'row'`` to match on its first character.
        values (list): Accepted column or row values. The default list is
            only read, never modified.

    Returns:
        list: Names (not full paths) of the matching files.

    Raises:
        ValueError: if *mode* is neither ``'column'`` nor ``'row'``.
    """
    # Compare by value: identity ('is') on strings only works by accident of
    # interning and fails for strings built at runtime.
    if mode == 'column':
        def well_part(well):
            return well[1:]
    elif mode == 'row':
        def well_part(well):
            return well[:1]
    else:
        raise ValueError(f"mode must be 'column' or 'row', got {mode!r}")

    filtered_files = []
    for file in os.listdir(folder):
        fields = file.split('_')
        if len(fields) < 2:
            # No well field in this name; nothing to match against.
            continue
        if well_part(fields[1]) in values:
            filtered_files.append(file)
    return filtered_files
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: spacr
|
3
|
-
Version: 0.3.46
|
3
|
+
Version: 0.3.50
|
4
4
|
Summary: Spatial phenotype analysis of crisp screens (SpaCr)
|
5
5
|
Home-page: https://github.com/EinarOlafsson/spacr
|
6
6
|
Author: Einar Birnir Olafsson
|
@@ -66,6 +66,7 @@ Requires-Dist: gdown
|
|
66
66
|
Requires-Dist: IPython<9.0,>=8.18.1
|
67
67
|
Requires-Dist: ipykernel
|
68
68
|
Requires-Dist: ipywidgets<9.0,>=8.1.2
|
69
|
+
Requires-Dist: brokenaxes<1.0,>=0.6.2
|
69
70
|
Requires-Dist: huggingface-hub<0.25,>=0.24.0
|
70
71
|
Provides-Extra: dev
|
71
72
|
Requires-Dist: pytest<3.11,>=3.9; extra == "dev"
|