spacr 0.3.46__py3-none-any.whl → 0.3.47__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/ml.py CHANGED
@@ -343,12 +343,79 @@ def regression(df, csv_path, dependent_variable='predictions', regression_type=N
343
343
 
344
344
  return model, coef_df
345
345
 
346
- def perform_regression(settings):
346
def graph_cell_count_threshold(settings):
    """Scan minimum cells-per-well thresholds and return the selected one.

    Each CSV in ``settings['score_data']`` is loaded, its metadata column names
    are normalised, and it is labelled ``plate1``, ``plate2``, ... in list
    order. Rows are grouped per well through a ``plate_row_column`` key. For
    every threshold from 1 to the largest per-well row count, wells with at
    least that many rows are kept, and the mean and variance of their per-well
    mean score are recorded. The selected threshold is the one where the
    variance changes least relative to the previous threshold. Both curves are
    plotted.

    Args:
        settings (dict): Reads ``'score_data'`` (a CSV path or a list of
            paths) and ``'score_column'``. A string ``'score_data'`` is
            replaced in place by a one-element list.

    Returns:
        The ``cell_count_threshold`` entry of the selected results row.

    Raises:
        ValueError: If the loaded score table contains no rows.
    """
    from .utils import correct_metadata_column_names

    def _line_plot(df, x, y, log_x=False, log_y=False, title=""):
        """Draw column ``y`` against column ``x`` of ``df`` and show the figure."""
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(df[x], df[y], linestyle='-', color=(0, 0.6, 0.6), label=f"{y}")
        ax.set_xlabel(x)
        ax.set_ylabel(y)
        ax.set_title(title)
        ax.legend()
        if log_x:
            ax.set_xscale('log')
        if log_y:
            ax.set_yscale('log')
        plt.show()

    if isinstance(settings['score_data'], str):
        settings['score_data'] = [settings['score_data']]

    dfs = []
    for i, score_data in enumerate(settings['score_data']):
        df = pd.read_csv(score_data)
        df = correct_metadata_column_names(df)
        df['plate'] = f'plate{i+1}'
        df['prc'] = df['plate'] + '_' + df['row'].astype(str) + '_' + df['column'].astype(str)
        dfs.append(df)

    df = pd.concat(dfs, axis=0)

    # Fail early with a readable message; with no rows the threshold range
    # below cannot be built.
    if df.empty:
        raise ValueError("No valid results were found. Check your data and thresholds.")

    # Per-well row count and per-well mean score do not depend on the
    # threshold, so compute them once instead of re-grouping inside the loop.
    per_well = df.groupby('prc')
    well_counts = per_well.size()
    well_means = per_well[settings['score_column']].mean()

    thresholds = np.arange(1, well_counts.max() + 1)
    results = []

    # For each threshold keep only the wells with enough rows and summarise
    # the distribution of their mean scores.
    for threshold in thresholds:
        kept_means = well_means[well_counts >= threshold]
        results.append((threshold, kept_means.mean(), kept_means.var()))

    results_df = pd.DataFrame(results, columns=['cell_count_threshold', 'score_mean', 'score_variance'])

    if results_df.empty:
        raise ValueError("No valid results were found. Check your data and thresholds.")

    # Smallest absolute change in variance between consecutive thresholds.
    variance_step = results_df['score_variance'].diff().abs()
    if variance_step.notna().any():
        closest_threshold = variance_step.argmin()
    else:
        # Only one threshold, or the variance is undefined everywhere (e.g. a
        # single well): argmin has nothing to compare, so fall back to the
        # first row, i.e. the lowest threshold.
        closest_threshold = 0
    optimal_threshold = results_df.iloc[closest_threshold]

    print(f"Optimal Threshold: {optimal_threshold['cell_count_threshold']}")
    print(f"Score Mean at Optimal Threshold: {optimal_threshold['score_mean']}")
    print(f"Score Variance at Optimal Threshold: {optimal_threshold['score_variance']}")

    _line_plot(results_df, x='cell_count_threshold', y='score_mean',
               title='Mean Well Score vs. Cell Count Threshold')
    _line_plot(results_df, x='cell_count_threshold', y='score_variance',
               title='Score Variance vs. Cell Count Threshold')

    return optimal_threshold['cell_count_threshold']
411
+
412
+ def perform_regression(settings):
413
+
348
414
  from .plot import plot_plates
349
415
  from .utils import merge_regression_res_with_metadata, save_settings
350
416
  from .settings import get_perform_regression_default_settings
351
417
  from .toxo import go_term_enrichment_by_column, custom_volcano_plot
418
+ from .sequencing import graph_sequencing_stats
352
419
 
353
420
  def _perform_regression_read_data(settings):
354
421
 
@@ -468,9 +535,15 @@ def perform_regression(settings):
468
535
  score_data_df = clean_controls(score_data_df, settings['filter_value'], settings['filter_column'])
469
536
  print(f"Dependent variable after clean_controls: {len(score_data_df)}")
470
537
 
538
+ if settings['min_cell_count'] is None:
539
+ settings['min_cell_count'] = graph_cell_count_threshold(settings)
540
+
471
541
  dependent_df, dependent_variable = process_scores(score_data_df, settings['dependent_variable'], settings['plate'], settings['min_cell_count'], settings['agg_type'], settings['transform'])
472
542
  print(f"Dependent variable after process_scores: {len(dependent_df)}")
473
543
 
544
+ if settings['fraction_threshold'] is None:
545
+ settings['fraction_threshold'] = graph_sequencing_stats(settings)
546
+
474
547
  independent_df = process_reads(count_data_df, settings['fraction_threshold'], settings['plate'], filter_column=filter_column, filter_value=filter_value)
475
548
  independent_df, n_grna, n_gene = _count_variable_instances(independent_df, column_1='grna', column_2='gene')
476
549
 
@@ -499,8 +572,12 @@ def perform_regression(settings):
499
572
  grna_coef_df = grna_coef_df.dropna(subset=['n_grna'])
500
573
 
501
574
  if settings['controls'] is not None:
575
+
502
576
  control_coef_df = grna_coef_df[grna_coef_df['grna'].isin(settings['controls'])]
503
577
  mean_coef = control_coef_df['coefficient'].mean()
578
+ significant_c = control_coef_df[control_coef_df['p_value']<= 0.05]
579
+ mean_coef_c = significant_c['coefficient'].mean()
580
+ print(mean_coef, mean_coef_c)
504
581
 
505
582
  if settings['threshold_method'] in ['var','variance']:
506
583
  coef_mes = control_coef_df['coefficient'].var()
@@ -508,6 +585,7 @@ def perform_regression(settings):
508
585
  coef_mes = control_coef_df['coefficient'].std()
509
586
  else:
510
587
  raise ValueError(f"Unsupported threshold method {settings['threshold_method']}. Supported methods: ['var','variance','std','standard_deveation']")
588
+
511
589
  reg_threshold = mean_coef + (settings['threshold_multiplier'] * coef_mes)
512
590
 
513
591
  coef_df.to_csv(results_path, index=False)
@@ -531,6 +609,12 @@ def perform_regression(settings):
531
609
 
532
610
  significant.to_csv(hits_path, index=False)
533
611
 
612
+ significant_grna_filtered = significant[significant['n_grna'] > settings['min_n']]
613
+ significant_gene_filtered = significant[significant['n_gene'] > settings['min_n']]
614
+ significant_filtered = pd.concat([significant_grna_filtered, significant_gene_filtered])
615
+ filtered_hit_path = os.path.join(os.path.dirname(hits_path), 'results_significant_filtered.csv')
616
+ significant_filtered.to_csv(filtered_hit_path, index=False)
617
+
534
618
  if isinstance(settings['metadata_files'], str):
535
619
  settings['metadata_files'] = [settings['metadata_files']]
536
620
 
@@ -549,9 +633,15 @@ def perform_regression(settings):
549
633
  base_dir = os.path.dirname(os.path.abspath(__file__))
550
634
  metadata_path = os.path.join(base_dir, 'resources', 'data', 'lopit.csv')
551
635
 
552
- custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', point_size=200, figsize=20, threshold=reg_threshold, split_axis_lims=settings['split_axis_lims'], save_path=volcano_path)
553
- #custom_volcano_plot(data_path_gene, metadata_path, metadata_column='tagm_location', point_size=50, figsize=20, threshold=reg_threshold)
554
- #custom_volcano_plot(data_path_grna, metadata_path, metadata_column='tagm_location', point_size=50, figsize=20, threshold=reg_threshold)
636
+ if settings['volcano'] == 'all':
637
+ print('all')
638
+ custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', point_size=600, figsize=20, threshold=reg_threshold, split_axis_lims=settings['split_axis_lims'], save_path=volcano_path, x_lim=settings['x_lim'])
639
+ elif settings['volcano'] == 'gene':
640
+ print('gene')
641
+ custom_volcano_plot(data_path_gene, metadata_path, metadata_column='tagm_location', point_size=600, figsize=20, threshold=reg_threshold, split_axis_lims=settings['split_axis_lims'], save_path=volcano_path, x_lim=settings['x_lim'])
642
+ elif settings['volcano'] == 'grna':
643
+ print('grna')
644
+ custom_volcano_plot(data_path_grna, metadata_path, metadata_column='tagm_location', point_size=600, figsize=20, threshold=reg_threshold, split_axis_lims=settings['split_axis_lims'], save_path=volcano_path, x_lim=settings['x_lim'])
555
645
 
556
646
  #if len(significant) > 2:
557
647
  # metadata_path = os.path.join(base_dir, 'resources', 'data', 'toxoplasma_metadata.csv')
spacr/plot.py CHANGED
@@ -2733,7 +2733,7 @@ class spacrGraph:
2733
2733
  hue = None
2734
2734
 
2735
2735
  # Create the jitter plot
2736
- sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax,alpha=0.6)
2736
+ sns.stripplot(data=self.df_melted,x=x_axis_column,y='Value',hue=self.hue, palette=self.sns_palette, dodge=self.jitter_bar_dodge, jitter=self.bar_width, ax=ax, alpha=0.6, size=16)
2737
2737
 
2738
2738
  # Adjust legend and labels
2739
2739
  ax.set_xlabel(self.grouping_column)
@@ -2754,6 +2754,12 @@ class spacrGraph:
2754
2754
  # Ensure epoch is used on the x-axis and accuracy on the y-axis
2755
2755
  x_axis_column = self.data_column[0]
2756
2756
  y_axis_column = self.data_column[1]
2757
+
2758
+ if self.log_y:
2759
+ self.df[y_axis_column] = np.log10(self.df[y_axis_column])
2760
+
2761
+ if self.log_x:
2762
+ self.df[x_axis_column] = np.log10(self.df[x_axis_column])
2757
2763
 
2758
2764
  # Set hue to the grouping column to get one line per group
2759
2765
  hue = self.grouping_column
@@ -2771,11 +2777,6 @@ class spacrGraph:
2771
2777
  ax.set_xlabel(f"{x_axis_column}")
2772
2778
  ax.set_ylabel(f"{y_axis_column}")
2773
2779
 
2774
- if self.log_y:
2775
- ax.set_yscale('log')
2776
- if self.log_x:
2777
- ax.set_xscale('log')
2778
-
2779
2780
  def _create_line_with_std_area(self, ax):
2780
2781
  """Helper method to create a line graph with shaded area representing standard deviation."""
2781
2782
 
@@ -2784,15 +2785,22 @@ class spacrGraph:
2784
2785
  y_axis_column_mean = f"mean_{y_axis_column}"
2785
2786
  y_axis_column_std = f"std_{y_axis_column_mean}"
2786
2787
 
2788
+ if self.log_y:
2789
+ self.df[y_axis_column] = np.log10(self.df[y_axis_column])
2790
+
2791
+ if self.log_x:
2792
+ self.df[x_axis_column] = np.log10(self.df[x_axis_column])
2793
+
2787
2794
  # Pivot the DataFrame to get mean and std for each epoch across plates
2788
2795
  summary_df = self.df.pivot_table(index=x_axis_column,values=y_axis_column,aggfunc=['mean', 'std']).reset_index()
2789
2796
 
2790
2797
  # Flatten MultiIndex columns (result of pivoting)
2791
2798
  summary_df.columns = [x_axis_column, y_axis_column_mean, y_axis_column_std]
2792
-
2799
+
2793
2800
  # Plot the mean accuracy as a line
2794
2801
  sns.lineplot(data=summary_df,x=x_axis_column,y=y_axis_column_mean,ax=ax,marker='o',linewidth=1,markersize=0,color='blue',label=y_axis_column_mean)
2795
2802
 
2803
+
2796
2804
  # Fill the area representing the standard deviation
2797
2805
  ax.fill_between(summary_df[x_axis_column],summary_df[y_axis_column_mean] - summary_df[y_axis_column_std],summary_df[y_axis_column_mean] + summary_df[y_axis_column_std],color='blue', alpha=0.1 )
2798
2806
 
@@ -2800,11 +2808,6 @@ class spacrGraph:
2800
2808
  ax.set_xlabel(f"{x_axis_column}")
2801
2809
  ax.set_ylabel(f"{y_axis_column}")
2802
2810
 
2803
- if self.log_y:
2804
- ax.set_yscale('log')
2805
- if self.log_x:
2806
- ax.set_xscale('log')
2807
-
2808
2811
  def _create_box_plot(self, ax):
2809
2812
  """Helper method to create a box plot with consistent spacing."""
2810
2813
  # Combine grouping column and data column if needed
@@ -2969,23 +2972,29 @@ def plot_data_from_db(settings):
2969
2972
  df (pd.DataFrame): The extracted table as a DataFrame.
2970
2973
  """
2971
2974
 
2975
+
2976
+
2972
2977
  if isinstance(settings['src'], str):
2973
2978
  srcs = [settings['src']]
2974
2979
  elif isinstance(settings['src'], list):
2975
2980
  srcs = settings['src']
2976
- if isinstance(settings['database'], str):
2977
- settings['database'] = [settings['database'] for _ in range(len(srcs))]
2978
2981
  else:
2979
2982
  raise ValueError("src must be a string or a list of strings.")
2980
2983
 
2984
+ if isinstance(settings['database'], str):
2985
+ settings['database'] = [settings['database'] for _ in range(len(srcs))]
2986
+
2987
+ settings['dst'] = os.path.join(srcs[0], 'results')
2988
+
2981
2989
  save_settings(settings, name=f"{settings['graph_name']}_plot_settings_db", show=True)
2982
2990
 
2983
2991
  dfs = []
2984
2992
  for i, src in enumerate(srcs):
2985
2993
 
2986
2994
  db_loc = os.path.join(src, 'measurements', settings['database'][i])
2987
-
2995
+ print(f"Database: {db_loc}")
2988
2996
  if settings['table_names'] in ['saliency_image_correlations']:
2997
+ print(f"Database table: {settings['table_names']}")
2989
2998
  [df1] = _read_db(db_loc, tables=[settings['table_names']])
2990
2999
  else:
2991
3000
  df1, _ = _read_and_merge_data(locs=[db_loc],
@@ -3006,8 +3015,9 @@ def plot_data_from_db(settings):
3006
3015
 
3007
3016
  df = pd.concat(dfs, axis=0)
3008
3017
  df['prc'] = df['plate'].astype(str) + '_' + df['row'].astype(str) + '_' + df['col'].astype(str)
3009
- df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
3010
- df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
3018
+ #df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
3019
+ #df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
3020
+ df['class'] = df['png_path'].apply(lambda x: 'class_1' if 'class_1' in x else ('class_0' if 'class_0' in x else None))
3011
3021
 
3012
3022
  if settings['cell_plate_metadata'] != None:
3013
3023
  df = df.dropna(subset='host_cell')
@@ -3021,7 +3031,7 @@ def plot_data_from_db(settings):
3021
3031
  df = df.dropna(subset=settings['data_column'])
3022
3032
  df = df.dropna(subset=settings['grouping_column'])
3023
3033
 
3024
- #df['class'] = df['png_path'].apply(lambda x: 'class_1' if 'class_1' in x else ('class_0' if 'class_0' in x else None))
3034
+
3025
3035
  src = srcs[0]
3026
3036
  dst = os.path.join(src, 'results', settings['graph_name'])
3027
3037
  os.makedirs(dst, exist_ok=True)
spacr/sequencing.py CHANGED
@@ -2,6 +2,11 @@ import os, gzip, re, time, gzip
2
2
  import pandas as pd
3
3
  from multiprocessing import Pool, cpu_count, Queue, Process
4
4
  from Bio.Seq import Seq
5
+ import matplotlib.pyplot as plt
6
+ import seaborn as sns
7
+ import numpy as np
8
+ from .plot import plot_plates
9
+ from IPython.display import display
5
10
 
6
11
  # Function to map sequences to names (same as your original)
7
12
  def map_sequences_to_names(csv_file, sequences, rc):
@@ -480,4 +485,117 @@ def barecodes_reverse_complement(csv_file):
480
485
  # Save the DataFrame with the reverse complement sequences
481
486
  df.to_csv(new_filename, index=False)
482
487
 
483
- print(f"Reverse complement file saved as {new_filename}")
488
+ print(f"Reverse complement file saved as {new_filename}")
489
+
490
def graph_sequencing_stats(settings):
    """Choose a read-fraction threshold from the count data and plot the result.

    Each CSV in ``settings['count_data']`` is loaded and labelled ``plate1``,
    ``plate2``, ... in list order. Within every well (``plate_row_column``
    key) each row's ``count`` is divided by the well total to give
    ``fraction``. Rows whose ``settings['filter_column']`` value is listed in
    ``settings['control_wells']`` are dropped. A range of fraction thresholds
    is then scanned, and the one whose mean number of unique gRNAs per well is
    closest to ``settings['target_unique_count']`` is selected, plotted, and
    applied before the per-well unique counts are passed to ``plot_plates``.

    Args:
        settings (dict): Reads ``'count_data'`` (a CSV path or list of paths),
            ``'control_wells'``, ``'filter_column'``, ``'target_unique_count'``,
            ``'log_x'`` and ``'log_y'``. A string ``'count_data'`` is replaced
            in place by a one-element list.

    Returns:
        The selected fraction threshold.
    """
    from .utils import correct_metadata_column_names

    def _plot_density(df, dependent_variable, dst=None):
        """Plot a density plot of the dependent variable."""
        plt.figure(figsize=(10, 6))
        sns.kdeplot(df[dependent_variable], fill=True, alpha=0.6)
        plt.title(f'Density Plot of {dependent_variable}')
        plt.xlabel(dependent_variable)
        plt.ylabel('Density')
        if dst is not None:
            filename = os.path.join(dst, 'dependent_variable_density.pdf')
            plt.savefig(filename, format='pdf')
            print(f'Saved density plot to {filename}')
        plt.show()

    def find_and_visualize_fraction_threshold(df, target_unique_count=5, log_x=False, log_y=False, dst=None):
        """
        Find the fraction threshold where the recalculated unique count matches the target value,
        and visualize the relationship between fraction thresholds and unique counts.
        """

        def _line_plot(df, x='fraction_threshold', y='unique_count', log_x=False, log_y=False):
            """Return ``(fig, ax)`` with column ``y`` drawn against column ``x``."""
            if x not in df.columns or y not in df.columns:
                raise ValueError(f"Columns '{x}' and/or '{y}' not found in the DataFrame.")
            fig, ax = plt.subplots(figsize=(10, 10))
            ax.plot(df[x], df[y], linestyle='-', color=(0 / 255, 155 / 255, 155 / 255), label=f"{y}")
            ax.set_xlabel(x)
            ax.set_ylabel(y)
            ax.set_title(f'{y} vs {x}')
            ax.legend()
            if log_x:
                ax.set_xscale('log')
            if log_y:
                ax.set_yscale('log')
            fig.tight_layout()
            return fig, ax

        fraction_thresholds = np.linspace(0.001, 0.99, 1000)
        results = []

        # Mean number of unique gRNAs per well after filtering at each threshold.
        for threshold in fraction_thresholds:
            filtered_df = df[df['fraction'] >= threshold]
            unique_count = filtered_df.groupby(['plate', 'row', 'column'])['grna'].nunique().mean()
            results.append((threshold, unique_count))

        results_df = pd.DataFrame(results, columns=['fraction_threshold', 'unique_count'])
        closest_index = (results_df['unique_count'] - target_unique_count).abs().argmin()
        closest_threshold = results_df.iloc[closest_index]

        print(f"Closest Fraction Threshold: {closest_threshold['fraction_threshold']}")
        print(f"Unique Count at Threshold: {closest_threshold['unique_count']}")

        fig, ax = _line_plot(df=results_df, x='fraction_threshold', y='unique_count', log_x=log_x, log_y=log_y)

        # Draw the guides on the returned axes, then rebuild the legend: the
        # legend created inside _line_plot predates these lines, so their
        # labels would otherwise never be shown.
        ax.axvline(x=closest_threshold['fraction_threshold'], color='black', linestyle='--',
                   label=f'Closest Threshold ({closest_threshold["fraction_threshold"]:.4f})')
        ax.axhline(y=target_unique_count, color='black', linestyle='--',
                   label=f'Target Unique Count ({target_unique_count})')
        ax.legend()

        if dst is not None:
            fig_path = os.path.join(dst, 'results')
            os.makedirs(fig_path, exist_ok=True)
            fig_file_path = os.path.join(fig_path, 'fraction_threshold.pdf')
            fig.savefig(fig_file_path, format='pdf', dpi=600, bbox_inches='tight')
            print(f"Saved {fig_file_path}")
        plt.show()

        return closest_threshold['fraction_threshold']

    if isinstance(settings['count_data'], str):
        settings['count_data'] = [settings['count_data']]

    dfs = []
    for i, count_data in enumerate(settings['count_data']):
        df = pd.read_csv(count_data)
        plate_label = f'plate{i+1}'
        df['plate'] = plate_label
        # Normalise column names before building the well key so that files
        # using either row_name/column_name or row/column are accepted.
        df = correct_metadata_column_names(df)
        df['prc'] = plate_label + '_' + df['row'].astype(str) + '_' + df['column'].astype(str)
        df['total_count'] = df.groupby(['prc'])['count'].transform('sum')
        df['fraction'] = df['count'] / df['total_count']
        dfs.append(df)

    df = pd.concat(dfs, axis=0)

    # Drop control rows. A bare string is treated as a single value rather
    # than iterated character by character; None means nothing to drop.
    control_wells = settings['control_wells']
    if isinstance(control_wells, str):
        control_wells = [control_wells]
    if control_wells is not None:
        for c in control_wells:
            df = df[df[settings['filter_column']] != c]

    dst = os.path.dirname(settings['count_data'][0])

    closest_threshold = find_and_visualize_fraction_threshold(df, settings['target_unique_count'], log_x=settings['log_x'], log_y=settings['log_y'], dst=dst)

    # Apply the closest threshold to the DataFrame
    df = df[df['fraction'] >= closest_threshold]

    # Unique gRNAs per well, computed once and reused for the table and the
    # summary statistics.
    per_well_unique = df.groupby(['plate', 'row', 'column'])['grna'].nunique()
    unique_counts = per_well_unique.reset_index(name='unique_counts')
    unique_count_mean = per_well_unique.mean()
    unique_count_std = per_well_unique.std()

    # Merge the unique counts back into the original DataFrame
    df = pd.merge(df, unique_counts, on=['plate', 'row', 'column'], how='left')

    print(f"unique_count mean: {unique_count_mean} std: {unique_count_std}")

    # The density plot is currently disabled; _plot_density is kept for ad-hoc use.
    plot_plates(df=df, variable='unique_counts', grouping='mean', min_max='allq', cmap='viridis',min_count=0, verbose=True, dst=dst)

    return closest_threshold
spacr/settings.py CHANGED
@@ -550,6 +550,7 @@ def get_perform_regression_default_settings(settings):
550
550
  settings.setdefault('plate','plate1')
551
551
  settings.setdefault('class_1_threshold',None)
552
552
  settings.setdefault('metadata_files',['/home/carruthers/Documents/TGME49_Summary.csv','/home/carruthers/Documents/TGGT1_Summary.csv'])
553
+ settings.setdefault('volcano','gene')
553
554
  settings.setdefault('toxo', True)
554
555
 
555
556
  if settings['regression_type'] == 'quantile':
spacr/toxo.py CHANGED
@@ -6,15 +6,17 @@ from adjustText import adjust_text
6
6
  import pandas as pd
7
7
  from scipy.stats import fisher_exact
8
8
  from IPython.display import display
9
+ from matplotlib.legend import Legend
9
10
 
10
- def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', point_size=50, figsize=20, threshold=0, split_axis_lims = [10, None, None, 10], save_path=None):
11
+ def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location', point_size=50, figsize=20, threshold=0, split_axis_lims = [10, None, None, 10], save_path=None, x_lim=[-0.5, 0.5]):
11
12
  """
12
13
  Create a volcano plot with the ability to control the shape of points based on a categorical column,
13
14
  color points based on a condition, annotate specific points based on p-value and coefficient thresholds,
14
15
  and control the size of points.
15
16
  """
16
17
  volcano_path = save_path
17
-
18
+ padd = 30
19
+ fontsize = 18
18
20
  # Load the data
19
21
  if isinstance(data_path, pd.DataFrame):
20
22
  data = data_path
@@ -42,15 +44,13 @@ def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location
42
44
  merged_data.loc[merged_data['gene_nr'].str.startswith('4'), metadata_column] = 'GT1_gene'
43
45
  merged_data.loc[merged_data['gene_nr'] == 'Intercept', metadata_column] = 'Intercept'
44
46
  merged_data.loc[merged_data['condition'] == 'control', metadata_column] = 'control'
47
+ merged_data[metadata_column].fillna('unknown', inplace=True)
45
48
 
46
49
  # Categorize condition for coloring
47
50
  merged_data['condition'] = pd.Categorical(
48
51
  merged_data['condition'],
49
52
  categories=['other','pc', 'nc', 'control'],
50
53
  ordered=True)
51
-
52
-
53
- display(merged_data)
54
54
 
55
55
  # Create subplots with a broken y-axis
56
56
  figsize_2 = figsize / 2
@@ -65,19 +65,19 @@ def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location
65
65
  'nc': 'green',
66
66
  'control': 'white',
67
67
  'other': 'gray'}
68
-
69
- # Scatter plot on both axes
68
+
69
+ # Scatter plot on both axes with legend completely disabled
70
70
  sns.scatterplot(
71
71
  data=merged_data,
72
72
  x='coefficient',
73
73
  y='-log10(p_value)',
74
- hue='condition', # Keep colors but prevent them from showing in the final legend
75
- style=metadata_column if metadata_column else None, # Shape-based legend
74
+ hue='condition',
75
+ style=metadata_column if metadata_column else None,
76
76
  s=point_size,
77
- edgecolor='black',
77
+ edgecolor='black',
78
78
  palette=palette,
79
- legend='brief', # Capture the full legend initially
80
- alpha=0.8,
79
+ legend=False, # Disable automatic legend
80
+ alpha=0.6,
81
81
  ax=ax2 # Lower plot
82
82
  )
83
83
 
@@ -88,13 +88,41 @@ def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location
88
88
  hue='condition',
89
89
  style=metadata_column if metadata_column else None,
90
90
  s=point_size,
91
+ edgecolor='black',
91
92
  palette=palette,
92
- edgecolor='black',
93
- legend=False, # Suppress legend for upper plot
94
- alpha=0.8,
93
+ legend=False, # No legend on the upper plot
94
+ alpha=0.6,
95
95
  ax=ax1 # Upper plot
96
96
  )
97
97
 
98
+ # Ensure no previous legends on ax1 or ax2
99
+ if ax1.get_legend() is not None:
100
+ ax1.get_legend().remove()
101
+
102
+ if ax2.get_legend() is not None:
103
+ ax2.get_legend().remove()
104
+
105
+ # Manually gather handles and labels from ax2 after plotting
106
+ handles, labels = ax2.get_legend_handles_labels()
107
+
108
+ # Debug: Print the captured handles and labels for verification
109
+ print(f"Handles: {handles}")
110
+ print(f"Labels: {labels}")
111
+
112
+ # Identify shape-based legend entries (skip color-based entries)
113
+ n_color_entries = len(set(merged_data['condition']))
114
+ shape_handles = handles[n_color_entries:]
115
+ shape_labels = labels[n_color_entries:]
116
+
117
+ # Create and add the legend with shape-based entries
118
+ legend = Legend(
119
+ ax2, shape_handles, shape_labels,
120
+ bbox_to_anchor=(1.05, 1), loc='upper left',
121
+ handletextpad=2.0, labelspacing=1.5, borderaxespad=1.0,
122
+ markerscale=2.0, prop={'size': 14}
123
+ )
124
+ ax2.add_artist(legend)
125
+
98
126
  if isinstance(split_axis_lims, list):
99
127
  if len(split_axis_lims) == 4:
100
128
  ylim_min_ax1 = split_axis_lims[0]
@@ -113,28 +141,15 @@ def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location
113
141
  # Set axis limits and hide unnecessary parts
114
142
  ax1.set_ylim(ylim_min_ax1, ylim_max_ax1)
115
143
  ax2.set_ylim(0, ylim_max_ax2)
144
+
145
+ if x_lim != None:
146
+ ax1.set_xlim(x_lim)
147
+ ax2.set_xlim(x_lim)
148
+
116
149
  ax1.spines['bottom'].set_visible(False)
117
150
  ax2.spines['top'].set_visible(False)
118
151
  ax1.tick_params(labelbottom=False)
119
152
 
120
- if ax1.get_legend() is not None:
121
- ax1.legend_.remove()
122
- ax1.get_legend().remove() # Extract handles and labels from the legend
123
- handles, labels = ax2.get_legend_handles_labels()
124
-
125
- # Identify shape-based legend entries (skip color-based entries)
126
- shape_handles = handles[len(set(merged_data['condition'])):]
127
- shape_labels = labels[len(set(merged_data['condition'])):]
128
-
129
- # Set the legend with only shape-based entries
130
- ax2.legend(
131
- shape_handles,
132
- shape_labels,
133
- bbox_to_anchor=(1.05, 1),
134
- loc='upper left',
135
- borderaxespad=0.
136
- )
137
-
138
153
  ax1.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
139
154
 
140
155
  # Add vertical threshold lines to both plots
@@ -152,18 +167,13 @@ def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location
152
167
 
153
168
  for i, row in merged_data.iterrows():
154
169
  if row['p_value'] <= 0.05 and abs(row['coefficient']) >= abs(threshold):
155
- # Select the appropriate axis for the annotation
156
- #ax = ax1 if row['-log10(p_value)'] > 10 else ax2
157
-
158
170
  ax = ax1 if row['-log10(p_value)'] >= ax1.get_ylim()[0] else ax2
159
-
160
-
161
171
  # Create the annotation on the selected axis
162
172
  text = ax.text(
163
173
  row['coefficient'],
164
174
  -np.log10(row['p_value']),
165
175
  row['variable'],
166
- fontsize=8,
176
+ fontsize=fontsize,
167
177
  ha='center',
168
178
  va='bottom',
169
179
  )
@@ -175,8 +185,8 @@ def custom_volcano_plot(data_path, metadata_path, metadata_column='tagm_location
175
185
  texts_ax2.append(text)
176
186
 
177
187
  # Adjust text positions to avoid overlap for both axes
178
- adjust_text(texts_ax1, arrowprops=dict(arrowstyle='-', color='black'), ax=ax1)
179
- adjust_text(texts_ax2, arrowprops=dict(arrowstyle='-', color='black'), ax=ax2)
188
+ adjust_text(texts_ax1, arrowprops=dict(arrowstyle='-', color='black'), ax=ax1, expand_points=(padd, padd), fontsize=fontsize)
189
+ adjust_text(texts_ax2, arrowprops=dict(arrowstyle='-', color='black'), ax=ax2, expand_points=(padd, padd), fontsize=fontsize)
180
190
 
181
191
  # Move the legend outside the lower plot
182
192
  ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
spacr/utils.py CHANGED
@@ -5209,4 +5209,19 @@ def fill_holes_in_mask(mask):
5209
5209
  # Assign the original label back to the filled object
5210
5210
  filled_mask[filled_object] = i
5211
5211
 
5212
- return filled_mask
5212
+ return filled_mask
5213
+
5214
def correct_metadata_column_names(df):
    """Return ``df`` with metadata columns renamed to the short names.

    Renames ``plate_name`` -> ``plate``, ``column_name`` -> ``column``,
    ``col`` -> ``column``, ``row_name`` -> ``row`` and ``grna_name`` ->
    ``grna`` when the source column is present. A rename is skipped when its
    target name already exists (or was just produced by an earlier rename), so
    the result never carries two columns with the same label.

    If a ``plate_row`` column is present, it is split at its last underscore
    into ``plate`` and ``row``, overwriting those columns. Splitting at the
    last underscore keeps plate names that contain underscores intact.

    Args:
        df (pandas.DataFrame): Table to normalise. It is not modified.

    Returns:
        pandas.DataFrame: A new frame with the normalised column names.

    Raises:
        ValueError: If ``plate_row`` is present but cannot be split into
            exactly two parts.
    """
    aliases = (
        ('plate_name', 'plate'),
        ('column_name', 'column'),
        ('col', 'column'),
        ('row_name', 'row'),
        ('grna_name', 'grna'),
    )

    taken = set(df.columns)
    renames = {}
    for old_name, new_name in aliases:
        if old_name in df.columns and new_name not in taken:
            renames[old_name] = new_name
            taken.add(new_name)

    # rename() returns a new frame even for an empty mapping, so the
    # caller's object is never mutated by the plate_row handling below.
    df = df.rename(columns=renames)

    if 'plate_row' in df.columns:
        parts = df['plate_row'].str.rsplit('_', n=1, expand=True)
        if parts.shape[1] != 2:
            raise ValueError("Column 'plate_row' must contain values of the form '<plate>_<row>'.")
        df['plate'] = parts[0]
        df['row'] = parts[1]
    return df
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: spacr
3
- Version: 0.3.46
3
+ Version: 0.3.47
4
4
  Summary: Spatial phenotype analysis of crisp screens (SpaCr)
5
5
  Home-page: https://github.com/EinarOlafsson/spacr
6
6
  Author: Einar Birnir Olafsson
@@ -18,16 +18,16 @@ spacr/io.py,sha256=1rIdJ_8dyn7W4D2zXjaOqlgyo_Y5Z7X86aRp4hNYWCU,144194
18
18
  spacr/logger.py,sha256=lJhTqt-_wfAunCPl93xE65Wr9Y1oIHJWaZMjunHUeIw,1538
19
19
  spacr/measure.py,sha256=KdboGXoi85BO5-_6er7932FgjFI7G7tuaQDnWSiEuew,54817
20
20
  spacr/mediar.py,sha256=FwLvbLQW5LQzPgvJZG8Lw7GniA2vbZx6Jv6vIKu7I5c,14743
21
- spacr/ml.py,sha256=bPcKVk1camnOhv8jQglj6EYyipAxxmiB1QJ2Fdo3dEM,50654
21
+ spacr/ml.py,sha256=Mkxl4n3OvNsVix8bbPQ09-HTgP3YTzQOESso4_exKLs,54634
22
22
  spacr/openai.py,sha256=5vBZ3Jl2llYcW3oaTEXgdyCB2aJujMUIO5K038z7w_A,1246
23
- spacr/plot.py,sha256=r4kbrMA8iQ317f0lvIDj4wJDIDwDXXYHEgGtFJrO3-k,145387
24
- spacr/sequencing.py,sha256=t18mgpK6rhWuB1LtFOsPxqgpFXxuUmrD06ecsaVQ0Gw,19655
25
- spacr/settings.py,sha256=3ygnAY6uLtkzFQdK8TMBbWV6zXEX-G_wV19YLyjCBeM,77668
23
+ spacr/plot.py,sha256=M04Cbv1n_FHxO0Qg3VNu_IQXRGt7lb-sgWzh9jjL4rI,145733
24
+ spacr/sequencing.py,sha256=aP2QUfb9wpeJRZFLWFwxq1o4EtVpSvyVohOcm3Wvrq0,24965
25
+ spacr/settings.py,sha256=XEpo9sQXmQ3-sdRNmcsss6q0j7ZAvoAFO-_D8ecgYQc,77710
26
26
  spacr/sim.py,sha256=1xKhXimNU3ukzIw-3l9cF3Znc_brW8h20yv8fSTzvss,71173
27
27
  spacr/submodules.py,sha256=3C5M4UbI9Ral1MX4PTpucaAaqhL3RADuCOCqaHhMyUg,28048
28
28
  spacr/timelapse.py,sha256=FSYpUtAVy6xc3lwprRYgyDTT9ysUhfRQ4zrP9_h2mvg,39465
29
- spacr/toxo.py,sha256=X62hKFcSzFhIxFYlhL2AZb0qNpvtjLs3y1HldReAQEY,12880
30
- spacr/utils.py,sha256=K36BxYr4GN956V4S7IkNty2sP4Y265WS7yMzAw8Tqeg,220451
29
+ spacr/toxo.py,sha256=RjAqI2sCcYYr-eiLPGnyJUn96zR_ATENuyiM2CZT408,13358
30
+ spacr/utils.py,sha256=zkgUP_w_w9HJe4000KhVmpwO2gELoeIvdYNrXlRAzG8,221050
31
31
  spacr/version.py,sha256=axH5tnGwtgSnJHb5IDhiu4Zjk5GhLyAEDRe-rnaoFOA,409
32
32
  spacr/resources/MEDIAR/.gitignore,sha256=Ff1q9Nme14JUd-4Q3jZ65aeQ5X4uttptssVDgBVHYo8,152
33
33
  spacr/resources/MEDIAR/LICENSE,sha256=yEj_TRDLUfDpHDNM0StALXIt6mLqSgaV2hcCwa6_TcY,1065
@@ -150,9 +150,9 @@ spacr/resources/icons/umap.png,sha256=dOLF3DeLYy9k0nkUybiZMe1wzHQwLJFRmgccppw-8b
150
150
  spacr/resources/images/plate1_E01_T0001F001L01A01Z01C02.tif,sha256=Tl0ZUfZ_AYAbu0up_nO0tPRtF1BxXhWQ3T3pURBCCRo,7958528
151
151
  spacr/resources/images/plate1_E01_T0001F001L01A02Z01C01.tif,sha256=m8N-V71rA1TT4dFlENNg8s0Q0YEXXs8slIn7yObmZJQ,7958528
152
152
  spacr/resources/images/plate1_E01_T0001F001L01A03Z01C03.tif,sha256=Pbhk7xn-KUP6RSIhJsxQcrHFImBm3GEpLkzx7WOc-5M,7958528
153
- spacr-0.3.46.dist-info/LICENSE,sha256=SR-2MeGc6SCM1UORJYyarSWY_A-JaOMFDj7ReSs9tRM,1083
154
- spacr-0.3.46.dist-info/METADATA,sha256=rDVd_7S8qknwKjW3gzWpaC4FvKLLArfmA3xqGlby088,5949
155
- spacr-0.3.46.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
156
- spacr-0.3.46.dist-info/entry_points.txt,sha256=BMC0ql9aNNpv8lUZ8sgDLQMsqaVnX5L535gEhKUP5ho,296
157
- spacr-0.3.46.dist-info/top_level.txt,sha256=GJPU8FgwRXGzKeut6JopsSRY2R8T3i9lDgya42tLInY,6
158
- spacr-0.3.46.dist-info/RECORD,,
153
+ spacr-0.3.47.dist-info/LICENSE,sha256=SR-2MeGc6SCM1UORJYyarSWY_A-JaOMFDj7ReSs9tRM,1083
154
+ spacr-0.3.47.dist-info/METADATA,sha256=NEQNKKM40sYjqLkD1M-R4eOVTPSAQJ4zDD_5juTElbk,5949
155
+ spacr-0.3.47.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
156
+ spacr-0.3.47.dist-info/entry_points.txt,sha256=BMC0ql9aNNpv8lUZ8sgDLQMsqaVnX5L535gEhKUP5ho,296
157
+ spacr-0.3.47.dist-info/top_level.txt,sha256=GJPU8FgwRXGzKeut6JopsSRY2R8T3i9lDgya42tLInY,6
158
+ spacr-0.3.47.dist-info/RECORD,,
File without changes