spacr 0.3.81__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +0 -4
- spacr/core.py +27 -13
- spacr/deep_spacr.py +378 -5
- spacr/gui_core.py +69 -38
- spacr/gui_elements.py +193 -3
- spacr/gui_utils.py +1 -1
- spacr/io.py +5 -176
- spacr/measure.py +10 -6
- spacr/ml.py +369 -46
- spacr/plot.py +201 -90
- spacr/settings.py +52 -16
- spacr/submodules.py +282 -1
- spacr/toxo.py +98 -75
- spacr/utils.py +128 -36
- {spacr-0.3.81.dist-info → spacr-0.4.0.dist-info}/METADATA +2 -1
- {spacr-0.3.81.dist-info → spacr-0.4.0.dist-info}/RECORD +20 -20
- {spacr-0.3.81.dist-info → spacr-0.4.0.dist-info}/LICENSE +0 -0
- {spacr-0.3.81.dist-info → spacr-0.4.0.dist-info}/WHEEL +0 -0
- {spacr-0.3.81.dist-info → spacr-0.4.0.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.81.dist-info → spacr-0.4.0.dist-info}/top_level.txt +0 -0
spacr/submodules.py
CHANGED
@@ -1,3 +1,6 @@
+
+
+
 import seaborn as sns
 import os, random, sqlite3, re, shap
 import pandas as pd
@@ -10,7 +13,10 @@ from IPython.display import display
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.inspection import permutation_importance
 from math import pi
-from scipy.stats import chi2_contingency
+from scipy.stats import chi2_contingency, pearsonr
+from scipy.spatial.distance import cosine
+
+from sklearn.metrics import mean_absolute_error
 
 import matplotlib.pyplot as plt
 from natsort import natsorted
@@ -1132,3 +1138,278 @@ def analyze_class_proportion(settings):
     print("Statistical analysis results saved.")
 
     return output
+
+def generate_score_heatmap(settings):
+
+    def group_cv_score(csv, plate=1, column='c3', data_column='pred'):
+
+        df = pd.read_csv(csv)
+        if 'col' in df.columns:
+            df = df[df['col'] == column]
+        elif 'column' in df.columns:
+            df['col'] = df['column']
+            df = df[df['col'] == column]
+        if plate is not None:
+            df['plate'] = f"plate{plate}"
+        grouped_df = df.groupby(['plate', 'row', 'col'])[data_column].mean().reset_index()
+        grouped_df['prc'] = grouped_df['plate'].astype(str) + '_' + grouped_df['row'].astype(str) + '_' + grouped_df['col'].astype(str)
+        return grouped_df
+
+    def calculate_fraction_mixed_condition(csv, plate=1, column='c3', control_sgrnas=['TGGT1_220950_1', 'TGGT1_233460_4']):
+        df = pd.read_csv(csv)
+        df = df[df['column_name'] == column]
+        if 'plate' not in df.columns:  # add a plate column when the CSV lacks one
+            df['plate'] = f"plate{plate}"
+        df = df[df['grna_name'].str.match(f'^{control_sgrnas[0]}$|^{control_sgrnas[1]}$')]
+        grouped_df = df.groupby(['plate', 'row_name', 'column_name'])['count'].sum().reset_index()
+        grouped_df = grouped_df.rename(columns={'count': 'total_count'})
+        merged_df = pd.merge(df, grouped_df, on=['plate', 'row_name', 'column_name'])
+        merged_df['fraction'] = merged_df['count'] / merged_df['total_count']
+        merged_df['prc'] = merged_df['plate'].astype(str) + '_' + merged_df['row_name'].astype(str) + '_' + merged_df['column_name'].astype(str)
+        return merged_df
+
+    def plot_multi_channel_heatmap(df, column='c3', cmap='coolwarm'):
+        """
+        Plot a heatmap with multiple channels as columns.
+
+        Parameters:
+        - df: DataFrame with scores for different channels.
+        - column: Column to filter by (default is 'c3').
+        """
+        # Extract row number and convert to integer for sorting
+        df['row_num'] = df['row'].str.extract(r'(\d+)').astype(int)
+
+        # Filter and sort by plate, row, and column
+        df = df[df['col'] == column]
+        df = df.sort_values(by=['plate', 'row_num', 'col'])
+
+        # Drop temporary 'row_num' column after sorting
+        df = df.drop('row_num', axis=1)
+
+        # Create a new column combining plate, row, and column for the index
+        df['plate_row_col'] = df['plate'] + '-' + df['row'] + '-' + df['col']
+
+        # Set 'plate_row_col' as the index
+        df.set_index('plate_row_col', inplace=True)
+
+        # Extract only numeric data for the heatmap
+        heatmap_data = df.select_dtypes(include=[float, int])
+
+        # Plot heatmap with square cells, no annotations, and the requested colormap
+        plt.figure(figsize=(12, 8))
+        sns.heatmap(
+            heatmap_data,
+            cmap=cmap,
+            cbar=True,
+            square=True,
+            annot=False
+        )
+
+        plt.title("Heatmap of Prediction Scores for All Channels")
+        plt.xlabel("Channels")
+        plt.ylabel("Plate-Row-Column")
+        plt.tight_layout()
+
+        # Save the figure object and return it
+        fig = plt.gcf()
+        plt.show()
+
+        return fig
+
+
+    def combine_classification_scores(folders, csv_name, data_column, plate=1, column='c3'):
+        # Ensure `folders` is a list
+        if isinstance(folders, str):
+            folders = [folders]
+
+        ls = []  # Initialize ls to store found CSV file paths
+
+        # Iterate over the provided folders
+        for folder in folders:
+            sub_folders = os.listdir(folder)  # Get sub-folder list
+            for sub_folder in sub_folders:  # Iterate through sub-folders
+                path = os.path.join(folder, sub_folder)  # Join the full path
+
+                if os.path.isdir(path):  # Check if it's a directory
+                    csv = os.path.join(path, csv_name)  # Join path to the CSV file
+                    if os.path.exists(csv):  # If CSV exists, add to list
+                        ls.append(csv)
+                    else:
+                        print(f'No such file: {csv}')
+
+        # Initialize combined DataFrame
+        combined_df = None
+        print(f'Found {len(ls)} CSV files')
+
+        # Loop through all collected CSV files and process them
+        for csv_file in ls:
+            df = pd.read_csv(csv_file)  # Read CSV into DataFrame
+            df = df[df['col'] == column]
+            if plate is not None:
+                df['plate'] = f"plate{plate}"
+            # Group the data by 'plate', 'row', and 'col'
+            grouped_df = df.groupby(['plate', 'row', 'col'])[data_column].mean().reset_index()
+            # Use the parent folder name to create a unique column name
+            folder_name = os.path.dirname(csv_file).replace(".csv", "")
+            new_column_name = os.path.basename(f"{folder_name}_{data_column}")
+            print(new_column_name)
+            grouped_df = grouped_df.rename(columns={data_column: new_column_name})
+
+            # Merge into the combined DataFrame
+            if combined_df is None:
+                combined_df = grouped_df
+            else:
+                combined_df = pd.merge(combined_df, grouped_df, on=['plate', 'row', 'col'], how='outer')
+        combined_df['prc'] = combined_df['plate'].astype(str) + '_' + combined_df['row'].astype(str) + '_' + combined_df['col'].astype(str)
+        return combined_df
+
+    def calculate_mae(df):
+        """
+        Calculate the MAE between each channel's predictions and the fraction column for all rows.
+        """
+        # Extract numeric columns excluding 'fraction' and 'prc'
+        channels = df.drop(columns=['fraction', 'prc']).select_dtypes(include=[float, int])
+
+        mae_data = []
+
+        # Compute MAE for each channel with 'fraction' for all rows
+        for column in channels.columns:
+            for _, row in df.iterrows():
+                mae = mean_absolute_error([row['fraction']], [row[column]])
+                mae_data.append({'Channel': column, 'MAE': mae, 'Row': row['prc']})
+
+        # Convert the list of dictionaries to a DataFrame
+        mae_df = pd.DataFrame(mae_data)
+        return mae_df
+
+    result_df = combine_classification_scores(settings['folders'], settings['csv_name'], settings['data_column'], settings['plate'], settings['column'])
+    df = calculate_fraction_mixed_condition(settings['csv'], settings['plate'], settings['column'], settings['control_sgrnas'])
+    df = df[df['grna_name'] == settings['fraction_grna']]
+    fraction_df = df[['fraction', 'prc']]
+    merged_df = pd.merge(fraction_df, result_df, on=['prc'])
+    cv_df = group_cv_score(settings['cv_csv'], settings['plate'], settings['column'], settings['data_column_cv'])
+    cv_df = cv_df[[settings['data_column_cv'], 'prc']]
+    merged_df = pd.merge(merged_df, cv_df, on=['prc'])
+
+    fig = plot_multi_channel_heatmap(merged_df, settings['column'], settings['cmap'])
+    if 'row_num' in merged_df.columns:  # drop the temporary sorting column if present
+        merged_df = merged_df.drop('row_num', axis=1)
+    mae_df = calculate_mae(merged_df)
+    if 'row_num' in mae_df.columns:
+        mae_df = mae_df.drop('row_num', axis=1)
+
+    if settings['dst'] is not None:
+        mae_dst = os.path.join(settings['dst'], f"mae_scores_comparison_plate_{settings['plate']}.csv")
+        merged_dst = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plate']}_data.csv")
+        heatmap_save = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plate']}.pdf")
+        mae_df.to_csv(mae_dst, index=False)
+        merged_df.to_csv(merged_dst, index=False)
+        fig.savefig(heatmap_save, format='pdf', dpi=600, bbox_inches='tight')
+    return merged_df
+
+def post_regression_analysis(csv_file, grna_dict, grna_list, save=False):
+
+    def _analyze_and_visualize_grna_correlation(df, grna_list, save_folder, save=False):
+        """
+        Analyze and visualize the correlation matrix of gRNAs based on their fractions and overlap.
+
+        Parameters:
+        df (pd.DataFrame): DataFrame with columns ['grna', 'fraction', 'prc'].
+        grna_list (list): List of gRNAs to include in the correlation analysis.
+        save_folder (str): Path to the folder where figures and data will be saved.
+
+        Returns:
+        pd.DataFrame: Correlation matrix of the gRNAs.
+        """
+        # Filter the DataFrame to include only rows with gRNAs in the list
+        filtered_df = df[df['grna'].isin(grna_list)]
+
+        # Pivot the data to create a prc-by-gRNA matrix, using fractions as values
+        pivot_df = filtered_df.pivot_table(index='prc', columns='grna', values='fraction', aggfunc='sum').fillna(0)
+
+        # Compute the correlation matrix
+        correlation_matrix = pivot_df.corr()
+
+        if save:
+            # Save the correlation matrix
+            correlation_matrix.to_csv(os.path.join(save_folder, 'correlation_matrix.csv'))
+
+        # Visualize the correlation matrix as a heatmap
+        plt.figure(figsize=(10, 8))
+        sns.heatmap(correlation_matrix, annot=False, cmap='coolwarm', cbar=True)
+        plt.title('gRNA Correlation Matrix')
+        plt.xlabel('gRNAs')
+        plt.ylabel('gRNAs')
+        plt.tight_layout()
+
+        if save:
+            correlation_fig_path = os.path.join(save_folder, 'correlation_matrix_heatmap.pdf')
+            plt.savefig(correlation_fig_path, dpi=300)
+
+        plt.show()
+
+        return correlation_matrix
+
+    def _compute_effect_sizes(correlation_matrix, grna_dict, save_folder, save=False):
+        """
+        Compute and visualize the effect sizes of gRNAs given fixed effect sizes for a subset of gRNAs.
+
+        Parameters:
+        correlation_matrix (pd.DataFrame): Correlation matrix of gRNAs.
+        grna_dict (dict): Dictionary of gRNAs with fixed effect sizes {grna_name: effect_size}.
+        save_folder (str): Path to the folder where figures and data will be saved.
+
+        Returns:
+        pd.Series: Effect sizes of all gRNAs.
+        """
+        # Normalize the correlation matrix values to the 0-1 range
+        corr_matrix = correlation_matrix.copy()
+        corr_matrix = (corr_matrix - corr_matrix.min().min()) / (corr_matrix.max().max() - corr_matrix.min().min())
+
+        # Initialize the effect sizes with dtype float
+        effect_sizes = pd.Series(0.0, index=corr_matrix.index)
+
+        # Set the effect sizes for the specified gRNAs
+        for grna, size in grna_dict.items():
+            effect_sizes[grna] = size
+
+        # Propagate the effect sizes: each remaining gRNA gets the
+        # correlation-weighted average of the fixed effect sizes
+        for grna in corr_matrix.index:
+            if grna not in grna_dict:
+                effect_sizes[grna] = np.dot(corr_matrix.loc[grna], effect_sizes) / np.sum(corr_matrix.loc[grna])
+
+        if save:
+            # Save the effect sizes
+            effect_sizes.to_csv(os.path.join(save_folder, 'effect_sizes.csv'))
+
+        # Visualization
+        plt.figure(figsize=(10, 6))
+        sns.barplot(x=effect_sizes.index, y=effect_sizes.values, palette="viridis", hue=None, legend=False)
+
+        #for i, val in enumerate(effect_sizes.values):
+        #    plt.text(i, val + 0.02, f"{val:.2f}", ha='center', va='bottom', fontsize=9)
+        plt.title("Effect Sizes of gRNAs")
+        plt.xlabel("gRNAs")
+        plt.ylabel("Effect Size")
+        plt.xticks(rotation=45)
+        plt.tight_layout()
+
+        if save:
+            effect_sizes_fig_path = os.path.join(save_folder, 'effect_sizes_barplot.pdf')
+            plt.savefig(effect_sizes_fig_path, dpi=300)
+
+        plt.show()
+
+        return effect_sizes
+
+    # Ensure the save folder exists
+    save_folder = os.path.join(os.path.dirname(csv_file), 'post_regression_analysis_results')
+    os.makedirs(save_folder, exist_ok=True)
+
+    # Load the data
+    df = pd.read_csv(csv_file)
+
+    # Perform analysis
+    correlation_matrix = _analyze_and_visualize_grna_correlation(df, grna_list, save_folder, save)
+    effect_sizes = _compute_effect_sizes(correlation_matrix, grna_dict, save_folder, save)
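For orientation, the two new entry points above are driven by a settings dict and a regression CSV, respectively. In `_compute_effect_sizes`, each gRNA g without a fixed effect receives the correlation-weighted mean of the fixed ones, effect[g] = Σ_j corr(g,j)·effect[j] / Σ_j corr(g,j). Below is a minimal usage sketch inferred from the `settings[...]` lookups and column names in the diff; every path, plate/column value, and file name is a hypothetical placeholder, not a spacr default.

from spacr.submodules import generate_score_heatmap, post_regression_analysis

# Sketch only: keys mirror the settings[...] lookups in generate_score_heatmap;
# all paths and values below are hypothetical.
settings = {
    'folders': '/data/classification_runs',  # folder(s) whose sub-folders each hold csv_name
    'csv_name': 'predictions.csv',
    'data_column': 'pred',                   # per-object score column, averaged per well
    'csv': '/data/grna_counts.csv',          # needs row_name/column_name/grna_name/count columns
    'control_sgrnas': ['TGGT1_220950_1', 'TGGT1_233460_4'],
    'fraction_grna': 'TGGT1_220950_1',       # gRNA whose well fraction anchors the comparison
    'cv_csv': '/data/cv_predictions.csv',
    'data_column_cv': 'pred',
    'plate': 1,
    'column': 'c3',
    'cmap': 'coolwarm',
    'dst': '/data/output',                   # None skips saving the CSVs and the PDF
}
merged = generate_score_heatmap(settings)

# Fix effect sizes for two control gRNAs and propagate to the rest
# through the normalized correlation matrix.
post_regression_analysis(
    csv_file='/data/grna_fractions.csv',     # hypothetical; needs columns: grna, fraction, prc
    grna_dict={'TGGT1_220950_1': 1.0, 'TGGT1_233460_4': 0.0},
    grna_list=['TGGT1_220950_1', 'TGGT1_233460_4', 'TGGT1_233460_1'],
    save=True,
)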
spacr/toxo.py
CHANGED
@@ -24,117 +24,128 @@ from sklearn.metrics import mean_absolute_error
 
 from matplotlib.gridspec import GridSpec
 
+def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', point_size=50, figsize=20, threshold=0, save_path=None, x_lim=[-0.5, 0.5], y_lims=[[0, 6], [9, 20]]):
 
-def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location',
-                        point_size=50, figsize=20, threshold=0,
-                        save_path=None, x_lim=[-0.5, 0.5], y_lims=[[0, 6], [9, 15]]):
+    # Dictionary mapping compartment to color
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    colors = {'micronemes': 'black',
+              'rhoptries 1': 'darkviolet',
+              'rhoptries 2': 'darkviolet',
+              'nucleus - chromatin': 'blue',
+              'nucleus - non-chromatin': 'blue',
+              'dense granules': 'teal',
+              'ER 1': 'pink',
+              'ER 2': 'pink',
+              'unknown': 'black',
+              'tubulin cytoskeleton': 'slategray',
+              'IMC': 'slategray',
+              'PM - peripheral 1': 'slategray',
+              'PM - peripheral 2': 'slategray',
+              'cytosol': 'turquoise',
+              'mitochondrion - soluble': 'red',
+              'mitochondrion - membranes': 'red',
+              'apicoplast': 'slategray',
+              'Golgi': 'green',
+              'PM - integral': 'slategray',
+              'apical 1': 'orange',
+              'apical 2': 'orange',
+              '19S proteasome': 'slategray',
+              '20S proteasome': 'slategray',
+              '60S ribosome': 'slategray',
+              '40S ribosome': 'slategray',
+              }
+
+    # Increase font size for better readability
+    fontsize = 18
+    plt.rcParams.update({'font.size': fontsize})
 
-
-
-    # Load data
+    # --- Load data ---
     if isinstance(data_path, pd.DataFrame):
         data = data_path
     else:
         data = pd.read_csv(data_path)
-
-
-
-    plt.rcParams.update({'font.size': fontsize})
+
+    # Extract 'variable' and 'gene_nr' from your feature notation
     data['variable'] = data['feature'].str.extract(r'\[(.*?)\]')
     data['variable'].fillna(data['feature'], inplace=True)
     data['gene_nr'] = data['variable'].str.split('_').str[0]
     data = data[data['variable'] != 'Intercept']
 
-    # Load metadata
+    # --- Load metadata ---
     if isinstance(metadata_path, pd.DataFrame):
         metadata = metadata_path
     else:
         metadata = pd.read_csv(metadata_path)
+
     metadata['gene_nr'] = metadata['gene_nr'].astype(str)
     data['gene_nr'] = data['gene_nr'].astype(str)
 
-
+    # Merge data and metadata
+    merged_data = pd.merge(data, metadata[['gene_nr', metadata_column]],
+                           on='gene_nr', how='left')
     merged_data[metadata_column].fillna('unknown', inplace=True)
 
-    #
-
-    marker_dict = {val: marker for val, marker in zip(
-        merged_data[metadata_column].unique(), markers)}
-
-    # Create the figure with custom spacing
-    fig = plt.figure(figsize=(figsize,figsize))
+    # --- Create figure with "upper" and "lower" subplots sharing the x-axis ---
+    fig = plt.figure(figsize=(figsize, figsize))
     gs = GridSpec(2, 1, height_ratios=[1, 3], hspace=0.05)
-
     ax_upper = fig.add_subplot(gs[0])
     ax_lower = fig.add_subplot(gs[1], sharex=ax_upper)
 
     # Hide x-axis labels on the upper plot
     ax_upper.tick_params(axis='x', which='both', bottom=False, labelbottom=False)
 
+    # List to collect the variables (hits) that meet threshold criteria
     hit_list = []
 
-    # Scatter plot on both axes
+    # --- Scatter plot on both axes ---
     for _, row in merged_data.iterrows():
         y_val = -np.log10(row['p_value'])
+
+        # Decide which axis to draw on based on the p-value
         ax = ax_upper if y_val > y_lims[1][0] else ax_lower
 
+        # Here is the main change: color by the colors dict
         ax.scatter(
-            row['coefficient'],
-
-
-
+            row['coefficient'],
+            y_val,
+            color=colors.get(row[metadata_column], 'gray'),  # <-- Use your color dict
+            marker='o',  # You can fix a single marker if desired
+            s=point_size,
+            edgecolor='black',
+            alpha=0.6
         )
 
-
+        # Check significance thresholds
+        if (row['p_value'] <= 0.05) and (abs(row['coefficient']) >= abs(threshold)):
             hit_list.append(row['variable'])
 
-    #
+    # --- Adjust axis limits ---
     ax_upper.set_ylim(y_lims[1])
     ax_lower.set_ylim(y_lims[0])
     ax_lower.set_xlim(x_lim)
 
+    # Hide top spines
     ax_lower.spines['top'].set_visible(False)
    ax_upper.spines['top'].set_visible(False)
     ax_upper.spines['bottom'].set_visible(False)
 
-    # Set x-axis and y-axis
-    ax_lower.set_xlabel('Coefficient')
-    ax_lower.set_ylabel('-log10(p-value)')
-    ax_upper.set_ylabel('-log10(p-value)')
-
+    # Set x-axis and y-axis labels
+    ax_lower.set_xlabel('Coefficient')
+    ax_lower.set_ylabel('-log10(p-value)')
+    ax_upper.set_ylabel('-log10(p-value)')
+
     for ax in [ax_upper, ax_lower]:
         ax.spines['right'].set_visible(False)
 
-    # Add threshold lines to both axes
+    # --- Add threshold lines to both axes ---
     for ax in [ax_upper, ax_lower]:
         ax.axvline(x=-abs(threshold), linestyle='--', color='black')
         ax.axvline(x=abs(threshold), linestyle='--', color='black')
 
     ax_lower.axhline(y=-np.log10(0.05), linestyle='--', color='black')
 
-    # Annotate significant points
-    texts_upper, texts_lower = [], []
+    # --- Annotate significant points ---
+    texts_upper, texts_lower = [], []
 
     for _, row in merged_data.iterrows():
         y_val = -np.log10(row['p_value'])
@@ -142,38 +153,50 @@ def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location
             continue
 
         ax = ax_upper if y_val > y_lims[1][0] else ax_lower
-        text = ax.text(
-
+        text = ax.text(
+            row['coefficient'],
+            y_val,
+            row['variable'],
+            fontsize=fontsize,
+            ha='center',
+            va='bottom'
+        )
 
         if ax == ax_upper:
             texts_upper.append(text)
         else:
             texts_lower.append(text)
 
-    #
+    # Attempt to keep text labels from overlapping
     adjust_text(texts_upper, ax=ax_upper, arrowprops=dict(arrowstyle='-', color='black'))
     adjust_text(texts_lower, ax=ax_lower, arrowprops=dict(arrowstyle='-', color='black'))
 
-    # Add a
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # --- Add a legend keyed by color (optional) ---
+    # If you'd like a legend that shows what each compartment color represents:
+    legend_handles = []
+    for comp, comp_color in colors.items():
+        # Create a "dummy" scatter for legend
+        legend_handles.append(
+            plt.Line2D([0], [0], marker='o', color=comp_color,
+                       label=comp, linewidth=0, markersize=8)
+        )
+    # You can adjust the location and styling of the legend to taste:
+    ax_lower.legend(
+        handles=legend_handles,
+        bbox_to_anchor=(1.05, 1),
+        loc='upper left',
+        borderaxespad=0.25,
+        labelspacing=2,
+        handletextpad=0.25,
+        markerscale=1.5,
+        prop={'size': fontsize}
+    )
+
+    # --- Save and show ---
     if save_path:
         plt.savefig(save_path, format='pdf', bbox_inches='tight')
     plt.show()
-
+
     return hit_list
 
 def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=['Computed GO Processes', 'Curated GO Components', 'Curated GO Functions', 'Curated GO Processes']):
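For the reworked `custom_volcano_plot`, a similarly hedged usage sketch: judging from the code above, the regression table must provide `feature`, `coefficient`, and `p_value` columns, and the metadata table `gene_nr` plus the chosen `metadata_column`. The file paths and threshold below are hypothetical placeholders.

from spacr.toxo import custom_volcano_plot

hits = custom_volcano_plot(
    data_path='/data/regression_results.csv',  # hypothetical; columns: feature, coefficient, p_value
    metadata_path='/data/tagm_locations.csv',  # hypothetical; columns: gene_nr, tagm_location
    metadata_column='tagm_location',           # compartment used to look up point colors
    point_size=50,
    figsize=20,
    threshold=0.1,                             # |coefficient| cutoff for hit calling
    save_path='/data/volcano_plate1.pdf',
    x_lim=[-0.5, 0.5],
    y_lims=[[0, 6], [9, 20]],                  # y ranges for the lower and upper panels
)
print(hits)  # variables with p_value <= 0.05 and |coefficient| >= threshold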