openms-insight 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,374 @@
1
+ """VolcanoPlot component for differential expression visualization."""
2
+
3
+ from typing import Any, Dict, Optional
4
+
5
+ import polars as pl
6
+
7
+ from ..core.base import BaseComponent
8
+ from ..core.registry import register_component
9
+ from ..preprocessing.scatter import build_scatter_columns
10
+
11
+
12
@register_component("volcanoplot")
class VolcanoPlot(BaseComponent):
    """
    Interactive volcano plot for differential expression analysis.

    Displays log2 fold change (x-axis) vs -log10(p-value) (y-axis) with
    three-category coloring based on significance thresholds. Thresholds
    are passed at render time to avoid cache invalidation when adjusting
    sliders.

    Features:
    - Client-side significance computation (instant threshold updates)
    - Three-category coloring (up-regulated, down-regulated, not significant)
    - Threshold lines at ±fc_threshold and -log10(p_threshold)
    - Optional labels on significant points
    - Click-to-select with cross-component linking
    - SVG export

    Example:
        volcano = VolcanoPlot(
            cache_id="protein_volcano",
            data_path="proteins.parquet",
            log2fc_column="log2FC",
            pvalue_column="pvalue",
            label_column="protein_name",
            interactivity={'protein': 'protein_id'},
            filters={'comparison': 'comparison_id'},
        )

        # Thresholds passed at render time - no cache impact
        volcano(
            state_manager=state,
            fc_threshold=1.0,
            p_threshold=0.05,
            height=500,
        )
    """

    _component_type: str = "volcanoplot"

    def __init__(
        self,
        cache_id: str,
        log2fc_column: str = "log2FC",
        pvalue_column: str = "pvalue",
        data: Optional[pl.LazyFrame] = None,
        data_path: Optional[str] = None,
        label_column: Optional[str] = None,
        filters: Optional[Dict[str, str]] = None,
        filter_defaults: Optional[Dict[str, Any]] = None,
        interactivity: Optional[Dict[str, str]] = None,
        cache_path: str = ".",
        regenerate_cache: bool = False,
        title: Optional[str] = None,
        x_label: Optional[str] = None,
        y_label: Optional[str] = None,
        up_color: str = "#E74C3C",
        down_color: str = "#3498DB",
        ns_color: str = "#95A5A6",
        show_threshold_lines: bool = True,
        threshold_line_style: str = "dash",
        **kwargs,
    ):
        """
        Initialize the VolcanoPlot component.

        Args:
            cache_id: Unique identifier for this component's cache (MANDATORY).
                Creates a folder {cache_path}/{cache_id}/ for cached data.
            log2fc_column: Name of column for log2 fold change (x-axis).
            pvalue_column: Name of column for p-value. Will be transformed
                to -log10(pvalue) for display on y-axis.
            data: Polars LazyFrame with volcano data. Optional if cache exists.
            data_path: Path to parquet file (preferred for large datasets).
            label_column: Name of column for point labels (shown on hover and
                optionally as annotations on significant points).
            filters: Mapping of identifier names to column names for filtering.
                Example: {'comparison': 'comparison_id'} filters by comparison.
            filter_defaults: Default values for filter identifiers when no
                selection is present in state.
            interactivity: Mapping of identifier names to column names for clicks.
                When a point is clicked, sets each identifier to the clicked
                point's value in the corresponding column.
            cache_path: Base path for cache storage. Default "." (current dir).
            regenerate_cache: If True, regenerate cache even if valid cache exists.
            title: Plot title displayed above the volcano plot.
            x_label: X-axis label (default: "log2 Fold Change"). Any falsy
                value, including an empty string, selects the default.
            y_label: Y-axis label (default: "-log10(p-value)"). Any falsy
                value, including an empty string, selects the default.
            up_color: Color for up-regulated points (default: red #E74C3C).
            down_color: Color for down-regulated points (default: blue #3498DB).
            ns_color: Color for not significant points (default: gray #95A5A6).
            show_threshold_lines: Show threshold lines on plot (default: True).
            threshold_line_style: Line style for thresholds (default: "dash").
            **kwargs: Additional configuration options, forwarded unchanged
                to the base-class constructor.
        """
        # Component-specific attributes are assigned before super().__init__()
        # is called, so they already exist while the base constructor runs.
        self._log2fc_column = log2fc_column
        self._pvalue_column = pvalue_column
        self._label_column = label_column
        self._title = title
        self._x_label = x_label or "log2 Fold Change"
        self._y_label = y_label or "-log10(p-value)"
        self._up_color = up_color
        self._down_color = down_color
        self._ns_color = ns_color
        self._show_threshold_lines = show_threshold_lines
        self._threshold_line_style = threshold_line_style

        # Render-time threshold values (overwritten on every __call__)
        self._current_fc_threshold: float = 1.0
        self._current_p_threshold: float = 0.05
        self._current_max_labels: int = 10

        # Name of the derived -log10(pvalue) column added in _preprocess()
        self._neglog10p_column = "_neglog10_pvalue"

        super().__init__(
            cache_id=cache_id,
            data=data,
            data_path=data_path,
            filters=filters,
            filter_defaults=filter_defaults,
            interactivity=interactivity,
            cache_path=cache_path,
            regenerate_cache=regenerate_cache,
            **kwargs,
        )

    def _validate_columns(self, schema: pl.Schema) -> None:
        """Validate that required columns exist in the data schema.

        Args:
            schema: Schema of the input data.

        Raises:
            ValueError: If the log2FC or p-value column is missing, or if a
                label column was configured but is not present.
        """
        available = set(schema.names())

        required = [self._log2fc_column, self._pvalue_column]
        missing = [col for col in required if col not in available]
        if missing:
            raise ValueError(
                f"Missing required columns: {missing}. "
                f"Available columns: {sorted(available)}"
            )

        if self._label_column and self._label_column not in available:
            raise ValueError(
                f"Label column '{self._label_column}' not found. "
                f"Available columns: {sorted(available)}"
            )

    def _get_component_config_hash_inputs(self) -> Dict[str, Any]:
        """Get inputs for component config hash (cache invalidation).

        Only the column names are included; thresholds are render-time
        parameters and are deliberately left out.
        """
        return {
            "log2fc_column": self._log2fc_column,
            "pvalue_column": self._pvalue_column,
            "label_column": self._label_column,
            # Note: thresholds are NOT included - they're render-time params
        }

    def _get_cache_config(self) -> Dict[str, Any]:
        """Get the component configuration stored alongside the cache.

        The keys returned here are the ones read back by
        ``_restore_cache_config``.
        """
        return {
            "log2fc_column": self._log2fc_column,
            "pvalue_column": self._pvalue_column,
            "label_column": self._label_column,
            "title": self._title,
            "x_label": self._x_label,
            "y_label": self._y_label,
            "up_color": self._up_color,
            "down_color": self._down_color,
            "ns_color": self._ns_color,
            "show_threshold_lines": self._show_threshold_lines,
            "threshold_line_style": self._threshold_line_style,
        }

    def _restore_cache_config(self, config: Dict[str, Any]) -> None:
        """Restore component-specific configuration from cached config.

        Keys missing from ``config`` fall back to the same defaults the
        constructor uses.

        Args:
            config: Mapping with the keys produced by ``_get_cache_config``.
        """
        self._log2fc_column = config.get("log2fc_column", "log2FC")
        self._pvalue_column = config.get("pvalue_column", "pvalue")
        self._label_column = config.get("label_column")
        self._title = config.get("title")
        self._x_label = config.get("x_label", "log2 Fold Change")
        self._y_label = config.get("y_label", "-log10(p-value)")
        self._up_color = config.get("up_color", "#E74C3C")
        self._down_color = config.get("down_color", "#3498DB")
        self._ns_color = config.get("ns_color", "#95A5A6")
        self._show_threshold_lines = config.get("show_threshold_lines", True)
        self._threshold_line_style = config.get("threshold_line_style", "dash")

    def _volcano_columns(
        self, y_column: str, extra_columns: Optional[list] = None
    ) -> list:
        """Build the de-duplicated, order-preserving list of columns to select.

        Args:
            y_column: Column used for the y-axis; it is also passed as the
                value column, so the resulting duplicate is removed here.
            extra_columns: Additional columns to carry along. An empty list
                is forwarded as None.

        Returns:
            Column names in first-seen order without duplicates.
        """
        columns = build_scatter_columns(
            x_column=self._log2fc_column,
            y_column=y_column,
            value_column=y_column,
            interactivity=self._interactivity,
            filters=self._filters,
            extra_columns=extra_columns or None,
        )
        # dict.fromkeys drops duplicates while preserving order
        return list(dict.fromkeys(columns))

    def _preprocess(self) -> None:
        """Preprocess data for volcano plot.

        Selects the columns the plot needs, adds a -log10(pvalue) column and
        stores the collected frame under the "volcanoData" key. No
        downsampling is performed here.

        Raises:
            ValueError: If no raw data is available.
        """
        if self._raw_data is None:
            raise ValueError("No data provided and no cache exists")

        # The p-value column serves as both y and value column at this stage;
        # the helper removes the duplicate entry this produces.
        extra_cols = [self._label_column] if self._label_column else []
        columns = self._volcano_columns(self._pvalue_column, extra_cols)

        # Only request columns that actually exist in the input
        schema_names = self._raw_data.collect_schema().names()
        available_cols = [c for c in columns if c in schema_names]

        # NOTE(review): rows whose p-value is not > 0 take the fallback value
        # 0.0, which places them at the bottom of the y-axis. Confirm this is
        # the intended treatment for p-values reported as exactly 0.
        neglog10p = (
            pl.when(pl.col(self._pvalue_column) > 0)
            .then(-pl.col(self._pvalue_column).log(10))
            .otherwise(0.0)
            .alias(self._neglog10p_column)
        )
        df = self._raw_data.select(available_cols).with_columns(neglog10p).collect()

        self._preprocessed_data = {"volcanoData": df}

    def _get_vue_component_name(self) -> str:
        """Return the Vue component name."""
        return "PlotlyVolcano"

    def _get_data_key(self) -> str:
        """Return the key for the primary data in Vue payload."""
        return "volcanoData"

    def _prepare_vue_data(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Prepare filtered data for Vue component.

        Loads the preprocessed data if necessary, restricts it to the columns
        the plot needs, applies the configured filters (if any) and sorts the
        rows by ascending -log10(p-value).

        Args:
            state: Current selection state used to resolve filter values.

        Returns:
            Dict with "volcanoData" (a pandas DataFrame) and "_hash" (the
            hash of that data).
        """
        if not self._preprocessed_data:
            self._load_preprocessed_data()

        data = self._preprocessed_data["volcanoData"]
        # Handle both LazyFrame (from cache) and DataFrame
        df_polars = data.collect() if isinstance(data, pl.LazyFrame) else data

        # The raw p-value is carried along in addition to its transform
        extra_cols = [self._pvalue_column]
        if self._label_column:
            extra_cols.insert(0, self._label_column)
        columns = self._volcano_columns(self._neglog10p_column, extra_cols)

        if self._filters:
            from ..preprocessing.filtering import filter_and_collect_cached

            df_pandas, data_hash = filter_and_collect_cached(
                df_polars.lazy(),
                self._filters,
                state,
                columns=columns,
                filter_defaults=self._filter_defaults,
            )

            # Ascending sort puts the most significant rows last
            # (most significant on top for rendering)
            if len(df_pandas) > 0 and self._neglog10p_column in df_pandas.columns:
                df_pandas = df_pandas.sort_values(
                    self._neglog10p_column, ascending=True
                ).reset_index(drop=True)

            return {"volcanoData": df_pandas, "_hash": data_hash}

        # No filters - select columns and convert to pandas
        from ..preprocessing.filtering import compute_dataframe_hash

        available_cols = [c for c in columns if c in df_polars.columns]
        df_filtered = df_polars.select(available_cols)

        # Sort by significance (ascending, same order as the filtered path)
        if self._neglog10p_column in df_filtered.columns:
            df_filtered = df_filtered.sort(self._neglog10p_column, descending=False)

        # The hash is computed on the polars frame before conversion
        data_hash = compute_dataframe_hash(df_filtered)
        df_pandas = df_filtered.to_pandas()

        return {"volcanoData": df_pandas, "_hash": data_hash}

    def _get_component_args(self) -> Dict[str, Any]:
        """Return configuration for Vue component.

        Includes the static styling options as well as the threshold values
        stored by the most recent ``__call__``.
        """
        return {
            "componentType": self._get_vue_component_name(),
            "log2fcColumn": self._log2fc_column,
            "neglog10pColumn": self._neglog10p_column,
            "pvalueColumn": self._pvalue_column,
            "labelColumn": self._label_column,
            "title": self._title,
            "xLabel": self._x_label,
            "yLabel": self._y_label,
            "upColor": self._up_color,
            "downColor": self._down_color,
            "nsColor": self._ns_color,
            "showThresholdLines": self._show_threshold_lines,
            "thresholdLineStyle": self._threshold_line_style,
            # Render-time threshold values
            "fcThreshold": self._current_fc_threshold,
            "pThreshold": self._current_p_threshold,
            "maxLabels": self._current_max_labels,
            "interactivity": self._interactivity or {},
        }

    def __call__(
        self,
        key: Optional[str] = None,
        state_manager: Optional[Any] = None,
        height: Optional[int] = None,
        fc_threshold: float = 1.0,
        p_threshold: float = 0.05,
        max_labels: int = 10,
    ) -> Any:
        """
        Render the volcano plot component.

        Args:
            key: Optional unique key for this component instance.
            state_manager: StateManager for cross-component linking.
            height: Optional height override in pixels.
            fc_threshold: Fold change threshold for significance
                (default: 1.0, meaning |log2FC| >= 1).
            p_threshold: P-value threshold for significance
                (default: 0.05). Points with p < threshold are significant.
            max_labels: Maximum number of labels to show on significant
                points (default: 10). Labels are shown for top N by
                significance.

        Returns:
            Component result for Streamlit rendering.
        """
        # Store render-time threshold values; _get_component_args() reads them
        self._current_fc_threshold = fc_threshold
        self._current_p_threshold = p_threshold
        self._current_max_labels = max_labels

        # Update height if provided
        if height is not None:
            self._height = height

        return super().__call__(key=key, state_manager=state_manager, height=height)
@@ -19,6 +19,10 @@ if TYPE_CHECKING:
19
19
  # Version 3: Downcast numeric types (Int64→Int32, Float64→Float32) for efficient transfer
20
20
  CACHE_VERSION = 3
21
21
 
22
+ # Default height for components when not specified
23
+ # This is the single source of truth for component height
24
+ DEFAULT_COMPONENT_HEIGHT = 400
25
+
22
26
 
23
27
  class BaseComponent(ABC):
24
28
  """
@@ -318,6 +322,9 @@ class BaseComponent(ABC):
318
322
  "data_values": {},
319
323
  }
320
324
 
325
+ # Check if files were already saved during preprocessing (e.g., cascading)
326
+ files_already_saved = self._preprocessed_data.pop("_files_already_saved", False)
327
+
321
328
  # Save preprocessed data with type optimization for efficient transfer
322
329
  # Float64→Float32 reduces Arrow payload size
323
330
  # Int64→Int32 (when safe) avoids BigInt overhead in JavaScript
@@ -325,18 +332,28 @@ class BaseComponent(ABC):
325
332
  if isinstance(value, pl.LazyFrame):
326
333
  filename = f"{key}.parquet"
327
334
  filepath = preprocessed_dir / filename
328
- # Apply streaming-safe optimization (Float64→Float32 only)
329
- # Int64 bounds checking would require collect(), breaking streaming
330
- value = optimize_for_transfer_lazy(value)
331
- value.sink_parquet(filepath, compression="zstd")
332
- manifest["data_files"][key] = filename
335
+
336
+ if files_already_saved and filepath.exists():
337
+ # File was saved during preprocessing (cascading) - just register it
338
+ manifest["data_files"][key] = filename
339
+ else:
340
+ # Apply streaming-safe optimization (Float64→Float32 only)
341
+ # Int64 bounds checking would require collect(), breaking streaming
342
+ value = optimize_for_transfer_lazy(value)
343
+ value.sink_parquet(filepath, compression="zstd")
344
+ manifest["data_files"][key] = filename
333
345
  elif isinstance(value, pl.DataFrame):
334
346
  filename = f"{key}.parquet"
335
347
  filepath = preprocessed_dir / filename
336
- # Full optimization including Int64→Int32 with bounds checking
337
- value = optimize_for_transfer(value)
338
- value.write_parquet(filepath, compression="zstd")
339
- manifest["data_files"][key] = filename
348
+
349
+ if files_already_saved and filepath.exists():
350
+ # File was saved during preprocessing - just register it
351
+ manifest["data_files"][key] = filename
352
+ else:
353
+ # Full optimization including Int64→Int32 with bounds checking
354
+ value = optimize_for_transfer(value)
355
+ value.write_parquet(filepath, compression="zstd")
356
+ manifest["data_files"][key] = filename
340
357
  elif self._is_json_serializable(value):
341
358
  manifest["data_values"][key] = value
342
359
 
@@ -472,7 +489,8 @@ class BaseComponent(ABC):
472
489
  key: Optional unique key for the Streamlit component
473
490
  state_manager: Optional StateManager for cross-component state.
474
491
  If not provided, uses a default shared StateManager.
475
- height: Optional height in pixels for the component
492
+ height: Optional height in pixels for the component.
493
+ If not provided, uses DEFAULT_COMPONENT_HEIGHT (400px).
476
494
 
477
495
  Returns:
478
496
  The value returned by the Vue component (usually selection state)
@@ -483,6 +501,10 @@ class BaseComponent(ABC):
483
501
  if state_manager is None:
484
502
  state_manager = get_default_state_manager()
485
503
 
504
+ # Use default height if not specified
505
+ if height is None:
506
+ height = DEFAULT_COMPONENT_HEIGHT
507
+
486
508
  return render_component(
487
509
  component=self, state_manager=state_manager, key=key, height=height
488
510
  )