spacr 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. spacr/__init__.py +19 -3
  2. spacr/cellpose.py +311 -0
  3. spacr/core.py +142 -2495
  4. spacr/deep_spacr.py +151 -29
  5. spacr/gui.py +1 -0
  6. spacr/gui_core.py +74 -63
  7. spacr/gui_elements.py +110 -5
  8. spacr/gui_utils.py +346 -6
  9. spacr/io.py +631 -51
  10. spacr/logger.py +28 -9
  11. spacr/measure.py +107 -95
  12. spacr/mediar.py +0 -5
  13. spacr/ml.py +964 -0
  14. spacr/openai.py +37 -0
  15. spacr/plot.py +281 -16
  16. spacr/resources/data/lopit.csv +3833 -0
  17. spacr/resources/data/toxoplasma_metadata.csv +8843 -0
  18. spacr/resources/icons/convert.png +0 -0
  19. spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
  20. spacr/sequencing.py +241 -1311
  21. spacr/settings.py +129 -43
  22. spacr/sim.py +0 -2
  23. spacr/submodules.py +348 -0
  24. spacr/timelapse.py +0 -2
  25. spacr/toxo.py +233 -0
  26. spacr/utils.py +275 -173
  27. {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/METADATA +7 -1
  28. {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/RECORD +32 -33
  29. spacr/chris.py +0 -50
  30. spacr/graph_learning.py +0 -340
  31. spacr/resources/MEDIAR/.git +0 -1
  32. spacr/resources/MEDIAR_weights/.DS_Store +0 -0
  33. spacr/resources/icons/.DS_Store +0 -0
  34. spacr/resources/icons/spacr_logo_rotation.gif +0 -0
  35. spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
  36. spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
  37. spacr/sim_app.py +0 -0
  38. {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/LICENSE +0 -0
  39. {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/WHEEL +0 -0
  40. {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/entry_points.txt +0 -0
  41. {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/top_level.txt +0 -0
spacr/openai.py ADDED
@@ -0,0 +1,37 @@
1
+ #from openai import OpenAI, APIError, OpenAIError
2
+ #import openai
3
+ #
4
+ #class Chatbot:
5
+ # def __init__(self, api_key):
6
+ # openai.api_key = api_key
7
+ #
8
+ # def ask_question(self, question):
9
+ # try:
10
+ # # Sending the request to the chat model named below
11
+ # response = openai.chat.completions.create(
12
+ # model="gpt-3.5-turbo", # Correct model name for your setup
13
+ # messages=[
14
+ # {"role": "user", "content": question}
15
+ # ],
16
+ # max_tokens=150
17
+ # )
18
+ # # Extracting the response from the model and returning it
19
+ # return response.choices[0].message['content'].strip()
20
+ # except (APIError, OpenAIError) as e:
21
+ # return f"Error: {str(e)}"
22
+ #
23
+ #def list_available_models():
24
+ # try:
25
+ # # List available models
26
+ # models = openai.models.list()
27
+ #
28
+ # # Iterate through models directly
29
+ # for model in models:
30
+ # print(model.id) # Print the model ID
31
+ # except (APIError, OpenAIError) as e:
32
+ # print(f"Error: {str(e)}")
33
+
34
+ # Replace with your actual API key or ensure it's set as an environment variable
35
+
36
+ #openai.api_key = api_key
37
+ #list_available_models()
spacr/plot.py CHANGED
@@ -1,4 +1,4 @@
1
- import os,re, random, cv2, glob, time, math, torch
1
+ import os, random, cv2, glob, math, torch
2
2
 
3
3
  import numpy as np
4
4
  import pandas as pd
@@ -14,11 +14,14 @@ from skimage.segmentation import find_boundaries
14
14
  from skimage import measure
15
15
  from skimage.measure import find_contours, label, regionprops
16
16
 
17
+ from scipy.stats import normaltest, ttest_ind, mannwhitneyu, f_oneway, kruskal
18
+ from statsmodels.stats.multicomp import pairwise_tukeyhsd
19
+ import itertools
20
+
21
+
17
22
  from ipywidgets import IntSlider, interact
18
23
  from IPython.display import Image as ipyimage
19
24
 
20
- from .logger import log_function_call
21
-
22
25
  def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, pathogen_channel, figuresize=10, normalize=True, thickness=3, save_pdf=True):
23
26
  """Plot image and mask overlays."""
24
27
 
@@ -123,7 +126,7 @@ def plot_image_mask_overlay(file, channels, cell_channel, nucleus_channel, patho
123
126
 
124
127
  fig = _plot_merged_plot(image=image, outlines=outlines, outline_colors=outline_colors, figuresize=figuresize, thickness=thickness)
125
128
 
126
- return
129
+ return fig
127
130
 
128
131
  def plot_masks(batch, masks, flows, cmap='inferno', figuresize=10, nr=1, file_type='.npz', print_object_number=True):
129
132
  """
@@ -409,7 +412,7 @@ def plot_images_and_arrays(folders, lower_percentile=1, upper_percentile=99, thr
409
412
  plot_from_file_dict(file_dict, threshold, lower_percentile, upper_percentile, overlay, save=False)
410
413
  return
411
414
 
412
- def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, mask_dims, filter_min_max, include_multinucleated, include_multiinfected):
415
+ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, mask_dims, filter_min_max, nuclei_limit, pathogen_limit):
413
416
  """
414
417
  Filters objects in a plot based on various criteria.
415
418
 
@@ -420,8 +423,8 @@ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mas
420
423
  pathogen_mask_dim (int): The dimension index of the pathogen mask.
421
424
  mask_dims (list): A list of dimension indices for additional masks.
422
425
  filter_min_max (list): A list of minimum and maximum area values for each mask.
423
- include_multinucleated (bool): Whether to include multinucleated cells.
424
- include_multiinfected (bool): Whether to include multiinfected cells.
426
+ nuclei_limit (bool): Whether to include multinucleated cells.
427
+ pathogen_limit (bool): Whether to include multiinfected cells.
425
428
 
426
429
  Returns:
427
430
  numpy.ndarray: The filtered stack of masks.
@@ -451,9 +454,9 @@ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mas
451
454
  total_count_after = len(props_after['label'])
452
455
 
453
456
  if mask_dim == cell_mask_dim:
454
- if include_multinucleated is False and nucleus_mask_dim is not None:
457
+ if nuclei_limit is False and nucleus_mask_dim is not None:
455
458
  stack = _remove_multiobject_cells(stack, mask_dim, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, object_dim=pathogen_mask_dim)
456
- if include_multiinfected is False and cell_mask_dim is not None and pathogen_mask_dim is not None:
459
+ if pathogen_limit is False and cell_mask_dim is not None and pathogen_mask_dim is not None:
457
460
  stack = _remove_multiobject_cells(stack, mask_dim, cell_mask_dim, nucleus_mask_dim, pathogen_mask_dim, object_dim=nucleus_mask_dim)
458
461
  cell_area_before = avg_size_before
459
462
  cell_count_before = total_count_before
@@ -662,18 +665,18 @@ def plot_merged(src, settings):
662
665
  display(settings)
663
666
 
664
667
  if settings['pathogen_mask_dim'] is None:
665
- settings['include_multiinfected'] = True
668
+ settings['pathogen_limit'] = True
666
669
 
667
670
  for file in os.listdir(src):
668
671
  path = os.path.join(src, file)
669
672
  stack = np.load(path)
670
673
  print(f'Loaded: {path}')
671
- if not settings['include_noninfected']:
674
+ if not settings['uninfected']:
672
675
  if settings['pathogen_mask_dim'] is not None and settings['cell_mask_dim'] is not None:
673
676
  stack = _remove_noninfected(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'])
674
677
 
675
- if settings['include_multiinfected'] is not True or settings['include_multinucleated'] is not True or settings['filter_min_max'] is not None:
676
- stack = _filter_objects_in_plot(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'], mask_dims, settings['filter_min_max'], settings['include_multinucleated'], settings['include_multiinfected'])
678
+ if settings['pathogen_limit'] is not True or settings['nuclei_limit'] is not True or settings['filter_min_max'] is not None:
679
+ stack = _filter_objects_in_plot(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'], mask_dims, settings['filter_min_max'], settings['nuclei_limit'], settings['pathogen_limit'])
677
680
 
678
681
  overlayed_image, image, outlines = _normalize_and_outline(image=stack,
679
682
  remove_background=settings['remove_background'],
@@ -1359,7 +1362,7 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_co
1359
1362
 
1360
1363
  return plate_map, min_max
1361
1364
 
1362
- def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True):
1365
+ def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True, dst=None):
1363
1366
  plates = df['prc'].str.split('_', expand=True)[0].unique()
1364
1367
  n_rows, n_cols = (len(plates) + 3) // 4, 4
1365
1368
  fig, ax = plt.subplots(n_rows, n_cols, figsize=(40, 5 * n_rows))
@@ -1374,6 +1377,12 @@ def plot_plates(df, variable, grouping, min_max, cmap, min_count=0, verbose=True
1374
1377
  fig.delaxes(ax[i])
1375
1378
 
1376
1379
  plt.subplots_adjust(wspace=0.1, hspace=0.4)
1380
+
1381
+ if not dst is None:
1382
+ filename = os.path.join(dst, 'plate_heatmap.pdf')
1383
+ fig.savefig(filename, format='pdf')
1384
+ print(f'Saved heatmap to {filename}')
1385
+
1377
1386
  if verbose:
1378
1387
  plt.show()
1379
1388
  return fig
@@ -1605,13 +1614,19 @@ def volcano_plot(coef_df, filename='volcano_plot.pdf'):
1605
1614
  print(f'Saved Volcano plot: {filename}')
1606
1615
  plt.show()
1607
1616
 
1608
- def plot_histogram(df, dependent_variable):
1617
def plot_histogram(df, dependent_variable, dst=None):
    """
    Plot a histogram (with a KDE overlay) of one DataFrame column.

    Args:
        df (pd.DataFrame): Data containing ``dependent_variable``.
        dependent_variable (str): Name of the column to plot.
        dst (str, optional): Directory in which the figure is saved as
            'dependent_variable_histogram.pdf'. The directory is created if
            it does not exist. If None, nothing is written to disk.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Keep a handle on the figure so the save below targets this plot explicitly.
    fig = plt.figure(figsize=(10, 6))
    sns.histplot(df[dependent_variable], kde=True)
    plt.title(f'Histogram of {dependent_variable}')
    plt.xlabel(dependent_variable)
    plt.ylabel('Frequency')

    if dst is not None:
        # savefig does not create missing directories, so make sure dst exists first.
        os.makedirs(dst, exist_ok=True)
        filename = os.path.join(dst, 'dependent_variable_histogram.pdf')
        fig.savefig(filename, format='pdf')
        print(f'Saved histogram to {filename}')

    plt.show()
1616
1631
 
1617
1632
  def plot_lorenz_curves(csv_files, remove_keys=['TGGT1_220950_1', 'TGGT1_233460_4']):
@@ -1732,4 +1747,254 @@ def read_and_plot__vision_results(base_dir, y_axis='accuracy', name_split='_time
1732
1747
  plt.ylim(y_lim)
1733
1748
  plt.show()
1734
1749
  else:
1735
- print("No CSV files found in the specified directory.")
1750
+ print("No CSV files found in the specified directory.")
1751
+
1752
def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot', output_path=None, filter_column=None, filter_values=None):
    """
    Draw a class-balanced jitter plot of a measurement grouped by an annotation column.

    Measurements are read from ``<src>/measurements/measurements.db`` and left-joined
    with the 'png_list' table on 'prcfo'. The data is optionally filtered, restricted
    to wells (plate_x/row_x/col_x) that contain at least one row with a non-missing
    ``x_column`` value, and then down-sampled so that every group in ``x_column``
    has the same number of rows.

    Args:
        src (str): Folder containing 'measurements/measurements.db'.
        x_column (str): Column used for grouping on the x-axis. Missing values are
            relabelled with the string 'NaN' and kept as their own group.
        y_column (str): Column plotted on the y-axis.
        plot_title (str): Title of the plot. Default is 'Jitter Plot'.
        output_path (str): Path to save the plot image. If None, the plot is displayed instead.
        filter_column (str or list of str, optional): Column(s) to filter on.
        filter_values (list or list of lists, optional): Allowed values. A flat list when
            ``filter_column`` is a string; one list per column when it is a list.

    Returns:
        pd.DataFrame: The filtered and balanced DataFrame.

    Raises:
        ValueError: If ``filter_column`` is given without matching ``filter_values``,
            or if no rows remain to plot.
        KeyError: If the merged data lacks the 'plate_x', 'row_x' or 'col_x' columns.
    """

    def _join_measurements_and_annotation(src, tables):
        # Local import, as in the original code.
        from .io import _read_and_merge_data, _read_db
        db_path = src + '/measurements/measurements.db'
        measurements, _ = _read_and_merge_data([db_path],
                                               tables,
                                               verbose=True,
                                               nuclei_limit=True,
                                               pathogen_limit=True,
                                               uninfected=True)
        png_tables = _read_db(db_path, tables=['png_list'])
        return pd.merge(measurements, png_tables[0], on='prcfo', how='left')

    # Load object measurements joined with the png_list (annotation) table
    df = _join_measurements_and_annotation(src, tables=['cell', 'nucleus', 'pathogen', 'cytoplasm'])
    print(f"Generated dataframe with: {df.shape[1]} columns and {df.shape[0]} rows")

    # Replace NaN values with a specific label in x_column
    df[x_column] = df[x_column].fillna('NaN')

    # Filter the DataFrame if filter_column and filter_values are provided
    if filter_column is not None:
        if filter_values is None:
            raise ValueError("filter_values must be provided when filter_column is set.")
        if isinstance(filter_column, str):
            df = df[df[filter_column].isin(filter_values)]
        elif isinstance(filter_column, list):
            if len(filter_values) < len(filter_column):
                raise ValueError("filter_values must contain one list of values per entry in filter_column.")
            for column, allowed_values in zip(filter_column, filter_values):
                df = df[df[column].isin(allowed_values)]

    # Columns that identify a well in the merged table
    required_columns = ['plate_x', 'row_x', 'col_x']
    if not all(column in df.columns for column in required_columns):
        raise KeyError(f"DataFrame does not contain the necessary columns: {required_columns}")

    if df.empty:
        raise ValueError("No rows left to plot after filtering.")

    # Retain every row from wells that hold at least one annotated (non-'NaN') row.
    # The per-row well key is computed once and reused for both sides of the match.
    well_keys = df[required_columns].apply(tuple, axis=1)
    annotated_well_keys = well_keys[df[x_column] != 'NaN']
    retained_rows = df[well_keys.isin(annotated_well_keys)]

    if retained_rows.empty:
        raise ValueError(f"No annotated rows found in column '{x_column}'.")

    # Determine the minimum count of examples across all groups in x_column
    min_count = retained_rows[x_column].value_counts().min()
    print(f'Found {min_count} annotated images')

    # Randomly sample min_count examples from each group (fixed seed for reproducibility)
    balanced_df = retained_rows.groupby(x_column).apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)

    # Create the jitter plot
    plt.figure(figsize=(10, 6))
    sns.stripplot(data=balanced_df, x=x_column, y=y_column, hue=x_column, jitter=True, palette='viridis', dodge=False)
    plt.title(plot_title)
    plt.xlabel(x_column)
    plt.ylabel(y_column)

    # Rotated tick labels, centred below the data
    plt.xticks(rotation=45, ha='center')

    # Save the plot to a file or display it
    if output_path:
        plt.savefig(output_path, bbox_inches='tight')
        print(f"Jitter plot saved to {output_path}")
    else:
        plt.show()

    return balanced_df
1838
+
1839
def create_grouped_plot(df, grouping_column, data_column, graph_type='bar', summary_func='mean', order=None, colors=None, output_dir='./output', save=False, y_axis_start=None, error_bar_type='std'):
    """
    Create a grouped plot, perform statistical tests, and optionally export the results along with the plot.

    Each group is tested for normality with ``scipy.stats.normaltest``. If every group
    has p > 0.05 the parametric test is used (t-test for two groups, one-way ANOVA
    otherwise); if not, the non-parametric one is used (Mann-Whitney U for two groups,
    Kruskal-Wallis otherwise). The selected test is applied to every pair of groups.
    With more than two normally distributed groups a Tukey HSD post-hoc test is added.

    Parameters:
    - df: DataFrame containing the data. It is not modified; a copy is used internally.
    - grouping_column: Column name for the categorical grouping.
    - data_column: Column name for the data to be grouped and plotted.
    - graph_type: Type of plot ('bar', 'violin', 'jitter', 'box', 'jitter_box').
    - summary_func: Summary function to apply to each group ('mean', 'median', etc.). Used for 'bar' plots.
    - order: List specifying the order of the groups. If None, groups are sorted. Rows whose
      group is not listed in `order` are excluded.
    - colors: List of colors for each group.
    - output_dir: Directory where the figure and test results will be saved if `save=True`.
    - save: Boolean flag indicating whether to save the plot and results to files.
    - y_axis_start: Optional starting value for the y-axis. Defaults to 0.
    - error_bar_type: Type of error bars for 'bar' plots, either 'std' for standard deviation or 'sem' for standard error of the mean.

    Outputs:
    - Figure of the plot (the figure created by this function).
    - DataFrame with full statistical test results, including normality tests.

    Raises:
    - ValueError: If `graph_type` is not supported, or if `error_bar_type` is invalid for a 'bar' plot.
    """

    valid_graph_types = ('bar', 'violin', 'jitter', 'box', 'jitter_box')
    if graph_type not in valid_graph_types:
        raise ValueError(f"Invalid graph_type: {graph_type}. Choose one of {valid_graph_types}.")

    # Remove NaN rows in grouping_column. Copy so the categorical conversion below
    # operates on an independent frame rather than a potential view of the caller's data.
    df = df.dropna(subset=[grouping_column]).copy()

    # Ensure the output directory exists if save is True
    if save:
        os.makedirs(output_dir, exist_ok=True)

    # Sorting and ordering
    categories = order if order else sorted(df[grouping_column].unique())
    df[grouping_column] = pd.Categorical(df[grouping_column], categories=categories, ordered=True)

    # Values missing from `order` turn into NaN during the conversion; drop them so they
    # do not show up as an empty pseudo-group in the statistical tests.
    df = df.dropna(subset=[grouping_column])

    # Get unique groups
    unique_groups = list(df[grouping_column].unique())

    # Initialize test results
    test_results = []

    # Test normality for each group (one normaltest call per group)
    grouped_data = [df.loc[df[grouping_column] == group, data_column] for group in unique_groups]
    normality_results = [normaltest(data) for data in grouped_data]
    is_normal = all(result.pvalue > 0.05 for result in normality_results)

    # Add normality test results to the results table
    for group, result in zip(unique_groups, normality_results):
        test_results.append({
            'Comparison': f'Normality test for {group}',
            'Test Statistic': result.statistic,
            'p-value': result.pvalue,
            'Test Name': 'Normality test'
        })

    # Determine statistical test
    if len(unique_groups) == 2:
        if is_normal:
            stat_test = ttest_ind
            test_name = 'T-test'
        else:
            stat_test = mannwhitneyu
            test_name = 'Mann-Whitney U test'
    else:
        if is_normal:
            stat_test = f_oneway
            test_name = 'One-way ANOVA'
        else:
            stat_test = kruskal
            test_name = 'Kruskal-Wallis test'

    # Perform pairwise statistical tests, reusing the per-group data selected above
    data_by_group = dict(zip(unique_groups, grouped_data))
    for group1, group2 in itertools.combinations(unique_groups, 2):
        stat, p = stat_test(data_by_group[group1], data_by_group[group2])
        test_results.append({'Comparison': f'{group1} vs {group2}', 'Test Statistic': stat, 'p-value': p, 'Test Name': test_name})

    # Post-hoc test (Tukey HSD for ANOVA)
    if is_normal and len(unique_groups) > 2:
        tukey_result = pairwise_tukeyhsd(df[data_column], df[grouping_column], alpha=0.05)
        # Row 0 of the results table is the header; the remaining rows line up with pvalues.
        for comparison, p_value in zip(tukey_result._results_table.data[1:], tukey_result.pvalues):
            test_results.append({
                'Comparison': f'{comparison[0]} vs {comparison[1]}',
                'Test Statistic': None,  # Tukey does not provide a test statistic in the same way
                'p-value': p_value,
                'Test Name': 'Tukey HSD Post-hoc'
            })

    # Create plot. Keep the figure handle: after plt.show() the current figure may no
    # longer be this one, so plt.gcf() is not a reliable way to return it.
    fig = plt.figure(figsize=(10, 6))
    sns.set(style="whitegrid")

    if colors:
        color_palette = colors
    else:
        color_palette = sns.color_palette("husl", len(unique_groups))

    # Choose graph type
    if graph_type == 'bar':
        summary_df = df.groupby(grouping_column)[data_column].agg([summary_func, 'std', 'sem'])

        # Set error bars based on error_bar_type
        if error_bar_type == 'std':
            error_bars = summary_df['std']
        elif error_bar_type == 'sem':
            error_bars = summary_df['sem']
        else:
            raise ValueError(f"Invalid error_bar_type: {error_bar_type}. Choose either 'std' or 'sem'.")

        sns.barplot(x=grouping_column, y=summary_func, data=summary_df.reset_index(), ci=None, order=order, palette=color_palette)

        # Add error bars (standard deviation or standard error of the mean)
        plt.errorbar(x=np.arange(len(summary_df)), y=summary_df[summary_func], yerr=error_bars, fmt='none', c='black', capsize=5)

    elif graph_type == 'violin':
        sns.violinplot(x=grouping_column, y=data_column, data=df, order=order, palette=color_palette)
    elif graph_type == 'jitter':
        sns.stripplot(x=grouping_column, y=data_column, data=df, jitter=True, order=order, palette=color_palette)
    elif graph_type == 'box':
        sns.boxplot(x=grouping_column, y=data_column, data=df, order=order, palette=color_palette)
    elif graph_type == 'jitter_box':
        sns.boxplot(x=grouping_column, y=data_column, data=df, order=order, palette=color_palette)
        sns.stripplot(x=grouping_column, y=data_column, data=df, jitter=True, color='black', alpha=0.5, order=order)

    # Create a DataFrame to summarize the test results
    results_df = pd.DataFrame(test_results)

    # Set y-axis start if provided
    if y_axis_start is not None:
        plt.ylim(bottom=y_axis_start)
    else:
        plt.ylim(0, None)  # Default to starting at 0 if no custom start value is provided

    # If save is True, save the plot and results as PNG and CSV
    if save:
        # Save the plot as PNG
        plot_path = os.path.join(output_dir, 'grouped_plot.png')
        plt.title(f'{test_name} results for {graph_type} plot')
        plt.xticks(rotation=45)
        plt.tight_layout()
        fig.savefig(plot_path)
        print(f"Plot saved to {plot_path}")

        # Save the test results as a CSV file
        results_path = os.path.join(output_dir, 'test_results.csv')
        results_df.to_csv(results_path, index=False)
        print(f"Test results saved to {results_path}")

    # Show the plot
    plt.show()

    return fig, results_df