spacr 0.3.81__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/submodules.py CHANGED
@@ -1,3 +1,6 @@
+
+
+
  import seaborn as sns
  import os, random, sqlite3, re, shap
  import pandas as pd
@@ -10,7 +13,10 @@ from IPython.display import display
  from sklearn.ensemble import RandomForestClassifier
  from sklearn.inspection import permutation_importance
  from math import pi
- from scipy.stats import chi2_contingency
+ from scipy.stats import chi2_contingency, pearsonr
+ from scipy.spatial.distance import cosine
+
+ from sklearn.metrics import mean_absolute_error

  import matplotlib.pyplot as plt
  from natsort import natsorted
@@ -1132,3 +1138,278 @@ def analyze_class_proportion(settings):
      print("Statistical analysis results saved.")

      return output
+
+ def generate_score_heatmap(settings):
+
+     def group_cv_score(csv, plate=1, column='c3', data_column='pred'):
+         df = pd.read_csv(csv)
+         if 'col' in df.columns:
+             df = df[df['col'] == column]
+         elif 'column' in df.columns:
+             df['col'] = df['column']
+             df = df[df['col'] == column]
+         if plate is not None:
+             df['plate'] = f"plate{plate}"
+         grouped_df = df.groupby(['plate', 'row', 'col'])[data_column].mean().reset_index()
+         grouped_df['prc'] = grouped_df['plate'].astype(str) + '_' + grouped_df['row'].astype(str) + '_' + grouped_df['col'].astype(str)
+         return grouped_df
+
+     def calculate_fraction_mixed_condition(csv, plate=1, column='c3', control_sgrnas=['TGGT1_220950_1', 'TGGT1_233460_4']):
+         df = pd.read_csv(csv)
+         df = df[df['column_name'] == column]
+         if 'plate' not in df.columns:  # add a plate column when the CSV lacks one
+             df['plate'] = f"plate{plate}"
+         df = df[df['grna_name'].str.match(f'^{control_sgrnas[0]}$|^{control_sgrnas[1]}$')]
+         grouped_df = df.groupby(['plate', 'row_name', 'column_name'])['count'].sum().reset_index()
+         grouped_df = grouped_df.rename(columns={'count': 'total_count'})
+         merged_df = pd.merge(df, grouped_df, on=['plate', 'row_name', 'column_name'])
+         merged_df['fraction'] = merged_df['count'] / merged_df['total_count']
+         merged_df['prc'] = merged_df['plate'].astype(str) + '_' + merged_df['row_name'].astype(str) + '_' + merged_df['column_name'].astype(str)
+         return merged_df
+
+     def plot_multi_channel_heatmap(df, column='c3', cmap='coolwarm'):
+         """
+         Plot a heatmap with multiple channels as columns.
+
+         Parameters:
+         - df: DataFrame with scores for different channels.
+         - column: Column to filter by (default is 'c3').
+         - cmap: Colormap for the heatmap (default is 'coolwarm').
+         """
+         # Extract the row number and convert it to an integer for sorting
+         df['row_num'] = df['row'].str.extract(r'(\d+)').astype(int)
+
+         # Filter and sort by plate, row, and column
+         df = df[df['col'] == column]
+         df = df.sort_values(by=['plate', 'row_num', 'col'])
+
+         # Drop the temporary 'row_num' column after sorting
+         df = df.drop('row_num', axis=1)
+
+         # Create a new column combining plate, row, and column for the index
+         df['plate_row_col'] = df['plate'] + '-' + df['row'] + '-' + df['col']
+
+         # Set 'plate_row_col' as the index
+         df.set_index('plate_row_col', inplace=True)
+
+         # Extract only numeric data for the heatmap
+         heatmap_data = df.select_dtypes(include=[float, int])
+
+         # Plot the heatmap with square cells and no annotations
+         plt.figure(figsize=(12, 8))
+         sns.heatmap(
+             heatmap_data,
+             cmap=cmap,
+             cbar=True,
+             square=True,
+             annot=False
+         )
+
+         plt.title("Heatmap of Prediction Scores for All Channels")
+         plt.xlabel("Channels")
+         plt.ylabel("Plate-Row-Column")
+         plt.tight_layout()
+
+         # Save the figure object and return it
+         fig = plt.gcf()
+         plt.show()
+
+         return fig
+
+     def combine_classification_scores(folders, csv_name, data_column, plate=1, column='c3'):
+         # Ensure `folders` is a list
+         if isinstance(folders, str):
+             folders = [folders]
+
+         ls = []  # Collect the paths of the CSV files that are found
+
+         # Iterate over the provided folders
+         for folder in folders:
+             sub_folders = os.listdir(folder)  # Get the sub-folder list
+             for sub_folder in sub_folders:  # Iterate through the sub-folders
+                 path = os.path.join(folder, sub_folder)  # Build the full path
+                 if os.path.isdir(path):  # Check whether it is a directory
+                     csv = os.path.join(path, csv_name)  # Build the path to the CSV file
+                     if os.path.exists(csv):  # If the CSV exists, add it to the list
+                         ls.append(csv)
+                     else:
+                         print(f'No such file: {csv}')
+
+         # Initialize the combined DataFrame
+         combined_df = None
+         print(f'Found {len(ls)} CSV files')
+
+         # Loop through all collected CSV files and process them
+         for csv_file in ls:
+             df = pd.read_csv(csv_file)  # Read the CSV into a DataFrame
+             df = df[df['col'] == column]
+             if plate is not None:
+                 df['plate'] = f"plate{plate}"
+             # Group the data by 'plate', 'row', and 'col'
+             grouped_df = df.groupby(['plate', 'row', 'col'])[data_column].mean().reset_index()
+             # Use the parent folder name to create a unique column name
+             folder_name = os.path.dirname(csv_file)
+             new_column_name = os.path.basename(f"{folder_name}_{data_column}")
+             print(new_column_name)
+             grouped_df = grouped_df.rename(columns={data_column: new_column_name})
+
+             # Merge into the combined DataFrame
+             if combined_df is None:
+                 combined_df = grouped_df
+             else:
+                 combined_df = pd.merge(combined_df, grouped_df, on=['plate', 'row', 'col'], how='outer')
+         combined_df['prc'] = combined_df['plate'].astype(str) + '_' + combined_df['row'].astype(str) + '_' + combined_df['col'].astype(str)
+         return combined_df
+
+     def calculate_mae(df):
+         """
+         Calculate the MAE between each channel's predictions and the 'fraction' column for all rows.
+         """
+         # Extract the numeric columns, excluding 'fraction' and 'prc'
+         channels = df.drop(columns=['fraction', 'prc']).select_dtypes(include=[float, int])
+
+         mae_data = []
+
+         # Compute the MAE of each channel against 'fraction', row by row
+         # (for a single pair of values the MAE is simply the absolute difference)
+         for column in channels.columns:
+             for index, row in df.iterrows():
+                 mae = mean_absolute_error([row['fraction']], [row[column]])
+                 mae_data.append({'Channel': column, 'MAE': mae, 'Row': row['prc']})
+
+         # Convert the list of dictionaries to a DataFrame
+         mae_df = pd.DataFrame(mae_data)
+         return mae_df
+
+     result_df = combine_classification_scores(settings['folders'], settings['csv_name'], settings['data_column'], settings['plate'], settings['column'])
+     df = calculate_fraction_mixed_condition(settings['csv'], settings['plate'], settings['column'], settings['control_sgrnas'])
+     df = df[df['grna_name'] == settings['fraction_grna']]
+     fraction_df = df[['fraction', 'prc']]
+     merged_df = pd.merge(fraction_df, result_df, on=['prc'])
+     cv_df = group_cv_score(settings['cv_csv'], settings['plate'], settings['column'], settings['data_column_cv'])
+     cv_df = cv_df[[settings['data_column_cv'], 'prc']]
+     merged_df = pd.merge(merged_df, cv_df, on=['prc'])
+
+     fig = plot_multi_channel_heatmap(merged_df, settings['column'], settings['cmap'])
+     # Drop the temporary 'row_num' column that plot_multi_channel_heatmap adds in place
+     if 'row_num' in merged_df.columns:
+         merged_df = merged_df.drop('row_num', axis=1)
+     mae_df = calculate_mae(merged_df)
+     if 'row_num' in mae_df.columns:
+         mae_df = mae_df.drop('row_num', axis=1)
+
+     if settings['dst'] is not None:
+         mae_dst = os.path.join(settings['dst'], f"mae_scores_comparison_plate_{settings['plate']}.csv")
+         merged_dst = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plate']}_data.csv")
+         heatmap_save = os.path.join(settings['dst'], f"scores_comparison_plate_{settings['plate']}.pdf")
+         mae_df.to_csv(mae_dst, index=False)
+         merged_df.to_csv(merged_dst, index=False)
+         fig.savefig(heatmap_save, format='pdf', dpi=600, bbox_inches='tight')
+     return merged_df
+
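For orientation, a hypothetical call to the new generate_score_heatmap entry point. The settings keys mirror exactly those read in the function body above; every path and value below is an illustrative placeholder, not something shipped with the package.

    settings = {
        'folders': '/data/classification_runs',   # hypothetical root(s) scanned for per-model sub-folders
        'csv_name': 'results.csv',                # hypothetical name of the per-folder score CSV
        'data_column': 'pred',                    # score column in those CSVs
        'cv_csv': '/data/cv_scores.csv',          # hypothetical cross-validation score CSV
        'data_column_cv': 'pred',                 # score column in the CV CSV
        'csv': '/data/sequencing_counts.csv',     # hypothetical sgRNA count CSV
        'control_sgrnas': ['TGGT1_220950_1', 'TGGT1_233460_4'],
        'fraction_grna': 'TGGT1_220950_1',        # gRNA whose well fraction serves as the reference
        'plate': 1,
        'column': 'c3',
        'cmap': 'coolwarm',
        'dst': '/data/output',                    # set to None to skip writing the CSVs and the PDF
    }
    merged_df = generate_score_heatmap(settings)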
+ def post_regression_analysis(csv_file, grna_dict, grna_list, save=False):
+
+     def _analyze_and_visualize_grna_correlation(df, grna_list, save_folder, save=False):
+         """
+         Analyze and visualize the correlation matrix of gRNAs based on their fractions and overlap.
+
+         Parameters:
+         df (pd.DataFrame): DataFrame with columns ['grna', 'fraction', 'prc'].
+         grna_list (list): List of gRNAs to include in the correlation analysis.
+         save_folder (str): Path to the folder where figures and data will be saved.
+
+         Returns:
+         pd.DataFrame: Correlation matrix of the gRNAs.
+         """
+         # Filter the DataFrame to include only rows with gRNAs in the list
+         filtered_df = df[df['grna'].isin(grna_list)]
+
+         # Pivot the data to create a prc-by-gRNA matrix, using fractions as values
+         pivot_df = filtered_df.pivot_table(index='prc', columns='grna', values='fraction', aggfunc='sum').fillna(0)
+
+         # Compute the correlation matrix
+         correlation_matrix = pivot_df.corr()
+
+         if save:
+             # Save the correlation matrix
+             correlation_matrix.to_csv(os.path.join(save_folder, 'correlation_matrix.csv'))
+
+         # Visualize the correlation matrix as a heatmap
+         plt.figure(figsize=(10, 8))
+         sns.heatmap(correlation_matrix, annot=False, cmap='coolwarm', cbar=True)
+         plt.title('gRNA Correlation Matrix')
+         plt.xlabel('gRNAs')
+         plt.ylabel('gRNAs')
+         plt.tight_layout()
+
+         if save:
+             correlation_fig_path = os.path.join(save_folder, 'correlation_matrix_heatmap.pdf')
+             plt.savefig(correlation_fig_path, dpi=300)
+
+         plt.show()
+
+         return correlation_matrix
+
+     def _compute_effect_sizes(correlation_matrix, grna_dict, save_folder, save=False):
+         """
+         Compute and visualize the effect sizes of gRNAs, given fixed effect sizes for a subset of gRNAs.
+
+         Parameters:
+         correlation_matrix (pd.DataFrame): Correlation matrix of gRNAs.
+         grna_dict (dict): Dictionary of gRNAs with fixed effect sizes {grna_name: effect_size}.
+         save_folder (str): Path to the folder where figures and data will be saved.
+
+         Returns:
+         pd.Series: Effect sizes of all gRNAs.
+         """
+         # Work on a copy and normalize the correlations to the 0-1 range
+         corr_matrix = correlation_matrix.copy()
+         corr_matrix = (corr_matrix - corr_matrix.min().min()) / (corr_matrix.max().max() - corr_matrix.min().min())
+
+         # Initialize the effect sizes with dtype float
+         effect_sizes = pd.Series(0.0, index=corr_matrix.index)
+
+         # Set the effect sizes for the specified gRNAs
+         for grna, size in grna_dict.items():
+             effect_sizes[grna] = size
+
+         # Propagate the effect sizes to the remaining gRNAs as a
+         # correlation-weighted average of the fixed effect sizes
+         for grna in corr_matrix.index:
+             if grna not in grna_dict:
+                 effect_sizes[grna] = np.dot(corr_matrix.loc[grna], effect_sizes) / np.sum(corr_matrix.loc[grna])
+
+         if save:
+             # Save the effect sizes
+             effect_sizes.to_csv(os.path.join(save_folder, 'effect_sizes.csv'))
+
+         # Visualization
+         plt.figure(figsize=(10, 6))
+         sns.barplot(x=effect_sizes.index, y=effect_sizes.values, palette="viridis", hue=None, legend=False)
+
+         # for i, val in enumerate(effect_sizes.values):
+         #     plt.text(i, val + 0.02, f"{val:.2f}", ha='center', va='bottom', fontsize=9)
+         plt.title("Effect Sizes of gRNAs")
+         plt.xlabel("gRNAs")
+         plt.ylabel("Effect Size")
+         plt.xticks(rotation=45)
+         plt.tight_layout()
+
+         if save:
+             effect_sizes_fig_path = os.path.join(save_folder, 'effect_sizes_barplot.pdf')
+             plt.savefig(effect_sizes_fig_path, dpi=300)
+
+         plt.show()
+
+         return effect_sizes
+
+     # Ensure the save folder exists
+     save_folder = os.path.join(os.path.dirname(csv_file), 'post_regression_analysis_results')
+     os.makedirs(save_folder, exist_ok=True)
+
+     # Load the data
+     df = pd.read_csv(csv_file)
+
+     # Perform the analysis
+     correlation_matrix = _analyze_and_visualize_grna_correlation(df, grna_list, save_folder, save)
+     effect_sizes = _compute_effect_sizes(correlation_matrix, grna_dict, save_folder, save)
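A usage sketch for post_regression_analysis, again with purely hypothetical inputs. The CSV must provide 'grna', 'fraction', and 'prc' columns; effect sizes pinned in grna_dict are propagated to the other gRNAs in grna_list as a correlation-weighted average, and results land in a folder next to the CSV when save=True.

    grna_list = ['TGGT1_220950_1', 'TGGT1_220950_2', 'TGGT1_233460_4']  # hypothetical gRNAs
    grna_dict = {'TGGT1_220950_1': 1.0, 'TGGT1_233460_4': 0.0}          # pinned anchor effect sizes

    post_regression_analysis(
        csv_file='/data/regression/fractions.csv',  # hypothetical path; outputs go to a sibling folder
        grna_dict=grna_dict,
        grna_list=grna_list,
        save=True,
    )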
spacr/toxo.py CHANGED
@@ -24,117 +24,128 @@ from sklearn.metrics import mean_absolute_error
 
  from matplotlib.gridspec import GridSpec
 
+ def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', point_size=50, figsize=20, threshold=0, save_path=None, x_lim=[-0.5, 0.5], y_lims=[[0, 6], [9, 20]]):
 
- def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location',
-                         point_size=50, figsize=20, threshold=0,
-                         save_path=None, x_lim=[-0.5, 0.5], y_lims=[[0, 6], [9, 15]]):
+     # Dictionary mapping each compartment to a color
 
-     markers = [
-         'o',  # Circle
-         'X',  # X-shaped marker
-         '^',  # Upward triangle
-         's',  # Square
-         'v',  # Downward triangle
-         'P',  # Plus-filled pentagon
-         '*',  # Star
-         '+',  # Plus
-         'x',  # Cross
-         '.',  # Point
-         ',',  # Pixel
-         'd',  # Diamond
-         'D',  # Thin diamond
-         'h',  # Hexagon 1
-         'H',  # Hexagon 2
-         'p',  # Pentagon
-         '|',  # Vertical line
-         '_',  # Horizontal line
-     ]
+     colors = {'micronemes': 'black',
+               'rhoptries 1': 'darkviolet',
+               'rhoptries 2': 'darkviolet',
+               'nucleus - chromatin': 'blue',
+               'nucleus - non-chromatin': 'blue',
+               'dense granules': 'teal',
+               'ER 1': 'pink',
+               'ER 2': 'pink',
+               'unknown': 'black',
+               'tubulin cytoskeleton': 'slategray',
+               'IMC': 'slategray',
+               'PM - peripheral 1': 'slategray',
+               'PM - peripheral 2': 'slategray',
+               'cytosol': 'turquoise',
+               'mitochondrion - soluble': 'red',
+               'mitochondrion - membranes': 'red',
+               'apicoplast': 'slategray',
+               'Golgi': 'green',
+               'PM - integral': 'slategray',
+               'apical 1': 'orange',
+               'apical 2': 'orange',
+               '19S proteasome': 'slategray',
+               '20S proteasome': 'slategray',
+               '60S ribosome': 'slategray',
+               '40S ribosome': 'slategray',
+               }
+
+     # Increase the font size for better readability
+     fontsize = 18
+     plt.rcParams.update({'font.size': fontsize})
 
-     plt.rcParams.update({'font.size': 14})
-
-     # Load data
+     # --- Load data ---
      if isinstance(data_path, pd.DataFrame):
          data = data_path
      else:
          data = pd.read_csv(data_path)
-
-     fontsize = 18
-
-     plt.rcParams.update({'font.size': fontsize})
+
+     # Extract 'variable' and 'gene_nr' from the feature notation
      data['variable'] = data['feature'].str.extract(r'\[(.*?)\]')
      data['variable'].fillna(data['feature'], inplace=True)
      data['gene_nr'] = data['variable'].str.split('_').str[0]
      data = data[data['variable'] != 'Intercept']
 
-     # Load metadata
+     # --- Load metadata ---
      if isinstance(metadata_path, pd.DataFrame):
          metadata = metadata_path
      else:
          metadata = pd.read_csv(metadata_path)
+
      metadata['gene_nr'] = metadata['gene_nr'].astype(str)
      data['gene_nr'] = data['gene_nr'].astype(str)
 
-     merged_data = pd.merge(data, metadata[['gene_nr', metadata_column]], on='gene_nr', how='left')
+     # Merge data and metadata
+     merged_data = pd.merge(data, metadata[['gene_nr', metadata_column]],
+                            on='gene_nr', how='left')
      merged_data[metadata_column].fillna('unknown', inplace=True)
 
-     # Define palette and markers
-     palette = {'pc': 'red', 'nc': 'green', 'control': 'white', 'other': 'gray'}
-     marker_dict = {val: marker for val, marker in zip(
-         merged_data[metadata_column].unique(), markers)}
-
-     # Create the figure with custom spacing
-     fig = plt.figure(figsize=(figsize,figsize))
+     # --- Create a figure with upper and lower subplots sharing the x-axis ---
+     fig = plt.figure(figsize=(figsize, figsize))
      gs = GridSpec(2, 1, height_ratios=[1, 3], hspace=0.05)
-
      ax_upper = fig.add_subplot(gs[0])
      ax_lower = fig.add_subplot(gs[1], sharex=ax_upper)
 
      # Hide x-axis labels on the upper plot
      ax_upper.tick_params(axis='x', which='both', bottom=False, labelbottom=False)
 
+     # List to collect the variables (hits) that meet the threshold criteria
      hit_list = []
 
-     # Scatter plot on both axes
+     # --- Scatter plot on both axes ---
      for _, row in merged_data.iterrows():
          y_val = -np.log10(row['p_value'])
+
+         # Decide which axis to draw on based on the p-value
          ax = ax_upper if y_val > y_lims[1][0] else ax_lower
 
+         # Color each point by compartment via the colors dict
          ax.scatter(
-             row['coefficient'], y_val,
-             color=palette.get(row['condition'], 'gray'),
-             marker=marker_dict.get(row[metadata_column], 'o'),
-             s=point_size, edgecolor='black', alpha=0.6
+             row['coefficient'],
+             y_val,
+             color=colors.get(row[metadata_column], 'gray'),  # compartment color lookup
+             marker='o',  # a single fixed marker for all points
+             s=point_size,
+             edgecolor='black',
+             alpha=0.6
          )
 
-         if row['p_value'] <= 0.05 and abs(row['coefficient']) >= abs(threshold):
+         # Check the significance thresholds
+         if (row['p_value'] <= 0.05) and (abs(row['coefficient']) >= abs(threshold)):
              hit_list.append(row['variable'])
 
-     # Set axis limits
+     # --- Adjust axis limits ---
      ax_upper.set_ylim(y_lims[1])
      ax_lower.set_ylim(y_lims[0])
      ax_lower.set_xlim(x_lim)
 
+     # Hide the top spines
      ax_lower.spines['top'].set_visible(False)
      ax_upper.spines['top'].set_visible(False)
      ax_upper.spines['bottom'].set_visible(False)
 
-     # Set x-axis and y-axis titles
-     ax_lower.set_xlabel('Coefficient') # X-axis title on the lower graph
-     ax_lower.set_ylabel('-log10(p-value)') # Y-axis title on the lower graph
-     ax_upper.set_ylabel('-log10(p-value)') # Y-axis title on the upper graph
-
+     # Set the x-axis and y-axis labels
+     ax_lower.set_xlabel('Coefficient')
+     ax_lower.set_ylabel('-log10(p-value)')
+     ax_upper.set_ylabel('-log10(p-value)')
+
      for ax in [ax_upper, ax_lower]:
          ax.spines['right'].set_visible(False)
 
-     # Add threshold lines to both axes
+     # --- Add threshold lines to both axes ---
      for ax in [ax_upper, ax_lower]:
          ax.axvline(x=-abs(threshold), linestyle='--', color='black')
          ax.axvline(x=abs(threshold), linestyle='--', color='black')
 
      ax_lower.axhline(y=-np.log10(0.05), linestyle='--', color='black')
 
-     # Annotate significant points
-     texts_upper, texts_lower = [], [] # Collect text annotations separately
+     # --- Annotate significant points ---
+     texts_upper, texts_lower = [], []
 
      for _, row in merged_data.iterrows():
          y_val = -np.log10(row['p_value'])
@@ -142,38 +153,50 @@ def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location
              continue
 
          ax = ax_upper if y_val > y_lims[1][0] else ax_lower
-         text = ax.text(row['coefficient'], y_val, row['variable'],
-                        fontsize=fontsize, ha='center', va='bottom')
+         text = ax.text(
+             row['coefficient'],
+             y_val,
+             row['variable'],
+             fontsize=fontsize,
+             ha='center',
+             va='bottom'
+         )
 
          if ax == ax_upper:
              texts_upper.append(text)
          else:
              texts_lower.append(text)
 
-     # Adjust text positions to avoid overlap
+     # Attempt to keep the text labels from overlapping
      adjust_text(texts_upper, ax=ax_upper, arrowprops=dict(arrowstyle='-', color='black'))
      adjust_text(texts_lower, ax=ax_lower, arrowprops=dict(arrowstyle='-', color='black'))
 
-     # Add a single legend on the lower axis
-     handles = [plt.Line2D([0], [0], marker=m, color='w', markerfacecolor='gray', markersize=10)
-                for m in marker_dict.values()]
-     labels = marker_dict.keys()
-     ax_lower.legend(handles,
-                     labels,
-                     bbox_to_anchor=(1.05, 1),
-                     loc='upper left',
-                     borderaxespad=0.25,
-                     labelspacing=2,
-                     handletextpad=0.25,
-                     markerscale=2,
-                     prop={'size': fontsize})
-
-
-     # Save and show the plot
+     # --- Add a legend showing what each compartment color represents ---
+     legend_handles = []
+     for comp, comp_color in colors.items():
+         # Create a "dummy" line artist for the legend entry
+         legend_handles.append(
+             plt.Line2D([0], [0], marker='o', color=comp_color,
+                        label=comp, linewidth=0, markersize=8)
+         )
+     # Adjust the location and styling of the legend to taste
+     ax_lower.legend(
+         handles=legend_handles,
+         bbox_to_anchor=(1.05, 1),
+         loc='upper left',
+         borderaxespad=0.25,
+         labelspacing=2,
+         handletextpad=0.25,
+         markerscale=1.5,
+         prop={'size': fontsize}
+     )
+
+     # --- Save and show ---
      if save_path:
          plt.savefig(save_path, format='pdf', bbox_inches='tight')
      plt.show()
-
+
      return hit_list
 
  def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=['Computed GO Processes', 'Curated GO Components', 'Curated GO Functions', 'Curated GO Processes']):
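Finally, a hypothetical invocation of the reworked custom_volcano_plot. The paths are placeholders: the data table must carry 'feature', 'coefficient', and 'p_value' columns, and the metadata table a 'gene_nr' column plus the chosen metadata_column (default 'tagm_location', whose values are mapped to colors by the compartment dictionary above).

    hits = custom_volcano_plot(
        data_path='/data/regression_results.csv',  # hypothetical regression output
        metadata_path='/data/tagm_metadata.csv',   # hypothetical TAGM localization table
        metadata_column='tagm_location',
        point_size=50,
        figsize=20,
        threshold=0.1,                             # |coefficient| cutoff for calling hits
        save_path='volcano.pdf',
        x_lim=[-0.5, 0.5],
        y_lims=[[0, 6], [9, 20]],                  # y-ranges for the lower and upper panels
    )
    print(f'{len(hits)} significant variables')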