spacr 0.3.47__py3-none-any.whl → 0.3.52__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/toxo.py CHANGED
@@ -7,26 +7,63 @@ import pandas as pd
7
7
  from scipy.stats import fisher_exact
8
8
  from IPython.display import display
9
9
  from matplotlib.legend import Legend
10
+ from matplotlib.transforms import Bbox
11
+ from brokenaxes import brokenaxes
10
12
 
11
- def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', point_size=50, figsize=20, threshold=0, split_axis_lims = [10, None, None, 10], save_path=None, x_lim=[-0.5, 0.5]):
12
- """
13
- Create a volcano plot with the ability to control the shape of points based on a categorical column,
14
- color points based on a condition, annotate specific points based on p-value and coefficient thresholds,
15
- and control the size of points.
16
- """
17
- volcano_path = save_path
18
- padd = 30
19
- fontsize = 18
20
- # Load the data
13
+ import os
14
+ import pandas as pd
15
+ import seaborn as sns
16
+ import matplotlib.pyplot as plt
17
+ from scipy.spatial.distance import cosine
18
+ from scipy.stats import pearsonr
19
+ import pandas as pd
20
+ import matplotlib.pyplot as plt
21
+ import seaborn as sns
22
+ from sklearn.metrics import mean_absolute_error
23
+
24
+
25
+ from matplotlib.gridspec import GridSpec
26
+
27
+
28
+ def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location',
29
+ point_size=50, figsize=20, threshold=0,
30
+ save_path=None, x_lim=[-0.5, 0.5], y_lims=[[0, 6], [9, 15]]):
31
+
32
+ markers = [
33
+ 'o', # Circle
34
+ 'X', # X-shaped marker
35
+ '^', # Upward triangle
36
+ 's', # Square
37
+ 'v', # Downward triangle
38
+ 'P', # Plus-filled pentagon
39
+ '*', # Star
40
+ '+', # Plus
41
+ 'x', # Cross
42
+ '.', # Point
43
+ ',', # Pixel
44
+ 'd', # Diamond
45
+ 'D', # Thin diamond
46
+ 'h', # Hexagon 1
47
+ 'H', # Hexagon 2
48
+ 'p', # Pentagon
49
+ '|', # Vertical line
50
+ '_', # Horizontal line
51
+ ]
52
+
53
+ plt.rcParams.update({'font.size': 14})
54
+
55
+ # Load data
21
56
  if isinstance(data_path, pd.DataFrame):
22
57
  data = data_path
23
58
  else:
24
59
  data = pd.read_csv(data_path)
25
-
60
+
61
+ fontsize = 18
62
+
63
+ plt.rcParams.update({'font.size': fontsize})
26
64
  data['variable'] = data['feature'].str.extract(r'\[(.*?)\]')
27
65
  data['variable'].fillna(data['feature'], inplace=True)
28
- split_columns = data['variable'].str.split('_', expand=True)
29
- data['gene_nr'] = split_columns[0]
66
+ data['gene_nr'] = data['variable'].str.split('_').str[0]
30
67
  data = data[data['variable'] != 'Intercept']
31
68
 
32
69
  # Load metadata
@@ -34,173 +71,110 @@ def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location
34
71
  metadata = metadata_path
35
72
  else:
36
73
  metadata = pd.read_csv(metadata_path)
37
-
38
74
  metadata['gene_nr'] = metadata['gene_nr'].astype(str)
39
75
  data['gene_nr'] = data['gene_nr'].astype(str)
40
76
 
41
- # Merge data and metadata on 'gene_nr'
42
- merged_data = pd.merge(data, metadata[['gene_nr', 'tagm_location']], on='gene_nr', how='left')
43
-
44
- merged_data.loc[merged_data['gene_nr'].str.startswith('4'), metadata_column] = 'GT1_gene'
45
- merged_data.loc[merged_data['gene_nr'] == 'Intercept', metadata_column] = 'Intercept'
46
- merged_data.loc[merged_data['condition'] == 'control', metadata_column] = 'control'
77
+ merged_data = pd.merge(data, metadata[['gene_nr', metadata_column]], on='gene_nr', how='left')
47
78
  merged_data[metadata_column].fillna('unknown', inplace=True)
48
79
 
49
- # Categorize condition for coloring
50
- merged_data['condition'] = pd.Categorical(
51
- merged_data['condition'],
52
- categories=['other','pc', 'nc', 'control'],
53
- ordered=True)
54
-
55
- # Create subplots with a broken y-axis
56
- figsize_2 = figsize / 2
57
- fig, (ax1, ax2) = plt.subplots(
58
- 2, 1, figsize=(figsize, figsize),
59
- sharex=True, gridspec_kw={'height_ratios': [1, 3]}
60
- )
61
-
62
- # Define color palette
63
- palette = {
64
- 'pc': 'red',
65
- 'nc': 'green',
66
- 'control': 'white',
67
- 'other': 'gray'}
68
-
69
- # Scatter plot on both axes with legend completely disabled
70
- sns.scatterplot(
71
- data=merged_data,
72
- x='coefficient',
73
- y='-log10(p_value)',
74
- hue='condition',
75
- style=metadata_column if metadata_column else None,
76
- s=point_size,
77
- edgecolor='black',
78
- palette=palette,
79
- legend=False, # Disable automatic legend
80
- alpha=0.6,
81
- ax=ax2 # Lower plot
82
- )
83
-
84
- sns.scatterplot(
85
- data=merged_data[merged_data['-log10(p_value)'] > 10],
86
- x='coefficient',
87
- y='-log10(p_value)',
88
- hue='condition',
89
- style=metadata_column if metadata_column else None,
90
- s=point_size,
91
- edgecolor='black',
92
- palette=palette,
93
- legend=False, # No legend on the upper plot
94
- alpha=0.6,
95
- ax=ax1 # Upper plot
96
- )
97
-
98
- # Ensure no previous legends on ax1 or ax2
99
- if ax1.get_legend() is not None:
100
- ax1.get_legend().remove()
80
+ # Define palette and markers
81
+ palette = {'pc': 'red', 'nc': 'green', 'control': 'white', 'other': 'gray'}
82
+ marker_dict = {val: marker for val, marker in zip(
83
+ merged_data[metadata_column].unique(), markers)}
101
84
 
102
- if ax2.get_legend() is not None:
103
- ax2.get_legend().remove()
85
+ # Create the figure with custom spacing
86
+ fig = plt.figure(figsize=(figsize,figsize))
87
+ gs = GridSpec(2, 1, height_ratios=[1, 3], hspace=0.05)
104
88
 
105
- # Manually gather handles and labels from ax2 after plotting
106
- handles, labels = ax2.get_legend_handles_labels()
89
+ ax_upper = fig.add_subplot(gs[0])
90
+ ax_lower = fig.add_subplot(gs[1], sharex=ax_upper)
107
91
 
108
- # Debug: Print the captured handles and labels for verification
109
- print(f"Handles: {handles}")
110
- print(f"Labels: {labels}")
92
+ # Hide x-axis labels on the upper plot
93
+ ax_upper.tick_params(axis='x', which='both', bottom=False, labelbottom=False)
111
94
 
112
- # Identify shape-based legend entries (skip color-based entries)
113
- n_color_entries = len(set(merged_data['condition']))
114
- shape_handles = handles[n_color_entries:]
115
- shape_labels = labels[n_color_entries:]
95
+ hit_list = []
116
96
 
117
- # Create and add the legend with shape-based entries
118
- legend = Legend(
119
- ax2, shape_handles, shape_labels,
120
- bbox_to_anchor=(1.05, 1), loc='upper left',
121
- handletextpad=2.0, labelspacing=1.5, borderaxespad=1.0,
122
- markerscale=2.0, prop={'size': 14}
123
- )
124
- ax2.add_artist(legend)
125
-
126
- if isinstance(split_axis_lims, list):
127
- if len(split_axis_lims) == 4:
128
- ylim_min_ax1 = split_axis_lims[0]
129
- if split_axis_lims[1] is None:
130
- ylim_max_ax1 = merged_data['-log10(p_value)'].max() + 5
131
- else:
132
- ylim_max_ax1 = split_axis_lims[1]
133
- ylim_min_ax2 = split_axis_lims[2]
134
- ylim_max_ax2 = split_axis_lims[3]
135
- else:
136
- ylim_min_ax1 = None
137
- ylim_max_ax1 = merged_data['-log10(p_value)'].max() + 5
138
- ylim_min_ax2 = 0
139
- ylim_max_ax2 = None
140
-
141
- # Set axis limits and hide unnecessary parts
142
- ax1.set_ylim(ylim_min_ax1, ylim_max_ax1)
143
- ax2.set_ylim(0, ylim_max_ax2)
97
+ # Scatter plot on both axes
98
+ for _, row in merged_data.iterrows():
99
+ y_val = -np.log10(row['p_value'])
100
+ ax = ax_upper if y_val > y_lims[1][0] else ax_lower
144
101
 
145
- if x_lim != None:
146
- ax1.set_xlim(x_lim)
147
- ax2.set_xlim(x_lim)
102
+ ax.scatter(
103
+ row['coefficient'], y_val,
104
+ color=palette.get(row['condition'], 'gray'),
105
+ marker=marker_dict.get(row[metadata_column], 'o'),
106
+ s=point_size, edgecolor='black', alpha=0.6
107
+ )
148
108
 
149
- ax1.spines['bottom'].set_visible(False)
150
- ax2.spines['top'].set_visible(False)
151
- ax1.tick_params(labelbottom=False)
152
-
153
- ax1.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
154
-
155
- # Add vertical threshold lines to both plots
156
- if threshold > 0:
157
- for ax in (ax1, ax2):
158
- ax.axvline(x=-abs(threshold), linestyle='--', color='black')
159
- ax.axvline(x=abs(threshold), linestyle='--', color='black')
109
+ if row['p_value'] <= 0.05 and abs(row['coefficient']) >= abs(threshold):
110
+ hit_list.append(row['variable'])
160
111
 
161
- # Add a horizontal line at p-value threshold (0.05)
162
- ax2.axhline(y=-np.log10(0.05), color='black', linestyle='--')
112
+ # Set axis limits
113
+ ax_upper.set_ylim(y_lims[1])
114
+ ax_lower.set_ylim(y_lims[0])
115
+ ax_lower.set_xlim(x_lim)
163
116
 
164
- # Annotate significant points on both axes
165
- texts_ax1 = []
166
- texts_ax2 = []
117
+ ax_lower.spines['top'].set_visible(False)
118
+ ax_upper.spines['top'].set_visible(False)
119
+ ax_upper.spines['bottom'].set_visible(False)
167
120
 
168
- for i, row in merged_data.iterrows():
169
- if row['p_value'] <= 0.05 and abs(row['coefficient']) >= abs(threshold):
170
- ax = ax1 if row['-log10(p_value)'] >= ax1.get_ylim()[0] else ax2
171
- # Create the annotation on the selected axis
172
- text = ax.text(
173
- row['coefficient'],
174
- -np.log10(row['p_value']),
175
- row['variable'],
176
- fontsize=fontsize,
177
- ha='center',
178
- va='bottom',
179
- )
121
+ # Set x-axis and y-axis titles
122
+ ax_lower.set_xlabel('Coefficient') # X-axis title on the lower graph
123
+ ax_lower.set_ylabel('-log10(p-value)') # Y-axis title on the lower graph
124
+ ax_upper.set_ylabel('-log10(p-value)') # Y-axis title on the upper graph
125
+
126
+ for ax in [ax_upper, ax_lower]:
127
+ ax.spines['right'].set_visible(False)
180
128
 
181
- # Store the text annotation in the correct list
182
- if ax == ax1:
183
- texts_ax1.append(text)
184
- else:
185
- texts_ax2.append(text)
129
+ # Add threshold lines to both axes
130
+ for ax in [ax_upper, ax_lower]:
131
+ ax.axvline(x=-abs(threshold), linestyle='--', color='black')
132
+ ax.axvline(x=abs(threshold), linestyle='--', color='black')
186
133
 
187
- # Adjust text positions to avoid overlap for both axes
188
- adjust_text(texts_ax1, arrowprops=dict(arrowstyle='-', color='black'), ax=ax1, expand_points=(padd, padd), fontsize=fontsize)
189
- adjust_text(texts_ax2, arrowprops=dict(arrowstyle='-', color='black'), ax=ax2, expand_points=(padd, padd), fontsize=fontsize)
134
+ ax_lower.axhline(y=-np.log10(0.05), linestyle='--', color='black')
190
135
 
191
- # Move the legend outside the lower plot
192
- ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
136
+ # Annotate significant points
137
+ texts_upper, texts_lower = [], [] # Collect text annotations separately
193
138
 
194
- # Adjust the spacing between subplots and move the title
195
- plt.subplots_adjust(hspace=0.05)
196
- fig.suptitle('Custom Volcano Plot of Coefficients', y=1.02, fontsize=16) # Title above the top plot
139
+ for _, row in merged_data.iterrows():
140
+ y_val = -np.log10(row['p_value'])
141
+ if row['p_value'] > 0.05 or abs(row['coefficient']) < abs(threshold):
142
+ continue
197
143
 
198
- # Save the plot as PDF
199
- plt.savefig(volcano_path, format='pdf', bbox_inches='tight')
200
- print(f'Saved Volcano plot: {volcano_path}')
144
+ ax = ax_upper if y_val > y_lims[1][0] else ax_lower
145
+ text = ax.text(row['coefficient'], y_val, row['variable'],
146
+ fontsize=fontsize, ha='center', va='bottom')
201
147
 
202
- # Show the plot
148
+ if ax == ax_upper:
149
+ texts_upper.append(text)
150
+ else:
151
+ texts_lower.append(text)
152
+
153
+ # Adjust text positions to avoid overlap
154
+ adjust_text(texts_upper, ax=ax_upper, arrowprops=dict(arrowstyle='-', color='black'))
155
+ adjust_text(texts_lower, ax=ax_lower, arrowprops=dict(arrowstyle='-', color='black'))
156
+
157
+ # Add a single legend on the lower axis
158
+ handles = [plt.Line2D([0], [0], marker=m, color='w', markerfacecolor='gray', markersize=10)
159
+ for m in marker_dict.values()]
160
+ labels = marker_dict.keys()
161
+ ax_lower.legend(handles,
162
+ labels,
163
+ bbox_to_anchor=(1.05, 1),
164
+ loc='upper left',
165
+ borderaxespad=0.25,
166
+ labelspacing=2,
167
+ handletextpad=0.25,
168
+ markerscale=2,
169
+ prop={'size': fontsize})
170
+
171
+
172
+ # Save and show the plot
173
+ if save_path:
174
+ plt.savefig(save_path, format='pdf', bbox_inches='tight')
203
175
  plt.show()
176
+
177
+ return hit_list
204
178
 
205
179
  def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=['Computed GO Processes', 'Curated GO Components', 'Curated GO Functions', 'Curated GO Processes']):
206
180
  """
@@ -341,4 +315,318 @@ def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=
341
315
 
342
316
  # Show the combined plot
343
317
  plt.tight_layout()
344
- plt.show()
318
+ plt.show()
319
+
320
def plot_gene_phenotypes(data, gene_list, x_column='Gene ID',
                         data_column='T.gondii GT1 CRISPR Phenotype - Mean Phenotype',
                         error_column='T.gondii GT1 CRISPR Phenotype - Standard Error',
                         save_path=None):
    """
    Plot the ranked mean phenotype with a standard-error band and highlighted genes.

    Rows are sorted by ``data_column`` and plotted against their 1-based rank.
    Rows whose ``data_column`` or ``error_column`` cannot be converted to a
    number are dropped. The caller's DataFrame is not modified.

    Args:
        data (pd.DataFrame): The input DataFrame containing gene data.
        gene_list (list): Genes to highlight. Each entry is reduced to a gene ID
            with the same rule applied to ``x_column`` (the second
            underscore-separated field of a string, otherwise ``str(value)``).
        x_column (str): Column holding the gene identifiers.
        data_column (str): Column holding the mean phenotype (y values).
        error_column (str): Column holding the standard error used for shading.
        save_path (str): Optional. If provided, the figure is saved there as PDF.
    """
    def extract_gene_id(gene):
        # 'PREFIX_ID_...' -> 'ID'; anything else is just stringified.
        if isinstance(gene, str) and '_' in gene:
            return gene.split('_')[1]
        return str(gene)

    line_color = (0/255, 155/255, 155/255)
    highlight_color = (155/255, 55/255, 155/255)

    # Work on a private copy: the original wrote the converted columns back
    # into the caller's frame and then assigned onto dropna() results.
    data = data.copy()
    # Plain column assignment (not .loc[:, col]) so the column really becomes
    # numeric instead of keeping an object dtype.
    data[data_column] = pd.to_numeric(data[data_column], errors='coerce')
    data[error_column] = pd.to_numeric(data[error_column], errors='coerce')
    data['x'] = data[x_column].apply(extract_gene_id)
    data = data.dropna(subset=[data_column, error_column])

    # Sort by the data_column and assign ranks
    data = data.sort_values(by=data_column).reset_index(drop=True)
    data['rank'] = range(1, len(data) + 1)

    x = data['rank']
    y = data[data_column]
    yerr = data[error_column]

    plt.figure(figsize=(10, 10))

    # Mean phenotype line with a +/- standard error band
    plt.plot(x, y, label='Mean Phenotype', color=line_color, linewidth=2)
    plt.fill_between(
        x, y - yerr, y + yerr,
        color=line_color, alpha=0.1, label='Standard Error'
    )

    texts = []  # Text objects handed to adjust_text

    for gene in gene_list:
        gene_data = data[data['x'] == extract_gene_id(gene)]
        if gene_data.empty:
            continue
        plt.scatter(
            gene_data['rank'],
            gene_data[data_column],
            color=highlight_color,
            s=200,
            alpha=0.6,
            label=f'Highlighted Gene: {gene}',
            zorder=3  # Ensure the points are on top
        )
        # Only the first matching row is labelled.
        texts.append(
            plt.text(
                gene_data['rank'].values[0],
                gene_data[data_column].values[0],
                gene,
                fontsize=18,
                ha='right'
            )
        )

    # Nothing to de-overlap when no gene matched.
    if texts:
        adjust_text(texts, arrowprops=dict(arrowstyle='-', color='gray'))

    plt.xlabel('Rank')
    plt.ylabel('Mean Phenotype')
    plt.legend().remove()  # Labels are set above, but no legend is displayed
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, format='pdf', dpi=600, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()
405
+
406
def plot_gene_heatmaps(data, gene_list, columns, x_column='Gene ID', normalize=False, save_path=None):
    """
    Generate a viridis heatmap of the specified columns for the selected genes.

    Genes are on the y-axis and ``columns`` on the x-axis. The caller's
    DataFrame is not modified.

    Args:
        data (pd.DataFrame): The input DataFrame containing gene data.
        gene_list (list): Genes to include in the heatmap. Entries are reduced
            to gene IDs with the same rule applied to ``x_column`` (the second
            underscore-separated field of a string, otherwise ``str(value)``),
            so both full names and bare IDs match.
        columns (list): A list of column names to visualize as heatmaps.
        x_column (str): Column holding the gene identifiers.
        normalize (bool): If True, normalize the values for each gene between 0 and 1.
            A gene whose values are all equal is mapped to 0.
        save_path (str): Optional. If provided, the plot will be saved to this path.

    Raises:
        ValueError: If none of the genes in ``gene_list`` is present in ``data``.
    """
    def extract_gene_id(gene):
        # 'PREFIX_ID_...' -> 'ID'; anything else is just stringified.
        if isinstance(gene, str) and '_' in gene:
            return gene.split('_')[1]
        return str(gene)

    # Normalise the requested genes exactly like the data column; otherwise a
    # name containing '_' could never equal an extracted ID.
    gene_ids = {extract_gene_id(gene) for gene in gene_list}

    # assign() returns a new frame, so no 'x' column leaks into the caller's data.
    keyed = data.assign(x=data[x_column].apply(extract_gene_id))
    filtered_data = keyed[keyed['x'].isin(gene_ids)].set_index('x')[columns]

    if filtered_data.empty:
        raise ValueError(f"None of the requested genes were found in column '{x_column}'.")

    # Normalize each gene's values between 0 and 1 if normalize=True
    if normalize:
        row_min = filtered_data.min(axis=1)
        # A constant row has range 0; divide by 1 instead so it becomes 0, not NaN.
        row_range = (filtered_data.max(axis=1) - row_min).replace(0, 1)
        filtered_data = filtered_data.sub(row_min, axis=0).div(row_range, axis=0)

    # Figure size scales with the number of columns and requested genes
    width = len(columns) * 4
    height = len(gene_list) * 1

    plt.figure(figsize=(width, height))
    cmap = sns.color_palette("viridis", as_cmap=True)

    sns.heatmap(
        filtered_data,
        cmap=cmap,
        cbar=True,
        annot=False,
        linewidths=0.5,
        square=True
    )

    plt.xticks(rotation=90, ha='center')  # Rotate x-axis labels for better readability
    plt.yticks(rotation=0)  # Keep y-axis labels horizontal
    plt.xlabel('')
    plt.ylabel('')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, format='pdf', dpi=600, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()
465
+
466
def generate_score_heatmap(settings):
    """
    Compare per-well classification scores with the measured control-sgRNA fraction.

    Reads one score CSV per sub-folder, a cross-validation score CSV and an
    sgRNA count CSV, joins them on a ``plate_row_col`` key (``prc``), plots a
    heatmap of all numeric columns, and computes the absolute error between
    every score column and the sgRNA fraction.

    Args:
        settings (dict): Configuration with the keys
            'folders' (str or list): folder(s) whose sub-folders hold ``csv_name``.
            'csv_name' (str): score CSV file name looked up in each sub-folder.
            'data_column' (str): score column averaged per well in those CSVs.
            'plate': plate identifier; when not None every 'plate' column is
                set to ``f"plate{plate}"``.
            'column' (str): plate column to keep (e.g. 'c3').
            'csv' (str): sgRNA count CSV with 'column_name', 'row_name',
                'grna_name' and 'count' columns.
            'control_sgrnas' (list): sgRNA names used as the total per well.
            'fraction_grna' (str): sgRNA whose fraction is compared to the scores.
            'cv_csv' (str): cross-validation score CSV.
            'data_column_cv' (str): score column in ``cv_csv``.
            'dst' (str or None, optional): output folder; when set, the MAE
                table, the merged table and the heatmap PDF are written there.

    Returns:
        pd.DataFrame: The merged per-well table (fraction plus all scores).

    Raises:
        FileNotFoundError: If no ``csv_name`` file is found under 'folders'.
    """

    def group_cv_score(csv, plate=1, column='c3', data_column='pred'):
        # Mean of `data_column` per well for one plate column.
        df = pd.read_csv(csv)
        # Accept either 'col' or 'column' as the plate-column header.
        if 'col' not in df.columns and 'column' in df.columns:
            df['col'] = df['column']
        if 'col' in df.columns:
            df = df[df['col'] == column].copy()
        if plate is not None:
            df['plate'] = f"plate{plate}"
        grouped_df = df.groupby(['plate', 'row', 'col'])[data_column].mean().reset_index()
        grouped_df['prc'] = grouped_df['plate'].astype(str) + '_' + grouped_df['row'].astype(str) + '_' + grouped_df['col'].astype(str)
        return grouped_df

    def calculate_fraction_mixed_condition(csv, plate=1, column='c3', control_sgrnas=('TGGT1_220950_1', 'TGGT1_233460_4')):
        # Per-well fraction of each control sgRNA among all control sgRNA counts.
        df = pd.read_csv(csv)
        df = df[df['column_name'] == column].copy()
        # Same plate rule as the score loaders so the 'prc' keys line up;
        # the original tested whether the plate *value* was a column name.
        if plate is not None:
            df['plate'] = f"plate{plate}"
        # Exact-name membership: any number of controls, no regex escaping issues.
        df = df[df['grna_name'].isin(list(control_sgrnas))]
        well_keys = ['plate', 'row_name', 'column_name']
        totals = df.groupby(well_keys)['count'].sum().reset_index()
        totals = totals.rename(columns={'count': 'total_count'})
        merged_df = pd.merge(df, totals, on=well_keys)
        merged_df['fraction'] = merged_df['count'] / merged_df['total_count']
        merged_df['prc'] = merged_df['plate'].astype(str) + '_' + merged_df['row_name'].astype(str) + '_' + merged_df['column_name'].astype(str)
        return merged_df

    def plot_multi_channel_heatmap(df, column='c3'):
        """
        Plot a heatmap with multiple channels as columns.

        Parameters:
        - df: DataFrame with scores for different channels.
        - column: Column to filter by (default is 'c3').
        """
        # Copy first: the temporary sort key must not leak into the caller's frame.
        df = df[df['col'] == column].copy()

        # Numeric part of the row label, so rows sort r2 < r10
        df['row_num'] = df['row'].str.extract(r'(\d+)', expand=False).astype(int)
        df = df.sort_values(by=['plate', 'row_num', 'col']).drop(columns='row_num')

        # Index rows by a combined plate-row-column label
        df['plate_row_col'] = df['plate'] + '-' + df['row'] + '-' + df['col']
        df = df.set_index('plate_row_col')

        # Only numeric columns are drawn
        heatmap_data = df.select_dtypes(include=[float, int])

        plt.figure(figsize=(12, 8))
        sns.heatmap(
            heatmap_data,
            cmap="viridis",
            cbar=True,
            square=True,
            annot=False
        )

        plt.title("Heatmap of Prediction Scores for All Channels")
        plt.xlabel("Channels")
        plt.ylabel("Plate-Row-Column")
        plt.tight_layout()

        # Grab the figure before show() so the caller can still save it
        fig = plt.gcf()
        plt.show()

        return fig

    def combine_classification_scores(folders, csv_name, data_column, plate=1, column='c3'):
        # One mean-score column per sub-folder, outer-joined on well.
        if isinstance(folders, str):
            folders = [folders]

        csv_paths = []
        for folder in folders:
            for sub_folder in os.listdir(folder):
                path = os.path.join(folder, sub_folder)
                if not os.path.isdir(path):
                    continue
                csv = os.path.join(path, csv_name)
                if os.path.exists(csv):
                    csv_paths.append(csv)
                else:
                    print(f'No such file: {csv}')

        print(f'Found {len(csv_paths)} CSV files')
        if not csv_paths:
            # Previously this surfaced later as "'NoneType' object is not subscriptable".
            raise FileNotFoundError(f"No '{csv_name}' files found in sub-folders of {folders}")

        combined_df = None
        for csv_file in csv_paths:
            df = pd.read_csv(csv_file)
            df = df[df['col'] == column].copy()
            if plate is not None:
                df['plate'] = f"plate{plate}"
            grouped_df = df.groupby(['plate', 'row', 'col'])[data_column].mean().reset_index()
            # Column is named after the folder that contains the CSV
            folder_name = os.path.dirname(csv_file).replace(".csv", "")
            new_column_name = os.path.basename(f"{folder_name}_{data_column}")
            print(new_column_name)
            grouped_df = grouped_df.rename(columns={data_column: new_column_name})

            if combined_df is None:
                combined_df = grouped_df
            else:
                combined_df = pd.merge(combined_df, grouped_df, on=['plate', 'row', 'col'], how='outer')

        combined_df['prc'] = combined_df['plate'].astype(str) + '_' + combined_df['row'].astype(str) + '_' + combined_df['col'].astype(str)
        return combined_df

    def calculate_mae(df):
        """
        Calculate the MAE between each channel's predictions and the fraction column for all rows.
        """
        # Numeric columns excluding 'fraction' and 'prc'
        channels = df.drop(columns=['fraction', 'prc']).select_dtypes(include=[float, int])

        # The MAE of a single (truth, prediction) pair is |truth - prediction|,
        # so compute it column-wise; missing scores yield NaN.
        frames = [
            pd.DataFrame({
                'Channel': channel,
                'MAE': (df['fraction'] - df[channel]).abs().to_numpy(),
                'Row': df['prc'].to_numpy(),
            })
            for channel in channels.columns
        ]
        if not frames:
            return pd.DataFrame(columns=['Channel', 'MAE', 'Row'])
        return pd.concat(frames, ignore_index=True)

    result_df = combine_classification_scores(settings['folders'], settings['csv_name'], settings['data_column'], settings['plate'], settings['column'])

    fraction_df = calculate_fraction_mixed_condition(settings['csv'], settings['plate'], settings['column'], settings['control_sgrnas'])
    fraction_df = fraction_df[fraction_df['grna_name'] == settings['fraction_grna']]
    fraction_df = fraction_df[['fraction', 'prc']]
    merged_df = pd.merge(fraction_df, result_df, on=['prc'])

    cv_df = group_cv_score(settings['cv_csv'], settings['plate'], settings['column'], settings['data_column_cv'])
    cv_df = cv_df[[settings['data_column_cv'], 'prc']]
    merged_df = pd.merge(merged_df, cv_df, on=['prc'])

    # The plotting helper works on a copy, so merged_df gains no temporary
    # sort column that would otherwise be scored as a channel and saved.
    fig = plot_multi_channel_heatmap(merged_df, settings['column'])
    mae_df = calculate_mae(merged_df)

    dst = settings.get('dst')
    if dst is not None:
        os.makedirs(dst, exist_ok=True)
        mae_dst = os.path.join(dst, f"mae_scores_comparison_plate_{settings['plate']}.csv")
        merged_dst = os.path.join(dst, f"scores_comparison_plate_{settings['plate']}_data.csv")
        heatmap_save = os.path.join(dst, f"scores_comparison_plate_{settings['plate']}.pdf")
        mae_df.to_csv(mae_dst, index=False)
        merged_df.to_csv(merged_dst, index=False)
        fig.savefig(heatmap_save, format='pdf', dpi=600, bbox_inches='tight')
    return merged_df