spacr 0.3.1__py3-none-any.whl → 0.3.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +19 -3
- spacr/cellpose.py +311 -0
- spacr/core.py +245 -2494
- spacr/deep_spacr.py +316 -48
- spacr/gui.py +1 -0
- spacr/gui_core.py +74 -63
- spacr/gui_elements.py +110 -5
- spacr/gui_utils.py +346 -6
- spacr/io.py +680 -141
- spacr/logger.py +28 -9
- spacr/measure.py +107 -95
- spacr/mediar.py +0 -3
- spacr/ml.py +1051 -0
- spacr/openai.py +37 -0
- spacr/plot.py +707 -20
- spacr/resources/data/lopit.csv +3833 -0
- spacr/resources/data/toxoplasma_metadata.csv +8843 -0
- spacr/resources/icons/convert.png +0 -0
- spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
- spacr/sequencing.py +241 -1311
- spacr/settings.py +134 -47
- spacr/sim.py +0 -2
- spacr/submodules.py +349 -0
- spacr/timelapse.py +0 -2
- spacr/toxo.py +238 -0
- spacr/utils.py +419 -180
- {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/METADATA +31 -22
- {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/RECORD +32 -33
- spacr/chris.py +0 -50
- spacr/graph_learning.py +0 -340
- spacr/resources/MEDIAR/.git +0 -1
- spacr/resources/MEDIAR_weights/.DS_Store +0 -0
- spacr/resources/icons/.DS_Store +0 -0
- spacr/resources/icons/spacr_logo_rotation.gif +0 -0
- spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
- spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/sim_app.py +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/LICENSE +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/WHEEL +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/top_level.txt +0 -0
spacr/plot.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
import os,
|
1
|
+
import os, random, cv2, glob, math, torch
|
2
2
|
|
3
3
|
import numpy as np
|
4
4
|
import pandas as pd
|
@@ -14,11 +14,15 @@ from skimage.segmentation import find_boundaries
|
|
14
14
|
from skimage import measure
|
15
15
|
from skimage.measure import find_contours, label, regionprops
|
16
16
|
|
17
|
+
from scipy.stats import normaltest, ttest_ind, mannwhitneyu, f_oneway, kruskal
|
18
|
+
from statsmodels.stats.multicomp import pairwise_tukeyhsd
|
19
|
+
from scipy.stats import ttest_ind, mannwhitneyu, levene, wilcoxon, kruskal
|
20
|
+
import itertools
|
21
|
+
import pingouin as pg
|
22
|
+
|
17
23
|
from ipywidgets import IntSlider, interact
|
18
24
|
from IPython.display import Image as ipyimage
|
19
25
|
|
20
|
-
from .logger import log_function_call
|
21
|
-
|
22
26
|
def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, pathogen_channel, figuresize=10, normalize=True, thickness=3, save_pdf=True):
|
23
27
|
"""Plot image and mask overlays."""
|
24
28
|
|
@@ -409,7 +413,7 @@ def plot_images_and_arrays(folders, lower_percentile=1, upper_percentile=99, thr
|
|
409
413
|
plot_from_file_dict(file_dict, threshold, lower_percentile, upper_percentile, overlay, save=False)
|
410
414
|
return
|
411
415
|
|
412
|
-
def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, mask_dims, filter_min_max,
|
416
|
+
def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, mask_dims, filter_min_max, nuclei_limit, pathogen_limit):
|
413
417
|
"""
|
414
418
|
Filters objects in a plot based on various criteria.
|
415
419
|
|
@@ -420,8 +424,8 @@ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mas
|
|
420
424
|
pathogen_mask_dim (int): The dimension index of the pathogen mask.
|
421
425
|
mask_dims (list): A list of dimension indices for additional masks.
|
422
426
|
filter_min_max (list): A list of minimum and maximum area values for each mask.
|
423
|
-
|
424
|
-
|
427
|
+
nuclei_limit (bool): Whether to include multinucleated cells.
|
428
|
+
pathogen_limit (bool): Whether to include multiinfected cells.
|
425
429
|
|
426
430
|
Returns:
|
427
431
|
numpy.ndarray: The filtered stack of masks.
|
@@ -451,9 +455,9 @@ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mas
|
|
451
455
|
total_count_after = len(props_after['label'])
|
452
456
|
|
453
457
|
if mask_dim == cell_mask_dim:
|
454
|
-
if
|
458
|
+
if nuclei_limit is False and nucleus_mask_dim is not None:
|
455
459
|
stack = _remove_multiobject_cells(stack, mask_dim, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, object_dim=pathogen_mask_dim)
|
456
|
-
if
|
460
|
+
if pathogen_limit is False and cell_mask_dim is not None and pathogen_mask_dim is not None:
|
457
461
|
stack = _remove_multiobject_cells(stack, mask_dim, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, object_dim=nucleus_mask_dim)
|
458
462
|
cell_area_before = avg_size_before
|
459
463
|
cell_count_before = total_count_before
|
@@ -662,18 +666,18 @@ def plot_merged(src, settings):
|
|
662
666
|
display(settings)
|
663
667
|
|
664
668
|
if settings['pathogen_mask_dim'] is None:
|
665
|
-
settings['
|
669
|
+
settings['pathogen_limit'] = True
|
666
670
|
|
667
671
|
for file in os.listdir(src):
|
668
672
|
path = os.path.join(src, file)
|
669
673
|
stack = np.load(path)
|
670
674
|
print(f'Loaded: {path}')
|
671
|
-
if not settings['
|
675
|
+
if not settings['uninfected']:
|
672
676
|
if settings['pathogen_mask_dim'] is not None and settings['cell_mask_dim'] is not None:
|
673
677
|
stack = _remove_noninfected(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'])
|
674
678
|
|
675
|
-
if settings['
|
676
|
-
stack = _filter_objects_in_plot(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'], mask_dims, settings['filter_min_max'], settings['
|
679
|
+
if settings['pathogen_limit'] is not True or settings['nuclei_limit'] is not True or settings['filter_min_max'] is not None:
|
680
|
+
stack = _filter_objects_in_plot(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'], mask_dims, settings['filter_min_max'], settings['nuclei_limit'], settings['pathogen_limit'])
|
677
681
|
|
678
682
|
overlayed_image, image, outlines = _normalize_and_outline(image=stack,
|
679
683
|
remove_background=settings['remove_background'],
|
@@ -999,8 +1003,107 @@ def _display_gif(path):
|
|
999
1003
|
"""
|
1000
1004
|
with open(path, 'rb') as file:
|
1001
1005
|
display(ipyimage(file.read()))
|
1006
|
+
|
1007
|
+
def _plot_recruitment_v2(df, df_type, channel_of_interest, columns=[], figuresize=10):
|
1008
|
+
"""
|
1009
|
+
Plot recruitment data for different conditions and pathogens.
|
1010
|
+
|
1011
|
+
Args:
|
1012
|
+
df (DataFrame): The input DataFrame containing the recruitment data.
|
1013
|
+
df_type (str): The type of DataFrame (e.g., 'train', 'test').
|
1014
|
+
channel_of_interest (str): The channel of interest for plotting.
|
1015
|
+
target (str): The target variable for plotting.
|
1016
|
+
columns (list, optional): Additional columns to plot. Defaults to an empty list.
|
1017
|
+
figuresize (int, optional): The size of the figure. Defaults to 50.
|
1018
|
+
|
1019
|
+
Returns:
|
1020
|
+
None
|
1021
|
+
"""
|
1022
|
+
|
1023
|
+
from .plot import spacrGraph
|
1024
|
+
|
1025
|
+
color_list = [(55/255, 155/255, 155/255),
|
1026
|
+
(155/255, 55/255, 155/255),
|
1027
|
+
(55/255, 155/255, 255/255),
|
1028
|
+
(255/255, 55/255, 155/255)]
|
1029
|
+
|
1030
|
+
sns.set_palette(sns.color_palette(color_list))
|
1031
|
+
font = figuresize/2
|
1032
|
+
width=figuresize
|
1033
|
+
height=figuresize/4
|
1034
|
+
|
1035
|
+
# Create the subplots
|
1036
|
+
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(width, height))
|
1037
|
+
|
1038
|
+
# Plot for 'cell_channel' on axes[0]
|
1039
|
+
plotter_cell = spacrGraph(df,grouping_column='condition', data_column=f'cell_channel_{channel_of_interest}_mean_intensity')
|
1040
|
+
plotter_cell.create_plot(ax=axes[0])
|
1041
|
+
axes[0].set_xlabel(f'pathogen {df_type}', fontsize=font)
|
1042
|
+
axes[0].set_ylabel(f'cell_channel_{channel_of_interest}_mean_intensity', fontsize=font)
|
1043
|
+
|
1044
|
+
# Plot for 'nucleus_channel' on axes[1]
|
1045
|
+
plotter_nucleus = spacrGraph(df,grouping_column='condition', data_column=f'nucleus_channel_{channel_of_interest}_mean_intensity')
|
1046
|
+
plotter_nucleus.create_plot(ax=axes[1])
|
1047
|
+
axes[1].set_xlabel(f'pathogen {df_type}', fontsize=font)
|
1048
|
+
axes[1].set_ylabel(f'nucleus_channel_{channel_of_interest}_mean_intensity', fontsize=font)
|
1049
|
+
|
1050
|
+
# Plot for 'cytoplasm_channel' on axes[2]
|
1051
|
+
plotter_cytoplasm = spacrGraph(df, grouping_column='condition', data_column=f'cytoplasm_channel_{channel_of_interest}_mean_intensity')
|
1052
|
+
plotter_cytoplasm.create_plot(ax=axes[2])
|
1053
|
+
axes[2].set_xlabel(f'pathogen {df_type}', fontsize=font)
|
1054
|
+
axes[2].set_ylabel(f'cytoplasm_channel_{channel_of_interest}_mean_intensity', fontsize=font)
|
1055
|
+
|
1056
|
+
# Plot for 'pathogen_channel' on axes[3]
|
1057
|
+
plotter_pathogen = spacrGraph(df, grouping_column='condition', data_column=f'pathogen_channel_{channel_of_interest}_mean_intensity')
|
1058
|
+
plotter_pathogen.create_plot(ax=axes[3])
|
1059
|
+
axes[3].set_xlabel(f'pathogen {df_type}', fontsize=font)
|
1060
|
+
axes[3].set_ylabel(f'pathogen_channel_{channel_of_interest}_mean_intensity', fontsize=font)
|
1061
|
+
|
1062
|
+
#axes[0].legend_.remove()
|
1063
|
+
#axes[1].legend_.remove()
|
1064
|
+
#axes[2].legend_.remove()
|
1065
|
+
#axes[3].legend_.remove()
|
1066
|
+
|
1067
|
+
handles, labels = axes[3].get_legend_handles_labels()
|
1068
|
+
axes[3].legend(handles, labels, bbox_to_anchor=(1.05, 0.5), loc='center left')
|
1069
|
+
for i in [0,1,2,3]:
|
1070
|
+
axes[i].tick_params(axis='both', which='major', labelsize=font)
|
1071
|
+
axes[i].set_xticklabels(axes[i].get_xticklabels(), rotation=45)
|
1072
|
+
|
1073
|
+
plt.tight_layout()
|
1074
|
+
plt.show()
|
1075
|
+
|
1076
|
+
columns = columns + ['pathogen_cytoplasm_mean_mean', 'pathogen_cytoplasm_q75_mean', 'pathogen_periphery_cytoplasm_mean_mean', 'pathogen_outside_cytoplasm_mean_mean', 'pathogen_outside_cytoplasm_q75_mean']
|
1077
|
+
#columns = columns + [f'pathogen_slope_channel_{channel_of_interest}', f'pathogen_cell_distance_channel_{channel_of_interest}', f'nucleus_cell_distance_channel_{channel_of_interest}']
|
1078
|
+
|
1079
|
+
width = figuresize*2
|
1080
|
+
columns_per_row = math.ceil(len(columns) / 2)
|
1081
|
+
height = (figuresize*2)/columns_per_row
|
1082
|
+
|
1083
|
+
fig, axes = plt.subplots(nrows=2, ncols=columns_per_row, figsize=(width, height * 2))
|
1084
|
+
axes = axes.flatten()
|
1085
|
+
|
1086
|
+
print(f'{columns}')
|
1087
|
+
for i, col in enumerate(columns):
|
1088
|
+
ax = axes[i]
|
1089
|
+
plotter_col = spacrGraph(df, grouping_column='condition', data_column=col)
|
1090
|
+
plotter_col.create_plot(ax=ax)
|
1091
|
+
ax.set_xlabel(f'pathogen {df_type}', fontsize=font)
|
1092
|
+
ax.set_ylabel(f'{col}', fontsize=int(font * 2))
|
1093
|
+
#ax.legend_.remove()
|
1094
|
+
ax.tick_params(axis='both', which='major', labelsize=font)
|
1095
|
+
ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
|
1096
|
+
if i <= 5:
|
1097
|
+
ax.set_ylim(1, None)
|
1098
|
+
|
1099
|
+
# Turn off any unused axes
|
1100
|
+
for i in range(len(columns), len(axes)):
|
1101
|
+
axes[i].axis('off')
|
1102
|
+
|
1103
|
+
plt.tight_layout()
|
1104
|
+
plt.show()
|
1002
1105
|
|
1003
|
-
def _plot_recruitment(df, df_type, channel_of_interest,
|
1106
|
+
def _plot_recruitment(df, df_type, channel_of_interest, columns=[], figuresize=10):
|
1004
1107
|
"""
|
1005
1108
|
Plot recruitment data for different conditions and pathogens.
|
1006
1109
|
|
@@ -1153,10 +1256,6 @@ def _plot_controls(df, mask_chans, channel_of_interest, figuresize=5):
|
|
1153
1256
|
plt.tight_layout()
|
1154
1257
|
plt.show()
|
1155
1258
|
|
1156
|
-
###################################################
|
1157
|
-
# Classify
|
1158
|
-
###################################################
|
1159
|
-
|
1160
1259
|
def _imshow(img, labels, nrow=20, color='white', fontsize=12):
|
1161
1260
|
"""
|
1162
1261
|
Display multiple images in a grid with corresponding labels.
|
@@ -1359,7 +1458,7 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_co
|
|
1359
1458
|
|
1360
1459
|
return plate_map, min_max
|
1361
1460
|
|
1362
|
-
def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True):
|
1461
|
+
def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True, dst=None):
|
1363
1462
|
plates = df['prc'].str.split('_', expand=True)[0].unique()
|
1364
1463
|
n_rows, n_cols = (len(plates) + 3) // 4, 4
|
1365
1464
|
fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
|
@@ -1374,6 +1473,12 @@ def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True
|
|
1374
1473
|
fig.delaxes(ax[i])
|
1375
1474
|
|
1376
1475
|
plt.subplots_adjust(wspace=0.1, hspace=0.4)
|
1476
|
+
|
1477
|
+
if not dst is None:
|
1478
|
+
filename = os.path.join(dst, 'plate_heatmap.pdf')
|
1479
|
+
fig.savefig(filename, format='pdf')
|
1480
|
+
print(f'Saved heatmap to {filename}')
|
1481
|
+
|
1377
1482
|
if verbose:
|
1378
1483
|
plt.show()
|
1379
1484
|
return fig
|
@@ -1605,13 +1710,19 @@ def volcano_plot(coef_df, filename='volcano_plot.pdf'):
|
|
1605
1710
|
print(f'Saved Volcano plot: {filename}')
|
1606
1711
|
plt.show()
|
1607
1712
|
|
1608
|
-
def plot_histogram(df, dependent_variable):
|
1713
|
+
def plot_histogram(df, dependent_variable, dst=None):
|
1609
1714
|
# Plot histogram of the dependent variable
|
1610
1715
|
plt.figure(figsize=(10, 6))
|
1611
1716
|
sns.histplot(df[dependent_variable], kde=True)
|
1612
1717
|
plt.title(f'Histogram of {dependent_variable}')
|
1613
1718
|
plt.xlabel(dependent_variable)
|
1614
1719
|
plt.ylabel('Frequency')
|
1720
|
+
|
1721
|
+
if not dst is None:
|
1722
|
+
filename = os.path.join(dst, 'dependent_variable_histogram.pdf')
|
1723
|
+
plt.savefig(filename, format='pdf')
|
1724
|
+
print(f'Saved histogram to {filename}')
|
1725
|
+
|
1615
1726
|
plt.show()
|
1616
1727
|
|
1617
1728
|
def plot_lorenz_curves(csv_files, remove_keys=['TGGT1_220950_1', 'TGGT1_233460_4']):
|
@@ -1732,4 +1843,580 @@ def read_and_plot__vision_results(base_dir, y_axis='accuracy', name_split='_time
|
|
1732
1843
|
plt.ylim(y_lim)
|
1733
1844
|
plt.show()
|
1734
1845
|
else:
|
1735
|
-
print("No CSV files found in the specified directory.")
|
1846
|
+
print("No CSV files found in the specified directory.")
|
1847
|
+
|
1848
|
+
def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot', output_path=None, filter_column=None, filter_values=None):
|
1849
|
+
"""
|
1850
|
+
Reads a CSV file and creates a jitter plot of one column grouped by another column.
|
1851
|
+
|
1852
|
+
Args:
|
1853
|
+
src (str): Path to the source data.
|
1854
|
+
x_column (str): Name of the column to be used for the x-axis.
|
1855
|
+
y_column (str): Name of the column to be used for the y-axis.
|
1856
|
+
plot_title (str): Title of the plot. Default is 'Jitter Plot'.
|
1857
|
+
output_path (str): Path to save the plot image. If None, the plot will be displayed. Default is None.
|
1858
|
+
|
1859
|
+
Returns:
|
1860
|
+
pd.DataFrame: The filtered and balanced DataFrame.
|
1861
|
+
"""
|
1862
|
+
|
1863
|
+
def join_measurments_and_annotation(src, tables = ['cell', 'nucleus', 'pathogen','cytoplasm']):
|
1864
|
+
from .io import _read_and_merge_data, _read_db
|
1865
|
+
db_loc = [src+'/measurements/measurements.db']
|
1866
|
+
loc = src+'/measurements/measurements.db'
|
1867
|
+
df, _ = _read_and_merge_data(db_loc,
|
1868
|
+
tables,
|
1869
|
+
verbose=True,
|
1870
|
+
nuclei_limit=True,
|
1871
|
+
pathogen_limit=True,
|
1872
|
+
uninfected=True)
|
1873
|
+
paths_df = _read_db(loc, tables=['png_list'])
|
1874
|
+
merged_df = pd.merge(df, paths_df[0], on='prcfo', how='left')
|
1875
|
+
return merged_df
|
1876
|
+
|
1877
|
+
# Read the CSV file into a DataFrame
|
1878
|
+
df = join_measurments_and_annotation(src, tables=['cell', 'nucleus', 'pathogen', 'cytoplasm'])
|
1879
|
+
|
1880
|
+
# Print column names for debugging
|
1881
|
+
print(f"Generated dataframe with: {df.shape[1]} columns and {df.shape[0]} rows")
|
1882
|
+
#print("Columns in DataFrame:", df.columns.tolist())
|
1883
|
+
|
1884
|
+
# Replace NaN values with a specific label in x_column
|
1885
|
+
df[x_column] = df[x_column].fillna('NaN')
|
1886
|
+
|
1887
|
+
# Filter the DataFrame if filter_column and filter_values are provided
|
1888
|
+
if not filter_column is None:
|
1889
|
+
if isinstance(filter_column, str):
|
1890
|
+
df = df[df[filter_column].isin(filter_values)]
|
1891
|
+
if isinstance(filter_column, list):
|
1892
|
+
for i,val in enumerate(filter_column):
|
1893
|
+
print(f'hello {len(df)}')
|
1894
|
+
df = df[df[val].isin(filter_values[i])]
|
1895
|
+
|
1896
|
+
# Use the correct column names based on your DataFrame
|
1897
|
+
required_columns = ['plate_x', 'row_x', 'col_x']
|
1898
|
+
if not all(column in df.columns for column in required_columns):
|
1899
|
+
raise KeyError(f"DataFrame does not contain the necessary columns: {required_columns}")
|
1900
|
+
|
1901
|
+
# Filter to retain rows with non-NaN values in x_column and with matching plate, row, col values
|
1902
|
+
non_nan_df = df[df[x_column] != 'NaN']
|
1903
|
+
retained_rows = df[df[['plate_x', 'row_x', 'col_x']].apply(tuple, axis=1).isin(non_nan_df[['plate_x', 'row_x', 'col_x']].apply(tuple, axis=1))]
|
1904
|
+
|
1905
|
+
# Determine the minimum count of examples across all groups in x_column
|
1906
|
+
min_count = retained_rows[x_column].value_counts().min()
|
1907
|
+
print(f'Found {min_count} annotated images')
|
1908
|
+
|
1909
|
+
# Randomly sample min_count examples from each group in x_column
|
1910
|
+
balanced_df = retained_rows.groupby(x_column).apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)
|
1911
|
+
|
1912
|
+
# Create the jitter plot
|
1913
|
+
plt.figure(figsize=(10, 6))
|
1914
|
+
jitter_plot = sns.stripplot(data=balanced_df, x=x_column, y=y_column, hue=x_column, jitter=True, palette='viridis', dodge=False)
|
1915
|
+
plt.title(plot_title)
|
1916
|
+
plt.xlabel(x_column)
|
1917
|
+
plt.ylabel(y_column)
|
1918
|
+
|
1919
|
+
# Customize the x-axis labels
|
1920
|
+
plt.xticks(rotation=45, ha='right')
|
1921
|
+
|
1922
|
+
# Adjust the position of the x-axis labels to be centered below the data
|
1923
|
+
ax = plt.gca()
|
1924
|
+
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='center')
|
1925
|
+
|
1926
|
+
# Save the plot to a file or display it
|
1927
|
+
if output_path:
|
1928
|
+
plt.savefig(output_path, bbox_inches='tight')
|
1929
|
+
print(f"Jitter plot saved to {output_path}")
|
1930
|
+
else:
|
1931
|
+
plt.show()
|
1932
|
+
|
1933
|
+
return balanced_df
|
1934
|
+
|
1935
|
+
def create_grouped_plot(df, grouping_column, data_column, graph_type='bar', summary_func='mean', order=None, colors=None, output_dir='./output', save=False, y_axis_start=None, error_bar_type='std'):
|
1936
|
+
"""
|
1937
|
+
Create a grouped plot, perform statistical tests, and optionally export the results along with the plot.
|
1938
|
+
|
1939
|
+
Parameters:
|
1940
|
+
- df: DataFrame containing the data.
|
1941
|
+
- grouping_column: Column name for the categorical grouping.
|
1942
|
+
- data_column: Column name for the data to be grouped and plotted.
|
1943
|
+
- graph_type: Type of plot ('bar', 'violin', 'jitter', 'box', 'jitter_box').
|
1944
|
+
- summary_func: Summary function to apply to each group ('mean', 'median', etc.).
|
1945
|
+
- order: List specifying the order of the groups. If None, groups will be ordered alphabetically.
|
1946
|
+
- colors: List of colors for each group.
|
1947
|
+
- output_dir: Directory where the figure and test results will be saved if `save=True`.
|
1948
|
+
- save: Boolean flag indicating whether to save the plot and results to files.
|
1949
|
+
- y_axis_start: Optional starting value for the y-axis.
|
1950
|
+
- error_bar_type: Type of error bars to plot, either 'std' for standard deviation or 'sem' for standard error of the mean.
|
1951
|
+
|
1952
|
+
Outputs:
|
1953
|
+
- Figure of the plot.
|
1954
|
+
- DataFrame with full statistical test results, including normality tests.
|
1955
|
+
"""
|
1956
|
+
|
1957
|
+
# Remove NaN rows in grouping_column
|
1958
|
+
df = df.dropna(subset=[grouping_column])
|
1959
|
+
|
1960
|
+
# Ensure the output directory exists if save is True
|
1961
|
+
if save:
|
1962
|
+
os.makedirs(output_dir, exist_ok=True)
|
1963
|
+
|
1964
|
+
# Sorting and ordering
|
1965
|
+
if order:
|
1966
|
+
df[grouping_column] = pd.Categorical(df[grouping_column], categories=order, ordered=True)
|
1967
|
+
else:
|
1968
|
+
df[grouping_column] = pd.Categorical(df[grouping_column], categories=sorted(df[grouping_column].unique()), ordered=True)
|
1969
|
+
|
1970
|
+
# Get unique groups
|
1971
|
+
unique_groups = df[grouping_column].unique()
|
1972
|
+
|
1973
|
+
# Initialize test results
|
1974
|
+
test_results = []
|
1975
|
+
|
1976
|
+
# Test normality for each group
|
1977
|
+
grouped_data = [df.loc[df[grouping_column] == group, data_column] for group in unique_groups]
|
1978
|
+
normal_p_values = [normaltest(data).pvalue for data in grouped_data]
|
1979
|
+
normal_stats = [normaltest(data).statistic for data in grouped_data]
|
1980
|
+
is_normal = all(p > 0.05 for p in normal_p_values)
|
1981
|
+
|
1982
|
+
# Add normality test results to the results_df
|
1983
|
+
for group, stat, p_value in zip(unique_groups, normal_stats, normal_p_values):
|
1984
|
+
test_results.append({
|
1985
|
+
'Comparison': f'Normality test for {group}',
|
1986
|
+
'Test Statistic': stat,
|
1987
|
+
'p-value': p_value,
|
1988
|
+
'Test Name': 'Normality test'
|
1989
|
+
})
|
1990
|
+
|
1991
|
+
# Determine statistical test
|
1992
|
+
if len(unique_groups) == 2:
|
1993
|
+
if is_normal:
|
1994
|
+
stat_test = ttest_ind
|
1995
|
+
test_name = 'T-test'
|
1996
|
+
else:
|
1997
|
+
stat_test = mannwhitneyu
|
1998
|
+
test_name = 'Mann-Whitney U test'
|
1999
|
+
else:
|
2000
|
+
if is_normal:
|
2001
|
+
stat_test = f_oneway
|
2002
|
+
test_name = 'One-way ANOVA'
|
2003
|
+
else:
|
2004
|
+
stat_test = kruskal
|
2005
|
+
test_name = 'Kruskal-Wallis test'
|
2006
|
+
|
2007
|
+
# Perform pairwise statistical tests
|
2008
|
+
comparisons = list(itertools.combinations(unique_groups, 2))
|
2009
|
+
p_values = []
|
2010
|
+
test_statistics = []
|
2011
|
+
|
2012
|
+
for (group1, group2) in comparisons:
|
2013
|
+
data1 = df[df[grouping_column] == group1][data_column]
|
2014
|
+
data2 = df[df[grouping_column] == group2][data_column]
|
2015
|
+
stat, p = stat_test(data1, data2)
|
2016
|
+
p_values.append(p)
|
2017
|
+
test_statistics.append(stat)
|
2018
|
+
test_results.append({'Comparison': f'{group1} vs {group2}', 'Test Statistic': stat, 'p-value': p, 'Test Name': test_name})
|
2019
|
+
|
2020
|
+
# Post-hoc test (Tukey HSD for ANOVA)
|
2021
|
+
posthoc_p_values = None
|
2022
|
+
if is_normal and len(unique_groups) > 2:
|
2023
|
+
tukey_result = pairwise_tukeyhsd(df[data_column], df[grouping_column], alpha=0.05)
|
2024
|
+
posthoc_p_values = tukey_result.pvalues
|
2025
|
+
for comparison, p_value in zip(tukey_result._results_table.data[1:], tukey_result.pvalues):
|
2026
|
+
test_results.append({
|
2027
|
+
'Comparison': f'{comparison[0]} vs {comparison[1]}',
|
2028
|
+
'Test Statistic': None, # Tukey does not provide a test statistic in the same way
|
2029
|
+
'p-value': p_value,
|
2030
|
+
'Test Name': 'Tukey HSD Post-hoc'
|
2031
|
+
})
|
2032
|
+
|
2033
|
+
# Create plot
|
2034
|
+
plt.figure(figsize=(10, 6))
|
2035
|
+
sns.set(style="whitegrid")
|
2036
|
+
|
2037
|
+
if colors:
|
2038
|
+
color_palette = colors
|
2039
|
+
else:
|
2040
|
+
color_palette = sns.color_palette("husl", len(unique_groups))
|
2041
|
+
|
2042
|
+
# Choose graph type
|
2043
|
+
if graph_type == 'bar':
|
2044
|
+
summary_df = df.groupby(grouping_column)[data_column].agg([summary_func, 'std', 'sem'])
|
2045
|
+
|
2046
|
+
# Set error bars based on error_bar_type
|
2047
|
+
if error_bar_type == 'std':
|
2048
|
+
error_bars = summary_df['std']
|
2049
|
+
elif error_bar_type == 'sem':
|
2050
|
+
error_bars = summary_df['sem']
|
2051
|
+
else:
|
2052
|
+
raise ValueError(f"Invalid error_bar_type: {error_bar_type}. Choose either 'std' or 'sem'.")
|
2053
|
+
|
2054
|
+
sns.barplot(x=grouping_column, y=summary_func, data=summary_df.reset_index(), ci=None, order=order, palette=color_palette)
|
2055
|
+
|
2056
|
+
# Add error bars (standard deviation or standard error of the mean)
|
2057
|
+
plt.errorbar(x=np.arange(len(summary_df)), y=summary_df[summary_func], yerr=error_bars, fmt='none', c='black', capsize=5)
|
2058
|
+
|
2059
|
+
elif graph_type == 'violin':
|
2060
|
+
sns.violinplot(x=grouping_column, y=data_column, data=df, order=order, palette=color_palette)
|
2061
|
+
elif graph_type == 'jitter':
|
2062
|
+
sns.stripplot(x=grouping_column, y=data_column, data=df, jitter=True, order=order, palette=color_palette)
|
2063
|
+
elif graph_type == 'box':
|
2064
|
+
sns.boxplot(x=grouping_column, y=data_column, data=df, order=order, palette=color_palette)
|
2065
|
+
elif graph_type == 'jitter_box':
|
2066
|
+
sns.boxplot(x=grouping_column, y=data_column, data=df, order=order, palette=color_palette)
|
2067
|
+
sns.stripplot(x=grouping_column, y=data_column, data=df, jitter=True, color='black', alpha=0.5, order=order)
|
2068
|
+
|
2069
|
+
# Create a DataFrame to summarize the test results
|
2070
|
+
results_df = pd.DataFrame(test_results)
|
2071
|
+
|
2072
|
+
# Set y-axis start if provided
|
2073
|
+
if y_axis_start is not None:
|
2074
|
+
plt.ylim(bottom=y_axis_start)
|
2075
|
+
else:
|
2076
|
+
plt.ylim(0, None) # Default to starting at 0 if no custom start value is provided
|
2077
|
+
|
2078
|
+
# If save is True, save the plot and results as PNG and CSV
|
2079
|
+
if save:
|
2080
|
+
# Save the plot as PNG
|
2081
|
+
plot_path = os.path.join(output_dir, 'grouped_plot.png')
|
2082
|
+
plt.title(f'{test_name} results for {graph_type} plot')
|
2083
|
+
plt.xticks(rotation=45)
|
2084
|
+
plt.tight_layout()
|
2085
|
+
plt.savefig(plot_path)
|
2086
|
+
print(f"Plot saved to {plot_path}")
|
2087
|
+
|
2088
|
+
# Save the test results as a CSV file
|
2089
|
+
results_path = os.path.join(output_dir, 'test_results.csv')
|
2090
|
+
results_df.to_csv(results_path, index=False)
|
2091
|
+
print(f"Test results saved to {results_path}")
|
2092
|
+
|
2093
|
+
# Show the plot
|
2094
|
+
plt.show()
|
2095
|
+
|
2096
|
+
return plt.gcf(), results_df
|
2097
|
+
|
2098
|
+
class spacrGraph:
|
2099
|
+
def __init__(self, df, grouping_column, data_column, graph_type='bar', summary_func='mean',
|
2100
|
+
order=None, colors=None, output_dir='./output', save=False, y_axis_start=None,
|
2101
|
+
error_bar_type='std', remove_outliers=False, theme='pastel', representation='object',
|
2102
|
+
paired=False, all_to_all=True, compare_group=None):
|
2103
|
+
"""
|
2104
|
+
Class for creating grouped plots with optional statistical tests and data preprocessing.
|
2105
|
+
"""
|
2106
|
+
self.df = df
|
2107
|
+
self.grouping_column = grouping_column
|
2108
|
+
self.data_column = data_column
|
2109
|
+
self.graph_type = graph_type
|
2110
|
+
self.summary_func = summary_func
|
2111
|
+
self.order = order
|
2112
|
+
self.colors = colors
|
2113
|
+
self.output_dir = output_dir
|
2114
|
+
self.save = save
|
2115
|
+
self.y_axis_start = y_axis_start
|
2116
|
+
self.error_bar_type = error_bar_type
|
2117
|
+
self.remove_outliers = remove_outliers
|
2118
|
+
self.theme = theme
|
2119
|
+
self.representation = representation
|
2120
|
+
self.paired = paired
|
2121
|
+
self.all_to_all = all_to_all
|
2122
|
+
self.compare_group = compare_group
|
2123
|
+
|
2124
|
+
self.results_df = pd.DataFrame()
|
2125
|
+
self.sns_palette = None
|
2126
|
+
self.fig = None # To store the generated figure
|
2127
|
+
|
2128
|
+
# Preprocess and set palette
|
2129
|
+
self._set_theme()
|
2130
|
+
self.raw_df = self.df.copy() # Preserve the raw data for n_object count
|
2131
|
+
self.df = self.preprocess_data()
|
2132
|
+
|
2133
|
+
def _set_theme(self):
|
2134
|
+
"""Set the Seaborn theme and reorder colors if necessary."""
|
2135
|
+
integer_list = list(range(1, 81))
|
2136
|
+
color_order = [0, 3, 9, 4, 6, 7, 9, 2] + integer_list
|
2137
|
+
self.sns_palette = self._set_reordered_theme(self.theme, color_order, 100)
|
2138
|
+
|
2139
|
+
def _set_reordered_theme(self, theme='muted', order=None, n_colors=100, show_theme=False):
|
2140
|
+
"""Set and reorder the Seaborn color palette."""
|
2141
|
+
palette = sns.color_palette(theme, n_colors)
|
2142
|
+
if order:
|
2143
|
+
reordered_palette = [palette[i] for i in order]
|
2144
|
+
else:
|
2145
|
+
reordered_palette = palette
|
2146
|
+
if show_theme:
|
2147
|
+
sns.palplot(reordered_palette)
|
2148
|
+
plt.show()
|
2149
|
+
return reordered_palette
|
2150
|
+
|
2151
|
+
def preprocess_data(self):
|
2152
|
+
"""Preprocess the data: remove NaNs, sort/order the grouping column, and optionally group by 'prc'."""
|
2153
|
+
df = self.df.dropna(subset=[self.grouping_column, self.data_column])
|
2154
|
+
|
2155
|
+
# Group by 'prc' column if representation is 'well'
|
2156
|
+
if self.representation == 'well':
|
2157
|
+
df = df.groupby(['prc', self.grouping_column])[self.data_column].agg(self.summary_func).reset_index()
|
2158
|
+
|
2159
|
+
if self.order:
|
2160
|
+
df[self.grouping_column] = pd.Categorical(df[self.grouping_column], categories=self.order, ordered=True)
|
2161
|
+
else:
|
2162
|
+
df[self.grouping_column] = pd.Categorical(df[self.grouping_column], categories=sorted(df[self.grouping_column].unique()), ordered=True)
|
2163
|
+
|
2164
|
+
return df
|
2165
|
+
|
2166
|
+
def remove_outliers_from_plot(self):
|
2167
|
+
"""Remove outliers from the plot but keep them in the data."""
|
2168
|
+
filtered_df = self.df.copy()
|
2169
|
+
unique_groups = filtered_df[self.grouping_column].unique()
|
2170
|
+
for group in unique_groups:
|
2171
|
+
group_data = filtered_df[filtered_df[self.grouping_column] == group][self.data_column]
|
2172
|
+
q1 = group_data.quantile(0.25)
|
2173
|
+
q3 = group_data.quantile(0.75)
|
2174
|
+
iqr = q3 - q1
|
2175
|
+
lower_bound = q1 - 1.5 * iqr
|
2176
|
+
upper_bound = q3 + 1.5 * iqr
|
2177
|
+
filtered_df = filtered_df.drop(filtered_df[(filtered_df[self.grouping_column] == group) & ((filtered_df[self.data_column] < lower_bound) | (filtered_df[self.data_column] > upper_bound))].index)
|
2178
|
+
return filtered_df
|
2179
|
+
|
2180
|
+
def perform_normality_tests(self):
|
2181
|
+
"""Perform normality tests for each group."""
|
2182
|
+
unique_groups = self.df[self.grouping_column].unique()
|
2183
|
+
grouped_data = [self.df.loc[self.df[self.grouping_column] == group, self.data_column] for group in unique_groups]
|
2184
|
+
raw_grouped_data = [self.raw_df.loc[self.raw_df[self.grouping_column] == group, self.data_column] for group in unique_groups]
|
2185
|
+
|
2186
|
+
normal_p_values = [normaltest(data).pvalue for data in grouped_data]
|
2187
|
+
normal_stats = [normaltest(data).statistic for data in grouped_data]
|
2188
|
+
is_normal = all(p > 0.05 for p in normal_p_values)
|
2189
|
+
|
2190
|
+
test_results = []
|
2191
|
+
for group, stat, p_value in zip(unique_groups, normal_stats, normal_p_values):
|
2192
|
+
test_results.append({
|
2193
|
+
'Comparison': f'Normality test for {group}',
|
2194
|
+
'Test Statistic': stat,
|
2195
|
+
'p-value': p_value,
|
2196
|
+
'Test Name': 'Normality test',
|
2197
|
+
'n_object': len(raw_grouped_data[unique_groups.tolist().index(group)]), # Raw sample size (objects/cells)
|
2198
|
+
'n_well': len(grouped_data[unique_groups.tolist().index(group)]) if self.representation == 'well' else np.nan # Summarized size (wells)
|
2199
|
+
})
|
2200
|
+
return is_normal, test_results
|
2201
|
+
|
2202
|
+
def perform_levene_test(self, unique_groups):
    """Run ``levene`` across the data-column samples of the given groups.

    Args:
        unique_groups: iterable of group labels found in
            ``self.df[self.grouping_column]``.

    Returns:
        tuple: ``(statistic, p_value)`` as returned by ``levene``.
    """
    labels = self.df[self.grouping_column]
    samples = []
    for label in unique_groups:
        # One sample per group: the data-column values of that group's rows.
        samples.append(self.df.loc[labels == label, self.data_column])

    statistic, p_value = levene(*samples)
    return statistic, p_value
|
2207
|
+
|
2208
|
+
def perform_statistical_tests(self, unique_groups, is_normal):
    """Run a two-sample test for every pair of groups.

    Test selection:
        * two groups, normal:      'Paired T-test' if ``self.paired`` else 'T-test'
        * two groups, not normal:  'Paired Wilcoxon test' if ``self.paired``
                                   else 'Mann-Whitney U test'
        * more than two groups:    'One-way ANOVA' if normal else
                                   'Kruskal-Wallis test'; ``self.paired`` is
                                   not used here, matching the selection logic.

    Args:
        unique_groups: group labels to compare (every 2-combination is tested).
        is_normal: whether the data passed the normality check.

    Returns:
        list[dict]: one dict per pair with the keys 'Comparison',
        'Test Statistic', 'p-value', 'Test Name', 'n_object' (rows of both
        groups in ``self.raw_df``) and 'n_well' (rows of both groups in
        ``self.df`` when ``self.representation == 'well'``, else NaN).

    Raises:
        ValueError: propagated from SciPy, e.g. when a paired test is
            requested for groups of unequal length.

    Notes:
        The paired tests use ``scipy.stats.ttest_rel`` / ``wilcoxon``. The
        previous pingouin calls returned a DataFrame that cannot be unpacked
        into ``(stat, p)``, and ``paired=True`` was also forwarded to
        functions that do not accept it.
        NOTE(review): with more than two groups the omnibus test is still
        applied to each pair separately rather than once across all groups.
    """
    # Local import: these names are not known to be imported at module level.
    from scipy.stats import ttest_rel, wilcoxon

    # The paired flag only selects a different test for two-group designs.
    use_paired = len(unique_groups) == 2 and bool(self.paired)

    if len(unique_groups) == 2:
        if is_normal:
            if use_paired:
                stat_test, test_name = ttest_rel, 'Paired T-test'
            else:
                stat_test, test_name = ttest_ind, 'T-test'
        else:
            if use_paired:
                stat_test, test_name = wilcoxon, 'Paired Wilcoxon test'
            else:
                stat_test, test_name = mannwhitneyu, 'Mann-Whitney U test'
    else:
        if is_normal:
            stat_test, test_name = f_oneway, 'One-way ANOVA'
        else:
            stat_test, test_name = kruskal, 'Kruskal-Wallis test'

    group_col = self.grouping_column
    value_col = self.data_column

    test_results = []
    for group1, group2 in itertools.combinations(unique_groups, 2):
        data1 = self.df[self.df[group_col] == group1][value_col]
        data2 = self.df[self.df[group_col] == group2][value_col]
        raw_data1 = self.raw_df[self.raw_df[group_col] == group1][value_col]
        raw_data2 = self.raw_df[self.raw_df[group_col] == group2][value_col]

        if use_paired:
            # Pairing is positional -- assumes the rows of both groups are in
            # matching order; TODO confirm with the caller.
            result = stat_test(data1.to_numpy(), data2.to_numpy())
        else:
            result = stat_test(data1, data2)
        stat, p = result.statistic, result.pvalue

        test_results.append({
            'Comparison': f'{group1} vs {group2}',
            'Test Statistic': stat,
            'p-value': p,
            'Test Name': test_name,
            'n_object': len(raw_data1) + len(raw_data2),  # Raw sample size (objects/cells)
            'n_well': len(data1) + len(data2) if self.representation == 'well' else np.nan  # Summarized size (wells)
        })
    return test_results
|
2255
|
+
|
2256
|
+
def perform_posthoc_tests(self, is_normal, unique_groups):
    """Run post-hoc pairwise comparisons when more than two groups are present.

    * ``is_normal`` and ``self.all_to_all``: Tukey HSD over all pairs via
      ``pairwise_tukeyhsd``.
    * not ``self.all_to_all`` and ``self.compare_group`` set: Dunn's test with
      Bonferroni correction over all pairs, reporting only the pairs that
      involve ``self.compare_group``.
    * otherwise: no post-hoc test.

    Args:
        is_normal: whether the data passed the normality check.
        unique_groups: group labels present in ``self.df``.

    Returns:
        list[dict]: one dict per reported pair with the keys 'Comparison',
        'Test Statistic', 'p-value', 'Test Name', 'n_object' and 'n_well';
        an empty list when no branch applies.

    Notes:
        Dunn's test is computed here directly (rank-based z statistic with tie
        correction). The previous ``pg.pairwise_tests(..., test='dunn')`` call
        used a parameter and result columns that pingouin does not provide, so
        that branch always raised.
        'n_well' is NaN unless ``self.representation == 'well'``, consistent
        with the other test-result rows.
    """
    group_col = self.grouping_column
    value_col = self.data_column

    def _pair_sizes(group_a, group_b):
        # Row counts of both groups in the raw and in the summarized table.
        n_object = (len(self.raw_df[self.raw_df[group_col] == group_a])
                    + len(self.raw_df[self.raw_df[group_col] == group_b]))
        n_rows = (len(self.df[self.df[group_col] == group_a])
                  + len(self.df[self.df[group_col] == group_b]))
        return n_object, (n_rows if self.representation == 'well' else np.nan)

    if is_normal and len(unique_groups) > 2 and self.all_to_all:
        # Tukey HSD Post-hoc when comparing all to all
        tukey_result = pairwise_tukeyhsd(self.df[value_col], self.df[group_col], alpha=0.05)
        posthoc_results = []
        # Row 0 of the results table is the header; the first two cells of
        # every following row are the two group labels.
        for comparison, p_value in zip(tukey_result._results_table.data[1:], tukey_result.pvalues):
            n_object, n_well = _pair_sizes(comparison[0], comparison[1])
            posthoc_results.append({
                'Comparison': f'{comparison[0]} vs {comparison[1]}',
                'Test Statistic': None,  # Tukey does not provide a test statistic
                'p-value': p_value,
                'Test Name': 'Tukey HSD Post-hoc',
                'n_object': n_object,
                'n_well': n_well
            })
        return posthoc_results

    elif len(unique_groups) > 2 and not self.all_to_all and self.compare_group:
        # Local import: these names are not known to be imported at module level.
        from scipy.stats import norm, rankdata

        data = self.df[[group_col, value_col]].dropna(subset=[value_col])
        labels = data[group_col].to_numpy()
        values = data[value_col].to_numpy(dtype=float)
        n_total = len(values)
        ranks = rankdata(values)  # ties receive their average rank

        # Tie correction term: sum(t^3 - t) / (12 * (N - 1)).
        _, tie_counts = np.unique(values, return_counts=True)
        tie_term = float((tie_counts ** 3 - tie_counts).sum()) / (12.0 * (n_total - 1)) if n_total > 1 else 0.0
        base_variance = n_total * (n_total + 1) / 12.0 - tie_term

        sizes = {}
        mean_ranks = {}
        for group in unique_groups:
            mask = labels == group
            sizes[group] = int(mask.sum())
            mean_ranks[group] = ranks[mask].mean() if sizes[group] else np.nan

        # Bonferroni factor: number of all pairwise comparisons.
        n_groups = len(unique_groups)
        n_comparisons = n_groups * (n_groups - 1) / 2

        posthoc_results = []
        for group_a, group_b in itertools.combinations(unique_groups, 2):
            if group_a != self.compare_group and group_b != self.compare_group:
                continue

            if sizes[group_a] and sizes[group_b]:
                variance = base_variance * (1.0 / sizes[group_a] + 1.0 / sizes[group_b])
            else:
                variance = 0.0

            if variance > 0:
                z_value = float((mean_ranks[group_a] - mean_ranks[group_b]) / np.sqrt(variance))
                p_value = min(1.0, float(2.0 * norm.sf(abs(z_value)) * n_comparisons))
            else:
                # Empty group or all values tied: the statistic is undefined.
                z_value, p_value = np.nan, np.nan

            n_object, n_well = _pair_sizes(group_a, group_b)
            posthoc_results.append({
                'Comparison': f"{group_a} vs {group_b}",
                'Test Statistic': z_value,  # z statistic from Dunn's test
                'p-value': p_value,
                'Test Name': 'Dunn’s Post-hoc',
                'n_object': n_object,
                'n_well': n_well
            })
        return posthoc_results

    return []
|
2292
|
+
|
2293
|
+
def create_plot(self, ax=None):
    """Run the statistical tests, build ``self.results_df`` and draw the plot.

    Args:
        ax: optional matplotlib axes to draw on. When None a new figure is
            created, sized from the number of groups.

    Side effects:
        * replaces ``self.df`` with the outlier-filtered table when
          ``self.remove_outliers`` is set;
        * stores the combined test results in ``self.results_df`` and the
          figure in ``self.fig``;
        * calls ``self._save_results()`` when ``self.save`` is set;
        * calls ``plt.show()``.

    Raises:
        ValueError: if ``self.summary_func`` is set and ``self.graph_type`` is
            not one of 'bar', 'box', 'violin' or 'jitter'.
    """
    # Optional: Remove outliers for plotting
    # NOTE(review): this overwrites self.df, so the tests below also run on
    # the filtered data, not only the plot.
    if self.remove_outliers:
        self.df = self.remove_outliers_from_plot()

    # Perform normality tests
    is_normal, normality_results = self.perform_normality_tests()

    # Perform Levene's test for equal variance
    unique_groups = self.df[self.grouping_column].unique()
    levene_stat, levene_p = self.perform_levene_test(unique_groups)
    levene_result = {
        'Comparison': 'Levene’s test for equal variance',
        'Test Statistic': levene_stat,
        'p-value': levene_p,
        'Test Name': 'Levene’s Test'
    }

    # Perform statistical tests
    stat_results = self.perform_statistical_tests(unique_groups, is_normal)

    # Perform post-hoc tests if applicable
    posthoc_results = self.perform_posthoc_tests(is_normal, unique_groups)

    # Combine all test results
    self.results_df = pd.DataFrame(normality_results + [levene_result] + stat_results + posthoc_results)

    # Add sample size column
    group_sizes = self.df.groupby(self.grouping_column)[self.data_column].count()

    def _sample_size(label):
        # Match the group against the exact label formats used for the result
        # rows. A plain substring test would, for example, find a group named
        # 'nc' inside "variance", and it fails outright for non-string labels.
        for group, n in group_sizes.items():
            name = str(group)
            if (label == f'Normality test for {name}'
                    or label.startswith(f'{name} vs ')
                    or label.endswith(f' vs {name}')):
                return n
        return np.nan

    # For a pairwise row this is the size of whichever of the two groups comes
    # first in the sorted group order; rows naming no group get NaN.
    self.results_df['n'] = self.results_df['Comparison'].apply(_sample_size)

    # Dynamically set figure dimensions based on the number of unique groups
    num_groups = len(unique_groups)
    bar_width = 0.6  # Set the desired thickness of each bar
    spacing_between_groups = 0.3  # Set the desired spacing between bars and axis

    fig_width = num_groups * (bar_width + spacing_between_groups)  # Dynamically calculate the figure width
    fig_height = 6  # Fixed height for the plot

    if ax is None:
        self.fig, ax = plt.subplots(figsize=(fig_width, fig_height))  # Store the figure in self.fig
    else:
        self.fig = ax.figure  # Store the figure if ax is provided

    sns.set(style="ticks")
    color_palette = self.sns_palette if not self.colors else self.colors

    # Calculate x-axis limits to ensure equal space between the bars and the y-axis
    xlim_lower = -0.5  # Ensures space between the y-axis and the first category
    xlim_upper = num_groups - 0.5  # Ensures space after the last category
    ax.set_xlim(xlim_lower, xlim_upper)

    if self.summary_func is None:
        sns.stripplot(x=self.grouping_column, y=self.data_column, data=self.df, palette=color_palette, jitter=True, alpha=0.6, ax=ax)
    elif self.graph_type == 'bar':
        self._create_bar_plot(bar_width, ax)
    elif self.graph_type == 'box':
        self._create_box_plot(ax)
    elif self.graph_type == 'violin':
        self._create_violin_plot(ax)
    elif self.graph_type == 'jitter':
        self._create_jitter_plot(ax)
    else:
        raise ValueError(f"Invalid graph_type: {self.graph_type}. Choose from 'bar', 'box', 'violin', or 'jitter'.")

    # Set y-axis start
    if self.y_axis_start is not None:
        ax.set_ylim(bottom=self.y_axis_start)

    # Add ticks, remove grid, and save plot
    ax.minorticks_on()
    ax.tick_params(axis='x', which='minor', bottom=False)  # Disable minor ticks on x-axis
    ax.tick_params(axis='x', which='major', length=6, width=2, direction='out')
    ax.tick_params(axis='y', which='major', length=6, width=2, direction='out')
    ax.tick_params(axis='y', which='minor', length=4, width=1, direction='out')
    sns.despine(ax=ax, top=True, right=True)

    if self.save:
        self._save_results()

    plt.show()  # Ensure the plot is shown, but plt.show() doesn't clear the figure context
|
2377
|
+
|
2378
|
+
def get_figure(self):
    """Return the figure object stored in ``self.fig``."""
    figure = self.fig
    return figure
|
2381
|
+
|
2382
|
+
def _create_bar_plot(self, bar_width, ax):
|
2383
|
+
"""Helper method to create a bar plot with consistent bar thickness and centered error bars."""
|
2384
|
+
summary_df = self.df.groupby(self.grouping_column)[self.data_column].agg([self.summary_func, 'std', 'sem'])
|
2385
|
+
|
2386
|
+
if self.error_bar_type == 'std':
|
2387
|
+
error_bars = summary_df['std']
|
2388
|
+
elif self.error_bar_type == 'sem':
|
2389
|
+
error_bars = summary_df['sem']
|
2390
|
+
else:
|
2391
|
+
raise ValueError(f"Invalid error_bar_type: {self.error_bar_type}. Choose either 'std' or 'sem'.")
|
2392
|
+
|
2393
|
+
sns.barplot(x=self.grouping_column, y=self.summary_func, data=summary_df.reset_index(), ci=None, palette=self.sns_palette, width=bar_width, ax=ax)
|
2394
|
+
|
2395
|
+
# Plot the error bars
|
2396
|
+
ax.errorbar(x=np.arange(len(summary_df)), y=summary_df[self.summary_func], yerr=error_bars, fmt='none', c='black', capsize=5)
|
2397
|
+
|
2398
|
+
def _create_jitter_plot(self, ax):
    """Draw a jittered ``sns.stripplot`` of the data column per group on ``ax``."""
    strip_kwargs = {
        'x': self.grouping_column,
        'y': self.data_column,
        'data': self.df,
        'palette': self.sns_palette,
        'jitter': True,
        'alpha': 0.6,
        'ax': ax,
    }
    sns.stripplot(**strip_kwargs)
|
2401
|
+
|
2402
|
+
def _create_box_plot(self, ax):
    """Draw a ``sns.boxplot`` of the data column per group on ``ax``."""
    box_kwargs = {
        'x': self.grouping_column,
        'y': self.data_column,
        'data': self.df,
        'palette': self.sns_palette,
        'ax': ax,
    }
    sns.boxplot(**box_kwargs)
|
2405
|
+
|
2406
|
+
def _create_violin_plot(self, ax):
    """Draw a ``sns.violinplot`` of the data column per group on ``ax``."""
    violin_kwargs = {
        'x': self.grouping_column,
        'y': self.data_column,
        'data': self.df,
        'palette': self.sns_palette,
        'ax': ax,
    }
    sns.violinplot(**violin_kwargs)
|
2409
|
+
|
2410
|
+
def _save_results(self):
    """Write the figure and the results table into ``self.output_dir``.

    Creates the directory if needed, saves ``self.fig`` as 'grouped_plot.png'
    and ``self.results_df`` as 'test_results.csv' (without the index), then
    prints both paths.
    """
    out_dir = self.output_dir
    os.makedirs(out_dir, exist_ok=True)

    figure_file = os.path.join(out_dir, 'grouped_plot.png')
    table_file = os.path.join(out_dir, 'test_results.csv')

    self.fig.savefig(figure_file)
    self.results_df.to_csv(table_file, index=False)

    print(f"Plot saved to {figure_file}")
    print(f"Test results saved to {table_file}")
|
2419
|
+
|
2420
|
+
def get_results(self):
    """Return the test-results table stored in ``self.results_df``."""
    results = self.results_df
    return results
|