linkml_store-0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- linkml_store/__init__.py +7 -0
- linkml_store/api/__init__.py +8 -0
- linkml_store/api/client.py +414 -0
- linkml_store/api/collection.py +1280 -0
- linkml_store/api/config.py +187 -0
- linkml_store/api/database.py +862 -0
- linkml_store/api/queries.py +69 -0
- linkml_store/api/stores/__init__.py +0 -0
- linkml_store/api/stores/chromadb/__init__.py +7 -0
- linkml_store/api/stores/chromadb/chromadb_collection.py +121 -0
- linkml_store/api/stores/chromadb/chromadb_database.py +89 -0
- linkml_store/api/stores/dremio/__init__.py +10 -0
- linkml_store/api/stores/dremio/dremio_collection.py +555 -0
- linkml_store/api/stores/dremio/dremio_database.py +1052 -0
- linkml_store/api/stores/dremio/mappings.py +105 -0
- linkml_store/api/stores/dremio_rest/__init__.py +11 -0
- linkml_store/api/stores/dremio_rest/dremio_rest_collection.py +502 -0
- linkml_store/api/stores/dremio_rest/dremio_rest_database.py +1023 -0
- linkml_store/api/stores/duckdb/__init__.py +16 -0
- linkml_store/api/stores/duckdb/duckdb_collection.py +339 -0
- linkml_store/api/stores/duckdb/duckdb_database.py +283 -0
- linkml_store/api/stores/duckdb/mappings.py +8 -0
- linkml_store/api/stores/filesystem/__init__.py +15 -0
- linkml_store/api/stores/filesystem/filesystem_collection.py +186 -0
- linkml_store/api/stores/filesystem/filesystem_database.py +81 -0
- linkml_store/api/stores/hdf5/__init__.py +7 -0
- linkml_store/api/stores/hdf5/hdf5_collection.py +104 -0
- linkml_store/api/stores/hdf5/hdf5_database.py +79 -0
- linkml_store/api/stores/ibis/__init__.py +5 -0
- linkml_store/api/stores/ibis/ibis_collection.py +488 -0
- linkml_store/api/stores/ibis/ibis_database.py +328 -0
- linkml_store/api/stores/mongodb/__init__.py +25 -0
- linkml_store/api/stores/mongodb/mongodb_collection.py +379 -0
- linkml_store/api/stores/mongodb/mongodb_database.py +114 -0
- linkml_store/api/stores/neo4j/__init__.py +0 -0
- linkml_store/api/stores/neo4j/neo4j_collection.py +429 -0
- linkml_store/api/stores/neo4j/neo4j_database.py +154 -0
- linkml_store/api/stores/solr/__init__.py +3 -0
- linkml_store/api/stores/solr/solr_collection.py +224 -0
- linkml_store/api/stores/solr/solr_database.py +83 -0
- linkml_store/api/stores/solr/solr_utils.py +0 -0
- linkml_store/api/types.py +4 -0
- linkml_store/cli.py +1147 -0
- linkml_store/constants.py +7 -0
- linkml_store/graphs/__init__.py +0 -0
- linkml_store/graphs/graph_map.py +24 -0
- linkml_store/index/__init__.py +53 -0
- linkml_store/index/implementations/__init__.py +0 -0
- linkml_store/index/implementations/llm_indexer.py +174 -0
- linkml_store/index/implementations/simple_indexer.py +43 -0
- linkml_store/index/indexer.py +211 -0
- linkml_store/inference/__init__.py +13 -0
- linkml_store/inference/evaluation.py +195 -0
- linkml_store/inference/implementations/__init__.py +0 -0
- linkml_store/inference/implementations/llm_inference_engine.py +154 -0
- linkml_store/inference/implementations/rag_inference_engine.py +276 -0
- linkml_store/inference/implementations/rule_based_inference_engine.py +169 -0
- linkml_store/inference/implementations/sklearn_inference_engine.py +314 -0
- linkml_store/inference/inference_config.py +66 -0
- linkml_store/inference/inference_engine.py +209 -0
- linkml_store/inference/inference_engine_registry.py +74 -0
- linkml_store/plotting/__init__.py +5 -0
- linkml_store/plotting/cli.py +826 -0
- linkml_store/plotting/dimensionality_reduction.py +453 -0
- linkml_store/plotting/embedding_plot.py +489 -0
- linkml_store/plotting/facet_chart.py +73 -0
- linkml_store/plotting/heatmap.py +383 -0
- linkml_store/utils/__init__.py +0 -0
- linkml_store/utils/change_utils.py +17 -0
- linkml_store/utils/dat_parser.py +95 -0
- linkml_store/utils/embedding_matcher.py +424 -0
- linkml_store/utils/embedding_utils.py +299 -0
- linkml_store/utils/enrichment_analyzer.py +217 -0
- linkml_store/utils/file_utils.py +37 -0
- linkml_store/utils/format_utils.py +550 -0
- linkml_store/utils/io.py +38 -0
- linkml_store/utils/llm_utils.py +122 -0
- linkml_store/utils/mongodb_utils.py +145 -0
- linkml_store/utils/neo4j_utils.py +42 -0
- linkml_store/utils/object_utils.py +190 -0
- linkml_store/utils/pandas_utils.py +93 -0
- linkml_store/utils/patch_utils.py +126 -0
- linkml_store/utils/query_utils.py +89 -0
- linkml_store/utils/schema_utils.py +23 -0
- linkml_store/utils/sklearn_utils.py +193 -0
- linkml_store/utils/sql_utils.py +177 -0
- linkml_store/utils/stats_utils.py +53 -0
- linkml_store/utils/vector_utils.py +158 -0
- linkml_store/webapi/__init__.py +0 -0
- linkml_store/webapi/html/__init__.py +3 -0
- linkml_store/webapi/html/base.html.j2 +24 -0
- linkml_store/webapi/html/collection_details.html.j2 +15 -0
- linkml_store/webapi/html/database_details.html.j2 +16 -0
- linkml_store/webapi/html/databases.html.j2 +14 -0
- linkml_store/webapi/html/generic.html.j2 +43 -0
- linkml_store/webapi/main.py +855 -0
- linkml_store-0.3.0.dist-info/METADATA +226 -0
- linkml_store-0.3.0.dist-info/RECORD +101 -0
- linkml_store-0.3.0.dist-info/WHEEL +4 -0
- linkml_store-0.3.0.dist-info/entry_points.txt +3 -0
- linkml_store-0.3.0.dist-info/licenses/LICENSE +22 -0

linkml_store/utils/embedding_utils.py
@@ -0,0 +1,299 @@
+"""Utilities for extracting and processing embeddings from indexed collections."""
+
+import logging
+from typing import Dict, List, Optional, Tuple, Union
+import numpy as np
+from dataclasses import dataclass, field
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class EmbeddingData:
+    """Container for embedding data from collections."""
+
+    vectors: np.ndarray
+    metadata: List[Dict]
+    collection_names: List[str]
+    collection_indices: List[int]
+    object_ids: List[str]
+
+    @property
+    def n_samples(self) -> int:
+        """Number of samples."""
+        return len(self.vectors)
+
+    @property
+    def n_dimensions(self) -> int:
+        """Number of dimensions in embeddings."""
+        return self.vectors.shape[1] if len(self.vectors.shape) > 1 else 0
+
+    def get_metadata_values(self, field: str) -> List:
+        """Extract values for a specific metadata field."""
+        return [m.get(field) for m in self.metadata]
+
+
+def extract_embeddings_from_collection(
+    collection,
+    index_name: Optional[str] = None,
+    limit: Optional[int] = None,
+    include_metadata: bool = True,
+    metadata_fields: Optional[List[str]] = None
+) -> EmbeddingData:
+    """
+    Extract embeddings from an indexed collection.
+
+    Args:
+        collection: LinkML collection object
+        index_name: Name of the index to use (defaults to first available)
+        limit: Maximum number of embeddings to extract
+        include_metadata: Whether to include source object metadata
+        metadata_fields: Specific metadata fields to include (None = all)
+
+    Returns:
+        EmbeddingData object containing vectors and metadata
+    """
+    # Get the index name - handle collections without loaded indexers
+    if index_name is None:
+        # Try to find index collections directly
+        db = collection.parent
+        all_collections = db.list_collection_names()
+        # TODO: use the indexer metadata to find the index name
+        index_prefix = f"internal__index__{collection.alias}__"
+        index_collections = [c for c in all_collections if c.startswith(index_prefix)]
+
+        if not index_collections:
+            raise ValueError(f"Collection {collection.alias} has no indexes")
+
+        # Extract index name from first index collection
+        index_name = index_collections[0].replace(index_prefix, "")
+        if len(index_collections) > 1:
+            logger.warning(f"Multiple indexes found, using: {index_name}")
+
+    # Get the index collection
+    index_collection_name = f"internal__index__{collection.alias}__{index_name}"
+    index_collection = collection.parent.get_collection(index_collection_name)
+
+    # Query the index collection
+    query_result = index_collection.find(limit=limit)
+
+    if query_result.num_rows == 0:
+        raise ValueError(f"No indexed data found in {index_collection_name}")
+
+    vectors = []
+    metadata = []
+    object_ids = []
+
+    for row in query_result.rows:
+        # Extract vector (usually stored in __index__ field)
+        vector = row.get("__index__")
+        if vector is None:
+            logger.warning(f"No vector found for object {row.get('id')}")
+            continue
+
+        vectors.append(vector)
+
+        # Extract object ID
+        obj_id = row.get("id") or row.get("_id") or str(len(vectors))
+        object_ids.append(obj_id)
+
+        # Extract metadata
+        if include_metadata:
+            meta = {}
+            if metadata_fields:
+                # Only include specified fields
+                for field in metadata_fields:
+                    if field in row:
+                        meta[field] = row[field]
+            else:
+                # Include all fields except the vector
+                meta = {k: v for k, v in row.items() if k != "__index__"}
+            metadata.append(meta)
+
+    return EmbeddingData(
+        vectors=np.array(vectors),
+        metadata=metadata,
+        collection_names=[collection.alias] * len(vectors),
+        collection_indices=[0] * len(vectors),
+        object_ids=object_ids
+    )
+
+
+def extract_embeddings_from_multiple_collections(
+    database,
+    collection_names: List[str],
+    index_name: Optional[str] = None,
+    limit_per_collection: Optional[int] = None,
+    include_metadata: bool = True,
+    metadata_fields: Optional[List[str]] = None,
+    normalize: bool = False
+) -> EmbeddingData:
+    """
+    Extract embeddings from multiple collections.
+
+    Args:
+        database: LinkML database object
+        collection_names: List of collection names
+        index_name: Name of index to use (must be same across collections)
+        limit_per_collection: Max embeddings per collection
+        include_metadata: Whether to include source object metadata
+        metadata_fields: Specific metadata fields to include
+        normalize: Whether to normalize vectors to unit length
+
+    Returns:
+        Combined EmbeddingData object
+    """
+    all_vectors = []
+    all_metadata = []
+    all_collection_names = []
+    all_collection_indices = []
+    all_object_ids = []
+
+    for i, coll_name in enumerate(collection_names):
+        try:
+            collection = database.get_collection(coll_name)
+            data = extract_embeddings_from_collection(
+                collection,
+                index_name=index_name,
+                limit=limit_per_collection,
+                include_metadata=include_metadata,
+                metadata_fields=metadata_fields
+            )
+
+            all_vectors.append(data.vectors)
+            all_metadata.extend(data.metadata)
+            all_collection_names.extend([coll_name] * data.n_samples)
+            all_collection_indices.extend([i] * data.n_samples)
+            all_object_ids.extend(data.object_ids)
+
+        except Exception as e:
+            logger.error(f"Failed to extract embeddings from {coll_name}: {e}")
+            continue
+
+    if not all_vectors:
+        raise ValueError("No embeddings extracted from any collection")
+
+    combined_vectors = np.vstack(all_vectors)
+
+    # Normalize if requested
+    if normalize:
+        # Ensure vectors are float type for division
+        combined_vectors = combined_vectors.astype(np.float64)
+        norms = np.linalg.norm(combined_vectors, axis=1, keepdims=True)
+        # Avoid division by zero
+        norms = np.where(norms == 0, 1, norms)
+        combined_vectors = combined_vectors / norms
+
+    return EmbeddingData(
+        vectors=combined_vectors,
+        metadata=all_metadata,
+        collection_names=all_collection_names,
+        collection_indices=all_collection_indices,
+        object_ids=all_object_ids
+    )
+
+
+def sample_embeddings(
+    embedding_data: EmbeddingData,
+    n_samples: int = 1000,
+    method: str = "random",
+    random_state: Optional[int] = None
+) -> EmbeddingData:
+    """
+    Sample embeddings for visualization.
+
+    Args:
+        embedding_data: Original embedding data
+        n_samples: Number of samples to select
+        method: Sampling method ('random' or 'uniform'; 'density' is not yet implemented)
+        random_state: Random seed for reproducibility
+
+    Returns:
+        Sampled EmbeddingData object
+    """
+    if embedding_data.n_samples <= n_samples:
+        return embedding_data
+
+    if random_state is not None:
+        np.random.seed(random_state)
+
+    if method == "random":
+        indices = np.random.choice(
+            embedding_data.n_samples,
+            size=n_samples,
+            replace=False
+        )
+    elif method == "uniform":
+        # Sample uniformly across collections
+        indices = []
+        for coll_idx in set(embedding_data.collection_indices):
+            coll_mask = np.array(embedding_data.collection_indices) == coll_idx
+            coll_indices = np.where(coll_mask)[0]
+            n_from_coll = min(
+                len(coll_indices),
+                n_samples // len(set(embedding_data.collection_indices))
+            )
+            indices.extend(
+                np.random.choice(coll_indices, size=n_from_coll, replace=False)
+            )
+        indices = np.array(indices[:n_samples])
+    else:
+        raise ValueError(f"Unknown sampling method: {method}")
+
+    return EmbeddingData(
+        vectors=embedding_data.vectors[indices],
+        metadata=[embedding_data.metadata[i] for i in indices],
+        collection_names=[embedding_data.collection_names[i] for i in indices],
+        collection_indices=[embedding_data.collection_indices[i] for i in indices],
+        object_ids=[embedding_data.object_ids[i] for i in indices]
+    )


+def compute_embedding_statistics(embedding_data: EmbeddingData) -> Dict:
+    """
+    Compute statistics about embeddings.
+
+    Args:
+        embedding_data: Embedding data
+
+    Returns:
+        Dictionary of statistics
+    """
+    stats = {
+        "n_samples": embedding_data.n_samples,
+        "n_dimensions": embedding_data.n_dimensions,
+        "n_collections": len(set(embedding_data.collection_names)),
+        "collections": list(set(embedding_data.collection_names)),
+    }
+
+    # Per-collection counts
+    from collections import Counter
+    collection_counts = Counter(embedding_data.collection_names)
+    stats["samples_per_collection"] = dict(collection_counts)
+
+    # Vector statistics
+    if embedding_data.n_samples > 0:
+        stats["mean_norm"] = float(np.mean(np.linalg.norm(embedding_data.vectors, axis=1)))
+        stats["std_norm"] = float(np.std(np.linalg.norm(embedding_data.vectors, axis=1)))
+
+        # Compute average pairwise similarity (on sample if large)
+        sample_size = min(100, embedding_data.n_samples)
+        if sample_size > 1:
+            sample_indices = np.random.choice(
+                embedding_data.n_samples,
+                size=sample_size,
+                replace=False
+            )
+            sample_vectors = embedding_data.vectors[sample_indices]
+
+            # Normalize for cosine similarity
+            norms = np.linalg.norm(sample_vectors, axis=1, keepdims=True)
+            normalized = sample_vectors / (norms + 1e-10)
+            similarities = np.dot(normalized, normalized.T)
+
+            # Extract upper triangle (excluding diagonal)
+            upper_tri = similarities[np.triu_indices(sample_size, k=1)]
+            stats["mean_similarity"] = float(np.mean(upper_tri))
+            stats["std_similarity"] = float(np.std(upper_tri))
+
+    return stats
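
Taken together, the helpers in embedding_utils.py cover a typical flow: extract vectors from an indexed collection, down-sample them for plotting, and summarize. The sketch below is illustrative only; the database handle, alias, and "persons" collection are hypothetical, and the collection is assumed to already have an attached index (so an internal__index__persons__* collection exists).

# Illustrative sketch; the duckdb handle, alias, and "persons" collection
# are assumptions, and the collection must already be indexed.
from linkml_store import Client
from linkml_store.utils.embedding_utils import (
    extract_embeddings_from_collection,
    sample_embeddings,
    compute_embedding_statistics,
)

client = Client()
db = client.attach_database("duckdb:///tmp/demo.ddb", alias="demo")  # hypothetical path
collection = db.get_collection("persons")

data = extract_embeddings_from_collection(collection, limit=5000)
subset = sample_embeddings(data, n_samples=1000, method="random", random_state=42)
print(compute_embedding_statistics(subset))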

linkml_store/utils/enrichment_analyzer.py
@@ -0,0 +1,217 @@
+from collections import Counter
+from typing import Dict, List
+
+import numpy as np
+import pandas as pd
+from pydantic import BaseModel
+from scipy import stats
+
+from linkml_store.api import Collection
+
+
+class EnrichedCategory(BaseModel):
+    """
+    Information about a category enriched in a sample
+    """
+
+    category: str
+    fold_change: float
+    original_p_value: float
+    adjusted_p_value: float
+
+
+class EnrichmentAnalyzer:
+    def __init__(self, df: pd.DataFrame, sample_key: str, classification_key: str):
+        """
+        Initialize the analyzer with a DataFrame and key column names.
+        Precomputes category frequencies for the entire dataset.
+
+        Args:
+            df: DataFrame containing the data
+            sample_key: Column name for sample IDs
+            classification_key: Column name for category lists
+        """
+        self.df = df
+        self.sample_key = sample_key
+        self.classification_key = classification_key
+
+        # Precompute global category statistics
+        self.global_stats = self._compute_global_stats()
+
+        # Cache for sample-specific category counts
+        self.sample_cache: Dict[str, Counter] = {}
+
+    @classmethod
+    def from_collection(cls, collection: Collection, sample_key: str, classification_key: str) -> "EnrichmentAnalyzer":
+        """
+        Initialize the analyzer with a Collection and key column names.
+        Precomputes category frequencies for the entire dataset.
+
+        Args:
+            collection: Collection containing the data
+            sample_key: Column name for sample IDs
+            classification_key: Column name for category lists
+        """
+        column_atts = [sample_key, classification_key]
+        results = collection.find(select_cols=column_atts, limit=-1)
+        df = results.rows_dataframe
+        ea = cls(df, sample_key=sample_key, classification_key=classification_key)
+        return ea
+
+    def _compute_global_stats(self) -> Dict[str, int]:
+        """
+        Compute global category frequencies across all samples.
+        Returns a dictionary of category -> count
+        """
+        global_counter = Counter()
+
+        # Flatten all categories and count
+        for categories in self.df[self.classification_key]:
+            if isinstance(categories, list):
+                global_counter.update(categories)
+            else:
+                # Handle case where categories might be a string
+                global_counter.update([categories])
+
+        return global_counter
+
+    @property
+    def sample_ids(self) -> List[str]:
+        df = self.df
+        return df[self.sample_key].unique().tolist()
+
+    def _get_sample_stats(self, sample_id: str) -> Counter:
+        """
+        Get category frequencies for a specific sample.
+        Uses caching to avoid recomputation.
+        """
+        if sample_id in self.sample_cache:
+            return self.sample_cache[sample_id]
+
+        sample_data = self.df[self.df[self.sample_key] == sample_id]
+        if sample_data.empty:
+            raise KeyError(f"Sample ID '{sample_id}' not found")
+        sample_data = sample_data.dropna()
+        # if sample_data.empty:
+        #     raise ValueError(f"Sample ID '{sample_id}' has missing values after dropping NA")
+        counter = Counter()
+
+        for categories in sample_data[self.classification_key]:
+            if isinstance(categories, list):
+                counter.update(categories)
+            else:
+                counter.update([categories])
+
+        self.sample_cache[sample_id] = counter
+        return counter
+
+    def find_enriched_categories(
+        self,
+        sample_id: str,
+        min_occurrences: int = 5,
+        p_value_threshold: float = 0.05,
+        multiple_testing_correction: str = "bh",
+    ) -> List[EnrichedCategory]:
+        """
+        Find categories that are enriched in the given sample.
+
+        Args:
+            sample_id: ID of the sample to analyze
+            min_occurrences: Minimum number of occurrences required for a category
+            p_value_threshold: P-value threshold for significance
+
+        Returns:
+            List of EnrichedCategory objects sorted by adjusted p-value
+        """
+        sample_stats = self._get_sample_stats(sample_id)
+        total_sample_annotations = sum(sample_stats.values())
+        total_global_annotations = sum(self.global_stats.values())
+
+        results = []
+
+        for category, sample_count in sample_stats.items():
+            global_count = self.global_stats[category]
+
+            # Skip rare categories
+            if global_count < min_occurrences:
+                continue
+
+            # Calculate fold change
+            sample_freq = sample_count / total_sample_annotations
+            global_freq = global_count / total_global_annotations
+            fold_change = sample_freq / global_freq if global_freq > 0 else float("inf")
+
+            # Perform Fisher's exact test
+            contingency_table = np.array(
+                [
+                    [sample_count, global_count - sample_count],
+                    [
+                        total_sample_annotations - sample_count,
+                        total_global_annotations - total_sample_annotations - (global_count - sample_count),
+                    ],
+                ]
+            )
+
+            _, p_value = stats.fisher_exact(contingency_table)
+
+            if p_value < p_value_threshold:
+                results.append((category, fold_change, p_value))
+
+        if not results:
+            return results
+
+        # Sort by p-value
+        results.sort(key=lambda x: x[2])
+
+        # Apply multiple testing correction
+        categories, fold_changes, p_values = zip(*results)
+
+        if multiple_testing_correction.lower() == "bonf":
+            # Bonferroni correction
+            n_tests = len(self.global_stats)  # Total number of categories tested
+            adjusted_p_values = [min(1.0, p * n_tests) for p in p_values]
+
+        elif multiple_testing_correction.lower() == "bh":
+            # Benjamini-Hochberg correction
+            n = len(p_values)
+            sorted_indices = np.argsort(p_values)
+            sorted_p_values = np.array(p_values)[sorted_indices]
+
+            # Calculate BH adjusted p-values
+            adjusted_p_values = np.zeros(n)
+            for i, p in enumerate(sorted_p_values):
+                adjusted_p_values[i] = p * n / (i + 1)
+
+            # Ensure monotonicity
+            for i in range(n - 2, -1, -1):
+                adjusted_p_values[i] = min(adjusted_p_values[i], adjusted_p_values[i + 1])
+
+            # Restore original order
+            inverse_indices = np.argsort(sorted_indices)
+            adjusted_p_values = adjusted_p_values[inverse_indices]
+
+            # Ensure we don't exceed 1.0
+            adjusted_p_values = np.minimum(adjusted_p_values, 1.0)
+
+        else:
+            # No correction
+            adjusted_p_values = p_values
+
+        # Filter by adjusted p-value threshold and create final results
+        # Create EnrichedCategory objects
+        final_results = [
+            EnrichedCategory(category=cat, fold_change=fc, original_p_value=p, adjusted_p_value=adj_p)
+            for cat, fc, p, adj_p in zip(categories, fold_changes, p_values, adjusted_p_values)
+            if adj_p < p_value_threshold
+        ]
+
+        # Sort by adjusted p-value
+        final_results.sort(key=lambda x: x.adjusted_p_value)
+        return final_results
+
+
+# Example usage:
+# analyzer = EnrichmentAnalyzer(df, 'sample_id', 'categories')
+# enriched = analyzer.find_enriched_categories('sample1')
+# for ec in enriched:
+#     print(f"{ec.category}: {ec.fold_change:.2f}x enrichment (p={ec.adjusted_p_value:.2e})")
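
Because EnrichmentAnalyzer's constructor only needs a DataFrame with a sample-ID column and a column of category lists, it can also be exercised directly, without from_collection or a backing database. A small self-contained sketch (the column names and toy data are illustrative only):

# Self-contained toy example for EnrichmentAnalyzer; data and column
# names are illustrative, not taken from a real linkml-store collection.
import pandas as pd
from linkml_store.utils.enrichment_analyzer import EnrichmentAnalyzer

df = pd.DataFrame(
    {
        "sample_id": ["s1"] * 8 + ["s2"] * 8,
        "categories": [["kinase", "membrane"]] * 8 + [["membrane"], ["nucleus"]] * 4,
    }
)

analyzer = EnrichmentAnalyzer(df, sample_key="sample_id", classification_key="categories")
for ec in analyzer.find_enriched_categories("s1", min_occurrences=2):
    print(f"{ec.category}: {ec.fold_change:.2f}x (adj p={ec.adjusted_p_value:.3g})")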

linkml_store/utils/file_utils.py
@@ -0,0 +1,37 @@
+import logging
+import shutil
+import tempfile
+from datetime import datetime
+from pathlib import Path
+from typing import Optional
+
+# Set up logging
+logger = logging.getLogger(__name__)
+
+
+def safe_remove_directory(dir_path: Path, no_backup: bool = False) -> Optional[Path]:
+    # Ensure the directory exists
+    if not dir_path.exists():
+        raise FileNotFoundError(f"Directory does not exist: {dir_path}")
+    try:
+        if no_backup:
+            # Move to a temporary directory instead of permanent removal
+            with tempfile.TemporaryDirectory() as tmpdir:
+                tmp_path = Path(tmpdir) / dir_path.name
+                shutil.move(str(dir_path), str(tmp_path))
+                logger.info(f"Directory moved to temporary location: {tmp_path}")
+                # The directory will be automatically removed when exiting the context manager
+            return None
+        else:
+            # Create a backup directory name with timestamp
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            backup_dir = dir_path.with_name(f"{dir_path.name}_backup_{timestamp}")
+
+            # Move the directory to the backup location
+            shutil.move(str(dir_path), str(backup_dir))
+            logger.info(f"Directory backed up to: {backup_dir}")
+            return backup_dir
+
+    except Exception as e:
+        logger.error(f"An error occurred: {e}")
+        return None
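
safe_remove_directory either renames the target in place with a timestamped _backup_ suffix and returns the backup path, or, with no_backup=True, discards it via a temporary directory and returns None. A short illustrative sketch using a throwaway scratch directory (the path is hypothetical):

# Illustrative use on a throwaway scratch directory (path is hypothetical).
import tempfile
from pathlib import Path

from linkml_store.utils.file_utils import safe_remove_directory

work = Path(tempfile.gettempdir()) / "linkml_store_demo"
work.mkdir(parents=True, exist_ok=True)
(work / "data.txt").write_text("example")

backup = safe_remove_directory(work)  # keeps a timestamped sibling backup
print("backup kept at:", backup)

work.mkdir(parents=True, exist_ok=True)
assert safe_remove_directory(work, no_backup=True) is None  # no backup kept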