linkml-store 0.2.6__py3-none-any.whl → 0.2.10rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of linkml-store might be problematic.
- linkml_store/api/client.py +2 -3
- linkml_store/api/collection.py +63 -8
- linkml_store/api/database.py +20 -3
- linkml_store/api/stores/duckdb/duckdb_collection.py +168 -4
- linkml_store/api/stores/duckdb/duckdb_database.py +5 -5
- linkml_store/api/stores/filesystem/__init__.py +1 -1
- linkml_store/api/stores/filesystem/filesystem_database.py +1 -1
- linkml_store/api/stores/mongodb/mongodb_collection.py +132 -15
- linkml_store/api/stores/mongodb/mongodb_database.py +2 -1
- linkml_store/api/stores/neo4j/neo4j_database.py +1 -1
- linkml_store/api/stores/solr/solr_collection.py +107 -18
- linkml_store/cli.py +201 -21
- linkml_store/index/implementations/llm_indexer.py +13 -6
- linkml_store/index/indexer.py +9 -5
- linkml_store/inference/implementations/llm_inference_engine.py +15 -13
- linkml_store/inference/implementations/rag_inference_engine.py +13 -10
- linkml_store/inference/implementations/sklearn_inference_engine.py +7 -1
- linkml_store/inference/inference_config.py +2 -1
- linkml_store/inference/inference_engine.py +1 -1
- linkml_store/plotting/__init__.py +5 -0
- linkml_store/plotting/cli.py +172 -0
- linkml_store/plotting/heatmap.py +356 -0
- linkml_store/utils/dat_parser.py +95 -0
- linkml_store/utils/enrichment_analyzer.py +217 -0
- linkml_store/utils/format_utils.py +124 -3
- linkml_store/utils/llm_utils.py +4 -2
- linkml_store/utils/object_utils.py +9 -3
- linkml_store/utils/pandas_utils.py +1 -1
- linkml_store/utils/sql_utils.py +1 -1
- linkml_store/utils/vector_utils.py +3 -10
- {linkml_store-0.2.6.dist-info → linkml_store-0.2.10rc1.dist-info}/METADATA +3 -1
- {linkml_store-0.2.6.dist-info → linkml_store-0.2.10rc1.dist-info}/RECORD +35 -30
- {linkml_store-0.2.6.dist-info → linkml_store-0.2.10rc1.dist-info}/WHEEL +1 -1
- {linkml_store-0.2.6.dist-info → linkml_store-0.2.10rc1.dist-info}/LICENSE +0 -0
- {linkml_store-0.2.6.dist-info → linkml_store-0.2.10rc1.dist-info}/entry_points.txt +0 -0

linkml_store/plotting/heatmap.py (new file)
@@ -0,0 +1,356 @@
"""
Heatmap visualization module for LinkML data.

This module provides functions to generate heatmaps from pandas DataFrames or tabular data files.
"""

import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from scipy.cluster import hierarchy
from scipy.spatial import distance

from linkml_store.utils.format_utils import Format, load_objects, write_output

logger = logging.getLogger(__name__)


def create_heatmap(
    data: pd.DataFrame,
    x_column: str,
    y_column: str,
    value_column: Optional[str] = None,
    title: Optional[str] = None,
    figsize: Tuple[int, int] = (10, 8),
    cmap: Union[str, LinearSegmentedColormap] = "YlGnBu",
    annot: bool = True,
    fmt: Optional[str] = None,  # Dynamically determined based on data
    linewidths: float = 0.5,
    linecolor: str = "white",
    square: bool = False,
    output_file: Optional[str] = None,
    dpi: int = 300,
    missing_value: Any = np.nan,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    robust: bool = False,
    remove_duplicates: bool = True,
    font_size: int = 10,
    cluster: Union[bool, Literal["both", "x", "y"]] = False,
    cluster_method: str = "complete",  # linkage method: complete, average, single, etc.
    cluster_metric: str = "euclidean",  # distance metric: euclidean, cosine, etc.
    **kwargs,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Create a heatmap from a pandas DataFrame.

    Args:
        data: Input DataFrame containing the data to plot
        x_column: Column to use for x-axis categories
        y_column: Column to use for y-axis categories
        value_column: Column containing values for the heatmap. If None, frequency counts will be used.
        title: Title for the heatmap
        figsize: Figure size as (width, height) in inches
        cmap: Colormap for the heatmap
        annot: Whether to annotate cells with values
        fmt: String formatting code for annotations (auto-detected if None)
        linewidths: Width of lines between cells
        linecolor: Color of lines between cells
        square: Whether to make cells square
        output_file: File path to save the figure (optional)
        dpi: Resolution for saved figure
        missing_value: Value to use for missing data (defaults to NaN)
        vmin: Minimum value for colormap scaling
        vmax: Maximum value for colormap scaling
        robust: If True, compute colormap limits using robust quantiles instead of min/max
        remove_duplicates: If True, removes duplicate rows before creating the heatmap
        font_size: Font size for annotations
        cluster: Whether and which axes to cluster:
            - False: No clustering (default)
            - True or "both": Cluster both x and y axes
            - "x": Cluster only x-axis
            - "y": Cluster only y-axis
        cluster_method: Linkage method for hierarchical clustering
            (e.g., "single", "complete", "average", "ward")
        cluster_metric: Distance metric for clustering (e.g., "euclidean", "correlation", "cosine")
        **kwargs: Additional keyword arguments to pass to seaborn's heatmap function

    Returns:
        Tuple containing the figure and axes objects
    """
    # Validate input
    if x_column not in data.columns:
        raise ValueError(f"x_column '{x_column}' not found in DataFrame columns: {list(data.columns)}")
    if y_column not in data.columns:
        raise ValueError(f"y_column '{y_column}' not found in DataFrame columns: {list(data.columns)}")
    if value_column and value_column not in data.columns:
        raise ValueError(f"value_column '{value_column}' not found in DataFrame columns: {list(data.columns)}")

    # Remove duplicates by default (assume they're accidents unless user overrides)
    if remove_duplicates:
        data = data.drop_duplicates()

    # Prepare the data
    if value_column:
        # Use the provided value column
        pivot_data = data.pivot_table(
            index=y_column,
            columns=x_column,
            values=value_column,
            aggfunc='mean',
            fill_value=missing_value
        )
    else:
        # Use frequency counts
        cross_tab = pd.crosstab(data[y_column], data[x_column])
        pivot_data = cross_tab

    # Auto-detect format string if not provided
    if fmt is None:
        # Check if the pivot table contains integers only
        if pivot_data.dtypes.apply(lambda x: pd.api.types.is_integer_dtype(x)).all():
            fmt = 'd'  # Integer format
        else:
            fmt = '.1f'  # One decimal place for floats

    # Make sure all cells have a reasonable minimum size
    min_height = max(4, 80 / len(pivot_data.index) if len(pivot_data.index) > 0 else 10)
    min_width = max(4, 80 / len(pivot_data.columns) if len(pivot_data.columns) > 0 else 10)

    # Adjust figure size based on the number of rows and columns
    adjusted_height = max(figsize[1], min_height * len(pivot_data.index) / 10)
    adjusted_width = max(figsize[0], min_width * len(pivot_data.columns) / 10)
    adjusted_figsize = (adjusted_width, adjusted_height)

    # Create figure and axes
    fig, ax = plt.subplots(figsize=adjusted_figsize)

    # Apply clustering if requested
    row_linkage = None
    col_linkage = None

    if cluster:
        cluster_axes = cluster
        if cluster_axes is True:
            cluster_axes = "both"

        # Fill NAs for clustering
        pivot_data_for_clustering = pivot_data.fillna(0)

        # Cluster rows (y-axis)
        if cluster_axes in ["both", "y"]:
            try:
                # Calculate distance matrix and linkage for rows
                row_distances = distance.pdist(pivot_data_for_clustering.values, metric=cluster_metric)
                row_linkage = hierarchy.linkage(row_distances, method=cluster_method)

                # Reorder rows based on clustering
                row_dendrogram = hierarchy.dendrogram(row_linkage, no_plot=True)
                row_order = row_dendrogram['leaves']
                pivot_data = pivot_data.iloc[row_order]

                logger.info(f"Applied clustering to rows using {cluster_method} linkage and {cluster_metric} metric")
            except Exception as e:
                logger.warning(f"Failed to cluster rows: {e}")

        # Cluster columns (x-axis)
        if cluster_axes in ["both", "x"]:
            try:
                # Calculate distance matrix and linkage for columns
                col_distances = distance.pdist(pivot_data_for_clustering.values.T, metric=cluster_metric)
                col_linkage = hierarchy.linkage(col_distances, method=cluster_method)

                # Reorder columns based on clustering
                col_dendrogram = hierarchy.dendrogram(col_linkage, no_plot=True)
                col_order = col_dendrogram['leaves']
                pivot_data = pivot_data.iloc[:, col_order]

                logger.info(f"Applied clustering to columns using {cluster_method} linkage and {cluster_metric} metric")
            except Exception as e:
                logger.warning(f"Failed to cluster columns: {e}")

    # Create the heatmap
    sns.heatmap(
        pivot_data,
        cmap=cmap,
        annot=annot,
        fmt=fmt,
        linewidths=linewidths,
        linecolor=linecolor,
        square=square,
        vmin=vmin,
        vmax=vmax,
        robust=robust,
        ax=ax,
        annot_kws={'fontsize': font_size},
        **kwargs
    )

    # Set title if provided
    if title:
        ax.set_title(title, fontsize=font_size + 4)

    # Improve display of tick labels
    plt.xticks(rotation=45, ha="right", fontsize=font_size)
    plt.yticks(rotation=0, fontsize=font_size)

    # Add grid lines to make the table more readable
    ax.grid(False)

    # Improve contrast for better readability
    for _, spine in ax.spines.items():
        spine.set_visible(True)
        spine.set_color('black')
        spine.set_linewidth(1)

    # Adjust layout
    plt.tight_layout()

    # Save the figure if output file is specified
    if output_file:
        output_path = Path(output_file)
        output_dir = output_path.parent
        if not output_dir.exists():
            output_dir.mkdir(parents=True, exist_ok=True)
        plt.savefig(output_file, dpi=dpi, bbox_inches="tight")
        logger.info(f"Heatmap saved to {output_file}")

    return fig, ax


def heatmap_from_file(
    file_path: str,
    x_column: str,
    y_column: str,
    value_column: Optional[str] = None,
    format: Optional[Union[Format, str]] = None,
    compression: Optional[str] = None,
    output_file: Optional[str] = None,
    remove_duplicates: bool = True,
    **kwargs,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Create a heatmap from a file (CSV, TSV, etc.).

    Args:
        file_path: Path to the input file or "-" for stdin
        x_column: Column to use for x-axis categories
        y_column: Column to use for y-axis categories
        value_column: Column containing values for the heatmap. If None, frequency counts will be used.
        format: Format of the input file (auto-detected if None)
        compression: Compression format ('gz' or 'tgz')
        output_file: File path to save the figure (optional)
        remove_duplicates: If True, removes duplicate rows before creating the heatmap
        **kwargs: Additional arguments to pass to create_heatmap

    Returns:
        Tuple containing the figure and axes objects
    """
    # Handle stdin input safely
    import sys
    import io
    import pandas as pd
    import click

    # Load the data
    if file_path == "-":
        # Read directly from stdin since format_utils will use sys.stdin which may already be consumed
        if not format or str(format).lower() in ['csv', 'tsv']:
            # Default to CSV if no format specified
            delimiter = ',' if not format or str(format).lower() == 'csv' else '\t'
            df = pd.read_csv(sys.stdin, delimiter=delimiter)
        else:
            # Try to use format_utils but with a backup plan
            try:
                objs = load_objects(file_path, format=format, compression=compression)
                df = pd.DataFrame(objs)
            except ValueError as e:
                if "I/O operation on closed file" in str(e):
                    logger.warning("Could not read from stdin. It may have been consumed already.")
                    raise click.UsageError("Error reading from stdin. Please provide a file path or ensure stdin has data.")
                else:
                    raise
    else:
        # For regular files, use format_utils as normal
        if (not format or format in ["csv", "tsv"]) and not compression:
            df = pd.read_csv(file_path)
        else:
            objs = load_objects(file_path, format=format, compression=compression)
            df = pd.DataFrame(objs)

    # Create the heatmap
    return create_heatmap(
        data=df,
        x_column=x_column,
        y_column=y_column,
        value_column=value_column,
        output_file=output_file,
        remove_duplicates=remove_duplicates,
        **kwargs
    )


def export_heatmap_data(
    data: pd.DataFrame,
    x_column: str,
    y_column: str,
    value_column: Optional[str] = None,
    output_file: Optional[str] = None,
    format: Union[Format, str] = Format.CSV,
    missing_value: Any = np.nan,
    remove_duplicates: bool = True,
) -> pd.DataFrame:
    """
    Export heatmap data to a file or return it as a DataFrame.

    Args:
        data: Input DataFrame containing the data
        x_column: Column to use for x-axis categories
        y_column: Column to use for y-axis categories
        value_column: Column containing values for the heatmap. If None, frequency counts will be used.
        output_file: File path to save the data (optional)
        format: Output format for the file
        missing_value: Value to use for missing data
        remove_duplicates: If True, removes duplicate rows before creating the pivot table

    Returns:
        DataFrame containing the pivot table data
    """
    # Remove duplicates by default (assume they're accidents unless user overrides)
    if remove_duplicates:
        # Keep the first occurrence of each x_column, y_column combination
        data = data.drop_duplicates(subset=[x_column, y_column])

    # Prepare the data
    if value_column:
        # Use the provided value column
        pivot_data = data.pivot_table(
            index=y_column,
            columns=x_column,
            values=value_column,
            aggfunc='mean',
            fill_value=missing_value
        )
    else:
        # Use frequency counts
        cross_tab = pd.crosstab(data[y_column], data[x_column])
        pivot_data = cross_tab

    # Reset index to make the y_column a regular column
    result_df = pivot_data.reset_index()

    # Write to file if output_file is provided
    if output_file:
        # Convert to records format for writing
        records = result_df.to_dict(orient='records')
        write_output(records, format=format, target=output_file)
        logger.info(f"Heatmap data saved to {output_file}")

    return result_df
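For orientation, here is a minimal usage sketch of the new heatmap module. It is not part of the released wheel; the DataFrame, column names, and output path below are invented for illustration.

```python
# Illustrative sketch only (not part of the package diff): exercising the new
# create_heatmap and export_heatmap_data functions on a toy DataFrame.
import pandas as pd

from linkml_store.plotting.heatmap import create_heatmap, export_heatmap_data

# Hypothetical tidy data: one row per (organism, tissue) observation
df = pd.DataFrame(
    {
        "organism": ["mouse", "mouse", "human", "human", "human"],
        "tissue": ["liver", "brain", "liver", "brain", "brain"],
        "expression": [1.0, 2.5, 0.7, 3.1, 2.9],
    }
)

# Mean expression per cell, clustered on both axes, written to a PNG
fig, ax = create_heatmap(
    df,
    x_column="tissue",
    y_column="organism",
    value_column="expression",
    cluster="both",
    output_file="heatmap.png",
)

# The same pivot as a plain table, without plotting
pivot_df = export_heatmap_data(df, x_column="tissue", y_column="organism", value_column="expression")
print(pivot_df)
```

When `value_column` is omitted, the module falls back to `pd.crosstab` frequency counts, which is why integer-valued pivots are annotated with the `'d'` format and float pivots with `'.1f'`.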

linkml_store/utils/dat_parser.py (new file)
@@ -0,0 +1,95 @@
from typing import Any, Dict, List, Optional, Tuple

ENTRY = Dict[str, Any]


def parse_sib_format(text) -> Tuple[Optional[ENTRY], List[ENTRY]]:
    """
    Parse SIB/Swiss-Prot flat-file format data into structured dictionaries.

    Args:
        text (str): The text in SIB/Swiss-Prot format

    Returns:
        A tuple (header, entries): the optional header block (fields appearing
        before the first ID-bearing entry) and a list of entry dictionaries
        keyed by two-letter field code
    """
    # Split the text into entries (separated by //)
    entries = text.split("//\n")
    header = None

    # Initialize results list
    results = []

    # Parse each entry
    for entry in entries:
        if not entry.strip():
            continue

        # Initialize dictionary for current entry
        current_entry = {}
        current_code = None

        # Process each line
        for line in entry.strip().split("\n"):
            if not line.strip():
                continue

            # Check if this is a new field (starts with a 2-letter code followed by space)
            if len(line) > 2 and line[2] == " ":
                current_code = line[0:2]
                # Remove the code and the following space(s)
                value = line[3:].strip()

                # Initialize as list if needed for multi-line fields
                if current_code not in current_entry:
                    current_entry[current_code] = []

                current_entry[current_code].append(value)

            # Continuation of previous field
            elif current_code is not None:
                # Handle continuation lines (typically indented)
                if current_code == "CC":
                    # For comments, preserve the indentation
                    current_entry[current_code].append(line)
                else:
                    # For other fields, strip and append
                    current_entry[current_code].append(line.strip())

        # Combine multiline comments; e.g.
        # -!- ...
        #     ...
        # -!- ...
        ccs = current_entry.get("CC", [])
        new_ccs = []
        for cc in ccs:
            if not cc.startswith("-!-") and new_ccs:
                new_ccs[-1] += " " + cc
            else:
                new_ccs.append(cc)
        current_entry["CC"] = new_ccs
        for k, vs in current_entry.items():
            if k != "CC":
                combined = "".join(vs)
                combined = combined.strip()
                if combined.endswith("."):
                    combined = combined.split(".")
                    combined = [c.strip() for c in combined if c.strip()]
                    if k == "DE":
                        combined = combined[0]
                current_entry[k] = combined

        if "ID" in current_entry:
            results.append(current_entry)
        else:
            header = current_entry

    return header, results


# Example usage:
# header, entries = parse_sib_format(text)
# for entry in entries:
#     print(f"Entry: {entry['ID']}")
#     for code, values in entry.items():
#         print(f"  {code}: {values}")
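For orientation, a small sketch (not part of the released wheel) of how `parse_sib_format` behaves on a tiny Swiss-Prot/ENZYME-style snippet; the entry contents are invented.

```python
# Illustrative sketch only (not part of the package diff): parsing a tiny
# Swiss-Prot/ENZYME-style snippet with the new parse_sib_format function.
from linkml_store.utils.dat_parser import parse_sib_format

text = """CC   Example header comment.
//
ID   1.1.1.1
DE   Alcohol dehydrogenase.
CC   -!- Acts on primary or secondary alcohols.
CC       Zinc-dependent.
//
"""

header, entries = parse_sib_format(text)
print(header)  # fields seen before the first ID-bearing entry
for entry in entries:
    print(entry["ID"], entry["DE"])
    for comment in entry.get("CC", []):
        print("  ", comment)
```

Blocks without an `ID` line (such as the leading comment block above) are returned as the `header`, and consecutive `CC` lines that do not start with `-!-` are folded into the preceding comment.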

linkml_store/utils/enrichment_analyzer.py (new file)
@@ -0,0 +1,217 @@
from collections import Counter
from typing import Dict, List

import numpy as np
import pandas as pd
from pydantic import BaseModel
from scipy import stats

from linkml_store.api import Collection


class EnrichedCategory(BaseModel):
    """
    Information about a category enriched in a sample
    """

    category: str
    fold_change: float
    original_p_value: float
    adjusted_p_value: float


class EnrichmentAnalyzer:
    def __init__(self, df: pd.DataFrame, sample_key: str, classification_key: str):
        """
        Initialize the analyzer with a DataFrame and key column names.
        Precomputes category frequencies for the entire dataset.

        Args:
            df: DataFrame containing the data
            sample_key: Column name for sample IDs
            classification_key: Column name for category lists
        """
        self.df = df
        self.sample_key = sample_key
        self.classification_key = classification_key

        # Precompute global category statistics
        self.global_stats = self._compute_global_stats()

        # Cache for sample-specific category counts
        self.sample_cache: Dict[str, Counter] = {}

    @classmethod
    def from_collection(cls, collection: Collection, sample_key: str, classification_key: str) -> "EnrichmentAnalyzer":
        """
        Initialize the analyzer with a Collection and key column names.
        Precomputes category frequencies for the entire dataset.

        Args:
            collection: Collection containing the data
            sample_key: Column name for sample IDs
            classification_key: Column name for category lists
        """
        column_atts = [sample_key, classification_key]
        results = collection.find(select_cols=column_atts, limit=-1)
        df = results.rows_dataframe
        ea = cls(df, sample_key=sample_key, classification_key=classification_key)
        return ea

    def _compute_global_stats(self) -> Dict[str, int]:
        """
        Compute global category frequencies across all samples.
        Returns a dictionary of category -> count
        """
        global_counter = Counter()

        # Flatten all categories and count
        for categories in self.df[self.classification_key]:
            if isinstance(categories, list):
                global_counter.update(categories)
            else:
                # Handle case where categories might be a string
                global_counter.update([categories])

        return global_counter

    @property
    def sample_ids(self) -> List[str]:
        df = self.df
        return df[self.sample_key].unique().tolist()

    def _get_sample_stats(self, sample_id: str) -> Counter:
        """
        Get category frequencies for a specific sample.
        Uses caching to avoid recomputation.
        """
        if sample_id in self.sample_cache:
            return self.sample_cache[sample_id]

        sample_data = self.df[self.df[self.sample_key] == sample_id]
        if sample_data.empty:
            raise KeyError(f"Sample ID '{sample_id}' not found")
        sample_data = sample_data.dropna()
        # if sample_data.empty:
        #     raise ValueError(f"Sample ID '{sample_id}' has missing values after dropping NA")
        counter = Counter()

        for categories in sample_data[self.classification_key]:
            if isinstance(categories, list):
                counter.update(categories)
            else:
                counter.update([categories])

        self.sample_cache[sample_id] = counter
        return counter

    def find_enriched_categories(
        self,
        sample_id: str,
        min_occurrences: int = 5,
        p_value_threshold: float = 0.05,
        multiple_testing_correction: str = "bh",
    ) -> List[EnrichedCategory]:
        """
        Find categories that are enriched in the given sample.

        Args:
            sample_id: ID of the sample to analyze
            min_occurrences: Minimum number of occurrences required for a category
            p_value_threshold: P-value threshold for significance
            multiple_testing_correction: Correction applied to p-values: "bh"
                (Benjamini-Hochberg, default), "bonf" (Bonferroni), or any other
                value for no correction

        Returns:
            List of EnrichedCategory objects sorted by adjusted p-value
        """
        sample_stats = self._get_sample_stats(sample_id)
        total_sample_annotations = sum(sample_stats.values())
        total_global_annotations = sum(self.global_stats.values())

        results = []

        for category, sample_count in sample_stats.items():
            global_count = self.global_stats[category]

            # Skip rare categories
            if global_count < min_occurrences:
                continue

            # Calculate fold change
            sample_freq = sample_count / total_sample_annotations
            global_freq = global_count / total_global_annotations
            fold_change = sample_freq / global_freq if global_freq > 0 else float("inf")

            # Perform Fisher's exact test
            contingency_table = np.array(
                [
                    [sample_count, global_count - sample_count],
                    [
                        total_sample_annotations - sample_count,
                        total_global_annotations - total_sample_annotations - (global_count - sample_count),
                    ],
                ]
            )

            _, p_value = stats.fisher_exact(contingency_table)

            if p_value < p_value_threshold:
                results.append((category, fold_change, p_value))

        if not results:
            return results

        # Sort by p-value
        results.sort(key=lambda x: x[2])

        # Apply multiple testing correction
        categories, fold_changes, p_values = zip(*results)

        if multiple_testing_correction.lower() == "bonf":
            # Bonferroni correction
            n_tests = len(self.global_stats)  # Total number of categories tested
            adjusted_p_values = [min(1.0, p * n_tests) for p in p_values]

        elif multiple_testing_correction.lower() == "bh":
            # Benjamini-Hochberg correction
            n = len(p_values)
            sorted_indices = np.argsort(p_values)
            sorted_p_values = np.array(p_values)[sorted_indices]

            # Calculate BH adjusted p-values
            adjusted_p_values = np.zeros(n)
            for i, p in enumerate(sorted_p_values):
                adjusted_p_values[i] = p * n / (i + 1)

            # Ensure monotonicity
            for i in range(n - 2, -1, -1):
                adjusted_p_values[i] = min(adjusted_p_values[i], adjusted_p_values[i + 1])

            # Restore original order
            inverse_indices = np.argsort(sorted_indices)
            adjusted_p_values = adjusted_p_values[inverse_indices]

            # Ensure we don't exceed 1.0
            adjusted_p_values = np.minimum(adjusted_p_values, 1.0)

        else:
            # No correction
            adjusted_p_values = p_values

        # Filter by adjusted p-value threshold and create EnrichedCategory objects
        final_results = [
            EnrichedCategory(category=cat, fold_change=fc, original_p_value=p, adjusted_p_value=adj_p)
            for cat, fc, p, adj_p in zip(categories, fold_changes, p_values, adjusted_p_values)
            if adj_p < p_value_threshold
        ]

        # Sort by adjusted p-value
        final_results.sort(key=lambda x: x.adjusted_p_value)
        return final_results


# Example usage:
# analyzer = EnrichmentAnalyzer(df, 'sample_id', 'categories')
# enriched = analyzer.find_enriched_categories('sample1')
# for ec in enriched:
#     print(f"{ec.category}: {ec.fold_change:.2f}x enrichment (adj p={ec.adjusted_p_value:.2e})")
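For orientation, a small sketch (not part of the released wheel) running the new `EnrichmentAnalyzer` over a toy DataFrame; the sample IDs and category labels are invented.

```python
# Illustrative sketch only (not part of the package diff): per-sample category
# enrichment against the whole dataset, using the new EnrichmentAnalyzer.
import pandas as pd

from linkml_store.utils.enrichment_analyzer import EnrichmentAnalyzer

df = pd.DataFrame(
    {
        "sample_id": ["s1"] * 20 + ["s2"] * 20,
        # each row carries a list of category labels for that observation
        "categories": [["transport", "membrane"]] * 20 + [["kinase"]] * 20,
    }
)

analyzer = EnrichmentAnalyzer(df, sample_key="sample_id", classification_key="categories")
for ec in analyzer.find_enriched_categories("s1", min_occurrences=5):
    print(f"{ec.category}: {ec.fold_change:.2f}x (adj p={ec.adjusted_p_value:.2e})")
```

Each returned `EnrichedCategory` carries both the raw Fisher's exact p-value and the Benjamini-Hochberg (or Bonferroni) adjusted value, so downstream code can apply its own threshold.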