masster 0.4.21__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of masster might be problematic. Click here for more details.

masster/study/plot.py CHANGED
@@ -308,7 +308,7 @@ def plot_alignment(
308
308
  self.logger.info("Showing current RT values for both plots. Run align() first to see alignment comparison.")
309
309
 
310
310
  # Get sample_uids to filter by if specified
311
- sample_uids = self._get_sample_uids(samples) if samples is not None else None
311
+ sample_uids = self._get_samples_uids(samples) if samples is not None else None
312
312
 
313
313
  # Start with full features_df
314
314
  features_df = self.features_df
@@ -836,7 +836,7 @@ def plot_samples_2d(
836
836
  from bokeh.io.export import export_png
837
837
  from bokeh.models import ColumnDataSource, HoverTool
838
838
 
839
- sample_uids = self._get_sample_uids(samples)
839
+ sample_uids = self._get_samples_uids(samples)
840
840
 
841
841
  if not sample_uids:
842
842
  self.logger.error("No valid sample_uids provided.")
@@ -1053,7 +1053,7 @@ def plot_bpc(
1053
1053
  from bokeh.io.export import export_png
1054
1054
  from masster.study.helpers import get_bpc
1055
1055
 
1056
- sample_uids = self._get_sample_uids(samples)
1056
+ sample_uids = self._get_samples_uids(samples)
1057
1057
  if not sample_uids:
1058
1058
  self.logger.error("No valid sample_uids provided for BPC plotting.")
1059
1059
  return
@@ -1238,7 +1238,7 @@ def plot_eic(
1238
1238
  self.logger.error("mz must be provided for EIC plotting")
1239
1239
  return
1240
1240
 
1241
- sample_uids = self._get_sample_uids(samples)
1241
+ sample_uids = self._get_samples_uids(samples)
1242
1242
  if not sample_uids:
1243
1243
  self.logger.error("No valid sample_uids provided for EIC plotting.")
1244
1244
  return
@@ -1400,7 +1400,7 @@ def plot_rt_correction(
1400
1400
  self.logger.error("Column 'rt_original' not found in features_df. Alignment/backup RTs missing.")
1401
1401
  return
1402
1402
 
1403
- sample_uids = self._get_sample_uids(samples)
1403
+ sample_uids = self._get_samples_uids(samples)
1404
1404
  if not sample_uids:
1405
1405
  self.logger.error("No valid sample_uids provided for RT correction plotting.")
1406
1406
  return
@@ -1537,7 +1537,7 @@ def plot_chrom(
1537
1537
  height=300,
1538
1538
  ):
1539
1539
  cons_uids = self._get_consensus_uids(uids)
1540
- sample_uids = self._get_sample_uids(samples)
1540
+ sample_uids = self._get_samples_uids(samples)
1541
1541
 
1542
1542
  chroms = self.get_chrom(uids=cons_uids, samples=sample_uids)
1543
1543
 
@@ -1723,226 +1723,213 @@ def plot_chrom(
1723
1723
  def plot_consensus_stats(
1724
1724
  self,
1725
1725
  filename=None,
1726
- width=1200,
1727
- height=1200,
1726
+ width=840, # Reduced from 1200 (30% smaller)
1727
+ height=None,
1728
1728
  alpha=0.6,
1729
- markersize=3,
1729
+ bins=30,
1730
+ n_cols=4,
1730
1731
  ):
1731
1732
  """
1732
- Plot a scatter plot matrix (SPLOM) of consensus statistics using Bokeh.
1733
-
1733
+ Plot histograms/distributions for specific consensus statistics in the requested order.
1734
+
1735
+ Shows the following properties in order:
1736
+ 1. rt: Retention time
1737
+ 2. rt_delta_mean: Mean retention time delta
1738
+ 3. mz: Mass-to-charge ratio
1739
+ 4. mz_range: Mass range (mz_max - mz_min)
1740
+ 5. log10_inty_mean: Log10 of mean intensity
1741
+ 6. number_samples: Number of samples
1742
+ 7. number_ms2: Number of MS2 spectra
1743
+ 8. charge_mean: Mean charge
1744
+ 9. quality: Feature quality
1745
+ 10. chrom_coherence_mean: Mean chromatographic coherence
1746
+ 11. chrom_height_scaled_mean: Mean scaled chromatographic height
1747
+ 12. chrom_prominence_scaled_mean: Mean scaled chromatographic prominence
1748
+
1734
1749
  Parameters:
1735
1750
  filename (str, optional): Output filename for saving the plot
1736
- width (int): Overall width of the plot (default: 1200)
1737
- height (int): Overall height of the plot (default: 1200)
1738
- alpha (float): Point transparency (default: 0.6)
1739
- markersize (int): Size of points (default: 5)
1751
+ width (int): Overall width of the plot (default: 840)
1752
+ height (int, optional): Overall height of the plot (auto-calculated if None)
1753
+ alpha (float): Histogram transparency (default: 0.6)
1754
+ bins (int): Number of histogram bins (default: 30)
1755
+ n_cols (int): Number of columns in the grid layout (default: 4)
1740
1756
  """
1741
1757
  from bokeh.layouts import gridplot
1742
- from bokeh.models import ColumnDataSource, HoverTool
1743
- from bokeh.plotting import figure, show, output_file
1758
+ from bokeh.plotting import figure
1759
+ import polars as pl
1760
+ import numpy as np
1744
1761
 
1745
1762
  # Check if consensus_df exists and has data
1746
1763
  if self.consensus_df is None or self.consensus_df.is_empty():
1747
1764
  self.logger.error("No consensus data available. Run merge/find_consensus first.")
1748
1765
  return
1749
1766
 
1750
- # Define the columns to plot
1751
- columns = [
1752
- "rt",
1753
- "mz",
1754
- "number_samples",
1755
- "log10_quality",
1756
- "mz_delta_mean",
1757
- "rt_delta_mean",
1758
- "chrom_coherence_mean",
1759
- "chrom_prominence_scaled_mean",
1760
- "inty_mean",
1761
- "number_ms2",
1762
- ]
1763
-
1764
- # Check which columns exist in the dataframe and compute missing ones
1765
- available_columns = self.consensus_df.columns
1767
+ # Get all columns and their data types - work with original dataframe
1766
1768
  data_df = self.consensus_df.clone()
1767
1769
 
1768
- # Add log10_quality if quality exists
1769
- if "quality" in available_columns and "log10_quality" not in available_columns:
1770
- data_df = data_df.with_columns(
1771
- pl.col("quality").log10().alias("log10_quality"),
1772
- )
1773
-
1774
- # Filter columns that actually exist
1775
- final_columns = [col for col in columns if col in data_df.columns]
1770
+ # Define specific columns to plot in the exact order requested
1771
+ desired_columns = [
1772
+ "rt",
1773
+ "rt_delta_mean",
1774
+ "mz",
1775
+ "mz_range", # mz_max-mz_min (will be calculated)
1776
+ "log10_inty_mean", # log10(inty_mean) (will be calculated)
1777
+ "number_samples",
1778
+ "number_ms2",
1779
+ "charge_mean",
1780
+ "quality",
1781
+ "chrom_coherence_mean",
1782
+ "chrom_height_scaled_mean",
1783
+ "chrom_prominence_scaled_mean"
1784
+ ]
1785
+
1786
+ # Calculate derived columns if they don't exist
1787
+ if "mz_range" not in data_df.columns and "mz_max" in data_df.columns and "mz_min" in data_df.columns:
1788
+ data_df = data_df.with_columns((pl.col("mz_max") - pl.col("mz_min")).alias("mz_range"))
1789
+
1790
+ if "log10_inty_mean" not in data_df.columns and "inty_mean" in data_df.columns:
1791
+ data_df = data_df.with_columns(pl.col("inty_mean").log10().alias("log10_inty_mean"))
1792
+
1793
+ # Filter to only include columns that exist in the dataframe, preserving order
1794
+ numeric_columns = [col for col in desired_columns if col in data_df.columns]
1795
+
1796
+ # Check if the numeric columns are actually numeric
1797
+ final_numeric_columns = []
1798
+ for col in numeric_columns:
1799
+ dtype = data_df[col].dtype
1800
+ if dtype in [pl.Int8, pl.Int16, pl.Int32, pl.Int64,
1801
+ pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
1802
+ pl.Float32, pl.Float64]:
1803
+ final_numeric_columns.append(col)
1804
+
1805
+ numeric_columns = final_numeric_columns
1776
1806
 
1777
- if len(final_columns) < 2:
1778
- self.logger.error(f"Need at least 2 columns for SPLOM. Available: {final_columns}")
1807
+ if len(numeric_columns) == 0:
1808
+ self.logger.error(f"None of the requested consensus statistics columns were found or are numeric. Available columns: {list(data_df.columns)}")
1779
1809
  return
1780
1810
 
1781
- self.logger.debug(f"Creating SPLOM with columns: {final_columns}")
1782
-
1783
- # Add important ID columns for tooltips even if not plotting them
1784
- tooltip_columns = []
1785
- for id_col in ["consensus_uid", "consensus_id"]:
1786
- if id_col in data_df.columns and id_col not in final_columns:
1787
- tooltip_columns.append(id_col)
1811
+ self.logger.debug(f"Creating distribution plots for {len(numeric_columns)} specific consensus columns: {numeric_columns}")
1788
1812
 
1789
- # Select plotting columns plus tooltip columns
1790
- all_columns = final_columns + tooltip_columns
1791
- data_pd = data_df.select(all_columns).to_pandas()
1813
+ # Work directly with Polars - no conversion to pandas needed
1814
+ data_df_clean = data_df.select(numeric_columns)
1792
1815
 
1793
- # Remove any infinite or NaN values
1794
- data_pd = data_pd.replace([np.inf, -np.inf], np.nan).dropna()
1795
-
1796
- if data_pd.empty:
1797
- self.logger.error("No valid data after removing NaN/infinite values.")
1816
+ # Check if all numeric columns are empty
1817
+ all_columns_empty = True
1818
+ for col in numeric_columns:
1819
+ # Check if column has any non-null, finite values
1820
+ non_null_count = data_df_clean[col].filter(
1821
+ data_df_clean[col].is_not_null() &
1822
+ (data_df_clean[col].is_finite() if data_df_clean[col].dtype in [pl.Float32, pl.Float64] else pl.lit(True))
1823
+ ).len()
1824
+
1825
+ if non_null_count > 0:
1826
+ all_columns_empty = False
1827
+ break
1828
+
1829
+ if all_columns_empty:
1830
+ self.logger.error("All numeric columns contain only NaN/infinite values.")
1798
1831
  return
1799
1832
 
1800
- source = ColumnDataSource(data_pd)
1801
-
1802
- n_vars = len(final_columns)
1803
-
1804
- # Fixed dimensions - override user input to ensure consistent layout
1805
- total_width = 1200
1806
- total_height = 1200
1807
-
1808
- # Calculate plot sizes to ensure uniform inner plot areas
1809
- # First column needs extra width for y-axis labels
1810
- plot_width_first = 180 # Wider to account for y-axis labels
1811
- plot_width_others = 120 # Standard width for other columns
1812
- plot_height_normal = 120 # Standard height
1813
- plot_height_last = 155 # Taller last row to accommodate x-axis labels while keeping inner plot area same size
1833
+ # Calculate grid dimensions
1834
+ n_plots = len(numeric_columns)
1835
+ n_rows = (n_plots + n_cols - 1) // n_cols # Ceiling division
1836
+
1837
+ # Auto-calculate height if not provided
1838
+ if height is None:
1839
+ plot_height = 210 # Reduced from 300 (30% smaller)
1840
+ height = plot_height * n_rows + 56 # Reduced from 80 (30% smaller)
1841
+ else:
1842
+ plot_height = (height - 56) // n_rows # Reduced padding (30% smaller)
1843
+
1844
+ plot_width = (width - 56) // n_cols # Reduced padding (30% smaller)
1814
1845
 
1815
- # Create grid of plots with variable outer sizes but equal inner areas
1846
+ # Create plots grid
1816
1847
  plots = []
1848
+ current_row = []
1849
+
1850
+ for i, col in enumerate(numeric_columns):
1851
+ # Check if this column should use log scale for y-axis
1852
+ y_axis_type = "log" if col in ["number_samples", "number_ms2"] else "linear"
1853
+
1854
+ # Create histogram for this column
1855
+ p = figure(
1856
+ width=plot_width,
1857
+ height=plot_height,
1858
+ title=col,
1859
+ toolbar_location="above",
1860
+ tools="pan,wheel_zoom,box_zoom,reset,save",
1861
+ y_axis_type=y_axis_type
1862
+ )
1863
+
1864
+ # Set white background
1865
+ p.background_fill_color = "white"
1866
+ p.border_fill_color = "white"
1867
+
1868
+ # Calculate histogram using Polars
1869
+ # Get valid (non-null, finite) values for this column
1870
+ if data_df_clean[col].dtype in [pl.Float32, pl.Float64]:
1871
+ valid_values = data_df_clean.filter(
1872
+ data_df_clean[col].is_not_null() & data_df_clean[col].is_finite()
1873
+ )[col]
1874
+ else:
1875
+ valid_values = data_df_clean.filter(data_df_clean[col].is_not_null())[col]
1876
+
1877
+ if valid_values.len() == 0:
1878
+ self.logger.warning(f"No valid values for column {col}")
1879
+ continue
1880
+
1881
+ # Convert to numpy for histogram calculation
1882
+ values_array = valid_values.to_numpy()
1883
+ hist, edges = np.histogram(values_array, bins=bins)
1884
+
1885
+ # Handle log y-axis: replace zero counts with small positive values
1886
+ if y_axis_type == "log":
1887
+ # Replace zero counts with a small value (1e-1) to make them visible on log scale
1888
+ hist_log_safe = np.where(hist == 0, 0.1, hist)
1889
+ bottom_val = 0.1 # Use small positive value for bottom on log scale
1890
+ else:
1891
+ hist_log_safe = hist
1892
+ bottom_val = 0
1893
+
1894
+ # Create histogram bars
1895
+ p.quad(
1896
+ top=hist_log_safe,
1897
+ bottom=bottom_val,
1898
+ left=edges[:-1],
1899
+ right=edges[1:],
1900
+ fill_color="steelblue",
1901
+ line_color="white",
1902
+ alpha=alpha,
1903
+ )
1904
+
1905
+ # Style the plot
1906
+ p.title.text_font_size = "10pt" # Reduced from 12pt
1907
+ p.xaxis.axis_label = "" # Remove x-axis title
1908
+ p.grid.grid_line_alpha = 0.3 # Show y-axis grid with transparency
1909
+ p.grid.grid_line_color = "gray"
1910
+ p.grid.grid_line_dash = [6, 4] # Dashed grid lines
1911
+ p.xgrid.visible = False # Hide x-axis grid
1912
+ p.outline_line_color = None # Remove gray border around plot area
1913
+
1914
+ # Remove y-axis label but keep y-axis visible
1915
+ p.yaxis.axis_label = ""
1916
+
1917
+ current_row.append(p)
1918
+
1919
+ # If we've filled a row or reached the end, add the row to plots
1920
+ if len(current_row) == n_cols or i == n_plots - 1:
1921
+ # Fill remaining spots in the last row with None if needed
1922
+ while len(current_row) < n_cols and i == n_plots - 1:
1923
+ current_row.append(None)
1924
+ plots.append(current_row)
1925
+ current_row = []
1926
+
1927
+ # Create grid layout with white background
1928
+ grid = gridplot(plots, toolbar_location="above", merge_tools=True)
1929
+
1930
+ # The background should be white by default in Bokeh
1931
+ # Individual plots already have white backgrounds set above
1817
1932
 
1818
- for i, y_var in enumerate(final_columns):
1819
- row = []
1820
- for j, x_var in enumerate(final_columns):
1821
- # Determine if this plot needs axis labels
1822
- has_x_label = i == n_vars - 1 # bottom row
1823
- has_y_label = j == 0 # left column
1824
-
1825
- # First column wider to accommodate y-axis labels, ensuring equal inner plot areas
1826
- current_width = plot_width_first if has_y_label else plot_width_others
1827
- current_height = plot_height_last if has_x_label else plot_height_normal
1828
-
1829
- p = figure(
1830
- width=current_width,
1831
- height=current_height,
1832
- title=None, # No title on any plot
1833
- toolbar_location=None,
1834
- # Adjusted borders - first column has more space, others minimal
1835
- min_border_left=70 if has_y_label else 15,
1836
- min_border_bottom=50 if has_x_label else 15,
1837
- min_border_right=15,
1838
- min_border_top=15,
1839
- )
1840
-
1841
- # Ensure subplot background and border are explicitly white so the plot looks
1842
- # correct in dark and light themes.
1843
- p.outline_line_color = None
1844
- p.border_fill_color = "white"
1845
- p.border_fill_alpha = 1.0
1846
- p.background_fill_color = "white"
1847
-
1848
- # Remove axis lines to eliminate black lines between plots
1849
- p.xaxis.axis_line_color = None
1850
- p.yaxis.axis_line_color = None
1851
-
1852
- # Keep subtle grid lines for data reference
1853
- p.grid.visible = True
1854
- p.grid.grid_line_color = "#E0E0E0" # Light gray grid lines
1855
-
1856
- # Set axis labels and formatting
1857
- if has_x_label: # bottom row
1858
- p.xaxis.axis_label = x_var
1859
- p.xaxis.axis_label_text_font_size = "12pt"
1860
- p.xaxis.major_label_text_font_size = "9pt"
1861
- p.xaxis.axis_label_standoff = 15
1862
- else:
1863
- p.xaxis.major_label_text_font_size = "0pt"
1864
- p.xaxis.minor_tick_line_color = None
1865
- p.xaxis.major_tick_line_color = None
1866
-
1867
- if has_y_label: # left column
1868
- p.yaxis.axis_label = y_var
1869
- p.yaxis.axis_label_text_font_size = "10pt" # Smaller y-axis title
1870
- p.yaxis.major_label_text_font_size = "8pt"
1871
- p.yaxis.axis_label_standoff = 12
1872
- else:
1873
- p.yaxis.major_label_text_font_size = "0pt"
1874
- p.yaxis.minor_tick_line_color = None
1875
- p.yaxis.major_tick_line_color = None
1876
-
1877
- if i == j:
1878
- # Diagonal: histogram
1879
- hist, edges = np.histogram(data_pd[x_var], bins=30)
1880
- p.quad(
1881
- top=hist,
1882
- bottom=0,
1883
- left=edges[:-1],
1884
- right=edges[1:],
1885
- fill_color="green",
1886
- line_color="white",
1887
- alpha=alpha,
1888
- )
1889
- else:
1890
- # Off-diagonal: scatter plot
1891
- scatter = p.scatter(
1892
- x=x_var,
1893
- y=y_var,
1894
- size=markersize,
1895
- alpha=alpha,
1896
- color="blue",
1897
- source=source,
1898
- )
1899
-
1900
- # Add hover tool
1901
- hover = HoverTool(
1902
- tooltips=[
1903
- (x_var, f"@{x_var}{{0.0000}}"),
1904
- (y_var, f"@{y_var}{{0.0000}}"),
1905
- (
1906
- "consensus_uid",
1907
- "@consensus_uid"
1908
- if "consensus_uid" in data_pd.columns
1909
- else "@consensus_id"
1910
- if "consensus_id" in data_pd.columns
1911
- else "N/A",
1912
- ),
1913
- ("rt", "@rt{0.00}" if "rt" in data_pd.columns else "N/A"),
1914
- ("mz", "@mz{0.0000}" if "mz" in data_pd.columns else "N/A"),
1915
- ],
1916
- renderers=[scatter],
1917
- )
1918
- p.add_tools(hover)
1919
-
1920
- row.append(p)
1921
- plots.append(row)
1922
-
1923
- # Link axes for same variables
1924
- for i in range(n_vars):
1925
- for j in range(n_vars):
1926
- if i != j: # Don't link diagonal plots
1927
- # Link x-axis to other plots in same column
1928
- for k in range(n_vars):
1929
- if k != i and k != j:
1930
- plots[i][j].x_range = plots[k][j].x_range
1931
-
1932
- # Link y-axis to other plots in same row
1933
- for k in range(n_vars):
1934
- if k != j and k != i:
1935
- plots[i][j].y_range = plots[i][k].y_range
1936
-
1937
- # Create grid layout and force overall background/border to white so the outer
1938
- # container doesn't show dark UI colors in night mode.
1939
- grid = gridplot(plots)
1940
-
1941
- # Set overall background and border to white when supported
1942
- if hasattr(grid, "background_fill_color"):
1943
- grid.background_fill_color = "white"
1944
- if hasattr(grid, "border_fill_color"):
1945
- grid.border_fill_color = "white"
1946
1933
 
1947
1934
  # Apply consistent save/display behavior
1948
1935
  if filename is not None:
@@ -1962,7 +1949,7 @@ def plot_consensus_stats(
1962
1949
  return grid
1963
1950
 
1964
1951
 
1965
- def plot_pca(
1952
+ def plot_samples_pca(
1966
1953
  self,
1967
1954
  filename=None,
1968
1955
  width=500,
@@ -2102,6 +2089,7 @@ def plot_pca(
2102
2089
  tools="pan,wheel_zoom,box_zoom,reset,save",
2103
2090
  )
2104
2091
 
2092
+ p.grid.visible = False
2105
2093
  p.xaxis.axis_label = f"PC1 ({explained_var[0]:.1%} variance)"
2106
2094
  p.yaxis.axis_label = f"PC2 ({explained_var[1]:.1%} variance)"
2107
2095
 
@@ -2226,6 +2214,293 @@ def plot_pca(
2226
2214
  return p
2227
2215
 
2228
2216
 
2217
+ def plot_samples_umap(
2218
+ self,
2219
+ filename=None,
2220
+ width=500,
2221
+ height=450,
2222
+ alpha=0.8,
2223
+ markersize=6,
2224
+ n_components=2,
2225
+ colorby=None,
2226
+ title="UMAP of Consensus Matrix",
2227
+ n_neighbors=15,
2228
+ min_dist=0.1,
2229
+ metric="euclidean",
2230
+ random_state=42,
2231
+ ):
2232
+ """
2233
+ Plot UMAP (Uniform Manifold Approximation and Projection) of the consensus matrix using Bokeh.
2234
+
2235
+ Parameters:
2236
+ filename (str, optional): Output filename for saving the plot
2237
+ width (int): Plot width (default: 500)
2238
+ height (int): Plot height (default: 450)
2239
+ alpha (float): Point transparency (default: 0.8)
2240
+ markersize (int): Size of points (default: 6)
2241
+ n_components (int): Number of UMAP components to compute (default: 2)
2242
+ colorby (str, optional): Column from samples_df to color points by
2243
+ title (str): Plot title (default: "UMAP of Consensus Matrix")
2244
+ n_neighbors (int): Number of neighbors for UMAP (default: 15)
2245
+ min_dist (float): Minimum distance for UMAP (default: 0.1)
2246
+ metric (str): Distance metric for UMAP (default: "euclidean")
2247
+ random_state (int or None): Random state for reproducibility (default: 42).
2248
+ - Use an integer (e.g., 42) for reproducible results (slower, single-threaded)
2249
+ - Use None for faster computation with multiple cores (non-reproducible)
2250
+
2251
+ Note:
2252
+ Setting random_state forces single-threaded computation but ensures reproducible results.
2253
+ Set random_state=None to enable parallel processing for faster computation.
2254
+ """
2255
+ try:
2256
+ import umap
2257
+ except ImportError:
2258
+ self.logger.error("UMAP not available. Please install umap-learn: pip install umap-learn")
2259
+ return
2260
+
2261
+ from bokeh.models import ColumnDataSource, HoverTool, ColorBar, LinearColorMapper
2262
+ from bokeh.plotting import figure
2263
+ from bokeh.palettes import Category20, viridis
2264
+ from bokeh.transform import factor_cmap
2265
+ from sklearn.preprocessing import StandardScaler
2266
+ import pandas as pd
2267
+ import numpy as np
2268
+
2269
+ # Check if consensus matrix and samples_df exist
2270
+ try:
2271
+ consensus_matrix = self.get_consensus_matrix()
2272
+ samples_df = self.samples_df
2273
+ except Exception as e:
2274
+ self.logger.error(f"Error getting consensus matrix or samples_df: {e}")
2275
+ return
2276
+
2277
+ if consensus_matrix is None or consensus_matrix.shape[0] == 0:
2278
+ self.logger.error("No consensus matrix available. Run merge/find_consensus first.")
2279
+ return
2280
+
2281
+ if samples_df is None or samples_df.is_empty():
2282
+ self.logger.error("No samples dataframe available.")
2283
+ return
2284
+
2285
+ self.logger.debug(f"Performing UMAP on consensus matrix with shape: {consensus_matrix.shape}")
2286
+
2287
+ # Extract only the sample columns (exclude consensus_uid column)
2288
+ sample_cols = [col for col in consensus_matrix.columns if col != "consensus_uid"]
2289
+
2290
+ # Convert consensus matrix to numpy, excluding the consensus_uid column
2291
+ if hasattr(consensus_matrix, "select"):
2292
+ # Polars DataFrame
2293
+ matrix_data = consensus_matrix.select(sample_cols).to_numpy()
2294
+ else:
2295
+ # Pandas DataFrame or other - drop consensus_uid column
2296
+ matrix_sample_data = consensus_matrix.drop(columns=["consensus_uid"], errors="ignore")
2297
+ if hasattr(matrix_sample_data, "values"):
2298
+ matrix_data = matrix_sample_data.values
2299
+ elif hasattr(matrix_sample_data, "to_numpy"):
2300
+ matrix_data = matrix_sample_data.to_numpy()
2301
+ else:
2302
+ matrix_data = np.array(matrix_sample_data)
2303
+
2304
+ # Transpose matrix so samples are rows and features are columns
2305
+ matrix_data = matrix_data.T
2306
+
2307
+ # Handle missing values by replacing with 0
2308
+ matrix_data = np.nan_to_num(matrix_data, nan=0.0, posinf=0.0, neginf=0.0)
2309
+
2310
+ # Standardize the data
2311
+ scaler = StandardScaler()
2312
+ matrix_scaled = scaler.fit_transform(matrix_data)
2313
+
2314
+ # Perform UMAP
2315
+ reducer = umap.UMAP(
2316
+ n_components=n_components,
2317
+ n_neighbors=n_neighbors,
2318
+ min_dist=min_dist,
2319
+ metric=metric,
2320
+ random_state=random_state,
2321
+ n_jobs=1
2322
+ )
2323
+ umap_result = reducer.fit_transform(matrix_scaled)
2324
+
2325
+ self.logger.debug(f"UMAP completed with shape: {umap_result.shape}")
2326
+
2327
+ # Convert samples_df to pandas for easier manipulation
2328
+ samples_pd = samples_df.to_pandas()
2329
+
2330
+ # Create dataframe with UMAP results and sample information
2331
+ umap_df = pd.DataFrame({
2332
+ "UMAP1": umap_result[:, 0],
2333
+ "UMAP2": umap_result[:, 1] if n_components > 1 else np.zeros(len(umap_result)),
2334
+ })
2335
+
2336
+ # Add sample information to UMAP dataframe
2337
+ if len(samples_pd) == len(umap_df):
2338
+ for col in samples_pd.columns:
2339
+ umap_df[col] = samples_pd[col].values
2340
+ else:
2341
+ self.logger.warning(
2342
+ f"Sample count mismatch: samples_df has {len(samples_pd)} rows, "
2343
+ f"but consensus matrix has {len(umap_df)} samples",
2344
+ )
2345
+
2346
+ # Prepare color mapping
2347
+ color_column = None
2348
+ color_mapper = None
2349
+
2350
+ if colorby and colorby in umap_df.columns:
2351
+ color_column = colorby
2352
+ unique_values = umap_df[colorby].unique()
2353
+
2354
+ # Handle categorical vs numeric coloring
2355
+ if umap_df[colorby].dtype in ["object", "string", "category"]:
2356
+ # Categorical coloring
2357
+ if len(unique_values) <= 20:
2358
+ palette = Category20[min(20, max(3, len(unique_values)))]
2359
+ else:
2360
+ palette = viridis(min(256, len(unique_values)))
2361
+ color_mapper = factor_cmap(colorby, palette, unique_values)
2362
+ else:
2363
+ # Numeric coloring
2364
+ palette = viridis(256)
2365
+ color_mapper = LinearColorMapper(
2366
+ palette=palette,
2367
+ low=umap_df[colorby].min(),
2368
+ high=umap_df[colorby].max(),
2369
+ )
2370
+
2371
+ # Create Bokeh plot
2372
+ p = figure(
2373
+ width=width,
2374
+ height=height,
2375
+ title=f"{title}",
2376
+ tools="pan,wheel_zoom,box_zoom,reset,save",
2377
+ )
2378
+
2379
+ p.grid.visible = False
2380
+ p.xaxis.axis_label = "UMAP1"
2381
+ p.yaxis.axis_label = "UMAP2"
2382
+
2383
+ # Create data source
2384
+ source = ColumnDataSource(umap_df)
2385
+
2386
+ # Create scatter plot
2387
+ if color_mapper:
2388
+ if isinstance(color_mapper, LinearColorMapper):
2389
+ scatter = p.scatter(
2390
+ "UMAP1",
2391
+ "UMAP2",
2392
+ size=markersize,
2393
+ alpha=alpha,
2394
+ color={"field": colorby, "transform": color_mapper},
2395
+ source=source,
2396
+ )
2397
+ # Add colorbar for numeric coloring
2398
+ color_bar = ColorBar(color_mapper=color_mapper, width=8, location=(0, 0))
2399
+ p.add_layout(color_bar, "right")
2400
+ else:
2401
+ scatter = p.scatter(
2402
+ "UMAP1",
2403
+ "UMAP2",
2404
+ size=markersize,
2405
+ alpha=alpha,
2406
+ color=color_mapper,
2407
+ source=source,
2408
+ legend_field=colorby,
2409
+ )
2410
+ else:
2411
+ # If no color_by provided, use sample_color column from samples_df
2412
+ if "sample_uid" in umap_df.columns or "sample_name" in umap_df.columns:
2413
+ # Choose the identifier to map colors by
2414
+ id_col = "sample_uid" if "sample_uid" in umap_df.columns else "sample_name"
2415
+
2416
+ # Get colors from samples_df based on the identifier
2417
+ if id_col == "sample_uid":
2418
+ sample_colors = (
2419
+ self.samples_df.filter(pl.col("sample_uid").is_in(umap_df[id_col].unique()))
2420
+ .select(["sample_uid", "sample_color"])
2421
+ .to_dict(as_series=False)
2422
+ )
2423
+ color_map = dict(zip(sample_colors["sample_uid"], sample_colors["sample_color"]))
2424
+ else: # sample_name
2425
+ sample_colors = (
2426
+ self.samples_df.filter(pl.col("sample_name").is_in(umap_df[id_col].unique()))
2427
+ .select(["sample_name", "sample_color"])
2428
+ .to_dict(as_series=False)
2429
+ )
2430
+ color_map = dict(zip(sample_colors["sample_name"], sample_colors["sample_color"]))
2431
+
2432
+ # Map colors into dataframe
2433
+ umap_df["color"] = [color_map.get(x, "#1f77b4") for x in umap_df[id_col]] # fallback to blue
2434
+ # Update the ColumnDataSource with new color column
2435
+ source = ColumnDataSource(umap_df)
2436
+ scatter = p.scatter(
2437
+ "UMAP1",
2438
+ "UMAP2",
2439
+ size=markersize,
2440
+ alpha=alpha,
2441
+ color="color",
2442
+ source=source,
2443
+ )
2444
+ else:
2445
+ scatter = p.scatter(
2446
+ "UMAP1",
2447
+ "UMAP2",
2448
+ size=markersize,
2449
+ alpha=alpha,
2450
+ color="blue",
2451
+ source=source,
2452
+ )
2453
+
2454
+ # Create comprehensive hover tooltips with all sample information
2455
+ tooltip_list = []
2456
+
2457
+ # Columns to exclude from tooltips (file paths and internal/plot fields)
2458
+ excluded_cols = {"file_source", "file_path", "sample_path", "map_id", "UMAP1", "UMAP2", "ms1", "ms2", "size"}
2459
+
2460
+ # Add all sample dataframe columns to tooltips, skipping excluded ones
2461
+ for col in samples_pd.columns:
2462
+ if col in excluded_cols:
2463
+ continue
2464
+ if col in umap_df.columns:
2465
+ if col == "sample_color":
2466
+ # Display sample_color as a colored swatch
2467
+ tooltip_list.append(("color", "$color[swatch]:sample_color"))
2468
+ elif umap_df[col].dtype in ["float64", "float32"]:
2469
+ tooltip_list.append((col, f"@{col}{{0.00}}"))
2470
+ else:
2471
+ tooltip_list.append((col, f"@{col}"))
2472
+
2473
+ hover = HoverTool(
2474
+ tooltips=tooltip_list,
2475
+ renderers=[scatter],
2476
+ )
2477
+ p.add_tools(hover)
2478
+
2479
+ # Add legend if using categorical coloring
2480
+ if color_mapper and not isinstance(color_mapper, LinearColorMapper) and colorby:
2481
+ # Only set legend properties if legends exist (avoid Bokeh warning when none created)
2482
+ if getattr(p, "legend", None) and len(p.legend) > 0:
2483
+ p.legend.location = "top_left"
2484
+ p.legend.click_policy = "hide"
2485
+
2486
+ # Apply consistent save/display behavior
2487
+ if filename is not None:
2488
+ # Convert relative paths to absolute paths using study folder as base
2489
+ import os
2490
+ if not os.path.isabs(filename):
2491
+ filename = os.path.join(self.folder, filename)
2492
+
2493
+ # Convert to absolute path for logging
2494
+ abs_filename = os.path.abspath(filename)
2495
+
2496
+ # Use isolated file saving
2497
+ _isolated_save_plot(p, filename, abs_filename, self.logger, "UMAP Plot")
2498
+ else:
2499
+ # Show in notebook when no filename provided
2500
+ _isolated_show_notebook(p)
2501
+ return p
2502
+
2503
+
2229
2504
  def plot_tic(
2230
2505
  self,
2231
2506
  samples=100,
@@ -2246,7 +2521,7 @@ def plot_tic(
2246
2521
  from bokeh.io.export import export_png
2247
2522
  from masster.study.helpers import get_tic
2248
2523
 
2249
- sample_uids = self._get_sample_uids(samples)
2524
+ sample_uids = self._get_samples_uids(samples)
2250
2525
  if not sample_uids:
2251
2526
  self.logger.error("No valid sample_uids provided for TIC plotting.")
2252
2527
  return
@@ -2379,3 +2654,16 @@ def plot_tic(
2379
2654
  _isolated_show_notebook(p)
2380
2655
 
2381
2656
  return p
2657
+
2658
+
2659
+ def plot_pca(self, *args, **kwargs):
2660
+ """Deprecated: Use plot_samples_pca instead."""
2661
+ import warnings
2662
+ warnings.warn("plot_pca is deprecated, use plot_samples_pca instead", DeprecationWarning, stacklevel=2)
2663
+ return self.plot_samples_pca(*args, **kwargs)
2664
+
2665
+ def plot_umap(self, *args, **kwargs):
2666
+ """Deprecated: Use plot_samples_umap instead."""
2667
+ import warnings
2668
+ warnings.warn("plot_umap is deprecated, use plot_samples_umap instead", DeprecationWarning, stacklevel=2)
2669
+ return self.plot_samples_umap(*args, **kwargs)