spacr 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +19 -3
- spacr/cellpose.py +311 -0
- spacr/core.py +142 -2495
- spacr/deep_spacr.py +151 -29
- spacr/gui.py +1 -0
- spacr/gui_core.py +74 -63
- spacr/gui_elements.py +110 -5
- spacr/gui_utils.py +346 -6
- spacr/io.py +631 -51
- spacr/logger.py +28 -9
- spacr/measure.py +107 -95
- spacr/mediar.py +0 -5
- spacr/ml.py +964 -0
- spacr/openai.py +37 -0
- spacr/plot.py +281 -16
- spacr/resources/data/lopit.csv +3833 -0
- spacr/resources/data/toxoplasma_metadata.csv +8843 -0
- spacr/resources/icons/convert.png +0 -0
- spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
- spacr/sequencing.py +241 -1311
- spacr/settings.py +129 -43
- spacr/sim.py +0 -2
- spacr/submodules.py +348 -0
- spacr/timelapse.py +0 -2
- spacr/toxo.py +233 -0
- spacr/utils.py +275 -173
- {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/METADATA +7 -1
- {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/RECORD +32 -33
- spacr/chris.py +0 -50
- spacr/graph_learning.py +0 -340
- spacr/resources/MEDIAR/.git +0 -1
- spacr/resources/MEDIAR_weights/.DS_Store +0 -0
- spacr/resources/icons/.DS_Store +0 -0
- spacr/resources/icons/spacr_logo_rotation.gif +0 -0
- spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
- spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/sim_app.py +0 -0
- {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/LICENSE +0 -0
- {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/WHEEL +0 -0
- {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/top_level.txt +0 -0
spacr/openai.py
ADDED
@@ -0,0 +1,37 @@
|
|
1
|
+
#from openai import OpenAI, APIError, OpenAIError
|
2
|
+
#import openai
|
3
|
+
#
|
4
|
+
#class Chatbot:
|
5
|
+
# def __init__(self, api_key):
|
6
|
+
# openai.api_key = api_key
|
7
|
+
#
|
8
|
+
# def ask_question(self, question):
|
9
|
+
# try:
|
10
|
+
# # Sending the request to the chat completions endpoint (model is set below)
|
11
|
+
# response = openai.chat.completions.create(
|
12
|
+
# model="gpt-3.5-turbo", # Correct model name for your setup
|
13
|
+
# messages=[
|
14
|
+
# {"role": "user", "content": question}
|
15
|
+
# ],
|
16
|
+
# max_tokens=150
|
17
|
+
# )
|
18
|
+
# # Extracting the response from the model and returning it
|
19
|
+
# return response.choices[0].message['content'].strip()
|
20
|
+
# except (APIError, OpenAIError) as e:
|
21
|
+
# return f"Error: {str(e)}"
|
22
|
+
#
|
23
|
+
#def list_available_models():
|
24
|
+
# try:
|
25
|
+
# # List available models
|
26
|
+
# models = openai.models.list()
|
27
|
+
#
|
28
|
+
# # Iterate through models directly
|
29
|
+
# for model in models:
|
30
|
+
# print(model.id) # Print the model ID
|
31
|
+
# except (APIError, OpenAIError) as e:
|
32
|
+
# print(f"Error: {str(e)}")
|
33
|
+
|
34
|
+
# Replace with your actual API key or ensure it's set as an environment variable
|
35
|
+
|
36
|
+
#openai.api_key = api_key
|
37
|
+
#list_available_models()
|
spacr/plot.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
import os,
|
1
|
+
import os, random, cv2, glob, math, torch
|
2
2
|
|
3
3
|
import numpy as np
|
4
4
|
import pandas as pd
|
@@ -14,11 +14,14 @@ from skimage.segmentation import find_boundaries
|
|
14
14
|
from skimage import measure
|
15
15
|
from skimage.measure import find_contours, label, regionprops
|
16
16
|
|
17
|
+
from scipy.stats import normaltest, ttest_ind, mannwhitneyu, f_oneway, kruskal
|
18
|
+
from statsmodels.stats.multicomp import pairwise_tukeyhsd
|
19
|
+
import itertools
|
20
|
+
|
21
|
+
|
17
22
|
from ipywidgets import IntSlider, interact
|
18
23
|
from IPython.display import Image as ipyimage
|
19
24
|
|
20
|
-
from .logger import log_function_call
|
21
|
-
|
22
25
|
def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, pathogen_channel, figuresize=10, normalize=True, thickness=3, save_pdf=True):
|
23
26
|
"""Plot image and mask overlays."""
|
24
27
|
|
@@ -123,7 +126,7 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
|
|
123
126
|
|
124
127
|
fig = _plot_merged_plot(image=image, outlines=outlines, outline_colors=outline_colors, figuresize=figuresize, thickness=thickness)
|
125
128
|
|
126
|
-
return
|
129
|
+
return fig
|
127
130
|
|
128
131
|
def plot_masks(batch, masks, flows, cmap='inferno', figuresize=10, nr=1, file_type='.npz', print_object_number=True):
|
129
132
|
"""
|
@@ -409,7 +412,7 @@ def plot_images_and_arrays(folders, lower_percentile=1, upper_percentile=99, thr
|
|
409
412
|
plot_from_file_dict(file_dict, threshold, lower_percentile, upper_percentile, overlay, save=False)
|
410
413
|
return
|
411
414
|
|
412
|
-
def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, mask_dims, filter_min_max,
|
415
|
+
def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, mask_dims, filter_min_max, nuclei_limit, pathogen_limit):
|
413
416
|
"""
|
414
417
|
Filters objects in a plot based on various criteria.
|
415
418
|
|
@@ -420,8 +423,8 @@ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mas
|
|
420
423
|
pathogen_mask_dim (int): The dimension index of the pathogen mask.
|
421
424
|
mask_dims (list): A list of dimension indices for additional masks.
|
422
425
|
filter_min_max (list): A list of minimum and maximum area values for each mask.
|
423
|
-
|
424
|
-
|
426
|
+
nuclei_limit (bool): Whether to include multinucleated cells.
|
427
|
+
pathogen_limit (bool): Whether to include multiinfected cells.
|
425
428
|
|
426
429
|
Returns:
|
427
430
|
numpy.ndarray: The filtered stack of masks.
|
@@ -451,9 +454,9 @@ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mas
|
|
451
454
|
total_count_after = len(props_after['label'])
|
452
455
|
|
453
456
|
if mask_dim == cell_mask_dim:
|
454
|
-
if
|
457
|
+
if nuclei_limit is False and nucleus_mask_dim is not None:
|
455
458
|
stack = _remove_multiobject_cells(stack, mask_dim, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, object_dim=pathogen_mask_dim)
|
456
|
-
if
|
459
|
+
if pathogen_limit is False and cell_mask_dim is not None and pathogen_mask_dim is not None:
|
457
460
|
stack = _remove_multiobject_cells(stack, mask_dim, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, object_dim=nucleus_mask_dim)
|
458
461
|
cell_area_before = avg_size_before
|
459
462
|
cell_count_before = total_count_before
|
@@ -662,18 +665,18 @@ def plot_merged(src, settings):
|
|
662
665
|
display(settings)
|
663
666
|
|
664
667
|
if settings['pathogen_mask_dim'] is None:
|
665
|
-
settings['
|
668
|
+
settings['pathogen_limit'] = True
|
666
669
|
|
667
670
|
for file in os.listdir(src):
|
668
671
|
path = os.path.join(src, file)
|
669
672
|
stack = np.load(path)
|
670
673
|
print(f'Loaded: {path}')
|
671
|
-
if not settings['
|
674
|
+
if not settings['uninfected']:
|
672
675
|
if settings['pathogen_mask_dim'] is not None and settings['cell_mask_dim'] is not None:
|
673
676
|
stack = _remove_noninfected(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'])
|
674
677
|
|
675
|
-
if settings['
|
676
|
-
stack = _filter_objects_in_plot(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'], mask_dims, settings['filter_min_max'], settings['
|
678
|
+
if settings['pathogen_limit'] is not True or settings['nuclei_limit'] is not True or settings['filter_min_max'] is not None:
|
679
|
+
stack = _filter_objects_in_plot(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'], mask_dims, settings['filter_min_max'], settings['nuclei_limit'], settings['pathogen_limit'])
|
677
680
|
|
678
681
|
overlayed_image, image, outlines = _normalize_and_outline(image=stack,
|
679
682
|
remove_background=settings['remove_background'],
|
@@ -1359,7 +1362,7 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_co
|
|
1359
1362
|
|
1360
1363
|
return plate_map, min_max
|
1361
1364
|
|
1362
|
-
def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True):
|
1365
|
+
def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True, dst=None):
|
1363
1366
|
plates = df['prc'].str.split('_', expand=True)[0].unique()
|
1364
1367
|
n_rows, n_cols = (len(plates) + 3) // 4, 4
|
1365
1368
|
fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
|
@@ -1374,6 +1377,12 @@ def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True
|
|
1374
1377
|
fig.delaxes(ax[i])
|
1375
1378
|
|
1376
1379
|
plt.subplots_adjust(wspace=0.1, hspace=0.4)
|
1380
|
+
|
1381
|
+
if not dst is None:
|
1382
|
+
filename = os.path.join(dst, 'plate_heatmap.pdf')
|
1383
|
+
fig.savefig(filename, format='pdf')
|
1384
|
+
print(f'Saved heatmap to {filename}')
|
1385
|
+
|
1377
1386
|
if verbose:
|
1378
1387
|
plt.show()
|
1379
1388
|
return fig
|
@@ -1605,13 +1614,19 @@ def volcano_plot(coef_df, filename='volcano_plot.pdf'):
|
|
1605
1614
|
print(f'Saved Volcano plot: {filename}')
|
1606
1615
|
plt.show()
|
1607
1616
|
|
1608
|
-
def plot_histogram(df, dependent_variable, dst=None):
    """
    Plot a histogram (with a KDE overlay) of one column of a DataFrame.

    Args:
        df (pd.DataFrame): Data containing the column to plot.
        dependent_variable (str): Name of the column to plot.
        dst (str, optional): Directory in which to save the figure as
            'dependent_variable_histogram.pdf'. The directory is created if it
            does not exist. If None (default), the figure is only shown.
    """
    # Plot histogram of the dependent variable
    plt.figure(figsize=(10, 6))
    sns.histplot(df[dependent_variable], kde=True)
    plt.title(f'Histogram of {dependent_variable}')
    plt.xlabel(dependent_variable)
    plt.ylabel('Frequency')

    if dst is not None:
        # Saving into a missing directory would raise; create it first.
        os.makedirs(dst, exist_ok=True)
        filename = os.path.join(dst, 'dependent_variable_histogram.pdf')
        plt.savefig(filename, format='pdf')
        print(f'Saved histogram to {filename}')

    plt.show()
|
1616
1631
|
|
1617
1632
|
def plot_lorenz_curves(csv_files, remove_keys=['TGGT1_220950_1', 'TGGT1_233460_4']):
|
@@ -1732,4 +1747,254 @@ def read_and_plot__vision_results(base_dir, y_axis='accuracy', name_split='_time
|
|
1732
1747
|
plt.ylim(y_lim)
|
1733
1748
|
plt.show()
|
1734
1749
|
else:
|
1735
|
-
print("No CSV files found in the specified directory.")
|
1750
|
+
print("No CSV files found in the specified directory.")
|
1751
|
+
|
1752
|
+
def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot', output_path=None, filter_column=None, filter_values=None):
    """
    Load measurements and annotations from a spacr measurement database and create a
    class-balanced jitter plot of one column grouped by another column.

    Data are read from ``src + '/measurements/measurements.db'``. The 'cell', 'nucleus',
    'pathogen' and 'cytoplasm' tables are merged, then left-joined with the 'png_list'
    table on the 'prcfo' column.

    Args:
        src (str): Path to the source folder containing 'measurements/measurements.db'.
        x_column (str): Column used for grouping (x-axis). Missing values are labelled
            'NaN' and kept as their own group.
        y_column (str): Name of the column to be used for the y-axis.
        plot_title (str): Title of the plot. Default is 'Jitter Plot'.
        output_path (str): Path to save the plot image. If None, the plot will be displayed. Default is None.
        filter_column (str | list | None): Column(s) used to filter rows before plotting.
        filter_values (list | None): Values to keep. A flat list when ``filter_column`` is a
            string; one list of values per column when ``filter_column`` is a list.

    Returns:
        pd.DataFrame: The filtered and balanced DataFrame.

    Raises:
        KeyError: If the merged data lacks the 'plate_x', 'row_x' and 'col_x' columns.
        ValueError: If ``filter_values`` does not match a list ``filter_column``, or if no
            (annotated) rows remain to plot.
    """
    from .io import _read_and_merge_data, _read_db

    # Merge the per-object measurement tables and attach the image paths/annotations
    # stored in the 'png_list' table.
    db_path = src + '/measurements/measurements.db'
    df, _ = _read_and_merge_data([db_path],
                                 ['cell', 'nucleus', 'pathogen', 'cytoplasm'],
                                 verbose=True,
                                 nuclei_limit=True,
                                 pathogen_limit=True,
                                 uninfected=True)
    paths_df = _read_db(db_path, tables=['png_list'])
    df = pd.merge(df, paths_df[0], on='prcfo', how='left')

    print(f"Generated dataframe with: {df.shape[1]} columns and {df.shape[0]} rows")

    # Unannotated rows get an explicit 'NaN' label so they form their own group.
    df[x_column] = df[x_column].fillna('NaN')

    # Optional filtering: one column with a list of values, or a list of columns with
    # one list of values per column.
    if filter_column is not None:
        if isinstance(filter_column, str):
            df = df[df[filter_column].isin(filter_values)]
        elif isinstance(filter_column, list):
            if filter_values is None or len(filter_values) != len(filter_column):
                raise ValueError("filter_values must provide one list of values per entry in filter_column.")
            for column, values in zip(filter_column, filter_values):
                df = df[df[column].isin(values)]

    # Use the correct column names based on the merged DataFrame
    well_columns = ['plate_x', 'row_x', 'col_x']
    if not all(column in df.columns for column in well_columns):
        raise KeyError(f"DataFrame does not contain the necessary columns: {well_columns}")

    if df.empty:
        raise ValueError("No rows left to plot after filtering.")

    # Keep every row (annotated or not) from wells that contain at least one annotated row.
    # The well keys are computed once and reused for both sides of the membership test.
    well_keys = df[well_columns].apply(tuple, axis=1)
    annotated_well_keys = well_keys[df[x_column] != 'NaN']
    retained_rows = df[well_keys.isin(annotated_well_keys)]
    if retained_rows.empty:
        raise ValueError(f"No annotated rows found in column '{x_column}'.")

    # Balance the groups: draw the size of the smallest group from every group.
    min_count = retained_rows[x_column].value_counts().min()
    print(f'Found {min_count} annotated images')
    balanced_df = retained_rows.groupby(x_column).apply(lambda group: group.sample(min_count, random_state=42)).reset_index(drop=True)

    # Create the jitter plot
    plt.figure(figsize=(10, 6))
    sns.stripplot(data=balanced_df, x=x_column, y=y_column, hue=x_column, jitter=True, palette='viridis', dodge=False)
    plt.title(plot_title)
    plt.xlabel(x_column)
    plt.ylabel(y_column)

    # Rotate the existing x tick labels, centred below each group. Re-setting the labels
    # from get_xticklabels() is avoided because it installs a FixedFormatter and can
    # blank the labels when they have not been drawn yet.
    plt.xticks(rotation=45, ha='center')

    # Save the plot to a file or display it
    if output_path:
        plt.savefig(output_path, bbox_inches='tight')
        print(f"Jitter plot saved to {output_path}")
    else:
        plt.show()

    return balanced_df
|
1838
|
+
|
1839
|
+
def create_grouped_plot(df, grouping_column, data_column, graph_type='bar', summary_func='mean', order=None, colors=None, output_dir='./output', save=False, y_axis_start=None, error_bar_type='std'):
    """
    Create a grouped plot, perform statistical tests, and optionally export the results along with the plot.

    Parameters:
    - df: DataFrame containing the data. The caller's DataFrame is not modified.
    - grouping_column: Column name for the categorical grouping.
    - data_column: Column name for the data to be grouped and plotted.
    - graph_type: Type of plot ('bar', 'violin', 'jitter', 'box', 'jitter_box').
    - summary_func: Pandas aggregation name applied to each group for bar plots ('mean', 'median', etc.).
    - order: List specifying the order of the groups. Rows whose group is not listed are excluded.
      If None, groups are ordered by sorting their values.
    - colors: List of colors for each group.
    - output_dir: Directory where the figure and test results will be saved if `save=True`.
    - save: Boolean flag indicating whether to save the plot and results to files.
    - y_axis_start: Optional starting value for the y-axis (defaults to 0).
    - error_bar_type: Error bars for bar plots, either 'std' for standard deviation or 'sem' for standard error of the mean.

    Returns:
    - fig: The matplotlib Figure containing the plot.
    - results_df: DataFrame with the statistical test results. It contains one normality test per group,
      an omnibus test across all groups when there are more than two groups, the pairwise tests, and
      Tukey HSD post-hoc comparisons when the data are normal and there are more than two groups.

    Raises:
    - ValueError: If graph_type is not recognised, or (for bar plots) error_bar_type is not 'std'/'sem'.
    """
    valid_graph_types = ('bar', 'violin', 'jitter', 'box', 'jitter_box')
    if graph_type not in valid_graph_types:
        raise ValueError(f"Invalid graph_type: {graph_type}. Choose one of {valid_graph_types}.")

    # Remove NaN rows in grouping_column. Work on an explicit copy so the caller's
    # DataFrame is untouched and the column assignment below does not warn.
    df = df.dropna(subset=[grouping_column]).copy()

    # Ensure the output directory exists if save is True
    if save:
        os.makedirs(output_dir, exist_ok=True)

    # Sorting and ordering
    categories = order if order else sorted(df[grouping_column].unique())
    df[grouping_column] = pd.Categorical(df[grouping_column], categories=categories, ordered=True)
    # Values not listed in `order` become NaN in the categorical. Drop them so they are
    # not tested as a (NaN) group.
    df = df[df[grouping_column].notna()]

    # Get unique groups (in order of appearance)
    unique_groups = list(df[grouping_column].unique())

    # One sample per group. Missing measurements are ignored by every test.
    grouped_data = [df.loc[df[grouping_column] == group, data_column].dropna() for group in unique_groups]

    test_results = []

    # Test normality for each group (each test is computed once)
    normal_p_values = []
    for group, data in zip(unique_groups, grouped_data):
        try:
            normal_stat, normal_p = normaltest(data)
        except ValueError:
            # normaltest raises for very small samples (fewer than 8 observations).
            # Record the test as not computable instead of aborting the whole analysis.
            normal_stat, normal_p = np.nan, np.nan
        normal_p_values.append(normal_p)
        test_results.append({
            'Comparison': f'Normality test for {group}',
            'Test Statistic': normal_stat,
            'p-value': normal_p,
            'Test Name': 'Normality test'
        })

    # NaN p-values compare False, so groups whose normality could not be tested count
    # as non-normal and the non-parametric tests are chosen.
    is_normal = all(p > 0.05 for p in normal_p_values)

    # Determine statistical test
    if len(unique_groups) == 2:
        if is_normal:
            stat_test, test_name = ttest_ind, 'T-test'
        else:
            stat_test, test_name = mannwhitneyu, 'Mann-Whitney U test'
    else:
        if is_normal:
            stat_test, test_name = f_oneway, 'One-way ANOVA'
        else:
            stat_test, test_name = kruskal, 'Kruskal-Wallis test'

    # With more than two groups, run the omnibus test across all groups at once.
    # The pairwise loop below only ever passes two groups to the test.
    if len(unique_groups) > 2:
        omnibus_stat, omnibus_p = stat_test(*grouped_data)
        test_results.append({'Comparison': 'All groups', 'Test Statistic': omnibus_stat, 'p-value': omnibus_p, 'Test Name': test_name})

    # Perform pairwise statistical tests
    for (group1, data1), (group2, data2) in itertools.combinations(zip(unique_groups, grouped_data), 2):
        stat, p = stat_test(data1, data2)
        test_results.append({'Comparison': f'{group1} vs {group2}', 'Test Statistic': stat, 'p-value': p, 'Test Name': test_name})

    # Post-hoc test (Tukey HSD for ANOVA), on the same NaN-free measurements
    if is_normal and len(unique_groups) > 2:
        tukey_df = df.dropna(subset=[data_column])
        tukey_result = pairwise_tukeyhsd(tukey_df[data_column], tukey_df[grouping_column], alpha=0.05)
        # Summary table rows (after the header row) start with the two compared group labels.
        for comparison, p_value in zip(tukey_result._results_table.data[1:], tukey_result.pvalues):
            test_results.append({
                'Comparison': f'{comparison[0]} vs {comparison[1]}',
                'Test Statistic': None,  # Tukey does not provide a test statistic in the same way
                'p-value': p_value,
                'Test Name': 'Tukey HSD Post-hoc'
            })

    # Create plot. Keep a handle on the figure: after plt.show() the "current" figure
    # may already be closed, so plt.gcf() could return a new, empty figure.
    fig = plt.figure(figsize=(10, 6))
    sns.set(style="whitegrid")

    color_palette = colors if colors else sns.color_palette("husl", len(unique_groups))

    # Choose graph type
    if graph_type == 'bar':
        # observed=False keeps one row per category (including categories from `order`
        # absent in the data), so the error bars line up with the bars drawn below.
        summary_df = df.groupby(grouping_column, observed=False)[data_column].agg([summary_func, 'std', 'sem'])

        # Set error bars based on error_bar_type
        if error_bar_type == 'std':
            error_bars = summary_df['std']
        elif error_bar_type == 'sem':
            error_bars = summary_df['sem']
        else:
            raise ValueError(f"Invalid error_bar_type: {error_bar_type}. Choose either 'std' or 'sem'.")

        sns.barplot(x=grouping_column, y=summary_func, data=summary_df.reset_index(), ci=None, order=order, palette=color_palette)

        # Add error bars (standard deviation or standard error of the mean)
        plt.errorbar(x=np.arange(len(summary_df)), y=summary_df[summary_func], yerr=error_bars, fmt='none', c='black', capsize=5)

    elif graph_type == 'violin':
        sns.violinplot(x=grouping_column, y=data_column, data=df, order=order, palette=color_palette)
    elif graph_type == 'jitter':
        sns.stripplot(x=grouping_column, y=data_column, data=df, jitter=True, order=order, palette=color_palette)
    elif graph_type == 'box':
        sns.boxplot(x=grouping_column, y=data_column, data=df, order=order, palette=color_palette)
    elif graph_type == 'jitter_box':
        sns.boxplot(x=grouping_column, y=data_column, data=df, order=order, palette=color_palette)
        sns.stripplot(x=grouping_column, y=data_column, data=df, jitter=True, color='black', alpha=0.5, order=order)

    # Create a DataFrame to summarize the test results
    results_df = pd.DataFrame(test_results)

    # Set y-axis start if provided
    if y_axis_start is not None:
        plt.ylim(bottom=y_axis_start)
    else:
        plt.ylim(0, None)  # Default to starting at 0 if no custom start value is provided

    # If save is True, save the plot and results as PNG and CSV
    if save:
        # Save the plot as PNG
        plot_path = os.path.join(output_dir, 'grouped_plot.png')
        plt.title(f'{test_name} results for {graph_type} plot')
        plt.xticks(rotation=45)
        plt.tight_layout()
        fig.savefig(plot_path)
        print(f"Plot saved to {plot_path}")

        # Save the test results as a CSV file
        results_path = os.path.join(output_dir, 'test_results.csv')
        results_df.to_csv(results_path, index=False)
        print(f"Test results saved to {results_path}")

    # Show the plot
    plt.show()

    return fig, results_df
|