linkml-store 0.2.6__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of linkml-store might be problematic.

Files changed (35)
  1. linkml_store/api/client.py +2 -3
  2. linkml_store/api/collection.py +63 -8
  3. linkml_store/api/database.py +20 -3
  4. linkml_store/api/stores/duckdb/duckdb_collection.py +168 -4
  5. linkml_store/api/stores/duckdb/duckdb_database.py +5 -5
  6. linkml_store/api/stores/filesystem/__init__.py +1 -1
  7. linkml_store/api/stores/filesystem/filesystem_database.py +1 -1
  8. linkml_store/api/stores/mongodb/mongodb_collection.py +132 -15
  9. linkml_store/api/stores/mongodb/mongodb_database.py +2 -1
  10. linkml_store/api/stores/neo4j/neo4j_database.py +1 -1
  11. linkml_store/api/stores/solr/solr_collection.py +107 -18
  12. linkml_store/cli.py +201 -21
  13. linkml_store/index/implementations/llm_indexer.py +13 -6
  14. linkml_store/index/indexer.py +9 -5
  15. linkml_store/inference/implementations/llm_inference_engine.py +15 -13
  16. linkml_store/inference/implementations/rag_inference_engine.py +13 -10
  17. linkml_store/inference/implementations/sklearn_inference_engine.py +7 -1
  18. linkml_store/inference/inference_config.py +2 -1
  19. linkml_store/inference/inference_engine.py +1 -1
  20. linkml_store/plotting/__init__.py +5 -0
  21. linkml_store/plotting/cli.py +172 -0
  22. linkml_store/plotting/heatmap.py +356 -0
  23. linkml_store/utils/dat_parser.py +95 -0
  24. linkml_store/utils/enrichment_analyzer.py +217 -0
  25. linkml_store/utils/format_utils.py +124 -3
  26. linkml_store/utils/llm_utils.py +4 -2
  27. linkml_store/utils/object_utils.py +9 -3
  28. linkml_store/utils/pandas_utils.py +1 -1
  29. linkml_store/utils/sql_utils.py +1 -1
  30. linkml_store/utils/vector_utils.py +3 -10
  31. {linkml_store-0.2.6.dist-info → linkml_store-0.2.10.dist-info}/METADATA +3 -1
  32. {linkml_store-0.2.6.dist-info → linkml_store-0.2.10.dist-info}/RECORD +35 -30
  33. {linkml_store-0.2.6.dist-info → linkml_store-0.2.10.dist-info}/WHEEL +1 -1
  34. {linkml_store-0.2.6.dist-info → linkml_store-0.2.10.dist-info}/LICENSE +0 -0
  35. {linkml_store-0.2.6.dist-info → linkml_store-0.2.10.dist-info}/entry_points.txt +0 -0
linkml_store/plotting/heatmap.py
@@ -0,0 +1,356 @@
+"""
+Heatmap visualization module for LinkML data.
+
+This module provides functions to generate heatmaps from pandas DataFrames or tabular data files.
+"""
+
+import logging
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from matplotlib.colors import LinearSegmentedColormap
+from scipy.cluster import hierarchy
+from scipy.spatial import distance
+
+from linkml_store.utils.format_utils import Format, load_objects, write_output
+
+logger = logging.getLogger(__name__)
+
+
+def create_heatmap(
+    data: pd.DataFrame,
+    x_column: str,
+    y_column: str,
+    value_column: Optional[str] = None,
+    title: Optional[str] = None,
+    figsize: Tuple[int, int] = (10, 8),
+    cmap: Union[str, LinearSegmentedColormap] = "YlGnBu",
+    annot: bool = True,
+    fmt: Optional[str] = None,  # Dynamically determined based on data
+    linewidths: float = 0.5,
+    linecolor: str = "white",
+    square: bool = False,
+    output_file: Optional[str] = None,
+    dpi: int = 300,
+    missing_value: Any = np.nan,
+    vmin: Optional[float] = None,
+    vmax: Optional[float] = None,
+    robust: bool = False,
+    remove_duplicates: bool = True,
+    font_size: int = 10,
+    cluster: Union[bool, Literal["both", "x", "y"]] = False,
+    cluster_method: str = "complete",  # linkage method: complete, average, single, etc.
+    cluster_metric: str = "euclidean",  # distance metric: euclidean, cosine, etc.
+    **kwargs,
+) -> Tuple[plt.Figure, plt.Axes]:
+    """
+    Create a heatmap from a pandas DataFrame.
+
+    Args:
+        data: Input DataFrame containing the data to plot
+        x_column: Column to use for x-axis categories
+        y_column: Column to use for y-axis categories
+        value_column: Column containing values for the heatmap. If None, frequency counts will be used.
+        title: Title for the heatmap
+        figsize: Figure size as (width, height) in inches
+        cmap: Colormap for the heatmap
+        annot: Whether to annotate cells with values
+        fmt: String formatting code for annotations (auto-detected if None)
+        linewidths: Width of lines between cells
+        linecolor: Color of lines between cells
+        square: Whether to make cells square
+        output_file: File path to save the figure (optional)
+        dpi: Resolution for saved figure
+        missing_value: Value to use for missing data (defaults to NaN)
+        vmin: Minimum value for colormap scaling
+        vmax: Maximum value for colormap scaling
+        robust: If True, compute colormap limits using robust quantiles instead of min/max
+        remove_duplicates: If True, removes duplicate rows before creating the heatmap
+        font_size: Font size for annotations
+        cluster: Whether and which axes to cluster:
+            - False: No clustering (default)
+            - True or "both": Cluster both x and y axes
+            - "x": Cluster only x-axis
+            - "y": Cluster only y-axis
+        cluster_method: Linkage method for hierarchical clustering
+            (e.g., "single", "complete", "average", "ward")
+        cluster_metric: Distance metric for clustering (e.g., "euclidean", "correlation", "cosine")
+        **kwargs: Additional keyword arguments to pass to seaborn's heatmap function
+
+    Returns:
+        Tuple containing the figure and axes objects
+    """
+    # Validate input
+    if x_column not in data.columns:
+        raise ValueError(f"x_column '{x_column}' not found in DataFrame columns: {list(data.columns)}")
+    if y_column not in data.columns:
+        raise ValueError(f"y_column '{y_column}' not found in DataFrame columns: {list(data.columns)}")
+    if value_column and value_column not in data.columns:
+        raise ValueError(f"value_column '{value_column}' not found in DataFrame columns: {list(data.columns)}")
+
+    # Remove duplicates by default (assume they're accidents unless user overrides)
+    if remove_duplicates:
+        data = data.drop_duplicates()
+
+    # Prepare the data
+    if value_column:
+        # Use the provided value column
+        pivot_data = data.pivot_table(
+            index=y_column,
+            columns=x_column,
+            values=value_column,
+            aggfunc='mean',
+            fill_value=missing_value
+        )
+    else:
+        # Use frequency counts
+        cross_tab = pd.crosstab(data[y_column], data[x_column])
+        pivot_data = cross_tab
+
+    # Auto-detect format string if not provided
+    if fmt is None:
+        # Check if the pivot table contains integers only
+        if pivot_data.dtypes.apply(lambda x: pd.api.types.is_integer_dtype(x)).all():
+            fmt = 'd'  # Integer format
+        else:
+            fmt = '.1f'  # One decimal place for floats
+
+    # Make sure all cells have a reasonable minimum size
+    min_height = max(4, 80 / len(pivot_data.index) if len(pivot_data.index) > 0 else 10)
+    min_width = max(4, 80 / len(pivot_data.columns) if len(pivot_data.columns) > 0 else 10)
+
+    # Adjust figure size based on the number of rows and columns
+    adjusted_height = max(figsize[1], min_height * len(pivot_data.index) / 10)
+    adjusted_width = max(figsize[0], min_width * len(pivot_data.columns) / 10)
+    adjusted_figsize = (adjusted_width, adjusted_height)
+
+    # Create figure and axes
+    fig, ax = plt.subplots(figsize=adjusted_figsize)
+
+    # Apply clustering if requested
+    row_linkage = None
+    col_linkage = None
+
+    if cluster:
+        cluster_axes = cluster
+        if cluster_axes is True:
+            cluster_axes = "both"
+
+        # Fill NAs for clustering
+        pivot_data_for_clustering = pivot_data.fillna(0)
+
+        # Cluster rows (y-axis)
+        if cluster_axes in ["both", "y"]:
+            try:
+                # Calculate distance matrix and linkage for rows
+                row_distances = distance.pdist(pivot_data_for_clustering.values, metric=cluster_metric)
+                row_linkage = hierarchy.linkage(row_distances, method=cluster_method)
+
+                # Reorder rows based on clustering
+                row_dendrogram = hierarchy.dendrogram(row_linkage, no_plot=True)
+                row_order = row_dendrogram['leaves']
+                pivot_data = pivot_data.iloc[row_order]
+
+                logger.info(f"Applied clustering to rows using {cluster_method} linkage and {cluster_metric} metric")
+            except Exception as e:
+                logger.warning(f"Failed to cluster rows: {e}")
+
+        # Cluster columns (x-axis)
+        if cluster_axes in ["both", "x"]:
+            try:
+                # Calculate distance matrix and linkage for columns
+                col_distances = distance.pdist(pivot_data_for_clustering.values.T, metric=cluster_metric)
+                col_linkage = hierarchy.linkage(col_distances, method=cluster_method)
+
+                # Reorder columns based on clustering
+                col_dendrogram = hierarchy.dendrogram(col_linkage, no_plot=True)
+                col_order = col_dendrogram['leaves']
+                pivot_data = pivot_data.iloc[:, col_order]
+
+                logger.info(f"Applied clustering to columns using {cluster_method} linkage and {cluster_metric} metric")
+            except Exception as e:
+                logger.warning(f"Failed to cluster columns: {e}")
+
+    # Create the heatmap
+    sns.heatmap(
+        pivot_data,
+        cmap=cmap,
+        annot=annot,
+        fmt=fmt,
+        linewidths=linewidths,
+        linecolor=linecolor,
+        square=square,
+        vmin=vmin,
+        vmax=vmax,
+        robust=robust,
+        ax=ax,
+        annot_kws={'fontsize': font_size},
+        **kwargs
+    )
+
+    # Set title if provided
+    if title:
+        ax.set_title(title, fontsize=font_size + 4)
+
+    # Improve display of tick labels
+    plt.xticks(rotation=45, ha="right", fontsize=font_size)
+    plt.yticks(rotation=0, fontsize=font_size)
+
+    # Add grid lines to make the table more readable
+    ax.grid(False)
+
+    # Improve contrast for better readability
+    for _, spine in ax.spines.items():
+        spine.set_visible(True)
+        spine.set_color('black')
+        spine.set_linewidth(1)
+
+    # Adjust layout
+    plt.tight_layout()
+
+    # Save the figure if output file is specified
+    if output_file:
+        output_path = Path(output_file)
+        output_dir = output_path.parent
+        if not output_dir.exists():
+            output_dir.mkdir(parents=True, exist_ok=True)
+        plt.savefig(output_file, dpi=dpi, bbox_inches="tight")
+        logger.info(f"Heatmap saved to {output_file}")
+
+    return fig, ax
+
+
+def heatmap_from_file(
+    file_path: str,
+    x_column: str,
+    y_column: str,
+    value_column: Optional[str] = None,
+    format: Optional[Union[Format, str]] = None,
+    compression: Optional[str] = None,
+    output_file: Optional[str] = None,
+    remove_duplicates: bool = True,
+    **kwargs,
+) -> Tuple[plt.Figure, plt.Axes]:
+    """
+    Create a heatmap from a file (CSV, TSV, etc.).
+
+    Args:
+        file_path: Path to the input file or "-" for stdin
+        x_column: Column to use for x-axis categories
+        y_column: Column to use for y-axis categories
+        value_column: Column containing values for the heatmap. If None, frequency counts will be used.
+        format: Format of the input file (auto-detected if None)
+        compression: Compression format ('gz' or 'tgz')
+        output_file: File path to save the figure (optional)
+        remove_duplicates: If True, removes duplicate rows before creating the heatmap
+        **kwargs: Additional arguments to pass to create_heatmap
+
+    Returns:
+        Tuple containing the figure and axes objects
+    """
+    # Handle stdin input safely
+    import sys
+    import io
+    import pandas as pd
+    import click
+
+    # Load the data
+    if file_path == "-":
+        # Read directly from stdin since format_utils will use sys.stdin which may already be consumed
+        if not format or str(format).lower() in ['csv', 'tsv']:
+            # Default to CSV if no format specified
+            delimiter = ',' if not format or str(format).lower() == 'csv' else '\t'
+            df = pd.read_csv(sys.stdin, delimiter=delimiter)
+        else:
+            # Try to use format_utils but with a backup plan
+            try:
+                objs = load_objects(file_path, format=format, compression=compression)
+                df = pd.DataFrame(objs)
+            except ValueError as e:
+                if "I/O operation on closed file" in str(e):
+                    logger.warning("Could not read from stdin. It may have been consumed already.")
+                    raise click.UsageError("Error reading from stdin. Please provide a file path or ensure stdin has data.")
+                else:
+                    raise
+    else:
+        # For regular files, use format_utils as normal
+        if (not format or format in ["csv", "tsv"]) and not compression:
+            df = pd.read_csv(file_path)
+        else:
+            objs = load_objects(file_path, format=format, compression=compression)
+            df = pd.DataFrame(objs)
+
+    # Create the heatmap
+    return create_heatmap(
+        data=df,
+        x_column=x_column,
+        y_column=y_column,
+        value_column=value_column,
+        output_file=output_file,
+        remove_duplicates=remove_duplicates,
+        **kwargs
+    )
+
+
+def export_heatmap_data(
+    data: pd.DataFrame,
+    x_column: str,
+    y_column: str,
+    value_column: Optional[str] = None,
+    output_file: Optional[str] = None,
+    format: Union[Format, str] = Format.CSV,
+    missing_value: Any = np.nan,
+    remove_duplicates: bool = True,
+) -> pd.DataFrame:
+    """
+    Export heatmap data to a file or return it as a DataFrame.
+
+    Args:
+        data: Input DataFrame containing the data
+        x_column: Column to use for x-axis categories
+        y_column: Column to use for y-axis categories
+        value_column: Column containing values for the heatmap. If None, frequency counts will be used.
+        output_file: File path to save the data (optional)
+        format: Output format for the file
+        missing_value: Value to use for missing data
+        remove_duplicates: If True, removes duplicate rows before creating the pivot table
+
+    Returns:
+        DataFrame containing the pivot table data
+    """
+    # Remove duplicates by default (assume they're accidents unless user overrides)
+    if remove_duplicates:
+        # Keep the first occurrence of each x_column, y_column combination
+        data = data.drop_duplicates(subset=[x_column, y_column])
+
+    # Prepare the data
+    if value_column:
+        # Use the provided value column
+        pivot_data = data.pivot_table(
+            index=y_column,
+            columns=x_column,
+            values=value_column,
+            aggfunc='mean',
+            fill_value=missing_value
+        )
+    else:
+        # Use frequency counts
+        cross_tab = pd.crosstab(data[y_column], data[x_column])
+        pivot_data = cross_tab
+
+    # Reset index to make the y_column a regular column
+    result_df = pivot_data.reset_index()
+
+    # Write to file if output_file is provided
+    if output_file:
+        # Convert to records format for writing
+        records = result_df.to_dict(orient='records')
+        write_output(records, format=format, target=output_file)
+        logger.info(f"Heatmap data saved to {output_file}")
+
+    return result_df
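For orientation, a minimal usage sketch of the new plotting module (not part of the diff), based on the heatmap_from_file signature added above; the file name and column names are placeholders, not anything shipped with the package:

# Illustrative only: "samples.csv", "condition", "gene" and "expression" are placeholders.
from linkml_store.plotting.heatmap import heatmap_from_file

fig, ax = heatmap_from_file(
    file_path="samples.csv",
    x_column="condition",
    y_column="gene",
    value_column="expression",  # omit to fall back to frequency counts of x/y pairs
    cluster="both",             # passed through **kwargs to create_heatmap for hierarchical clustering
    output_file="heatmap.png",
)

When value_column is omitted, the module builds a crosstab of frequency counts instead of a pivot table of mean values, as the docstrings above describe.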
linkml_store/utils/dat_parser.py
@@ -0,0 +1,95 @@
+from typing import Any, Dict, List, Optional, Tuple
+
+ENTRY = Dict[str, Any]
+
+
+def parse_sib_format(text) -> Tuple[Optional[ENTRY], List[ENTRY]]:
+    """
+    Parse SIB/Swiss-Prot format data into a structured dictionary.
+
+    Args:
+        text (str): The text in SIB/Swiss-Prot format
+
+    Returns:
+        Tuple of (header entry or None, list of entry dictionaries keyed by two-letter field code)
+    """
+    # Split the text into entries (separated by //)
+    entries = text.split("//\n")
+    header = None
+
+    # Initialize results list
+    results = []
+
+    # Parse each entry
+    for entry in entries:
+        if not entry.strip():
+            continue
+
+        # Initialize dictionary for current entry
+        current_entry = {}
+        current_code = None
+
+        # Process each line
+        for line in entry.strip().split("\n"):
+            if not line.strip():
+                continue
+
+            # Check if this is a new field (starts with a 2-letter code followed by space)
+            if len(line) > 2 and line[2] == " ":
+                current_code = line[0:2]
+                # Remove the code and the following space(s)
+                value = line[3:].strip()
+
+                # Initialize as list if needed for multi-line fields
+                if current_code not in current_entry:
+                    current_entry[current_code] = []
+
+                current_entry[current_code].append(value)
+
+            # Continuation of previous field
+            elif current_code is not None:
+                # Handle continuation lines (typically indented)
+                if current_code == "CC":
+                    # For comments, preserve the indentation
+                    current_entry[current_code].append(line)
+                else:
+                    # For other fields, strip and append
+                    current_entry[current_code].append(line.strip())
+
+        # Combine multiline comments; e.g.
+        # -!- ...
+        # ...
+        # -!- ...
+        ccs = current_entry.get("CC", [])
+        new_ccs = []
+        for cc in ccs:
+            if not cc.startswith("-!-") and new_ccs:
+                new_ccs[-1] += " " + cc
+            else:
+                new_ccs.append(cc)
+        current_entry["CC"] = new_ccs
+        for k, vs in current_entry.items():
+            if k != "CC":
+                combined = "".join(vs)
+                combined = combined.strip()
+                if combined.endswith("."):
+                    combined = combined.split(".")
+                    combined = [c.strip() for c in combined if c.strip()]
+                    if k == "DE":
+                        combined = combined[0]
+                current_entry[k] = combined
+
+        if "ID" in current_entry:
+            results.append(current_entry)
+        else:
+            header = current_entry
+
+    return header, results
+
+
+# Example usage:
+# header, entries = parse_sib_format(text)
+# for entry in entries:
+#     print(f"Entry: {entry['ID']}")
+#     for code, values in entry.items():
+#         print(f"  {code}: {values}")
linkml_store/utils/enrichment_analyzer.py
@@ -0,0 +1,217 @@
+from collections import Counter
+from typing import Dict, List
+
+import numpy as np
+import pandas as pd
+from pydantic import BaseModel
+from scipy import stats
+
+from linkml_store.api import Collection
+
+
+class EnrichedCategory(BaseModel):
+    """
+    Information about a category enriched in a sample
+    """
+
+    category: str
+    fold_change: float
+    original_p_value: float
+    adjusted_p_value: float
+
+
+class EnrichmentAnalyzer:
+    def __init__(self, df: pd.DataFrame, sample_key: str, classification_key: str):
+        """
+        Initialize the analyzer with a DataFrame and key column names.
+        Precomputes category frequencies for the entire dataset.
+
+        Args:
+            df: DataFrame containing the data
+            sample_key: Column name for sample IDs
+            classification_key: Column name for category lists
+        """
+        self.df = df
+        self.sample_key = sample_key
+        self.classification_key = classification_key
+
+        # Precompute global category statistics
+        self.global_stats = self._compute_global_stats()
+
+        # Cache for sample-specific category counts
+        self.sample_cache: Dict[str, Counter] = {}
+
+    @classmethod
+    def from_collection(cls, collection: Collection, sample_key: str, classification_key: str) -> "EnrichmentAnalyzer":
+        """
+        Initialize the analyzer with a Collection and key column names.
+        Precomputes category frequencies for the entire dataset.
+
+        Args:
+            collection: Collection containing the data
+            sample_key: Column name for sample IDs
+            classification_key: Column name for category lists
+        """
+        column_atts = [sample_key, classification_key]
+        results = collection.find(select_cols=column_atts, limit=-1)
+        df = results.rows_dataframe
+        ea = cls(df, sample_key=sample_key, classification_key=classification_key)
+        return ea
+
+    def _compute_global_stats(self) -> Dict[str, int]:
+        """
+        Compute global category frequencies across all samples.
+        Returns a dictionary of category -> count
+        """
+        global_counter = Counter()
+
+        # Flatten all categories and count
+        for categories in self.df[self.classification_key]:
+            if isinstance(categories, list):
+                global_counter.update(categories)
+            else:
+                # Handle case where categories might be a string
+                global_counter.update([categories])
+
+        return global_counter
+
+    @property
+    def sample_ids(self) -> List[str]:
+        df = self.df
+        return df[self.sample_key].unique().tolist()
+
+    def _get_sample_stats(self, sample_id: str) -> Counter:
+        """
+        Get category frequencies for a specific sample.
+        Uses caching to avoid recomputation.
+        """
+        if sample_id in self.sample_cache:
+            return self.sample_cache[sample_id]
+
+        sample_data = self.df[self.df[self.sample_key] == sample_id]
+        if sample_data.empty:
+            raise KeyError(f"Sample ID '{sample_id}' not found")
+        sample_data = sample_data.dropna()
+        # if sample_data.empty:
+        #     raise ValueError(f"Sample ID '{sample_id}' has missing values after dropping NA")
+        counter = Counter()
+
+        for categories in sample_data[self.classification_key]:
+            if isinstance(categories, list):
+                counter.update(categories)
+            else:
+                counter.update([categories])
+
+        self.sample_cache[sample_id] = counter
+        return counter
+
+    def find_enriched_categories(
+        self,
+        sample_id: str,
+        min_occurrences: int = 5,
+        p_value_threshold: float = 0.05,
+        multiple_testing_correction: str = "bh",
+    ) -> List[EnrichedCategory]:
+        """
+        Find categories that are enriched in the given sample.
+
+        Args:
+            sample_id: ID of the sample to analyze
+            min_occurrences: Minimum number of occurrences required for a category
+            p_value_threshold: P-value threshold for significance
+
+        Returns:
+            List of EnrichedCategory objects sorted by adjusted p-value
+        """
+        sample_stats = self._get_sample_stats(sample_id)
+        total_sample_annotations = sum(sample_stats.values())
+        total_global_annotations = sum(self.global_stats.values())
+
+        results = []
+
+        for category, sample_count in sample_stats.items():
+            global_count = self.global_stats[category]
+
+            # Skip rare categories
+            if global_count < min_occurrences:
+                continue
+
+            # Calculate fold change
+            sample_freq = sample_count / total_sample_annotations
+            global_freq = global_count / total_global_annotations
+            fold_change = sample_freq / global_freq if global_freq > 0 else float("inf")
+
+            # Perform Fisher's exact test
+            contingency_table = np.array(
+                [
+                    [sample_count, global_count - sample_count],
+                    [
+                        total_sample_annotations - sample_count,
+                        total_global_annotations - total_sample_annotations - (global_count - sample_count),
+                    ],
+                ]
+            )
+
+            _, p_value = stats.fisher_exact(contingency_table)
+
+            if p_value < p_value_threshold:
+                results.append((category, fold_change, p_value))
+
+        if not results:
+            return results
+
+        # Sort by p-value
+        results.sort(key=lambda x: x[2])
+
+        # Apply multiple testing correction
+        categories, fold_changes, p_values = zip(*results)
+
+        if multiple_testing_correction.lower() == "bonf":
+            # Bonferroni correction
+            n_tests = len(self.global_stats)  # Total number of categories tested
+            adjusted_p_values = [min(1.0, p * n_tests) for p in p_values]
+
+        elif multiple_testing_correction.lower() == "bh":
+            # Benjamini-Hochberg correction
+            n = len(p_values)
+            sorted_indices = np.argsort(p_values)
+            sorted_p_values = np.array(p_values)[sorted_indices]
+
+            # Calculate BH adjusted p-values
+            adjusted_p_values = np.zeros(n)
+            for i, p in enumerate(sorted_p_values):
+                adjusted_p_values[i] = p * n / (i + 1)
+
+            # Ensure monotonicity
+            for i in range(n - 2, -1, -1):
+                adjusted_p_values[i] = min(adjusted_p_values[i], adjusted_p_values[i + 1])
+
+            # Restore original order
+            inverse_indices = np.argsort(sorted_indices)
+            adjusted_p_values = adjusted_p_values[inverse_indices]
+
+            # Ensure we don't exceed 1.0
+            adjusted_p_values = np.minimum(adjusted_p_values, 1.0)
+
+        else:
+            # No correction
+            adjusted_p_values = p_values
+
+        # Filter by adjusted p-value threshold and create final results
+        # Create EnrichedCategory objects
+        final_results = [
+            EnrichedCategory(category=cat, fold_change=fc, original_p_value=p, adjusted_p_value=adj_p)
+            for cat, fc, p, adj_p in zip(categories, fold_changes, p_values, adjusted_p_values)
+            if adj_p < p_value_threshold
+        ]
+
+        # Sort by adjusted p-value
+        final_results.sort(key=lambda x: x.adjusted_p_value)
+        return final_results
+
+
+# Example usage:
+# analyzer = EnrichmentAnalyzer(df, 'sample_id', 'categories')
+# enriched = analyzer.find_enriched_categories('sample1')
+# for ec in enriched:
+#     print(f"{ec.category}: {ec.fold_change:.2f}x enrichment (p={ec.adjusted_p_value:.2e})")