spacr 0.3.46__py3-none-any.whl → 0.3.50__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/sequencing.py CHANGED
@@ -2,6 +2,11 @@ import os, gzip, re, time, gzip
2
2
  import pandas as pd
3
3
  from multiprocessing import Pool, cpu_count, Queue, Process
4
4
  from Bio.Seq import Seq
5
+ import matplotlib.pyplot as plt
6
+ import seaborn as sns
7
+ import numpy as np
8
+ from .plot import plot_plates
9
+ from IPython.display import display
5
10
 
6
11
  # Function to map sequences to names (same as your original)
7
12
  def map_sequences_to_names(csv_file, sequences, rc):
@@ -480,4 +485,120 @@ def barecodes_reverse_complement(csv_file):
480
485
  # Save the DataFrame with the reverse complement sequences
481
486
  df.to_csv(new_filename, index=False)
482
487
 
483
- print(f"Reverse complement file saved as {new_filename}")
488
+ print(f"Reverse complement file saved as {new_filename}")
489
+
490
def graph_sequencing_stats(settings):
    """Select a read-fraction cutoff and plot the number of unique gRNAs per well.

    Steps:
      1. Read every count CSV in ``settings['count_data']``; the i-th file is
         labelled ``plate{i}`` (1-based).
      2. For each well (plate + ``row_name`` + ``column_name``) compute each
         row's share of the well's total ``count`` as ``fraction``.
      3. Normalise metadata column names and drop rows whose
         ``settings['filter_column']`` value is listed in
         ``settings['control_wells']``.
      4. Sweep 1000 fraction thresholds in [0.001, 0.99] and pick the one whose
         mean number of distinct gRNAs per well is closest to
         ``settings['target_unique_count']`` (plotted and saved as a PDF next
         to the first count file, under ``results/``).
      5. Keep rows at or above that threshold, attach a per-well
         ``unique_counts`` column, display the table and plot it per plate.

    Args:
        settings (dict): Keys read: ``count_data`` (str or list of str),
            ``control_wells`` (value or list of values), ``filter_column``,
            ``target_unique_count``, ``log_x``, ``log_y``. The dict itself is
            not modified.

    Returns:
        The selected fraction threshold (a value from the sweep grid).
    """
    from .utils import correct_metadata_column_names

    well_keys = ['plate', 'row', 'column']

    def _line_plot(df, x='fraction_threshold', y='unique_count', log_x=False, log_y=False):
        """Plot ``y`` against ``x`` on a new 10x10 figure and return ``(fig, ax)``."""
        if x not in df.columns or y not in df.columns:
            raise ValueError(f"Columns '{x}' and/or '{y}' not found in the DataFrame.")
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.plot(df[x], df[y], linestyle='-', color=(0 / 255, 155 / 255, 155 / 255), label=f"{y}")
        ax.set_xlabel(x)
        ax.set_ylabel(y)
        ax.set_title(f'{y} vs {x}')
        if log_x:
            ax.set_xscale('log')
        if log_y:
            ax.set_yscale('log')
        return fig, ax

    def find_and_visualize_fraction_threshold(df, target_unique_count=5, log_x=False, log_y=False, dst=None):
        """
        Find the fraction threshold where the recalculated unique count matches the target value,
        and visualize the relationship between fraction thresholds and unique counts.

        Returns the threshold whose mean per-well unique gRNA count is closest
        to ``target_unique_count``.
        """
        fraction_thresholds = np.linspace(0.001, 0.99, 1000)
        results = []
        for threshold in fraction_thresholds:
            filtered_df = df[df['fraction'] >= threshold]
            unique_count = filtered_df.groupby(well_keys)['grna'].nunique().mean()
            results.append((threshold, unique_count))

        results_df = pd.DataFrame(results, columns=['fraction_threshold', 'unique_count'])
        # argmin is positional here, hence iloc below.
        closest_index = (results_df['unique_count'] - target_unique_count).abs().argmin()
        closest_threshold = results_df.iloc[closest_index]

        print(f"Closest Fraction Threshold: {closest_threshold['fraction_threshold']}")
        print(f"Unique Count at Threshold: {closest_threshold['unique_count']}")

        fig, ax = _line_plot(df=results_df, x='fraction_threshold', y='unique_count', log_x=log_x, log_y=log_y)

        ax.axvline(x=closest_threshold['fraction_threshold'], color='black', linestyle='--',
                   label=f'Closest Threshold ({closest_threshold["fraction_threshold"]:.4f})')
        ax.axhline(y=target_unique_count, color='black', linestyle='--',
                   label=f'Target Unique Count ({target_unique_count})')
        # Build the legend only now so it also lists the two reference lines.
        ax.legend()
        fig.tight_layout()

        # Zoom on the low-threshold region. A log-scaled axis cannot start at 0,
        # so only the upper bound is fixed when log scaling was requested.
        if log_x:
            ax.set_xlim(right=0.1)
        else:
            ax.set_xlim(0, 0.1)
        if log_y:
            ax.set_ylim(top=20)
        else:
            ax.set_ylim(0, 20)

        if dst is not None:
            fig_path = os.path.join(dst, 'results')
            os.makedirs(fig_path, exist_ok=True)
            fig_file_path = os.path.join(fig_path, 'fraction_threshold.pdf')
            fig.savefig(fig_file_path, format='pdf', dpi=600, bbox_inches='tight')
            print(f"Saved {fig_file_path}")
        plt.show()

        return closest_threshold['fraction_threshold']

    # Accept a single path or a list of paths without mutating the caller's settings.
    count_files = settings['count_data']
    if isinstance(count_files, str):
        count_files = [count_files]

    dfs = []
    for i, count_file in enumerate(count_files, start=1):
        plate_df = pd.read_csv(count_file)
        plate_df['plate'] = f'plate{i}'
        plate_df['prc'] = (plate_df['plate'].astype(str) + '_'
                           + plate_df['row_name'].astype(str) + '_'
                           + plate_df['column_name'].astype(str))
        plate_df['total_count'] = plate_df.groupby('prc')['count'].transform('sum')
        plate_df['fraction'] = plate_df['count'] / plate_df['total_count']
        dfs.append(plate_df)

    df = correct_metadata_column_names(pd.concat(dfs, axis=0))

    # A bare string would otherwise be treated as a collection of characters.
    control_wells = settings['control_wells']
    if isinstance(control_wells, str):
        control_wells = [control_wells]
    df = df[~df[settings['filter_column']].isin(control_wells)]

    dst = os.path.dirname(count_files[0])

    closest_threshold = find_and_visualize_fraction_threshold(
        df, settings['target_unique_count'],
        log_x=settings['log_x'], log_y=settings['log_y'], dst=dst)

    # Apply the closest threshold to the DataFrame
    df = df[df['fraction'] >= closest_threshold]

    # Unique gRNAs per well, computed once and reused for the summary and the merge.
    per_well = df.groupby(well_keys)['grna'].nunique()
    unique_counts = per_well.reset_index(name='unique_counts')
    df = pd.merge(df, unique_counts, on=well_keys, how='left')

    print(f"unique_count mean: {per_well.mean()} std: {per_well.std()}")
    display(df)
    plot_plates(df=df, variable='unique_counts', grouping='mean', min_max='allq', cmap='viridis', min_count=0, verbose=True, dst=dst)

    return closest_threshold
spacr/settings.py CHANGED
@@ -549,7 +549,8 @@ def get_perform_regression_default_settings(settings):
549
549
  settings.setdefault('filter_column','column')
550
550
  settings.setdefault('plate','plate1')
551
551
  settings.setdefault('class_1_threshold',None)
552
- settings.setdefault('metadata_files',['/home/carruthers/Documents/TGME49_Summary.csv','/home/carruthers/Documents/TGGT1_Summary.csv'])
552
+ settings.setdefault('metadata_files',['/home/carruthers/Documents/TGGT1_Summary.csv','/home/carruthers/Documents/TGME49_Summary.csv'])
553
+ settings.setdefault('volcano','gene')
553
554
  settings.setdefault('toxo', True)
554
555
 
555
556
  if settings['regression_type'] == 'quantile':
spacr/toxo.py CHANGED
@@ -6,25 +6,53 @@ from adjustText import adjust_text
6
6
  import pandas as pd
7
7
  from scipy.stats import fisher_exact
8
8
  from IPython.display import display
9
+ from matplotlib.legend import Legend
10
+ from matplotlib.transforms import Bbox
11
+ from brokenaxes import brokenaxes
9
12
 
10
- def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', point_size=50, figsize=20, threshold=0, split_axis_lims = [10, None, None, 10], save_path=None):
11
- """
12
- Create a volcano plot with the ability to control the shape of points based on a categorical column,
13
- color points based on a condition, annotate specific points based on p-value and coefficient thresholds,
14
- and control the size of points.
15
- """
16
- volcano_path = save_path
17
13
 
18
- # Load the data
14
+ from matplotlib.gridspec import GridSpec
15
+
16
+
17
def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location',
                        point_size=50, figsize=20, threshold=0,
                        save_path=None, x_lim=(-0.5, 0.5), y_lims=((0, 6), (9, 15))):
    """Draw a volcano plot (coefficient vs -log10 p-value) with a broken y-axis.

    The figure has two vertically stacked panels sharing the x-axis. The lower
    panel shows ``y_lims[0]`` and the upper panel shows ``y_lims[1]``. Points whose
    -log10(p) exceeds ``y_lims[1][0]`` are drawn on the upper panel and all
    others on the lower one, so values outside the panel limits are clipped.

    Point colour comes from the ``condition`` column: 'pc' is red, 'nc' green,
    'control' white and anything else gray. Marker shape comes from
    ``metadata_column``. Points with p <= 0.05 and |coefficient| >= |threshold|
    are labelled and returned.

    Note: sets the global ``matplotlib.rcParams['font.size']`` to 18.

    Args:
        data_path (str | pd.DataFrame): Regression results with ``feature``,
            ``coefficient``, ``p_value`` and ``condition`` columns. A DataFrame
            argument is not modified.
        metadata_path (str | pd.DataFrame): Metadata with ``gene_nr`` and
            ``metadata_column`` columns. A DataFrame argument is not modified.
        metadata_column (str): Metadata column mapped to marker shapes.
        point_size (float): Marker area passed to ``scatter``.
        figsize (float): Width and height of the square figure, in inches.
        threshold (float): Absolute coefficient cut-off for hits; dashed
            vertical lines are drawn at +/-|threshold|.
        save_path (str | None): If given, the figure is saved there as PDF.
        x_lim (sequence of 2 floats): Shared x-axis limits.
        y_lims (sequence of two (low, high) pairs): y limits of the lower and
            upper panel, respectively.

    Returns:
        list: ``variable`` values of the hits, in row order.
    """
    markers = [
        'o',  # Circle
        'X',  # X-shaped marker
        '^',  # Upward triangle
        's',  # Square
        'v',  # Downward triangle
        'P',  # Plus-filled pentagon
        '*',  # Star
        '+',  # Plus
        'x',  # Cross
        '.',  # Point
        ',',  # Pixel
        'd',  # Diamond
        'D',  # Thin diamond
        'h',  # Hexagon 1
        'H',  # Hexagon 2
        'p',  # Pentagon
        '|',  # Vertical line
        '_',  # Horizontal line
    ]
    palette = {'pc': 'red', 'nc': 'green', 'control': 'white', 'other': 'gray'}

    fontsize = 18
    plt.rcParams.update({'font.size': fontsize})

    # Load data (copy a caller-supplied frame so it is not modified).
    data = data_path.copy() if isinstance(data_path, pd.DataFrame) else pd.read_csv(data_path)
    # Use the bracketed part of the feature name (e.g. "C(x)[T.gene]") if present.
    data['variable'] = data['feature'].str.extract(r'\[(.*?)\]', expand=False)
    data['variable'] = data['variable'].fillna(data['feature'])
    data['gene_nr'] = data['variable'].str.split('_').str[0].astype(str)
    data = data[data['variable'] != 'Intercept']

    # Load metadata; only the join key and the marker column are needed.
    metadata = metadata_path if isinstance(metadata_path, pd.DataFrame) else pd.read_csv(metadata_path)
    metadata = metadata[['gene_nr', metadata_column]].copy()
    metadata['gene_nr'] = metadata['gene_nr'].astype(str)

    merged_data = pd.merge(data, metadata, on='gene_nr', how='left')
    merged_data[metadata_column] = merged_data[metadata_column].fillna('unknown')

    # Categories beyond len(markers) fall back to 'o' and are not in the legend.
    marker_dict = dict(zip(merged_data[metadata_column].unique(), markers))

    # Create the figure with custom spacing
    fig = plt.figure(figsize=(figsize, figsize))
    gs = GridSpec(2, 1, height_ratios=[1, 3], hspace=0.05)
    ax_upper = fig.add_subplot(gs[0])
    ax_lower = fig.add_subplot(gs[1], sharex=ax_upper)

    # Hide x-axis labels on the upper plot
    ax_upper.tick_params(axis='x', which='both', bottom=False, labelbottom=False)

    plot_df = merged_data.assign(
        _y=-np.log10(merged_data['p_value']),
        _color=merged_data['condition'].astype(object).map(palette).fillna('gray'),
        _marker=merged_data[metadata_column].map(marker_dict).fillna('o'),
    )
    on_upper = plot_df['_y'] > y_lims[1][0]

    # One scatter call per (panel, colour, marker) group instead of one per point.
    for ax, subset in ((ax_upper, plot_df[on_upper]), (ax_lower, plot_df[~on_upper])):
        for (color, marker), group in subset.groupby(['_color', '_marker'], sort=False):
            ax.scatter(group['coefficient'], group['_y'], color=color, marker=marker,
                       s=point_size, edgecolor='black', alpha=0.6)

    # The same mask drives both the returned hits and the annotations.
    is_hit = (plot_df['p_value'] <= 0.05) & (plot_df['coefficient'].abs() >= abs(threshold))
    hit_list = plot_df.loc[is_hit, 'variable'].tolist()

    # Set axis limits (x is shared, so setting it on the lower axis applies to both)
    ax_upper.set_ylim(y_lims[1])
    ax_lower.set_ylim(y_lims[0])
    ax_lower.set_xlim(x_lim)

    ax_lower.spines['top'].set_visible(False)
    ax_upper.spines['top'].set_visible(False)
    ax_upper.spines['bottom'].set_visible(False)

    ax_lower.set_xlabel('Coefficient')
    ax_lower.set_ylabel('-log10(p-value)')
    ax_upper.set_ylabel('-log10(p-value)')

    for ax in (ax_upper, ax_lower):
        ax.spines['right'].set_visible(False)
        # Threshold lines at +/-|threshold|
        ax.axvline(x=-abs(threshold), linestyle='--', color='black')
        ax.axvline(x=abs(threshold), linestyle='--', color='black')

    # p = 0.05 reference line
    ax_lower.axhline(y=-np.log10(0.05), linestyle='--', color='black')

    # Annotate significant points on the panel they were drawn on.
    texts_upper, texts_lower = [], []
    hits = plot_df[is_hit]
    for coef, y_val, label, upper in zip(hits['coefficient'], hits['_y'],
                                         hits['variable'], on_upper[is_hit]):
        ax = ax_upper if upper else ax_lower
        text = ax.text(coef, y_val, label, fontsize=fontsize, ha='center', va='bottom')
        (texts_upper if upper else texts_lower).append(text)

    # Adjust text positions to avoid overlap
    adjust_text(texts_upper, ax=ax_upper, arrowprops=dict(arrowstyle='-', color='black'))
    adjust_text(texts_lower, ax=ax_lower, arrowprops=dict(arrowstyle='-', color='black'))

    # Single shape legend on the lower axis
    handles = [plt.Line2D([0], [0], marker=m, color='w', markerfacecolor='gray', markersize=10)
               for m in marker_dict.values()]
    ax_lower.legend(handles,
                    list(marker_dict.keys()),
                    bbox_to_anchor=(1.05, 1),
                    loc='upper left',
                    borderaxespad=0.25,
                    labelspacing=2,
                    handletextpad=0.25,
                    markerscale=2,
                    prop={'size': fontsize})

    if save_path:
        fig.savefig(save_path, format='pdf', bbox_inches='tight')
    plt.show()

    return hit_list
194
167
 
195
168
  def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=['Computed GO Processes', 'Curated GO Components', 'Curated GO Functions', 'Curated GO Processes']):
196
169
  """
@@ -331,4 +304,150 @@ def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=
331
304
 
332
305
  # Show the combined plot
333
306
  plt.tight_layout()
307
+ plt.show()
308
+
309
def plot_gene_phenotypes(data, gene_list, x_column='Gene ID', data_column='T.gondii GT1 CRISPR Phenotype - Mean Phenotype', error_column='T.gondii GT1 CRISPR Phenotype - Standard Error', save_path=None):
    """
    Plot ranked mean phenotypes with standard-error shading and highlight selected genes.

    Rows whose ``data_column`` or ``error_column`` cannot be parsed as numbers
    are dropped. The remaining rows are sorted by ``data_column`` and plotted
    against their 1-based rank. Genes are matched on the second '_'-separated
    field of strings containing '_' (e.g. 'TGGT1_123' -> '123'), otherwise on
    ``str(value)``. The same rule is applied to ``data[x_column]`` and to
    ``gene_list``.

    Args:
        data (pd.DataFrame): Gene table; it is not modified.
        gene_list (list): Genes to highlight and label on the plot.
        x_column (str): Column holding the gene identifiers.
        data_column (str): Column with the mean phenotype (y values).
        error_column (str): Column with the standard error (shaded band).
        save_path (str | None): If given, the figure is saved there as PDF.
    """
    def extract_gene_id(gene):
        if isinstance(gene, str) and '_' in gene:
            return gene.split('_')[1]
        return str(gene)

    # Work on a copy: assigning into the caller's frame (or into a filtered
    # slice of it) would mutate their data and trigger SettingWithCopy warnings.
    data = data.copy()
    data[data_column] = pd.to_numeric(data[data_column], errors='coerce')
    data[error_column] = pd.to_numeric(data[error_column], errors='coerce')
    data = data.dropna(subset=[data_column, error_column])

    data['x'] = data[x_column].apply(extract_gene_id)

    # Sort by the data_column and assign ranks
    data = data.sort_values(by=data_column).reset_index(drop=True)
    data['rank'] = range(1, len(data) + 1)

    x = data['rank']
    y = data[data_column]
    yerr = data[error_column]

    plt.figure(figsize=(10, 10))

    # Mean phenotype with standard error shading
    plt.plot(x, y, label='Mean Phenotype', color=(0/255, 155/255, 155/255), linewidth=2)
    plt.fill_between(
        x, y - yerr, y + yerr,
        color=(0/255, 155/255, 155/255), alpha=0.1, label='Standard Error'
    )

    texts = []  # Text objects handed to adjust_text
    for gene in gene_list:
        gene_data = data[data['x'] == extract_gene_id(gene)]
        if gene_data.empty:
            continue
        plt.scatter(
            gene_data['rank'],
            gene_data[data_column],
            color=(155/255, 55/255, 155/255),
            s=200,
            alpha=0.6,
            label=f'Highlighted Gene: {gene}',
            zorder=3  # Draw highlighted points above the line
        )
        texts.append(
            plt.text(
                gene_data['rank'].values[0],
                gene_data[data_column].values[0],
                gene,
                fontsize=18,
                ha='right'
            )
        )

    # Move labels apart, connecting them to their points with lines
    adjust_text(texts, arrowprops=dict(arrowstyle='-', color='gray'))

    plt.xlabel('Rank')
    plt.ylabel('Mean Phenotype')
    # No legend is drawn (the original created one and immediately removed it).
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, format='pdf', dpi=600, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()
394
+
395
def plot_gene_heatmaps(data, gene_list, columns, x_column='Gene ID', normalize=False, save_path=None):
    """
    Draw a viridis heatmap of ``columns`` (x-axis) for the genes in ``gene_list`` (y-axis).

    Genes are matched on the second '_'-separated field of strings containing
    '_' (e.g. 'TGGT1_123' -> '123'), otherwise on ``str(value)``. The same rule
    is applied to ``data[x_column]`` and to ``gene_list``, so prefixed and bare
    identifiers both match.

    Args:
        data (pd.DataFrame): Gene table; it is not modified.
        gene_list (list): Genes to include in the heatmap.
        columns (list): Column names to visualize.
        x_column (str): Column holding the gene identifiers.
        normalize (bool): If True, min-max scale each gene's row to [0, 1]
            (rows with a constant value become NaN).
        save_path (str | None): If given, the figure is saved there as PDF.

    Raises:
        ValueError: If none of the genes in ``gene_list`` are found.
    """
    def extract_gene_id(gene):
        if isinstance(gene, str) and '_' in gene:
            return gene.split('_')[1]
        return str(gene)

    wanted = {extract_gene_id(gene) for gene in gene_list}
    # assign() works on a copy, so the caller's DataFrame gains no 'x' column.
    keyed = data.assign(x=data[x_column].apply(extract_gene_id))
    filtered_data = keyed[keyed['x'].isin(wanted)].set_index('x')[columns]

    if filtered_data.empty:
        raise ValueError(f"None of the genes in gene_list were found in column '{x_column}'.")

    if normalize:
        filtered_data = filtered_data.apply(lambda x: (x - x.min()) / (x.max() - x.min()), axis=1)

    # Size the figure by what is actually drawn (one inch per row found).
    width = len(columns) * 4
    height = len(filtered_data)

    plt.figure(figsize=(width, height))
    cmap = sns.color_palette("viridis", as_cmap=True)

    sns.heatmap(
        filtered_data,
        cmap=cmap,
        cbar=True,
        annot=False,
        linewidths=0.5,
        square=True
    )

    plt.xticks(rotation=90, ha='center')
    plt.yticks(rotation=0)
    plt.xlabel('')
    plt.ylabel('')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, format='pdf', dpi=600, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()
spacr/utils.py CHANGED
@@ -4067,7 +4067,7 @@ def generate_path_list_from_db(db_path, file_metadata):
4067
4067
 
4068
4068
  return all_paths
4069
4069
 
4070
- def correct_paths(df, base_path):
4070
+ def correct_paths(df, base_path, folder='data'):
4071
4071
 
4072
4072
  if isinstance(df, pd.DataFrame):
4073
4073
 
@@ -4083,9 +4083,9 @@ def correct_paths(df, base_path):
4083
4083
  adjusted_image_paths = []
4084
4084
  for path in image_paths:
4085
4085
  if base_path not in path:
4086
- parts = path.split('/data/')
4086
+ parts = path.split(f'/{folder}/')
4087
4087
  if len(parts) > 1:
4088
- new_path = os.path.join(base_path, 'data', parts[1])
4088
+ new_path = os.path.join(base_path, f'{folder}', parts[1])
4089
4089
  adjusted_image_paths.append(new_path)
4090
4090
  else:
4091
4091
  adjusted_image_paths.append(path)
@@ -5209,4 +5209,27 @@ def fill_holes_in_mask(mask):
5209
5209
  # Assign the original label back to the filled object
5210
5210
  filled_mask[filled_object] = i
5211
5211
 
5212
- return filled_mask
5212
+ return filled_mask
5213
+
5214
def correct_metadata_column_names(df):
    """Return ``df`` with well-metadata columns renamed to canonical names.

    Renames ``plate_name`` -> ``plate``, ``column_name`` / ``col`` ->
    ``column``, ``row_name`` -> ``row`` and ``grna_name`` -> ``grna``. An alias
    is only renamed when its canonical name is not already a column, so the
    result never contains duplicate column labels. If a ``plate_row`` column
    exists, it is split on its last '_' into ``plate`` and ``row`` columns
    (so plate names may themselves contain '_').

    Args:
        df (pd.DataFrame): Input table. It is never modified in place; it may
            be returned unchanged when nothing needs correcting.

    Returns:
        pd.DataFrame: The table with corrected column names.
    """
    aliases = (
        ('plate_name', 'plate'),
        ('column_name', 'column'),
        ('col', 'column'),
        ('row_name', 'row'),
        ('grna_name', 'grna'),
    )
    for alias, canonical in aliases:
        # Renaming onto an existing label would create duplicate columns,
        # which breaks later df[...] and groupby lookups.
        if alias in df.columns and canonical not in df.columns:
            df = df.rename(columns={alias: canonical})

    if 'plate_row' in df.columns:
        df = df.copy()
        df[['plate', 'row']] = df['plate_row'].str.rsplit('_', n=1, expand=True)
    return df
5228
+
5229
def control_filelist(folder, mode='column', values=('01', '02')):
    """Return file names in ``folder`` whose well token matches ``values``.

    The well token is the second '_'-separated field of the file name
    (e.g. ``A01`` in ``plate1_A01_1.tif``). Names without such a field are
    skipped.

    Args:
        folder (str): Directory to list (non-recursive, via ``os.listdir``).
        mode (str): ``'column'`` matches the token without its first character
            (``'01'`` for ``A01``); ``'row'`` matches its first character
            (``'A'`` for ``A01``).
        values: Accepted column or row identifiers (tested with ``in``).

    Returns:
        list[str]: Matching file names (not full paths), in ``os.listdir`` order.

    Raises:
        ValueError: If ``mode`` is neither ``'column'`` nor ``'row'``.
    """
    # Compare strings by value: `is` only worked by accident of interning.
    if mode == 'column':
        def _key(well):
            return well[1:]
    elif mode == 'row':
        def _key(well):
            return well[:1]
    else:
        raise ValueError(f"mode must be 'column' or 'row', got {mode!r}")

    filtered_files = []
    for file in os.listdir(folder):
        parts = file.split('_')
        if len(parts) < 2:
            continue
        if _key(parts[1]) in values:
            filtered_files.append(file)
    return filtered_files
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: spacr
3
- Version: 0.3.46
3
+ Version: 0.3.50
4
4
  Summary: Spatial phenotype analysis of crisp screens (SpaCr)
5
5
  Home-page: https://github.com/EinarOlafsson/spacr
6
6
  Author: Einar Birnir Olafsson
@@ -66,6 +66,7 @@ Requires-Dist: gdown
66
66
  Requires-Dist: IPython<9.0,>=8.18.1
67
67
  Requires-Dist: ipykernel
68
68
  Requires-Dist: ipywidgets<9.0,>=8.1.2
69
+ Requires-Dist: brokenaxes<1.0,>=0.6.2
69
70
  Requires-Dist: huggingface-hub<0.25,>=0.24.0
70
71
  Provides-Extra: dev
71
72
  Requires-Dist: pytest<3.11,>=3.9; extra == "dev"