spacr 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. spacr/__init__.py +19 -3
  2. spacr/cellpose.py +311 -0
  3. spacr/core.py +140 -2493
  4. spacr/deep_spacr.py +151 -29
  5. spacr/gui.py +1 -0
  6. spacr/gui_core.py +74 -63
  7. spacr/gui_elements.py +110 -5
  8. spacr/gui_utils.py +346 -6
  9. spacr/io.py +624 -44
  10. spacr/logger.py +28 -9
  11. spacr/measure.py +107 -95
  12. spacr/mediar.py +0 -3
  13. spacr/ml.py +964 -0
  14. spacr/openai.py +37 -0
  15. spacr/plot.py +280 -15
  16. spacr/resources/data/lopit.csv +3833 -0
  17. spacr/resources/data/toxoplasma_metadata.csv +8843 -0
  18. spacr/resources/icons/convert.png +0 -0
  19. spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
  20. spacr/sequencing.py +241 -1311
  21. spacr/settings.py +129 -43
  22. spacr/sim.py +0 -2
  23. spacr/submodules.py +348 -0
  24. spacr/timelapse.py +0 -2
  25. spacr/toxo.py +233 -0
  26. spacr/utils.py +271 -171
  27. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/METADATA +7 -1
  28. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/RECORD +32 -33
  29. spacr/chris.py +0 -50
  30. spacr/graph_learning.py +0 -340
  31. spacr/resources/MEDIAR/.git +0 -1
  32. spacr/resources/MEDIAR_weights/.DS_Store +0 -0
  33. spacr/resources/icons/.DS_Store +0 -0
  34. spacr/resources/icons/spacr_logo_rotation.gif +0 -0
  35. spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
  36. spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
  37. spacr/sim_app.py +0 -0
  38. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/LICENSE +0 -0
  39. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/WHEEL +0 -0
  40. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/entry_points.txt +0 -0
  41. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/top_level.txt +0 -0
spacr/submodules.py ADDED
@@ -0,0 +1,348 @@
+ import seaborn as sns
+ import os, random, sqlite3
+ import pandas as pd
+ import numpy as np
+ from skimage.measure import regionprops, label
+ from cellpose import io as cp_io  # explicit import: `import cellpose` alone does not reliably expose cellpose.io
+ from cellpose import models as cp_models
+ from cellpose import train as train_cp
+ from IPython.display import display
+
+ def analyze_recruitment(settings={}):
+     """
+     Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.
+
+     Parameters:
+     settings (dict): analysis settings; missing keys are filled in from get_analyze_recruitment_default_settings.
+
+     Returns:
+     list: [cells, wells] DataFrames, also written to CSV by _results_to_csv.
+     """
+
+     from .io import _read_and_merge_data, _results_to_csv
+     from .plot import plot_image_mask_overlay, _plot_controls, _plot_recruitment
+     from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well, save_settings
+     from .settings import get_analyze_recruitment_default_settings
+
+     settings = get_analyze_recruitment_default_settings(settings=settings)
+     save_settings(settings, name='recruitment')
+
+     print(f"Cell(s): {settings['cell_types']}, in {settings['cell_plate_metadata']}")
+     print(f"Pathogen(s): {settings['pathogen_types']}, in {settings['pathogen_plate_metadata']}")
+     print(f"Treatment(s): {settings['treatments']}, in {settings['treatment_plate_metadata']}")
+
+     mask_chans = [settings['nucleus_chann_dim'], settings['pathogen_chann_dim'], settings['cell_chann_dim']]
+
+     sns.color_palette("mako", as_cmap=True)
+     print(f"channel:{settings['channel_of_interest']} = {settings['target']}")
+
+     df, _ = _read_and_merge_data(db_loc=[settings['src']+'/measurements/measurements.db'],
+                                  tables=['cell', 'nucleus', 'pathogen', 'cytoplasm'],
+                                  verbose=True,
+                                  nuclei_limit=settings['nuclei_limit'],
+                                  pathogen_limit=settings['pathogen_limit'],
+                                  uninfected=settings['uninfected'])
+
+     df = annotate_conditions(df,
+                              cells=settings['cell_types'],
+                              cell_loc=settings['cell_plate_metadata'],
+                              pathogens=settings['pathogen_types'],
+                              pathogen_loc=settings['pathogen_plate_metadata'],
+                              treatments=settings['treatments'],
+                              treatment_loc=settings['treatment_plate_metadata'])
+
+     df = df.dropna(subset=['condition'])
+     print(f'After dropping non-annotated wells: {len(df)} rows')
+
+     files = df['file_name'].tolist()
+     print(f'found: {len(files)} files')
+
+     files = [item + '.npy' for item in files]
+     random.shuffle(files)
+
+     _max = 10**100
+     if settings['cell_size_range'] is None:
+         settings['cell_size_range'] = [0, _max]
+     if settings['nucleus_size_range'] is None:
+         settings['nucleus_size_range'] = [0, _max]
+     if settings['pathogen_size_range'] is None:
+         settings['pathogen_size_range'] = [0, _max]
+
+     if settings['plot']:
+         merged_path = os.path.join(settings['src'], 'merged')
+         if os.path.exists(merged_path):
+             try:
+                 for idx, file in enumerate(os.listdir(merged_path)):
+                     file_path = os.path.join(merged_path, file)
+                     if idx <= settings['plot_nr']:
+                         plot_image_mask_overlay(file_path,
+                                                 settings['channel_dims'],
+                                                 settings['cell_chann_dim'],
+                                                 settings['nucleus_chann_dim'],
+                                                 settings['pathogen_chann_dim'],
+                                                 figuresize=10,
+                                                 normalize=True,
+                                                 thickness=3,
+                                                 save_pdf=True)
+             except Exception as e:
+                 print(f'Failed to plot images with outlines, Error: {e}')
+
+     if settings['cell_chann_dim'] is not None:
+         df = _object_filter(df, 'cell', settings['cell_size_range'], settings['cell_intensity_range'], mask_chans, 0)
+         if settings['target_intensity_min'] is not None:
+             df = df[df[f"cell_channel_{settings['channel_of_interest']}_percentile_95"] > settings['target_intensity_min']]
+             print(f"After channel {settings['channel_of_interest']} filtration", len(df))
+     if settings['nucleus_chann_dim'] is not None:
+         df = _object_filter(df, 'nucleus', settings['nucleus_size_range'], settings['nucleus_intensity_range'], mask_chans, 1)
+     if settings['pathogen_chann_dim'] is not None:
+         df = _object_filter(df, 'pathogen', settings['pathogen_size_range'], settings['pathogen_intensity_range'], mask_chans, 2)
+
+     df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity"] / df[f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
+     for chan in settings['channel_dims']:
+         df = _calculate_recruitment(df, channel=chan)
+     print(f'calculated recruitment for: {len(df)} rows')
+
+     df_well = _group_by_well(df)
+     print(f'found: {len(df_well)} wells')
+
+     df_well = df_well[df_well['cells_per_well'] >= settings['cells_per_well']]
+     prc_list = df_well['prc'].unique().tolist()
+     df = df[df['prc'].isin(prc_list)]
+     print(f"After cells per well filter: {len(df)} cells in {len(df_well)} wells left with threshold {settings['cells_per_well']}")
+
+     if settings['plot_control']:
+         _plot_controls(df, mask_chans, settings['channel_of_interest'], figuresize=5)
+
+     print(f'PV level: {len(df)} rows')
+     _plot_recruitment(df, 'by PV', settings['channel_of_interest'], settings['target'], settings['figuresize'])
+     print(f'well level: {len(df_well)} rows')
+     _plot_recruitment(df_well, 'by well', settings['channel_of_interest'], settings['target'], settings['figuresize'])
+     cells, wells = _results_to_csv(settings['src'], df, df_well)
+
+     return [cells, wells]
+
+ def analyze_plaques(folder):
+     summary_data = []
+     details_data = []
+     stats_data = []
+
+     for filename in os.listdir(folder):
+         filepath = os.path.join(folder, filename)
+         if os.path.isfile(filepath):
+             # Each file is read as an image (e.g. a 16-bit labeled TIFF) via cellpose's io module
+             image = cp_io.imread(filepath)
+             labeled_image = label(image)
+             regions = regionprops(labeled_image)
+
+             object_count = len(regions)
+             sizes = [region.area for region in regions]
+             average_size = np.mean(sizes) if sizes else 0
+             std_dev_size = np.std(sizes) if sizes else 0
+
+             summary_data.append({'file': filename, 'object_count': object_count, 'average_size': average_size})
+             stats_data.append({'file': filename, 'plaque_count': object_count, 'average_size': average_size, 'std_dev_size': std_dev_size})
+             for size in sizes:
+                 details_data.append({'file': filename, 'plaque_size': size})
+
+     # Convert lists to pandas DataFrames
+     summary_df = pd.DataFrame(summary_data)
+     details_df = pd.DataFrame(details_data)
+     stats_df = pd.DataFrame(stats_data)
+
+     # Save DataFrames to a SQLite database
+     db_name = os.path.join(folder, 'plaques_analysis.db')
+     conn = sqlite3.connect(db_name)
+
+     summary_df.to_sql('summary', conn, if_exists='replace', index=False)
+     details_df.to_sql('details', conn, if_exists='replace', index=False)
+     stats_df.to_sql('stats', conn, if_exists='replace', index=False)
+
+     conn.close()
+
+     print(f"Analysis completed and saved to database '{db_name}'.")
+
+ def train_cellpose(settings):
+
+     from .io import _load_normalized_images_and_labels, _load_images_and_labels
+     from .settings import get_train_cellpose_default_settings
+     from .utils import save_settings
+
+     settings = get_train_cellpose_default_settings(settings)
+
+     img_src = settings['img_src']
+     mask_src = os.path.join(img_src, 'masks')
+     test_img_src = settings['test_img_src']
+     test_mask_src = settings['test_mask_src']
+
+     # Define the target dimensions unconditionally: they are used in the model name
+     # below even when settings['resize'] is False.
+     target_width = settings['width_height'][0]
+     target_height = settings['width_height'][1]
+
+     if settings['test']:
+         test_img_src = os.path.join(os.path.dirname(settings['img_src']), 'test')
+         test_mask_src = os.path.join(settings['test_img_src'], 'mask')
+
+     test_images, test_masks, test_image_names, test_mask_names = None, None, None, None
+     print(settings)
+
+     if settings['from_scratch']:
+         model_name = f"scratch_{settings['model_name']}_{settings['model_type']}_e{settings['n_epochs']}_X{target_width}_Y{target_height}.CP_model"
+     else:
+         if settings['resize']:
+             model_name = f"{settings['model_name']}_{settings['model_type']}_e{settings['n_epochs']}_X{target_width}_Y{target_height}.CP_model"
+         else:
+             model_name = f"{settings['model_name']}_{settings['model_type']}_e{settings['n_epochs']}.CP_model"
+
+     model_save_path = os.path.join(settings['mask_src'], 'models', 'cellpose_model')
+     print(model_save_path)
+     os.makedirs(model_save_path, exist_ok=True)
+
+     save_settings(settings, name=model_name)
+
+     if settings['from_scratch']:
+         model = cp_models.CellposeModel(gpu=True, model_type=settings['model_type'], diam_mean=settings['diameter'], pretrained_model=None)
+     else:
+         model = cp_models.CellposeModel(gpu=True, model_type=settings['model_type'])
+
+     if settings['normalize']:
+
+         image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
+         label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
+         images, masks, image_names, mask_names, orig_dims = _load_normalized_images_and_labels(image_files,
+                                                                                                label_files,
+                                                                                                settings['channels'],
+                                                                                                settings['percentiles'],
+                                                                                                settings['circular'],
+                                                                                                settings['invert'],
+                                                                                                settings['verbose'],
+                                                                                                settings['remove_background'],
+                                                                                                settings['background'],
+                                                                                                settings['Signal_to_noise'],
+                                                                                                settings['target_height'],
+                                                                                                settings['target_width'])
+         images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
+
+         if settings['test']:
+             test_image_files = [os.path.join(test_img_src, f) for f in os.listdir(test_img_src) if f.endswith('.tif')]
+             test_label_files = [os.path.join(test_mask_src, f) for f in os.listdir(test_mask_src) if f.endswith('.tif')]
+             test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(test_image_files,
+                                                                                                             test_label_files,
+                                                                                                             settings['channels'],
+                                                                                                             settings['percentiles'],
+                                                                                                             settings['circular'],
+                                                                                                             settings['invert'],
+                                                                                                             settings['verbose'],
+                                                                                                             settings['remove_background'],
+                                                                                                             settings['background'],
+                                                                                                             settings['Signal_to_noise'],
+                                                                                                             settings['target_height'],
+                                                                                                             settings['target_width'])
+             test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
+
+     else:
+         images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, settings['circular'], settings['invert'])
+         images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
+
+         if settings['test']:
+             test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(test_img_src,
+                                                                                                  test_mask_src,
+                                                                                                  settings['circular'],
+                                                                                                  settings['invert'])
+             test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
+
+     #if resize:
+     #    images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)
+
+     if settings['model_type'] == 'cyto':
+         cp_channels = [0, 1]
+     if settings['model_type'] == 'cyto2':
+         cp_channels = [0, 2]
+     if settings['model_type'] == 'nucleus':
+         cp_channels = [0, 0]
+     if settings['grayscale']:
+         cp_channels = [0, 0]
+         images = [np.squeeze(img) if img.ndim == 3 and 1 in img.shape else img for img in images]
+
+     masks = [np.squeeze(mask) if mask.ndim == 3 and 1 in mask.shape else mask for mask in masks]
+
+     print(f'image shape: {images[0].shape}, image dtype: {images[0].dtype}, mask shape: {masks[0].shape}, mask dtype: {masks[0].dtype}')
+     save_every = int(settings['n_epochs'] / 10)
+     if save_every < 10:
+         save_every = settings['n_epochs']
+
+     train_cp.train_seg(model.net,
+                        train_data=images,
+                        train_labels=masks,
+                        train_files=image_names,
+                        train_labels_files=mask_names,
+                        train_probs=None,
+                        test_data=test_images,
+                        test_labels=test_masks,
+                        test_files=test_image_names,
+                        test_labels_files=test_mask_names,
+                        test_probs=None,
+                        load_files=True,
+                        batch_size=settings['batch_size'],
+                        learning_rate=settings['learning_rate'],
+                        n_epochs=settings['n_epochs'],
+                        weight_decay=settings['weight_decay'],
+                        momentum=0.9,
+                        SGD=False,
+                        channels=cp_channels,
+                        channel_axis=None,
+                        #rgb=False,
+                        normalize=False,
+                        compute_flows=False,
+                        save_path=model_save_path,
+                        save_every=save_every,
+                        nimg_per_epoch=None,
+                        nimg_test_per_epoch=None,
+                        rescale=settings['rescale'],
+                        #scale_range=None,
+                        #bsize=224,
+                        min_train_masks=1,
+                        model_name=settings['model_name'])
+
+     print(f"Model saved at: {model_save_path}/{model_name}")
+
+ def count_phenotypes(settings):
+
+     from .io import _read_db
+
+     if not settings['src'].endswith('/measurements/measurements.db'):
+         settings['src'] = os.path.join(settings['src'], 'measurements/measurements.db')
+
+     df = _read_db(loc=settings['src'], tables=['png_list'])
+
+     unique_values_count = df[settings['annotation_column']].nunique(dropna=True)
+     print(f"Unique values in {settings['annotation_column']} (excluding NaN): {unique_values_count}")
+
+     # Count unique annotation values, grouped by 'plate', 'row', 'column'
+     grouped_unique_count = df.groupby(['plate', 'row', 'column'])[settings['annotation_column']].nunique(dropna=True).reset_index(name='unique_count')
+     display(grouped_unique_count)
+
+     # Group by plate, row, and column, then count the occurrences of each unique value
+     grouped_counts = df.groupby(['plate', 'row', 'column', 'value']).size().reset_index(name='count')
+
+     # Pivot the DataFrame so that unique values are columns and their counts are in the rows
+     pivot_df = grouped_counts.pivot_table(index=['plate', 'row', 'column'], columns='value', values='count', fill_value=0)
+
+     # Flatten the multi-level columns
+     pivot_df.columns = [f"value_{int(col)}" for col in pivot_df.columns]
+
+     # Collapse plate, row, and column into a single combined index
+     pivot_df.index = pivot_df.index.map(lambda x: f"{x[0]}_{x[1]}_{x[2]}")
+
+     # Save the counts as CSV next to the measurements database
+     output_dir = os.path.dirname(settings['src'])
+     output_path = os.path.join(output_dir, 'phenotype_counts.csv')
+
+     pivot_df.to_csv(output_path)
+
+     return
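
A minimal usage sketch for the new submodules API (the paths and the specific settings keys shown are illustrative assumptions, not taken from this diff; unspecified keys fall back to the defaults from spacr.settings):

    from spacr.submodules import analyze_recruitment, analyze_plaques, train_cellpose

    # Recruitment analysis; returns the per-cell and per-well DataFrames
    cells, wells = analyze_recruitment(settings={
        'src': '/data/plate1',            # hypothetical experiment folder containing measurements/measurements.db
        'channel_of_interest': 3,
        'target': 'GFP',
    })

    # Plaque counting; writes plaques_analysis.db into the folder
    analyze_plaques('/data/plaque_images')

    # Fine-tune a cellpose model on paired images and masks
    train_cellpose(settings={
        'img_src': '/data/train_images',  # masks are expected in /data/train_images/masks
        'model_name': 'toxo',
        'model_type': 'cyto',
        'n_epochs': 100,
        'test': False,
    })
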
spacr/timelapse.py CHANGED
@@ -13,8 +13,6 @@ from scipy.optimize import curve_fit
  from scipy.integrate import trapz
  import matplotlib.pyplot as plt

- from .logger import log_function_call
-
  def _npz_to_movie(arrays, filenames, save_path, fps=10):
      """
      Convert a list of numpy arrays to a movie file.
spacr/toxo.py ADDED
@@ -0,0 +1,233 @@
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import numpy as np
+ from adjustText import adjust_text
+ import pandas as pd
+ from scipy.stats import fisher_exact
+
+ def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', string_list=[], point_size=50, figsize=20):
+     """
+     Create a volcano plot with the ability to control the shape of points based on a categorical column,
+     color points based on a string list, annotate specific points based on p-value and coefficient thresholds,
+     and control the size of points.
+
+     Parameters:
+     - data_path: Path to the data CSV file, or a DataFrame.
+     - metadata_path: Path to the metadata CSV file, or a DataFrame.
+     - metadata_column: Column name in the metadata to control point shapes.
+     - string_list: List of strings that color points differently if present in 'gene_nr'.
+     - point_size: Fixed value to control the size of points.
+     - figsize: Height of the figure in inches (width is half the height).
+     """
+
+     filename = 'volcano_plot.pdf'
+
+     # Load the data
+     if isinstance(data_path, pd.DataFrame):
+         data = data_path
+     else:
+         data = pd.read_csv(data_path)
+
+     data['variable'] = data['feature'].str.extract(r'\[(.*?)\]')
+     data['variable'] = data['variable'].fillna(data['feature'])
+     split_columns = data['variable'].str.split('_', expand=True)
+     data['gene_nr'] = split_columns[0]
+
+     # Load metadata
+     if isinstance(metadata_path, pd.DataFrame):
+         metadata = metadata_path
+     else:
+         metadata = pd.read_csv(metadata_path)
+
+     metadata['gene_nr'] = metadata['gene_nr'].astype(str)
+     data['gene_nr'] = data['gene_nr'].astype(str)
+
+     # Merge data and metadata on 'gene_nr', pulling in the column that controls point shape
+     merged_data = pd.merge(data, metadata[['gene_nr', metadata_column]], on='gene_nr', how='left')
+
+     # Controls handling
+     controls = ['000000', '000001', '000002', '000003', '000004', '000005', '000006', '000007', '000008', '000009', '000010', '000011']
+     merged_data.loc[merged_data['gene_nr'].isin(controls), metadata_column] = 'control'
+     merged_data.loc[merged_data['gene_nr'].str.startswith('4'), metadata_column] = 'GT1_gene'
+     merged_data.loc[merged_data['gene_nr'] == 'Intercept', metadata_column] = 'Intercept'
+
+     # Create a 'highlight_color' column based on the string_list
+     merged_data['highlight_color'] = merged_data['gene_nr'].apply(lambda x: 'red' if any(s in x for s in string_list) else 'blue')
+
+     # Derive the y-axis column from p_value so the scatter plot below always has it
+     merged_data['-log10(p_value)'] = -np.log10(merged_data['p_value'])
+
+     # Create the volcano plot
+     figsize_2 = figsize / 2
+     plt.figure(figsize=(figsize_2, figsize))
+
+     # Create the scatter plot with fixed point size
+     sns.scatterplot(
+         data=merged_data,
+         x='coefficient',
+         y='-log10(p_value)',
+         hue='highlight_color',
+         style=metadata_column if metadata_column else None,  # Control point shape with metadata_column
+         s=point_size,  # Fixed size for all points
+         palette={'red': 'red', 'blue': 'blue'}
+     )
+
+     # Set the plot title and labels
+     plt.title('Custom Volcano Plot of Coefficients')
+     plt.xlabel('Coefficient')
+     plt.ylabel('-log10(p-value)')
+
+     # Horizontal line at p-value threshold (0.05)
+     plt.axhline(y=-np.log10(0.05), color='red', linestyle='--')
+
+     # Annotate points where p_value <= 0.05 and coefficient >= 0.25
+     texts = []
+     for i, row in merged_data.iterrows():
+         if row['p_value'] <= 0.05 and row['coefficient'] >= 0.25:
+             texts.append(plt.text(row['coefficient'], -np.log10(row['p_value']), row['gene_nr'], fontsize=9))
+
+     # Adjust text positions to avoid overlap
+     adjust_text(texts, arrowprops=dict(arrowstyle='-', color='black'))
+
+     # Move the legend outside the plot
+     plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
+
+     # Save the plot; bbox_inches ensures the legend doesn't get cut off
+     plt.savefig(filename, format='pdf', bbox_inches='tight')
+     print(f'Saved Volcano plot: {filename}')
+
+     # Show the plot
+     plt.show()
+
+ def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=['Computed GO Processes', 'Curated GO Components', 'Curated GO Functions', 'Curated GO Processes']):
+     """
+     Perform GO term enrichment analysis for each GO term column and generate plots.
+
+     Parameters:
+     - significant_df: DataFrame containing the significant genes from the screen.
+     - metadata_path: Path to the metadata file containing GO terms.
+     - go_term_columns: List of columns in the metadata corresponding to GO terms.
+
+     For each GO term column, this function will:
+     - Split the GO terms by semicolons.
+     - Count the occurrences of GO terms in the hits and in the background.
+     - Perform Fisher's exact test for enrichment.
+     - Plot the enrichment score vs -log10(p-value).
+     """
+
+     significant_df['variable'] = significant_df['variable'].fillna(significant_df['feature'])
+     split_columns = significant_df['variable'].str.split('_', expand=True)
+     significant_df['gene_nr'] = split_columns[0]
+     gene_list = significant_df['gene_nr'].to_list()
+
+     # Load metadata
+     metadata = pd.read_csv(metadata_path)
+     split_columns = metadata['Gene ID'].str.split('_', expand=True)
+     metadata['gene_nr'] = split_columns[1]
+
+     # Subset of metadata with only the rows whose genes are in gene_list (the hits);
+     # .copy() avoids a pandas SettingWithCopyWarning when columns are filled in below
+     hits_metadata = metadata[metadata['gene_nr'].isin(gene_list)].copy()
+
+     # Collect results from all columns
+     combined_results = []
+
+     for go_term_column in go_term_columns:
+         # Initialize lists to store results
+         go_terms = []
+         enrichment_scores = []
+         p_values = []
+
+         # Split the GO terms in the entire metadata and hits
+         metadata[go_term_column] = metadata[go_term_column].fillna('')
+         hits_metadata[go_term_column] = hits_metadata[go_term_column].fillna('')
+
+         all_go_terms = metadata[go_term_column].str.split(';').explode()
+         hit_go_terms = hits_metadata[go_term_column].str.split(';').explode()
+
+         # Count occurrences of each GO term in hits and total metadata,
+         # dropping the empty string produced by rows without annotations
+         all_go_term_counts = all_go_terms[all_go_terms != ''].value_counts()
+         hit_go_term_counts = hit_go_terms[hit_go_terms != ''].value_counts()
+
+         # Total number of genes and hits
+         total_genes = len(metadata)
+         total_hits = len(hits_metadata)
+
+         # Perform enrichment analysis for each GO term
+         for go_term in all_go_term_counts.index:
+             total_with_go_term = all_go_term_counts.get(go_term, 0)
+             hits_with_go_term = hit_go_term_counts.get(go_term, 0)
+
+             # Fisher's exact test on the 2x2 table:
+             # rows = hit / non-hit genes, columns = with / without this GO term
+             contingency_table = [[hits_with_go_term, total_hits - hits_with_go_term],
+                                  [total_with_go_term - hits_with_go_term, total_genes - total_hits - (total_with_go_term - hits_with_go_term)]]
+
+             _, p_value = fisher_exact(contingency_table)
+
+             # Enrichment score = (fraction of hits with the term) / (fraction of all genes with the term)
+             if total_with_go_term > 0 and total_hits > 0:
+                 enrichment_score = (hits_with_go_term / total_hits) / (total_with_go_term / total_genes)
+             else:
+                 enrichment_score = 0.0
+
+             # Store the results only if the enrichment score is non-zero
+             if enrichment_score > 0.0:
+                 go_terms.append(go_term)
+                 enrichment_scores.append(enrichment_score)
+                 p_values.append(p_value)
+
+         # Create a results DataFrame for this GO term column
+         results_df = pd.DataFrame({
+             'GO Term': go_terms,
+             'Enrichment Score': enrichment_scores,
+             'P-value': p_values,
+             'GO Column': go_term_column  # Track the GO term column for the final combined plot
+         })
+
+         # Sort by enrichment score
+         results_df = results_df.sort_values(by='Enrichment Score', ascending=False)
+
+         # Append this DataFrame to the combined list
+         combined_results.append(results_df)
+
+         # Plot the enrichment results for each individual column
+         plt.figure(figsize=(10, 6))
+
+         # Scatter plot of Enrichment Score vs -log10(p-value)
+         sns.scatterplot(data=results_df, x='Enrichment Score', y=-np.log10(results_df['P-value']), hue='GO Term', size='Enrichment Score', sizes=(50, 200))
+
+         # Set plot labels and title
+         plt.title(f'GO Term Enrichment Analysis for {go_term_column}')
+         plt.xlabel('Enrichment Score')
+         plt.ylabel('-log10(P-value)')
+
+         # Move the legend to the right of the plot
+         plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
+
+         # Show the plot
+         plt.tight_layout()  # Ensure everything fits in the figure area
+         plt.show()
+
+         # Optionally return or save the results for each column
+         print(f'Results for {go_term_column}')
+
+     # Combine results from all columns into a single DataFrame
+     combined_df = pd.concat(combined_results)
+
+     # Plot the combined results with text labels
+     plt.figure(figsize=(12, 8))
+     sns.scatterplot(data=combined_df, x='Enrichment Score', y=-np.log10(combined_df['P-value']),
+                     style='GO Column', size='Enrichment Score', sizes=(50, 200))
+
+     # Set plot labels and title for the combined graph
+     plt.title('Combined GO Term Enrichment Analysis')
+     plt.xlabel('Enrichment Score')
+     plt.ylabel('-log10(P-value)')
+
+     # Annotate the points with labels and connecting lines
+     texts = []
+     for i, row in combined_df.iterrows():
+         texts.append(plt.text(row['Enrichment Score'], -np.log10(row['P-value']), row['GO Term'], fontsize=9))
+
+     # Adjust text to avoid overlap
+     adjust_text(texts, arrowprops=dict(arrowstyle='-', color='black'))
+
+     # Show the combined plot
+     plt.tight_layout()
+     plt.show()
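
A minimal usage sketch for the new toxo module (file paths and the example gene number are illustrative assumptions, not taken from this diff; toxoplasma_metadata.csv ships under spacr/resources/data in this release):

    import pandas as pd
    from spacr.toxo import custom_volcano_plot, go_term_enrichment_by_column

    # Volcano plot of regression results; point shape follows 'tagm_location'
    # and genes matching the string list are colored red
    custom_volcano_plot(data_path='regression_results.csv',
                        metadata_path='toxoplasma_metadata.csv',
                        metadata_column='tagm_location',
                        string_list=['280810'],  # hypothetical gene number
                        point_size=50, figsize=20)

    # GO enrichment of the significant hits, one plot per GO column
    significant = pd.read_csv('regression_results.csv').query('p_value <= 0.05')
    go_term_enrichment_by_column(significant, 'toxoplasma_metadata.csv')

To make the enrichment score concrete: if 5 of 50 hit genes carry a GO term that 40 of 4000 background genes carry, the score is (5/50) / (40/4000) = 0.10 / 0.01 = 10, a ten-fold over-representation, and Fisher's exact test on the 2x2 table [[5, 45], [35, 3915]] supplies the p-value.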