spacr 0.3.1__py3-none-any.whl → 0.3.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. spacr/__init__.py +19 -3
  2. spacr/cellpose.py +311 -0
  3. spacr/core.py +245 -2494
  4. spacr/deep_spacr.py +316 -48
  5. spacr/gui.py +1 -0
  6. spacr/gui_core.py +74 -63
  7. spacr/gui_elements.py +110 -5
  8. spacr/gui_utils.py +346 -6
  9. spacr/io.py +680 -141
  10. spacr/logger.py +28 -9
  11. spacr/measure.py +107 -95
  12. spacr/mediar.py +0 -3
  13. spacr/ml.py +1051 -0
  14. spacr/openai.py +37 -0
  15. spacr/plot.py +707 -20
  16. spacr/resources/data/lopit.csv +3833 -0
  17. spacr/resources/data/toxoplasma_metadata.csv +8843 -0
  18. spacr/resources/icons/convert.png +0 -0
  19. spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
  20. spacr/sequencing.py +241 -1311
  21. spacr/settings.py +134 -47
  22. spacr/sim.py +0 -2
  23. spacr/submodules.py +349 -0
  24. spacr/timelapse.py +0 -2
  25. spacr/toxo.py +238 -0
  26. spacr/utils.py +419 -180
  27. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/METADATA +31 -22
  28. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/RECORD +32 -33
  29. spacr/chris.py +0 -50
  30. spacr/graph_learning.py +0 -340
  31. spacr/resources/MEDIAR/.git +0 -1
  32. spacr/resources/MEDIAR_weights/.DS_Store +0 -0
  33. spacr/resources/icons/.DS_Store +0 -0
  34. spacr/resources/icons/spacr_logo_rotation.gif +0 -0
  35. spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
  36. spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
  37. spacr/sim_app.py +0 -0
  38. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/LICENSE +0 -0
  39. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/WHEEL +0 -0
  40. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/entry_points.txt +0 -0
  41. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/top_level.txt +0 -0
spacr/plot.py CHANGED
@@ -1,4 +1,4 @@
1
- import os,re, random, cv2, glob, time, math, torch
1
+ import os, random, cv2, glob, math, torch
2
2
 
3
3
  import numpy as np
4
4
  import pandas as pd
@@ -14,11 +14,15 @@ from skimage.segmentation import find_boundaries
14
14
  from skimage import measure
15
15
  from skimage.measure import find_contours, label, regionprops
16
16
 
17
+ from scipy.stats import normaltest, ttest_ind, mannwhitneyu, f_oneway, kruskal
18
+ from statsmodels.stats.multicomp import pairwise_tukeyhsd
19
+ from scipy.stats import ttest_ind, mannwhitneyu, levene, wilcoxon, kruskal
20
+ import itertools
21
+ import pingouin as pg
22
+
17
23
  from ipywidgets import IntSlider, interact
18
24
  from IPython.display import Image as ipyimage
19
25
 
20
- from .logger import log_function_call
21
-
22
26
  def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, pathogen_channel, figuresize=10, normalize=True, thickness=3, save_pdf=True):
23
27
  """Plot image and mask overlays."""
24
28
 
@@ -409,7 +413,7 @@ def plot_images_and_arrays(folders, lower_percentile=1, upper_percentile=99, thr
409
413
  plot_from_file_dict(file_dict, threshold, lower_percentile, upper_percentile, overlay, save=False)
410
414
  return
411
415
 
412
- def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, mask_dims, filter_min_max, include_multinucleated, include_multiinfected):
416
+ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, mask_dims, filter_min_max, nuclei_limit, pathogen_limit):
413
417
  """
414
418
  Filters objects in a plot based on various criteria.
415
419
 
@@ -420,8 +424,8 @@ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mas
420
424
  pathogen_mask_dim (int): The dimension index of the pathogen mask.
421
425
  mask_dims (list): A list of dimension indices for additional masks.
422
426
  filter_min_max (list): A list of minimum and maximum area values for each mask.
423
- include_multinucleated (bool): Whether to include multinucleated cells.
424
- include_multiinfected (bool): Whether to include multiinfected cells.
427
+ nuclei_limit (bool): Whether to include multinucleated cells.
428
+ pathogen_limit (bool): Whether to include multiinfected cells.
425
429
 
426
430
  Returns:
427
431
  numpy.ndarray: The filtered stack of masks.
@@ -451,9 +455,9 @@ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mas
451
455
  total_count_after = len(props_after['label'])
452
456
 
453
457
  if mask_dim == cell_mask_dim:
454
- if include_multinucleated is False and nucleus_mask_dim is not None:
458
+ if nuclei_limit is False and nucleus_mask_dim is not None:
455
459
  stack = _remove_multiobject_cells(stack, mask_dim, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, object_dim=pathogen_mask_dim)
456
- if include_multiinfected is False and cell_mask_dim is not None and pathogen_mask_dim is not None:
460
+ if pathogen_limit is False and cell_mask_dim is not None and pathogen_mask_dim is not None:
457
461
  stack = _remove_multiobject_cells(stack, mask_dim, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, object_dim=nucleus_mask_dim)
458
462
  cell_area_before = avg_size_before
459
463
  cell_count_before = total_count_before
@@ -662,18 +666,18 @@ def plot_merged(src, settings):
662
666
  display(settings)
663
667
 
664
668
  if settings['pathogen_mask_dim'] is None:
665
- settings['include_multiinfected'] = True
669
+ settings['pathogen_limit'] = True
666
670
 
667
671
  for file in os.listdir(src):
668
672
  path = os.path.join(src, file)
669
673
  stack = np.load(path)
670
674
  print(f'Loaded: {path}')
671
- if not settings['include_noninfected']:
675
+ if not settings['uninfected']:
672
676
  if settings['pathogen_mask_dim'] is not None and settings['cell_mask_dim'] is not None:
673
677
  stack = _remove_noninfected(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'])
674
678
 
675
- if settings['include_multiinfected'] is not True or settings['include_multinucleated'] is not True or settings['filter_min_max'] is not None:
676
- stack = _filter_objects_in_plot(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'], mask_dims, settings['filter_min_max'], settings['include_multinucleated'], settings['include_multiinfected'])
679
+ if settings['pathogen_limit'] is not True or settings['nuclei_limit'] is not True or settings['filter_min_max'] is not None:
680
+ stack = _filter_objects_in_plot(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'], mask_dims, settings['filter_min_max'], settings['nuclei_limit'], settings['pathogen_limit'])
677
681
 
678
682
  overlayed_image, image, outlines = _normalize_and_outline(image=stack,
679
683
  remove_background=settings['remove_background'],
@@ -999,8 +1003,107 @@ def _display_gif(path):
999
1003
  """
1000
1004
  with open(path, 'rb') as file:
1001
1005
  display(ipyimage(file.read()))
1006
+
1007
def _plot_recruitment_v2(df, df_type, channel_of_interest, columns=None, figuresize=10):
    """
    Plot recruitment data for different conditions and pathogens using spacrGraph.

    Two figures are drawn and shown:
      1. Four panels with the mean intensity of ``channel_of_interest`` in the
         cell, nucleus, cytoplasm and pathogen compartments, grouped by the
         ``condition`` column.
      2. One panel per column in ``columns`` followed by a fixed set of
         pathogen/cytoplasm recruitment columns, also grouped by ``condition``.

    Args:
        df (DataFrame): The input DataFrame containing the recruitment data. It must
            contain a ``condition`` column and every column plotted below.
        df_type (str): Label used in the x-axis titles (rendered as ``'pathogen {df_type}'``).
        channel_of_interest (int or str): Channel used to build the
            ``{compartment}_channel_{channel_of_interest}_mean_intensity`` column names.
        columns (list, optional): Additional columns to plot in the second figure.
            Defaults to None (no additional columns).
        figuresize (int, optional): Base size of the figures; font size is half of it.
            Defaults to 10.

    Returns:
        None
    """
    color_list = [(55/255, 155/255, 155/255),
                  (155/255, 55/255, 155/255),
                  (55/255, 155/255, 255/255),
                  (255/255, 55/255, 155/255)]

    sns.set_palette(sns.color_palette(color_list))
    font = figuresize / 2
    width = figuresize
    height = figuresize / 4

    # Figure 1: one panel per compartment for the channel of interest.
    compartments = ['cell', 'nucleus', 'cytoplasm', 'pathogen']
    fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(width, height))
    for ax, compartment in zip(axes, compartments):
        data_column = f'{compartment}_channel_{channel_of_interest}_mean_intensity'
        # spacrGraph is defined later in this same module; no self-import is needed.
        plotter = spacrGraph(df, grouping_column='condition', data_column=data_column)
        plotter.create_plot(ax=ax)
        ax.set_xlabel(f'pathogen {df_type}', fontsize=font)
        ax.set_ylabel(data_column, fontsize=font)

    handles, labels = axes[3].get_legend_handles_labels()
    axes[3].legend(handles, labels, bbox_to_anchor=(1.05, 0.5), loc='center left')
    for ax in axes:
        ax.tick_params(axis='both', which='major', labelsize=font)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45)

    plt.tight_layout()
    plt.show()

    # Figure 2: caller-supplied columns plus the fixed recruitment columns.
    # Build a new list so the caller's list is never modified.
    extra_columns = ['pathogen_cytoplasm_mean_mean', 'pathogen_cytoplasm_q75_mean', 'pathogen_periphery_cytoplasm_mean_mean', 'pathogen_outside_cytoplasm_mean_mean', 'pathogen_outside_cytoplasm_q75_mean']
    columns = (list(columns) if columns is not None else []) + extra_columns

    width = figuresize * 2
    columns_per_row = math.ceil(len(columns) / 2)
    height = (figuresize * 2) / columns_per_row

    fig, axes = plt.subplots(nrows=2, ncols=columns_per_row, figsize=(width, height * 2))
    axes = axes.flatten()

    print(f'{columns}')
    for i, col in enumerate(columns):
        ax = axes[i]
        plotter_col = spacrGraph(df, grouping_column='condition', data_column=col)
        plotter_col.create_plot(ax=ax)
        ax.set_xlabel(f'pathogen {df_type}', fontsize=font)
        ax.set_ylabel(f'{col}', fontsize=int(font * 2))
        ax.tick_params(axis='both', which='major', labelsize=font)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
        # NOTE(review): the first six panels get a y-axis floor of 1 — presumably
        # ratio-like columns; confirm this threshold is intended.
        if i <= 5:
            ax.set_ylim(1, None)

    # Hide grid cells that received no column.
    for ax in axes[len(columns):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()
1002
1105
 
1003
- def _plot_recruitment(df, df_type, channel_of_interest, target, columns=[], figuresize=10):
1106
+ def _plot_recruitment(df, df_type, channel_of_interest, columns=[], figuresize=10):
1004
1107
  """
1005
1108
  Plot recruitment data for different conditions and pathogens.
1006
1109
 
@@ -1153,10 +1256,6 @@ def _plot_controls(df, mask_chans, channel_of_interest, figuresize=5):
1153
1256
  plt.tight_layout()
1154
1257
  plt.show()
1155
1258
 
1156
- ###################################################
1157
- # Classify
1158
- ###################################################
1159
-
1160
1259
  def _imshow(img, labels, nrow=20, color='white', fontsize=12):
1161
1260
  """
1162
1261
  Display multiple images in a grid with corresponding labels.
@@ -1359,7 +1458,7 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_co
1359
1458
 
1360
1459
  return plate_map, min_max
1361
1460
 
1362
- def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True):
1461
+ def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True, dst=None):
1363
1462
  plates = df['prc'].str.split('_', expand=True)[0].unique()
1364
1463
  n_rows, n_cols = (len(plates) + 3) // 4, 4
1365
1464
  fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
@@ -1374,6 +1473,12 @@ def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True
1374
1473
  fig.delaxes(ax[i])
1375
1474
 
1376
1475
  plt.subplots_adjust(wspace=0.1, hspace=0.4)
1476
+
1477
+ if not dst is None:
1478
+ filename = os.path.join(dst, 'plate_heatmap.pdf')
1479
+ fig.savefig(filename, format='pdf')
1480
+ print(f'Saved heatmap to {filename}')
1481
+
1377
1482
  if verbose:
1378
1483
  plt.show()
1379
1484
  return fig
@@ -1605,13 +1710,19 @@ def volcano_plot(coef_df, filename='volcano_plot.pdf'):
1605
1710
  print(f'Saved Volcano plot: {filename}')
1606
1711
  plt.show()
1607
1712
 
1608
- def plot_histogram(df, dependent_variable):
1713
+ def plot_histogram(df, dependent_variable, dst=None):
1609
1714
  # Plot histogram of the dependent variable
1610
1715
  plt.figure(figsize=(10, 6))
1611
1716
  sns.histplot(df[dependent_variable], kde=True)
1612
1717
  plt.title(f'Histogram of {dependent_variable}')
1613
1718
  plt.xlabel(dependent_variable)
1614
1719
  plt.ylabel('Frequency')
1720
+
1721
+ if not dst is None:
1722
+ filename = os.path.join(dst, 'dependent_variable_histogram.pdf')
1723
+ plt.savefig(filename, format='pdf')
1724
+ print(f'Saved histogram to {filename}')
1725
+
1615
1726
  plt.show()
1616
1727
 
1617
1728
  def plot_lorenz_curves(csv_files, remove_keys=['TGGT1_220950_1', 'TGGT1_233460_4']):
@@ -1732,4 +1843,580 @@ def read_and_plot__vision_results(base_dir, y_axis='accuracy', name_split='_time
1732
1843
  plt.ylim(y_lim)
1733
1844
  plt.show()
1734
1845
  else:
1735
- print("No CSV files found in the specified directory.")
1846
+ print("No CSV files found in the specified directory.")
1847
+
1848
def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot', output_path=None, filter_column=None, filter_values=None):
    """
    Load spaCR measurements for ``src`` and draw a class-balanced jitter plot.

    The ``cell``, ``nucleus``, ``pathogen`` and ``cytoplasm`` tables of
    ``{src}/measurements/measurements.db`` are merged (via ``_read_and_merge_data``)
    and left-joined with the ``png_list`` table on ``prcfo``. Only rows from wells
    (``plate_x``/``row_x``/``col_x``) containing at least one annotated row are kept.
    Every ``x_column`` group is then down-sampled to the size of the smallest
    group, and ``y_column`` is plotted per group.

    Args:
        src (str): Path to the source folder containing ``measurements/measurements.db``.
        x_column (str): Column used for grouping on the x-axis. Missing values are
            relabelled as the string ``'NaN'``.
        y_column (str): Column plotted on the y-axis.
        plot_title (str): Title of the plot. Default is 'Jitter Plot'.
        output_path (str): Path to save the plot image. If None, the plot is displayed.
            Default is None.
        filter_column (str or list of str, optional): Column(s) to filter on before plotting.
        filter_values (list, optional): Allowed values. For a single column this is one
            list; for a list of columns it is one list of allowed values per column,
            in the same order.

    Returns:
        pd.DataFrame: The filtered and balanced DataFrame.

    Raises:
        KeyError: If the merged data lacks ``plate_x``, ``row_x`` or ``col_x``.
    """

    def _join_measurements_and_annotation(src, tables):
        # Local import, presumably to avoid a circular import between plot and io.
        from .io import _read_and_merge_data, _read_db
        db_path = os.path.join(src, 'measurements', 'measurements.db')
        df, _ = _read_and_merge_data([db_path],
                                     tables,
                                     verbose=True,
                                     nuclei_limit=True,
                                     pathogen_limit=True,
                                     uninfected=True)
        paths_df = _read_db(db_path, tables=['png_list'])
        return pd.merge(df, paths_df[0], on='prcfo', how='left')

    df = _join_measurements_and_annotation(src, ['cell', 'nucleus', 'pathogen', 'cytoplasm'])
    print(f"Generated dataframe with: {df.shape[1]} columns and {df.shape[0]} rows")

    # Give unannotated rows an explicit 'NaN' label so they form their own group.
    df[x_column] = df[x_column].fillna('NaN')

    if filter_column is not None:
        if isinstance(filter_column, str):
            df = df[df[filter_column].isin(filter_values)]
        elif isinstance(filter_column, list):
            for i, column in enumerate(filter_column):
                df = df[df[column].isin(filter_values[i])]

    well_columns = ['plate_x', 'row_x', 'col_x']
    if not all(column in df.columns for column in well_columns):
        raise KeyError(f"DataFrame does not contain the necessary columns: {well_columns}")

    # Keep every row from wells that contain at least one annotated (non-'NaN') row.
    non_nan_df = df[df[x_column] != 'NaN']
    well_keys = df[well_columns].apply(tuple, axis=1)
    annotated_wells = non_nan_df[well_columns].apply(tuple, axis=1)
    retained_rows = df[well_keys.isin(annotated_wells)]

    # Balance the groups by down-sampling each to the smallest group size.
    min_count = retained_rows[x_column].value_counts().min()
    print(f'Found {min_count} annotated images')
    balanced_df = retained_rows.groupby(x_column).apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)

    plt.figure(figsize=(10, 6))
    sns.stripplot(data=balanced_df, x=x_column, y=y_column, hue=x_column, jitter=True, palette='viridis', dodge=False)
    plt.title(plot_title)
    plt.xlabel(x_column)
    plt.ylabel(y_column)

    # Rotate group labels once; centre alignment keeps each label under its strip.
    plt.xticks(rotation=45, ha='center')

    if output_path:
        plt.savefig(output_path, bbox_inches='tight')
        print(f"Jitter plot saved to {output_path}")
    else:
        plt.show()

    return balanced_df
1934
+
1935
def create_grouped_plot(df, grouping_column, data_column, graph_type='bar', summary_func='mean', order=None, colors=None, output_dir='./output', save=False, y_axis_start=None, error_bar_type='std'):
    """
    Create a grouped plot, perform statistical tests, and optionally export the results along with the plot.

    Parameters:
    - df: DataFrame containing the data. It is not modified.
    - grouping_column: Column name for the categorical grouping. Rows where it is NaN are dropped.
    - data_column: Column name for the data to be grouped and plotted.
    - graph_type: Type of plot ('bar', 'violin', 'jitter', 'box', 'jitter_box').
    - summary_func: Summary function applied to each group for the bar plot ('mean', 'median', etc.).
    - order: List specifying the order of the groups. If None, groups are ordered alphabetically.
    - colors: List of colors for each group. Defaults to a 'husl' palette.
    - output_dir: Directory where the figure and test results are saved if `save=True`.
    - save: Whether to save the plot (grouped_plot.png) and results (test_results.csv).
    - y_axis_start: Optional starting value for the y-axis; defaults to 0.
    - error_bar_type: Error bars for the bar plot, 'std' (standard deviation) or 'sem' (standard error of the mean).

    Test selection:
    - 2 groups: T-test if every group passes the normality test (p > 0.05), else Mann-Whitney U.
    - >2 groups: one-way ANOVA (normal) or Kruskal-Wallis, applied to each pair of groups,
      plus Tukey HSD when all groups are normal.

    Returns:
    - fig: The matplotlib Figure that was drawn.
    - results_df: DataFrame with normality, pairwise and (when applicable) Tukey HSD results,
      with columns 'Comparison', 'Test Statistic', 'p-value' and 'Test Name'.

    Raises:
    - ValueError: If graph_type is 'bar' and error_bar_type is neither 'std' nor 'sem'.
    """
    # Drop unlabelled rows and copy, so the categorical assignment below never
    # writes into (or warns about) a view of the caller's DataFrame.
    df = df.dropna(subset=[grouping_column]).copy()

    if save:
        os.makedirs(output_dir, exist_ok=True)

    # Fix the group order: the explicit `order` if given, otherwise alphabetical.
    categories = order if order else sorted(df[grouping_column].unique())
    df[grouping_column] = pd.Categorical(df[grouping_column], categories=categories, ordered=True)
    unique_groups = df[grouping_column].unique()

    test_results = []

    # Normality test per group, run once per group and reused.
    normality = [normaltest(df.loc[df[grouping_column] == group, data_column]) for group in unique_groups]
    is_normal = all(result.pvalue > 0.05 for result in normality)
    for group, result in zip(unique_groups, normality):
        test_results.append({
            'Comparison': f'Normality test for {group}',
            'Test Statistic': result.statistic,
            'p-value': result.pvalue,
            'Test Name': 'Normality test'
        })

    if len(unique_groups) == 2:
        stat_test, test_name = (ttest_ind, 'T-test') if is_normal else (mannwhitneyu, 'Mann-Whitney U test')
    else:
        stat_test, test_name = (f_oneway, 'One-way ANOVA') if is_normal else (kruskal, 'Kruskal-Wallis test')

    # Pairwise comparisons between every pair of groups.
    for group1, group2 in itertools.combinations(unique_groups, 2):
        data1 = df.loc[df[grouping_column] == group1, data_column]
        data2 = df.loc[df[grouping_column] == group2, data_column]
        stat, p = stat_test(data1, data2)
        test_results.append({'Comparison': f'{group1} vs {group2}', 'Test Statistic': stat, 'p-value': p, 'Test Name': test_name})

    # Post-hoc Tukey HSD when ANOVA assumptions (normality) hold for >2 groups.
    if is_normal and len(unique_groups) > 2:
        tukey_result = pairwise_tukeyhsd(df[data_column], df[grouping_column], alpha=0.05)
        for comparison, p_value in zip(tukey_result._results_table.data[1:], tukey_result.pvalues):
            test_results.append({
                'Comparison': f'{comparison[0]} vs {comparison[1]}',
                'Test Statistic': None,  # Tukey does not provide a test statistic in the same way
                'p-value': p_value,
                'Test Name': 'Tukey HSD Post-hoc'
            })

    # Keep a handle on the figure: plt.gcf() after plt.show() may return a new, empty figure.
    fig = plt.figure(figsize=(10, 6))
    sns.set(style="whitegrid")

    color_palette = colors if colors else sns.color_palette("husl", len(unique_groups))

    if graph_type == 'bar':
        summary_df = df.groupby(grouping_column)[data_column].agg([summary_func, 'std', 'sem'])

        if error_bar_type == 'std':
            error_bars = summary_df['std']
        elif error_bar_type == 'sem':
            error_bars = summary_df['sem']
        else:
            raise ValueError(f"Invalid error_bar_type: {error_bar_type}. Choose either 'std' or 'sem'.")

        sns.barplot(x=grouping_column, y=summary_func, data=summary_df.reset_index(), ci=None, order=order, palette=color_palette)
        # Bars are drawn at x = 0..n-1 in category order, matching summary_df's row order.
        plt.errorbar(x=np.arange(len(summary_df)), y=summary_df[summary_func], yerr=error_bars, fmt='none', c='black', capsize=5)

    elif graph_type == 'violin':
        sns.violinplot(x=grouping_column, y=data_column, data=df, order=order, palette=color_palette)
    elif graph_type == 'jitter':
        sns.stripplot(x=grouping_column, y=data_column, data=df, jitter=True, order=order, palette=color_palette)
    elif graph_type == 'box':
        sns.boxplot(x=grouping_column, y=data_column, data=df, order=order, palette=color_palette)
    elif graph_type == 'jitter_box':
        sns.boxplot(x=grouping_column, y=data_column, data=df, order=order, palette=color_palette)
        sns.stripplot(x=grouping_column, y=data_column, data=df, jitter=True, color='black', alpha=0.5, order=order)

    results_df = pd.DataFrame(test_results)

    if y_axis_start is not None:
        plt.ylim(bottom=y_axis_start)
    else:
        plt.ylim(0, None)  # Default to starting at 0 if no custom start value is provided

    if save:
        plot_path = os.path.join(output_dir, 'grouped_plot.png')
        plt.title(f'{test_name} results for {graph_type} plot')
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig(plot_path)
        print(f"Plot saved to {plot_path}")

        results_path = os.path.join(output_dir, 'test_results.csv')
        results_df.to_csv(results_path, index=False)
        print(f"Test results saved to {results_path}")

    plt.show()

    return fig, results_df
2097
+
2098
+ class spacrGraph:
2099
+ def __init__(self, df, grouping_column, data_column, graph_type='bar', summary_func='mean',
2100
+ order=None, colors=None, output_dir='./output', save=False, y_axis_start=None,
2101
+ error_bar_type='std', remove_outliers=False, theme='pastel', representation='object',
2102
+ paired=False, all_to_all=True, compare_group=None):
2103
+ """
2104
+ Class for creating grouped plots with optional statistical tests and data preprocessing.
2105
+ """
2106
+ self.df = df
2107
+ self.grouping_column = grouping_column
2108
+ self.data_column = data_column
2109
+ self.graph_type = graph_type
2110
+ self.summary_func = summary_func
2111
+ self.order = order
2112
+ self.colors = colors
2113
+ self.output_dir = output_dir
2114
+ self.save = save
2115
+ self.y_axis_start = y_axis_start
2116
+ self.error_bar_type = error_bar_type
2117
+ self.remove_outliers = remove_outliers
2118
+ self.theme = theme
2119
+ self.representation = representation
2120
+ self.paired = paired
2121
+ self.all_to_all = all_to_all
2122
+ self.compare_group = compare_group
2123
+
2124
+ self.results_df = pd.DataFrame()
2125
+ self.sns_palette = None
2126
+ self.fig = None # To store the generated figure
2127
+
2128
+ # Preprocess and set palette
2129
+ self._set_theme()
2130
+ self.raw_df = self.df.copy() # Preserve the raw data for n_object count
2131
+ self.df = self.preprocess_data()
2132
+
2133
+ def _set_theme(self):
2134
+ """Set the Seaborn theme and reorder colors if necessary."""
2135
+ integer_list = list(range(1, 81))
2136
+ color_order = [0, 3, 9, 4, 6, 7, 9, 2] + integer_list
2137
+ self.sns_palette = self._set_reordered_theme(self.theme, color_order, 100)
2138
+
2139
+ def _set_reordered_theme(self, theme='muted', order=None, n_colors=100, show_theme=False):
2140
+ """Set and reorder the Seaborn color palette."""
2141
+ palette = sns.color_palette(theme, n_colors)
2142
+ if order:
2143
+ reordered_palette = [palette[i] for i in order]
2144
+ else:
2145
+ reordered_palette = palette
2146
+ if show_theme:
2147
+ sns.palplot(reordered_palette)
2148
+ plt.show()
2149
+ return reordered_palette
2150
+
2151
+ def preprocess_data(self):
2152
+ """Preprocess the data: remove NaNs, sort/order the grouping column, and optionally group by 'prc'."""
2153
+ df = self.df.dropna(subset=[self.grouping_column, self.data_column])
2154
+
2155
+ # Group by 'prc' column if representation is 'well'
2156
+ if self.representation == 'well':
2157
+ df = df.groupby(['prc', self.grouping_column])[self.data_column].agg(self.summary_func).reset_index()
2158
+
2159
+ if self.order:
2160
+ df[self.grouping_column] = pd.Categorical(df[self.grouping_column], categories=self.order, ordered=True)
2161
+ else:
2162
+ df[self.grouping_column] = pd.Categorical(df[self.grouping_column], categories=sorted(df[self.grouping_column].unique()), ordered=True)
2163
+
2164
+ return df
2165
+
2166
+ def remove_outliers_from_plot(self):
2167
+ """Remove outliers from the plot but keep them in the data."""
2168
+ filtered_df = self.df.copy()
2169
+ unique_groups = filtered_df[self.grouping_column].unique()
2170
+ for group in unique_groups:
2171
+ group_data = filtered_df[filtered_df[self.grouping_column] == group][self.data_column]
2172
+ q1 = group_data.quantile(0.25)
2173
+ q3 = group_data.quantile(0.75)
2174
+ iqr = q3 - q1
2175
+ lower_bound = q1 - 1.5 * iqr
2176
+ upper_bound = q3 + 1.5 * iqr
2177
+ filtered_df = filtered_df.drop(filtered_df[(filtered_df[self.grouping_column] == group) & ((filtered_df[self.data_column] < lower_bound) | (filtered_df[self.data_column] > upper_bound))].index)
2178
+ return filtered_df
2179
+
2180
+ def perform_normality_tests(self):
2181
+ """Perform normality tests for each group."""
2182
+ unique_groups = self.df[self.grouping_column].unique()
2183
+ grouped_data = [self.df.loc[self.df[self.grouping_column] == group, self.data_column] for group in unique_groups]
2184
+ raw_grouped_data = [self.raw_df.loc[self.raw_df[self.grouping_column] == group, self.data_column] for group in unique_groups]
2185
+
2186
+ normal_p_values = [normaltest(data).pvalue for data in grouped_data]
2187
+ normal_stats = [normaltest(data).statistic for data in grouped_data]
2188
+ is_normal = all(p > 0.05 for p in normal_p_values)
2189
+
2190
+ test_results = []
2191
+ for group, stat, p_value in zip(unique_groups, normal_stats, normal_p_values):
2192
+ test_results.append({
2193
+ 'Comparison': f'Normality test for {group}',
2194
+ 'Test Statistic': stat,
2195
+ 'p-value': p_value,
2196
+ 'Test Name': 'Normality test',
2197
+ 'n_object': len(raw_grouped_data[unique_groups.tolist().index(group)]), # Raw sample size (objects/cells)
2198
+ 'n_well': len(grouped_data[unique_groups.tolist().index(group)]) if self.representation == 'well' else np.nan # Summarized size (wells)
2199
+ })
2200
+ return is_normal, test_results
2201
+
2202
+ def perform_levene_test(self, unique_groups):
2203
+ """Perform Levene's test for equal variance."""
2204
+ grouped_data = [self.df.loc[self.df[self.grouping_column] == group, self.data_column] for group in unique_groups]
2205
+ stat, p_value = levene(*grouped_data)
2206
+ return stat, p_value
2207
+
2208
def perform_statistical_tests(self, unique_groups, is_normal):
    """Run one statistical test for every pair of groups and collect the results.

    The test is chosen from ``self.paired``, the normality verdict and the
    number of groups:

    * paired data (any number of groups): paired t-test when normal,
      Wilcoxon signed-rank test otherwise;
    * two unpaired groups: independent t-test when normal, Mann-Whitney U
      otherwise;
    * more than two unpaired groups: one-way ANOVA when normal,
      Kruskal-Wallis otherwise (each applied to a single pair of groups).

    Args:
        unique_groups: Group labels found in ``self.grouping_column``.
        is_normal: True when every group passed the normality test.

    Returns:
        list[dict]: One row per pair of groups with keys 'Comparison',
        'Test Statistic', 'p-value', 'Test Name', 'n_object' and 'n_well'
        ('n_well' is NaN unless ``self.representation == 'well'``).

    Raises:
        ValueError: Propagated from scipy when paired groups differ in length.
    """
    # scipy's paired tests return (statistic, pvalue), so they can be unpacked
    # exactly like the unpaired tests below (pingouin's versions return a
    # DataFrame and do not accept a ``paired`` keyword).
    from scipy.stats import ttest_rel, wilcoxon

    if self.paired:
        if is_normal:
            stat_test, test_name = ttest_rel, 'Paired T-test'
        else:
            stat_test, test_name = wilcoxon, 'Paired Wilcoxon test'
    elif len(unique_groups) == 2:
        if is_normal:
            stat_test, test_name = ttest_ind, 'T-test'
        else:
            stat_test, test_name = mannwhitneyu, 'Mann-Whitney U test'
    else:
        if is_normal:
            stat_test, test_name = f_oneway, 'One-way ANOVA'
        else:
            stat_test, test_name = kruskal, 'Kruskal-Wallis test'

    test_results = []
    for group1, group2 in itertools.combinations(unique_groups, 2):
        data1 = self.df[self.df[self.grouping_column] == group1][self.data_column]
        data2 = self.df[self.df[self.grouping_column] == group2][self.data_column]
        raw_data1 = self.raw_df[self.raw_df[self.grouping_column] == group1][self.data_column]
        raw_data2 = self.raw_df[self.raw_df[self.grouping_column] == group2][self.data_column]

        # Paired tests match observations by position, so the rows of both
        # groups are assumed to be stored in the same subject order — TODO confirm.
        stat, p = stat_test(data1, data2)

        test_results.append({
            'Comparison': f'{group1} vs {group2}',
            'Test Statistic': stat,
            'p-value': p,
            'Test Name': test_name,
            'n_object': len(raw_data1) + len(raw_data2),  # Raw sample size (objects/cells)
            'n_well': len(data1) + len(data2) if self.representation == 'well' else np.nan  # Summarized size (wells)
        })
    return test_results
2255
+
2256
def perform_posthoc_tests(self, is_normal, unique_groups):
    """Run post-hoc pairwise comparisons when there are more than two groups.

    * ``is_normal`` and ``self.all_to_all``: Tukey HSD on every pair of groups.
    * ``self.all_to_all`` false and ``self.compare_group`` set: Dunn's test
      (tie-corrected z statistic on mean ranks, p-values Bonferroni-adjusted
      over all pairwise comparisons), reporting only the pairs that involve
      ``self.compare_group``.
    * anything else, including two or fewer groups: no post-hoc test.

    Args:
        is_normal: True when every group passed the normality test.
        unique_groups: Group labels found in ``self.grouping_column``.

    Returns:
        list[dict]: Rows with keys 'Comparison', 'Test Statistic', 'p-value',
        'Test Name', 'n_object' and 'n_well' ('n_well' is NaN unless
        ``self.representation == 'well'``); empty when no post-hoc test applies.
    """
    from scipy.stats import norm, rankdata

    if len(unique_groups) <= 2:
        return []

    def _pair_sizes(group1, group2):
        """Return (raw object count, well count or NaN) for two groups combined."""
        raw_labels = self.raw_df[self.grouping_column]
        n_object = int((raw_labels == group1).sum() + (raw_labels == group2).sum())
        if self.representation == 'well':
            labels = self.df[self.grouping_column]
            n_well = int((labels == group1).sum() + (labels == group2).sum())
        else:
            n_well = np.nan
        return n_object, n_well

    def _dunn_vs_compare_group():
        """Dunn's test for every pair that contains ``self.compare_group``."""
        data = self.df[[self.grouping_column, self.data_column]].dropna()
        values = data[self.data_column].to_numpy(dtype=float)
        labels = data[self.grouping_column].to_numpy()
        n_total = len(values)
        if n_total < 2:
            return []

        ranks = rankdata(values)  # average ranks for tied values
        _, tie_sizes = np.unique(values, return_counts=True)
        tie_correction = np.sum(tie_sizes ** 3 - tie_sizes) / (12.0 * (n_total - 1))
        rank_variance = n_total * (n_total + 1) / 12.0 - tie_correction

        masks = {group: labels == group for group in unique_groups}
        pairs = list(itertools.combinations(unique_groups, 2))
        results = []
        for group1, group2 in pairs:
            if self.compare_group not in (group1, group2):
                continue
            n1 = int(masks[group1].sum())
            n2 = int(masks[group2].sum())
            if n1 == 0 or n2 == 0:
                continue  # a group without non-missing values cannot be ranked
            se = np.sqrt(rank_variance * (1.0 / n1 + 1.0 / n2))
            z = (ranks[masks[group1]].mean() - ranks[masks[group2]].mean()) / se
            # Bonferroni over all pairwise comparisons, as padjust='bonf' over
            # the full pairwise table would do before filtering.
            p_value = min(1.0, 2.0 * norm.sf(abs(z)) * len(pairs))
            n_object, n_well = _pair_sizes(group1, group2)
            results.append({
                'Comparison': f"{group1} vs {group2}",
                'Test Statistic': z,
                'p-value': p_value,
                'Test Name': 'Dunn’s Post-hoc',
                'n_object': n_object,
                'n_well': n_well
            })
        return results

    if is_normal and self.all_to_all:
        # Tukey HSD Post-hoc when comparing all to all
        tukey_result = pairwise_tukeyhsd(self.df[self.data_column], self.df[self.grouping_column], alpha=0.05)
        posthoc_results = []
        # Row 0 of the summary table is skipped (header); each later row
        # begins with the two group labels being compared.
        for comparison, p_value in zip(tukey_result._results_table.data[1:], tukey_result.pvalues):
            n_object, n_well = _pair_sizes(comparison[0], comparison[1])
            posthoc_results.append({
                'Comparison': f'{comparison[0]} vs {comparison[1]}',
                'Test Statistic': None,  # Tukey does not provide a test statistic
                'p-value': p_value,
                'Test Name': 'Tukey HSD Post-hoc',
                'n_object': n_object,
                'n_well': n_well
            })
        return posthoc_results

    if not self.all_to_all and self.compare_group:
        return _dunn_vs_compare_group()
    return []
2292
+
2293
def create_plot(self, ax=None):
    """Run the statistics for the grouped data and draw the configured plot.

    Steps:
    1. Optionally drop outliers; this replaces ``self.df``, so every later
       step uses the filtered data.
    2. Run the normality, Levene, pairwise and post-hoc tests.
    3. Store the combined table in ``self.results_df`` with an added ``n``
       column.
    4. Draw the plot chosen by ``self.summary_func`` / ``self.graph_type``
       onto ``ax``. A new figure is created when ``ax`` is None.
    5. Style the axes, optionally save the results, and call ``plt.show()``.

    Args:
        ax: Optional matplotlib axes to draw on.

    Raises:
        ValueError: If ``self.summary_func`` is set and ``self.graph_type`` is
            not one of 'bar', 'box', 'violin' or 'jitter'.
    """
    # Optional: Remove outliers for plotting
    if self.remove_outliers:
        self.df = self.remove_outliers_from_plot()

    # Perform normality tests
    is_normal, normality_results = self.perform_normality_tests()

    # Perform Levene's test for equal variance
    unique_groups = self.df[self.grouping_column].unique()
    levene_stat, levene_p = self.perform_levene_test(unique_groups)
    levene_result = {
        'Comparison': 'Levene’s test for equal variance',
        'Test Statistic': levene_stat,
        'p-value': levene_p,
        'Test Name': 'Levene’s Test'
    }

    # Perform statistical tests
    stat_results = self.perform_statistical_tests(unique_groups, is_normal)

    # Perform post-hoc tests if applicable
    posthoc_results = self.perform_posthoc_tests(is_normal, unique_groups)

    # Combine all test results
    self.results_df = pd.DataFrame(normality_results + [levene_result] + stat_results + posthoc_results)

    # 'n' is the count of the first group (in groupby order) whose label occurs
    # in the 'Comparison' text. Labels are stringified first, because
    # `label in text` raises TypeError for non-string (e.g. integer) labels.
    group_sizes = self.df.groupby(self.grouping_column)[self.data_column].count()
    self.results_df['n'] = self.results_df['Comparison'].apply(
        lambda text: next((size for group, size in group_sizes.items() if str(group) in text), np.nan)
    )

    # Dynamically set figure dimensions based on the number of unique groups
    num_groups = len(unique_groups)
    bar_width = 0.6  # Set the desired thickness of each bar
    spacing_between_groups = 0.3  # Set the desired spacing between bars and axis

    fig_width = num_groups * (bar_width + spacing_between_groups)  # Dynamically calculate the figure width
    fig_height = 6  # Fixed height for the plot

    if ax is None:
        self.fig, ax = plt.subplots(figsize=(fig_width, fig_height))  # Store the figure in self.fig
    else:
        self.fig = ax.figure  # Store the figure if ax is provided

    sns.set(style="ticks")
    color_palette = self.sns_palette if not self.colors else self.colors

    # Calculate x-axis limits to ensure equal space between the bars and the y-axis
    xlim_lower = -0.5  # Ensures space between the y-axis and the first category
    xlim_upper = num_groups - 0.5  # Ensures space after the last category
    ax.set_xlim(xlim_lower, xlim_upper)

    if self.summary_func is None:
        sns.stripplot(x=self.grouping_column, y=self.data_column, data=self.df, palette=color_palette, jitter=True, alpha=0.6, ax=ax)
    elif self.graph_type == 'bar':
        self._create_bar_plot(bar_width, ax)
    elif self.graph_type == 'box':
        self._create_box_plot(ax)
    elif self.graph_type == 'violin':
        self._create_violin_plot(ax)
    elif self.graph_type == 'jitter':
        self._create_jitter_plot(ax)
    else:
        raise ValueError(f"Invalid graph_type: {self.graph_type}. Choose from 'bar', 'box', 'violin', or 'jitter'.")

    # Set y-axis start
    if self.y_axis_start is not None:
        ax.set_ylim(bottom=self.y_axis_start)

    # Add ticks, remove grid, and save plot
    ax.minorticks_on()
    ax.tick_params(axis='x', which='minor', bottom=False)  # Disable minor ticks on x-axis
    ax.tick_params(axis='x', which='major', length=6, width=2, direction='out')
    ax.tick_params(axis='y', which='major', length=6, width=2, direction='out')
    ax.tick_params(axis='y', which='minor', length=4, width=1, direction='out')
    sns.despine(ax=ax, top=True, right=True)

    if self.save:
        self._save_results()

    plt.show()  # Ensure the plot is shown, but plt.show() doesn't clear the figure context
2377
+
2378
def get_figure(self):
    """Return the figure stored in ``self.fig`` (set by ``create_plot``)."""
    figure = self.fig
    return figure
2381
+
2382
def _create_bar_plot(self, bar_width, ax):
    """Draw one bar per group with explicit, centered error bars.

    The bar height is the per-group ``self.summary_func`` aggregate of
    ``self.data_column``. The error bars are the per-group standard deviation
    or standard error, chosen by ``self.error_bar_type``.

    Args:
        bar_width: Width of each bar, forwarded to ``sns.barplot``.
        ax: Matplotlib axes to draw on.

    Raises:
        ValueError: If ``self.error_bar_type`` is neither 'std' nor 'sem'.
    """
    # Validate the option before doing any aggregation work.
    if self.error_bar_type not in ('std', 'sem'):
        raise ValueError(f"Invalid error_bar_type: {self.error_bar_type}. Choose either 'std' or 'sem'.")

    summary_df = self.df.groupby(self.grouping_column)[self.data_column].agg([self.summary_func, 'std', 'sem'])
    error_bars = summary_df[self.error_bar_type]

    # errorbar=None replaces the deprecated ci=None. Seaborn's own error bars
    # are disabled because std/sem bars are drawn explicitly below.
    sns.barplot(x=self.grouping_column, y=self.summary_func, data=summary_df.reset_index(), errorbar=None, palette=self.sns_palette, width=bar_width, ax=ax)

    # Bars sit at positions 0..k-1 in summary_df's row order, so the error
    # bars line up with them.
    ax.errorbar(x=np.arange(len(summary_df)), y=summary_df[self.summary_func], yerr=error_bars, fmt='none', c='black', capsize=5)
2397
+
2398
def _create_jitter_plot(self, ax):
    """Draw every observation per group as a jittered strip plot on ``ax``."""
    sns.stripplot(
        data=self.df,
        x=self.grouping_column,
        y=self.data_column,
        palette=self.sns_palette,
        jitter=True,
        alpha=0.6,
        ax=ax,
    )
2401
+
2402
def _create_box_plot(self, ax):
    """Draw a box plot of ``self.data_column`` for each group on ``ax``."""
    sns.boxplot(
        data=self.df,
        x=self.grouping_column,
        y=self.data_column,
        palette=self.sns_palette,
        ax=ax,
    )
2405
+
2406
def _create_violin_plot(self, ax):
    """Draw a violin plot of ``self.data_column`` for each group on ``ax``."""
    sns.violinplot(
        data=self.df,
        x=self.grouping_column,
        y=self.data_column,
        palette=self.sns_palette,
        ax=ax,
    )
2409
+
2410
def _save_results(self):
    """Write the figure and the statistics table into ``self.output_dir``.

    The directory is created if needed. The figure is saved as
    ``grouped_plot.png`` and ``self.results_df`` as ``test_results.csv``
    (without the index). Both paths are then printed.
    """
    out_dir = self.output_dir
    os.makedirs(out_dir, exist_ok=True)
    figure_file, table_file = (
        os.path.join(out_dir, filename)
        for filename in ('grouped_plot.png', 'test_results.csv')
    )
    self.fig.savefig(figure_file)
    self.results_df.to_csv(table_file, index=False)
    for message in (f"Plot saved to {figure_file}", f"Test results saved to {table_file}"):
        print(message)
2419
+
2420
def get_results(self):
    """Return ``self.results_df``, the statistics table built by ``create_plot``."""
    results = self.results_df
    return results