spacr-0.3.50-py3-none-any.whl → spacr-0.3.55-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/plot.py CHANGED
@@ -16,6 +16,7 @@ from skimage import measure
 from skimage.measure import find_contours, label, regionprops
 from skimage.segmentation import mark_boundaries
 from skimage.transform import resize as sk_resize
+import scikit_posthocs as sp
 
 import tifffile as tiff
 
@@ -365,146 +366,6 @@ def plot_image_mask_overlay(
 
     return fig
 
-def plot_image_mask_overlay_v1(file, channels, cell_channel, nucleus_channel, pathogen_channel, figuresize=10, percentiles=(2,98), thickness=3, save_pdf=True, mode='outlines', export_tiffs=False):
-    """Plot image and mask overlays."""
-
-    def _plot_merged_plot(image, outlines, outline_colors, figuresize, thickness, percentiles, mode='outlines'):
-        """Plot the merged plot with overlay, image channels, and masks."""
-
-        def _generate_colored_mask(mask, alpha):
-            """ Generate a colored mask with transparency using the given colormap. """
-            cmap = generate_mask_random_cmap(mask)
-            rgba_mask = cmap(mask / mask.max()) # Normalize mask and map to colormap (RGBA)
-            rgba_mask[..., 3] = np.where(mask > 0, alpha, 0) # Apply transparency only where mask is present
-            return rgba_mask
-
-        def _overlay_mask(image, mask):
-            """Overlay the colored mask onto the original image."""
-            combined = np.clip(image + mask[..., :3] * mask[..., 3:4], 0, 1) # Ensure pixel values stay in [0, 1]
-            return combined
-
-        def _normalize_image(image, percentiles=(2, 98)):
-            """Normalize the image to the given percentiles."""
-            v_min, v_max = np.percentile(image, percentiles)
-            image_normalized = np.clip((image - v_min) / (v_max - v_min), 0, 1)
-            return image_normalized
-
-        def _generate_contours(mask):
-            """Generate contours for the given mask using OpenCV."""
-            contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-            return contours
-
-        def _apply_contours(image, mask, color, thickness):
-            """Apply the contours to the RGB image for each unique label."""
-            unique_labels = np.unique(mask)
-            for label in unique_labels:
-                if label == 0:
-                    continue # Skip background
-                label_mask = np.where(mask == label, 1, 0).astype(np.uint8)
-                contours = _generate_contours(label_mask)
-                for contour in contours:
-                    cv2.drawContours(image, [contour], -1, mpl.colors.to_rgb(color), thickness)
-            return image
-
-        num_channels = image.shape[-1]
-        fig, ax = plt.subplots(1, num_channels + 1, figsize=(4 * figuresize, figuresize))
-
-        # Plot each channel with its corresponding outlines
-        for v in range(num_channels):
-            channel_image = image[..., v]
-            channel_image_normalized = _normalize_image(channel_image, percentiles)
-            channel_image_rgb = np.dstack((channel_image_normalized, channel_image_normalized, channel_image_normalized))
-
-            for outline, color in zip(outlines, outline_colors):
-                if mode == 'outlines':
-                    channel_image_rgb = _apply_contours(channel_image_rgb, outline, color, thickness)
-                else:
-                    mask = _generate_colored_mask(outline, alpha=0.5)
-                    channel_image_rgb = _overlay_mask(channel_image_rgb, mask)
-
-            ax[v].imshow(channel_image_rgb)
-            ax[v].set_title(f'Image - Channel {v}')
-
-        # Plot the combined RGB image with all outlines
-        rgb_image = np.zeros((*image.shape[:2], 3), dtype=float)
-        rgb_channels = min(3, num_channels)
-        for i in range(rgb_channels):
-            channel_image = image[..., i]
-            channel_image_normalized = _normalize_image(channel_image, percentiles)
-            rgb_image[..., i] = channel_image_normalized
-
-        for outline, color in zip(outlines, outline_colors):
-            if mode == 'outlines':
-                rgb_image = _apply_contours(rgb_image, outline, color, thickness)
-            else:
-                mask = _generate_colored_mask(outline, alpha=0.5)
-                rgb_image = _overlay_mask(rgb_image, mask)
-
-        ax[-1].imshow(rgb_image)
-        ax[-1].set_title('Combined RGB Image')
-
-        plt.tight_layout()
-
-        # Save the figure as a PDF
-        if save_pdf:
-            pdf_dir = os.path.join(os.path.dirname(os.path.dirname(file)), 'results', 'overlay')
-            os.makedirs(pdf_dir, exist_ok=True)
-            pdf_path = os.path.join(pdf_dir, os.path.basename(file).replace('.npy', '.pdf'))
-            fig.savefig(pdf_path, format='pdf')
-
-        plt.show()
-        return fig
-
-    def _save_channels_as_tiff(stack, save_dir, filename):
-        """Save each channel in the stack as a grayscale TIFF."""
-        os.makedirs(save_dir, exist_ok=True)
-        for i in range(stack.shape[-1]):
-            channel = stack[..., i]
-            tiff_path = os.path.join(save_dir, f"{filename}_channel_{i}.tiff")
-            tiff.imwrite(tiff_path, channel, photometric='minisblack')
-            print(f"Saved {tiff_path}")
-
-    stack = np.load(file)
-
-    if export_tiffs:
-        save_dir = os.path.join(os.path.dirname(os.path.dirname(file)), 'results', os.path.splitext(os.path.basename(file))[0], 'tiff')
-        filename = os.path.splitext(os.path.basename(file))[0]
-        _save_channels_as_tiff(stack, save_dir, filename)
-
-    # Convert to float for normalization and ensure correct handling of both 8-bit and 16-bit arrays
-    if stack.dtype == np.uint16:
-        stack = stack.astype(np.float32)
-    elif stack.dtype == np.uint8:
-        stack = stack.astype(np.float32)
-
-    image = stack[..., channels]
-    outlines = []
-    outline_colors = []
-
-    if pathogen_channel is not None:
-        pathogen_mask_dim = -1 # last dimension
-        outlines.append(np.take(stack, pathogen_mask_dim, axis=2))
-        outline_colors.append('blue')
-
-    if nucleus_channel is not None:
-        nucleus_mask_dim = -2 if pathogen_channel is not None else -1
-        outlines.append(np.take(stack, nucleus_mask_dim, axis=2))
-        outline_colors.append('green')
-
-    if cell_channel is not None:
-        if nucleus_channel is not None and pathogen_channel is not None:
-            cell_mask_dim = -3
-        elif nucleus_channel is not None or pathogen_channel is not None:
-            cell_mask_dim = -2
-        else:
-            cell_mask_dim = -1
-        outlines.append(np.take(stack, cell_mask_dim, axis=2))
-        outline_colors.append('red')
-
-    fig = _plot_merged_plot(image=image, outlines=outlines, outline_colors=outline_colors, figuresize=figuresize, thickness=thickness, percentiles=percentiles, mode=mode)
-
-    return fig
-
 def plot_masks(batch, masks, flows, cmap='inferno', figuresize=10, nr=1, file_type='.npz', print_object_number=True):
     """
     Plot the masks and flows for a given batch of images.
@@ -1792,25 +1653,40 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_co
     if not isinstance(min_count, (int, float)):
         min_count = 0
 
-    df = df.copy() # Work on a copy to avoid SettingWithCopyWarning
-    df['plate'], df['row'], df['col'] = zip(*df['prc'].str.split('_'))
+    # Check the number of parts in 'prc'
+    num_parts = len(df['prc'].iloc[0].split('_'))
+    if num_parts == 4:
+        split = df['prc'].str.split('_', expand=True)
+        df['row_name'] = split[2]
+        df['prc'] = f"{plate_number}" + '_' + split[2] + '_' + split[3]
+
+    # Construct 'prc' based on 'plate', 'row_name', and 'column' columns
+    #df['prc'] = df['plate'].astype(str) + '_' + df['row_name'].astype(str) + '_' + df['column'].astype(str)
+
+    if 'column_name' not in df.columns:
+        if 'column' in df.columns:
+            df['column_name'] = df['column']
+    if 'column_name' in df.columns:
+        df['column_name'] = df['column_name']
+
+    df['plate'], df['row_name'], df['column_name'] = zip(*df['prc'].str.split('_'))
 
     # Filtering the dataframe based on the plate_number
     df = df[df['plate'] == plate_number].copy() # Create another copy after filtering
-
+
     # Ensure proper ordering
     row_order = [f'r{i}' for i in range(1, 17)]
     col_order = [f'c{i}' for i in range(1, 28)] # Exclude c15 as per your earlier code
 
-    df['row'] = pd.Categorical(df['row'], categories=row_order, ordered=True)
-    df['col'] = pd.Categorical(df['col'], categories=col_order, ordered=True)
-    df['count'] = df.groupby(['row', 'col'])['row'].transform('count')
+    df['row_name'] = pd.Categorical(df['row_name'], categories=row_order, ordered=True)
+    df['column_name'] = pd.Categorical(df['column_name'], categories=col_order, ordered=True)
+    df['count'] = df.groupby(['row_name', 'column_name'])['row_name'].transform('count')
 
     if min_count > 0:
         df = df[df['count'] >= min_count]
 
     # Explicitly set observed=True to avoid FutureWarning
-    grouped = df.groupby(['row', 'col'], observed=True) # Group by row and column
+    grouped = df.groupby(['row_name', 'column_name'], observed=True) # Group by row and column
 
     if grouping == 'mean':
         plate = grouped[variable].mean().reset_index()
@@ -1822,7 +1698,7 @@ def generate_plate_heatmap(df, plate_number, variable, grouping, min_max, min_co
     else:
         raise ValueError(f"Unsupported grouping: {grouping}")
 
-    plate_map = pd.pivot_table(plate, values=variable, index='row', columns='col').fillna(0)
+    plate_map = pd.pivot_table(plate, values=variable, index='row_name', columns='column_name').fillna(0)
 
     if min_max == 'all':
         min_max = [plate_map.min().min(), plate_map.max().max()]
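The generate_plate_heatmap hunks above normalize both three-part (plate_row_column) and four-part 'prc' identifiers into 'plate', 'row_name', and 'column_name' columns before pivoting. A minimal standalone sketch of that normalization, using hypothetical identifiers (only the column handling is reproduced, not the full function):

import pandas as pd

# Hypothetical 4-part ids: <prefix>_<plate>_<row>_<column>
df = pd.DataFrame({'prc': ['screen1_plate1_r1_c1', 'screen1_plate1_r2_c3']})
plate_number = 'plate1'

if len(df['prc'].iloc[0].split('_')) == 4:
    split = df['prc'].str.split('_', expand=True)
    # Rebuild a 3-part id, plate_row_column, from the last two parts
    df['prc'] = f"{plate_number}" + '_' + split[2] + '_' + split[3]

# A 3-part id splits cleanly into the columns the heatmap pivots on
df['plate'], df['row_name'], df['column_name'] = zip(*df['prc'].str.split('_'))
print(df[['plate', 'row_name', 'column_name']])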
@@ -1964,81 +1840,6 @@ def print_mask_and_flows(stack, mask, flows, overlay=True, max_size=1000, thickn
 
     fig.tight_layout()
     plt.show()
-
-def print_mask_and_flows_v1(stack, mask, flows, overlay=False, max_size=1000):
-    """
-    Display the original image, mask, and flow with optional resizing for large images.
-
-    Args:
-        stack (np.array): Original image or stack.
-        mask (np.array): Mask image.
-        flows (list): List of flow images.
-        overlay (bool): Whether to overlay the mask on the original image.
-        max_size (int): Maximum allowed size for any dimension of the images.
-    """
-
-    def resize_if_needed(image, max_size):
-        """Resize image if any dimension exceeds max_size while maintaining aspect ratio."""
-        if max(image.shape[:2]) > max_size:
-            scale = max_size / max(image.shape[:2])
-            new_shape = (int(image.shape[0] * scale), int(image.shape[1] * scale))
-            if image.ndim == 3:
-                new_shape += (image.shape[2],)
-            return skimage.transform.resize(image, new_shape, preserve_range=True, anti_aliasing=True).astype(image.dtype)
-        return image
-
-    # Resize if necessary
-    stack = resize_if_needed(stack, max_size)
-    mask = resize_if_needed(mask, max_size)
-    flows = [resize_if_needed(flow, max_size) for flow in flows]
-
-    fig, axs = plt.subplots(1, 3, figsize=(12, 4)) # Adjust subplot layout
-
-    if stack.shape[-1] == 1:
-        stack = np.squeeze(stack)
-
-    # Display original image or its first channel
-    if stack.ndim == 2:
-        axs[0].imshow(stack, cmap='gray')
-    elif stack.ndim == 3:
-        axs[0].imshow(stack)
-    else:
-        raise ValueError("Unexpected stack dimensionality.")
-
-    axs[0].set_title('Original Image')
-    axs[0].axis('off')
-
-
-    # Overlay mask on original image if overlay is True
-    if overlay:
-        mask_cmap = generate_mask_random_cmap(mask) # Generate random colormap for mask
-        mask_overlay = np.ma.masked_where(mask == 0, mask) # Mask background
-        outlines = find_boundaries(mask, mode='thick') # Find mask outlines
-
-        if stack.ndim == 2 or stack.ndim == 3:
-            axs[1].imshow(stack, cmap='gray' if stack.ndim == 2 else None)
-            axs[1].imshow(mask_overlay, cmap=mask_cmap, alpha=0.5) # Overlay mask
-            axs[1].contour(outlines, colors='r', linewidths=2) # Add red outlines with thickness 2
-    else:
-        axs[1].imshow(mask, cmap='gray')
-
-    axs[1].set_title('Mask with Overlay' if overlay else 'Mask')
-    axs[1].axis('off')
-
-    # Display flow image or its first channel
-    if flows and isinstance(flows, list) and flows[0].ndim in [2, 3]:
-        flow_image = flows[0]
-        if flow_image.ndim == 3:
-            flow_image = flow_image[:, :, 0] # Use first channel for 3D
-        axs[2].imshow(flow_image, cmap='jet')
-    else:
-        raise ValueError("Unexpected flow dimensionality or structure.")
-
-    axs[2].set_title('Flows')
-    axs[2].axis('off')
-
-    fig.tight_layout()
-    plt.show()
 
 def plot_resize(images, resized_images, labels, resized_labels):
     # Display an example image and label before and after resizing
@@ -2296,48 +2097,6 @@ def plot_lorenz_curves(csv_files, name_column='grna_name', value_column='count',
     print(f"Saved Lorenz Curve: {save_file_path}")
     plt.show()
 
-def plot_lorenz_curves_v1(csv_files, remove_keys=['TGGT1_220950_1', 'TGGT1_233460_4']):
-
-    def lorenz_curve(data):
-        """Calculate Lorenz curve."""
-        sorted_data = np.sort(data)
-        cumulative_data = np.cumsum(sorted_data)
-        lorenz_curve = cumulative_data / cumulative_data[-1]
-        lorenz_curve = np.insert(lorenz_curve, 0, 0)
-        return lorenz_curve
-
-    combined_data = []
-
-    plt.figure(figsize=(10, 6))
-
-    for idx, csv_file in enumerate(csv_files):
-        if idx == 1:
-            save_fldr = os.path.dirname(csv_file)
-            save_path = os.path.join(save_fldr, 'lorenz_curve.pdf')
-
-        df = pd.read_csv(csv_file)
-        for remove in remove_keys:
-            df = df[df['key'] != remove]
-
-        values = df['value'].values
-        combined_data.extend(values)
-
-        lorenz = lorenz_curve(values)
-        name = os.path.basename(csv_file)[:3]
-        plt.plot(np.linspace(0, 1, len(lorenz)), lorenz, label=name)
-
-    # Plot combined Lorenz curve
-    combined_lorenz = lorenz_curve(np.array(combined_data))
-    plt.plot(np.linspace(0, 1, len(combined_lorenz)), combined_lorenz, label="Combined Lorenz Curve", linestyle='--', color='black')
-
-    plt.title('Lorenz Curves')
-    plt.xlabel('Cumulative Share of Individuals')
-    plt.ylabel('Cumulative Share of Value')
-    plt.legend()
-    plt.grid(False)
-    plt.savefig(save_path)
-    plt.show()
-
 def plot_permutation(permutation_df):
     num_features = len(permutation_df)
     fig_height = max(8, num_features * 0.3) # Set a minimum height of 8 and adjust height based on number of features
@@ -2844,9 +2603,13 @@ class spacrGraph:
                              len(self.df[self.df[self.grouping_column] == unique_groups[1]])})
 
        return test_results
-
+
    def perform_posthoc_tests(self, is_normal, unique_groups):
        """Perform post-hoc tests for multiple groups based on all_to_all flag."""
+
+        from .utils import choose_p_adjust_method
+
+        posthoc_results = []
        if is_normal and len(unique_groups) > 2 and self.all_to_all:
            tukey_result = pairwise_tukeyhsd(self.df[self.data_column], self.df[self.grouping_column], alpha=0.05)
            posthoc_results = []
@@ -2862,22 +2625,40 @@ class spacrGraph:
                    'n_object': len(raw_data1) + len(raw_data2),
                    'n_well': len(self.df[self.df[self.grouping_column] == comparison[0]]) + len(self.df[self.df[self.grouping_column] == comparison[1]])})
            return posthoc_results
-
-        elif len(unique_groups) > 2 and not self.all_to_all and self.compare_group:
-            dunn_result = pg.pairwise_tests(data=self.df, dv=self.data_column, between=self.grouping_column, padjust='bonf', test='dunn')
-            posthoc_results = []
-            for idx, row in dunn_result.iterrows():
-                if row['A'] == self.compare_group or row['B'] == self.compare_group:
-                    posthoc_results.append({
-                        'Comparison': f"{row['A']} vs {row['B']}",
-                        'Test Statistic': row['T'], # Test statistic from Dunn's test
-                        'p-value': row['p-val'],
-                        'Test Name': 'Dunn’s Post-hoc',
-                        'n_object': None,
-                        'n_well': None})
-
+
+        elif len(unique_groups) > 2 and self.all_to_all:
+            print('performing_dunns')
+
+            # Prepare data for Dunn's test in long format
+            long_data = self.df[[self.data_column[0], self.grouping_column]].dropna()
+
+            p_adjust_method = choose_p_adjust_method(num_groups=len(long_data[self.grouping_column].unique()),num_data_points=len(long_data) // len(long_data[self.grouping_column].unique()))
+
+            # Perform Dunn's test with Bonferroni correction
+            dunn_result = sp.posthoc_dunn(
+                long_data,
+                val_col=self.data_column[0],
+                group_col=self.grouping_column,
+                p_adjust=p_adjust_method
+            )
+
+            for group_a, group_b in zip(*np.triu_indices_from(dunn_result, k=1)):
+                raw_data1 = self.raw_df[self.raw_df[self.grouping_column] == dunn_result.index[group_a]][self.data_column]
+                raw_data2 = self.raw_df[self.raw_df[self.grouping_column] == dunn_result.columns[group_b]][self.data_column]
+
+                posthoc_results.append({
+                    'Comparison': f"{dunn_result.index[group_a]} vs {dunn_result.columns[group_b]}",
+                    'Test Statistic': None, # Dunn's test does not return a specific test statistic
+                    'p-value': dunn_result.iloc[group_a, group_b], # Extract the p-value from the matrix
+                    'Test Name': "Dunn's Post-hoc",
+                    'p_adjust_method': p_adjust_method,
+                    'n_object': len(raw_data1) + len(raw_data2), # Total objects
+                    'n_well': len(self.df[self.df[self.grouping_column] == dunn_result.index[group_a]]) +
+                              len(self.df[self.grouping_column] == dunn_result.columns[group_b])})
+
            return posthoc_results
-        return []
+
+        return posthoc_results
 
    def create_plot(self, ax=None):
        """Create and display the plot based on the chosen graph type."""
@@ -2913,31 +2694,40 @@ class spacrGraph:
            transposed_table = list(map(list, zip(*table_data)))
            return row_labels, transposed_table
 
-        def _place_symbols(row_labels, transposed_table, x_positions, ax):
+
+        def _place_symbols(row_labels, transposed_table, x_positions, ax):
+            """
+            Places symbols and row labels aligned under the bars or jitter points on the graph.
 
-            # Get the bottom of the y-axis (y=0) in data coordinates and convert to display coordinates
+            Parameters:
+            - row_labels: List of row titles to be displayed along the y-axis.
+            - transposed_table: Data to be placed under each bar/jitter as symbols.
+            - x_positions: X-axis positions for each group to align the symbols.
+            - ax: The matplotlib Axes object where the plot is drawn.
+            """
+            # Get plot dimensions and adjust for different plot sizes
            y_axis_min = ax.get_ylim()[0] # Minimum y-axis value (usually 0)
-            symbol_start_y = ax.transData.transform((0, y_axis_min))[1] - 30 # Slightly below the x-axis line
-
-            # Convert to figure coordinates
-            symbol_start_y_fig = ax.transAxes.inverted().transform((0, symbol_start_y))[1]
-
-            # Calculate y-spacing for the table rows (adjust as needed)
-            y_spacing = 0.02 # Control vertical spacing between elements
-
-            # X-coordinate for the row labels at the y-axis and x-axis intersection
-            label_x_pos = ax.get_xlim()[0] - 0.5 # Slightly offset from the y-axis
-
-            # Place the row titles at the y-axis intersection
+            symbol_start_y = y_axis_min - 0.05 * (ax.get_ylim()[1] - y_axis_min) # Adjust a bit below the x-axis
+
+            # Calculate spacing for the table rows (adjust as needed)
+            y_spacing = 0.04 # Adjust this for better spacing between rows
+
+            # Determine the leftmost x-position for row labels (align with the y-axis)
+            label_x_pos = ax.get_xlim()[0] - 0.3 # Adjust offset from the y-axis
+
+            # Place row labels vertically aligned with symbols
            for row_idx, title in enumerate(row_labels):
-                y_pos = symbol_start_y_fig - (row_idx * y_spacing) # Align with row index
+                y_pos = symbol_start_y - (row_idx * y_spacing) # Calculate vertical position for each label
                ax.text(label_x_pos, y_pos, title, ha='right', va='center', fontsize=12, fontweight='regular')
-
-            # Place the symbols under each bar
+
+            # Place symbols under each bar or jitter point based on x-positions
            for idx, (x_pos, column_data) in enumerate(zip(x_positions, transposed_table)):
                for row_idx, text in enumerate(column_data):
-                    y_pos = symbol_start_y_fig - (row_idx * y_spacing)
-                    ax.text(x_pos, y_pos, text, ha='center', va='center', fontsize=12)
+                    y_pos = symbol_start_y - (row_idx * y_spacing) # Adjust vertical spacing for symbols
+                    ax.text(x_pos, y_pos, text, ha='center', va='center', fontsize=12, fontweight='regular')
+
+            # Redraw to apply changes
+            ax.figure.canvas.draw()
 
    def _get_positions(self, ax):
        if self.graph_type in ['bar','jitter_bar']:
@@ -3048,6 +2838,10 @@ class spacrGraph:
        else:
            raise ValueError(f"Unknown graph type: {self.graph_type}")
 
+        if len(self.data_column) == 1:
+            num_groups = len(self.df[self.grouping_column].unique())
+            self._standerdize_figure_format(ax=ax, num_groups=num_groups, graph_type=self.graph_type)
+
        # Set y-axis start
        if isinstance(self.y_lim, list):
            if len(self.y_lim) == 2:
@@ -3082,7 +2876,73 @@ class spacrGraph:
        if self.save:
            self._save_results()
 
-        ax.margins(x=0.12)
+        ax.margins(x=0.12)
+
+    def _standerdize_figure_format(self, ax, num_groups, graph_type):
+        """
+        Adjusts the figure layout (size, bar width, jitter, and spacing) based on the number of groups.
+
+        Parameters:
+        - ax: The matplotlib Axes object.
+        - num_groups: Number of unique groups.
+        - graph_type: The type of graph (e.g., 'bar', 'jitter', 'box', etc.).
+
+        Returns:
+        - None. Modifies the figure and Axes in place.
+        """
+        if graph_type in ['line', 'line_std']:
+            print("Skipping layout adjustment for line graphs.")
+            return # Skip layout adjustment for line graphs
+
+        correction_factor = 4
+
+        # Set figure size to ensure it remains square with a minimum size
+        fig_size = max(6, num_groups * 2) / correction_factor
+        ax.figure.set_size_inches(fig_size, fig_size)
+
+        # Configure layout based on the number of groups
+        bar_width = min(0.8, 1.5 / num_groups) / correction_factor
+        jitter_amount = min(0.1, 0.2 / num_groups) / correction_factor
+        jitter_size = max(50 / num_groups, 200)
+
+        # Adjust axis limits to ensure bars are centered with respect to group labels
+        ax.set_xlim(-0.5, num_groups - 0.5)
+
+        # Set ticks to match the group labels in your DataFrame
+        group_labels = self.df[self.grouping_column].unique()
+        ax.set_xticks(range(len(group_labels)))
+        ax.set_xticklabels(group_labels, rotation=45, ha='right')
+
+        # Customize elements based on the graph type
+        if graph_type == 'bar':
+            # Adjust bars' width and position
+            for bar in ax.patches:
+                bar.set_width(bar_width)
+                bar.set_x(bar.get_x() - bar_width / 2)
+
+        elif graph_type in ['jitter', 'jitter_bar', 'jitter_box']:
+            # Adjust jitter points' position and size
+            for coll in ax.collections:
+                offsets = coll.get_offsets()
+                offsets[:, 0] += jitter_amount # Shift jitter points slightly
+                coll.set_offsets(offsets)
+                coll.set_sizes([jitter_size] * len(offsets)) # Adjust point size dynamically
+
+        elif graph_type in ['box', 'violin']:
+            # Adjust box width for consistent spacing
+            for artist in ax.artists:
+                artist.set_width(bar_width)
+
+        # Adjust legend and axis labels
+        ax.tick_params(axis='x', labelsize=max(10, 15 - num_groups // 2))
+        ax.tick_params(axis='y', labelsize=max(10, 15 - num_groups // 2))
+
+        if ax.get_legend():
+            ax.get_legend().set_bbox_to_anchor((1.05, 1)) #loc='upper left',borderaxespad=0.
+            ax.get_legend().prop.set_size(max(8, 12 - num_groups // 3))
+
+        # Redraw the figure to apply changes
+        ax.figure.canvas.draw()
 
    def _create_bar_plot(self, ax):
        """Helper method to create a bar plot with consistent bar thickness and centered error bars."""
@@ -3301,11 +3161,11 @@ class spacrGraph:
            bar.set_x(bar.get_x() - target_width / 2)
 
        # Adjust error bars alignment with bars
-        bars = [bar for bar in ax.patches if isinstance(bar, plt.Rectangle)]
-        for bar, (_, row) in zip(bars, summary_df.iterrows()):
-            x_bar = bar.get_x() + bar.get_width() / 2
-            err = row[self.error_bar_type]
-            ax.errorbar(x=x_bar, y=bar.get_height(), yerr=err, fmt='none', c='black', capsize=5, lw=2)
+        #bars = [bar for bar in ax.patches if isinstance(bar, plt.Rectangle)]
+        #for bar, (_, row) in zip(bars, summary_df.iterrows()):
+        #    x_bar = bar.get_x() + bar.get_width() / 2
+        #    err = row[self.error_bar_type]
+        #    ax.errorbar(x=x_bar, y=bar.get_height(), yerr=err, fmt='none', c='black', capsize=5, lw=2)
 
        # Set legend and labels
        ax.set_xlabel(self.grouping_column)
@@ -3420,7 +3280,7 @@ def plot_data_from_db(settings):
        dfs.append(dft)
 
    df = pd.concat(dfs, axis=0)
-    df['prc'] = df['plate'].astype(str) + '_' + df['row'].astype(str) + '_' + df['col'].astype(str)
+    df['prc'] = df['plate'].astype(str) + '_' + df['row_name'].astype(str) + '_' + df['column_name'].astype(str)
    #df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
    #df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
    df['class'] = df['png_path'].apply(lambda x: 'class_1' if 'class_1' in x else ('class_0' if 'class_0' in x else None))
spacr/sequencing.py CHANGED
@@ -125,7 +125,7 @@ def process_chunk(chunk_data):
                consensus_sequences.append(consensus_seq)
                column_sequence = match.group('column')
                grna_sequence = match.group('grna')
-                row_sequence = match.group('row')
+                row_sequence = match.group('row_name')
                columns.append(column_sequence)
                grnas.append(grna_sequence)
                rows.append(row_sequence)
@@ -176,7 +176,7 @@ def process_chunk(chunk_data):
                consensus_sequences.append(consensus_seq)
                column_sequence = match.group('column')
                grna_sequence = match.group('grna')
-                row_sequence = match.group('row')
+                row_sequence = match.group('row_name')
                columns.append(column_sequence)
                grnas.append(grna_sequence)
                rows.append(row_sequence)
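Both process_chunk hunks rename match.group('row') to match.group('row_name'), which only resolves if the barcode regex itself defines a (?P<row_name>...) named group; the pattern is built elsewhere in sequencing.py. A minimal illustration with a hypothetical barcode layout:

import re

# Hypothetical layout: 8 nt column barcode, 20 nt gRNA, 8 nt row barcode
pattern = re.compile(
    r'(?P<column>[ACGT]{8})'
    r'(?P<grna>[ACGT]{20})'
    r'(?P<row_name>[ACGT]{8})'  # previously named (?P<row>...)
)

match = pattern.match('ACGTACGT' + 'A' * 20 + 'TTGGCCAA')
if match:
    print(match.group('column'), match.group('grna'), match.group('row_name'))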
@@ -532,7 +532,7 @@ def graph_sequencing_stats(settings):
    # Iterate through the fraction thresholds
    for threshold in fraction_thresholds:
        filtered_df = df[df['fraction'] >= threshold]
-        unique_count = filtered_df.groupby(['plate', 'row', 'column'])['grna'].nunique().mean()
+        unique_count = filtered_df.groupby(['plate', 'row_name', 'column'])['grna'].nunique().mean()
        results.append((threshold, unique_count))
 
    results_df = pd.DataFrame(results, columns=['fraction_threshold', 'unique_count'])
@@ -588,17 +588,21 @@ def graph_sequencing_stats(settings):
    # Apply the closest threshold to the DataFrame
    df = df[df['fraction'] >= closest_threshold]
 
-    # Group by 'plate', 'row', 'column' and compute unique counts of 'grna'
-    unique_counts = df.groupby(['plate', 'row', 'column'])['grna'].nunique().reset_index(name='unique_counts')
-    unique_count_mean = df.groupby(['plate', 'row', 'column'])['grna'].nunique().mean()
-    unique_count_std = df.groupby(['plate', 'row', 'column'])['grna'].nunique().std()
+    # Group by 'plate', 'row_name', 'column' and compute unique counts of 'grna'
+    unique_counts = df.groupby(['plate', 'row_name', 'column'])['grna'].nunique().reset_index(name='unique_counts')
+    unique_count_mean = df.groupby(['plate', 'row_name', 'column'])['grna'].nunique().mean()
+    unique_count_std = df.groupby(['plate', 'row_name', 'column'])['grna'].nunique().std()
 
    # Merge the unique counts back into the original DataFrame
-    df = pd.merge(df, unique_counts, on=['plate', 'row', 'column'], how='left')
+    df = pd.merge(df, unique_counts, on=['plate', 'row_name', 'column'], how='left')
 
    print(f"unique_count mean: {unique_count_mean} std: {unique_count_std}")
-    display(df)
    #_plot_density(df, dependent_variable='unique_counts')
+
+    has_underscore = df['row_name'].str.contains('_').any()
+    if has_underscore:
+        df['row_name'] = df['row_name'].apply(lambda x: x.split('_')[1])
+
    plot_plates(df=df, variable='unique_counts', grouping='mean', min_max='allq', cmap='viridis',min_count=0, verbose=True, dst=dst)
 
    return closest_threshold
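The final hunk counts distinct gRNAs per (plate, row_name, column) well and, before plotting, strips anything ahead of the first underscore in 'row_name', presumably so labels like 'plate1_r1' reduce to the bare 'r1'–'r16' categories that generate_plate_heatmap orders on. A small sketch of both steps on made-up data:

import pandas as pd

df = pd.DataFrame({
    'plate':    ['p1', 'p1', 'p1', 'p1'],
    'row_name': ['p1_r1', 'p1_r1', 'p1_r2', 'p1_r2'],  # hypothetical prefixed labels
    'column':   ['c1', 'c1', 'c1', 'c2'],
    'grna':     ['g1', 'g2', 'g1', 'g1'],
})

# Distinct gRNAs per well, merged back onto every read
unique_counts = (df.groupby(['plate', 'row_name', 'column'])['grna']
                   .nunique().reset_index(name='unique_counts'))
df = pd.merge(df, unique_counts, on=['plate', 'row_name', 'column'], how='left')

# Normalize 'p1_r1' -> 'r1' so plate plots see bare row names
if df['row_name'].str.contains('_').any():
    df['row_name'] = df['row_name'].apply(lambda x: x.split('_')[1])

print(df)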