linkml-store 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. linkml_store/__init__.py +7 -0
  2. linkml_store/api/__init__.py +8 -0
  3. linkml_store/api/client.py +414 -0
  4. linkml_store/api/collection.py +1280 -0
  5. linkml_store/api/config.py +187 -0
  6. linkml_store/api/database.py +862 -0
  7. linkml_store/api/queries.py +69 -0
  8. linkml_store/api/stores/__init__.py +0 -0
  9. linkml_store/api/stores/chromadb/__init__.py +7 -0
  10. linkml_store/api/stores/chromadb/chromadb_collection.py +121 -0
  11. linkml_store/api/stores/chromadb/chromadb_database.py +89 -0
  12. linkml_store/api/stores/dremio/__init__.py +10 -0
  13. linkml_store/api/stores/dremio/dremio_collection.py +555 -0
  14. linkml_store/api/stores/dremio/dremio_database.py +1052 -0
  15. linkml_store/api/stores/dremio/mappings.py +105 -0
  16. linkml_store/api/stores/dremio_rest/__init__.py +11 -0
  17. linkml_store/api/stores/dremio_rest/dremio_rest_collection.py +502 -0
  18. linkml_store/api/stores/dremio_rest/dremio_rest_database.py +1023 -0
  19. linkml_store/api/stores/duckdb/__init__.py +16 -0
  20. linkml_store/api/stores/duckdb/duckdb_collection.py +339 -0
  21. linkml_store/api/stores/duckdb/duckdb_database.py +283 -0
  22. linkml_store/api/stores/duckdb/mappings.py +8 -0
  23. linkml_store/api/stores/filesystem/__init__.py +15 -0
  24. linkml_store/api/stores/filesystem/filesystem_collection.py +186 -0
  25. linkml_store/api/stores/filesystem/filesystem_database.py +81 -0
  26. linkml_store/api/stores/hdf5/__init__.py +7 -0
  27. linkml_store/api/stores/hdf5/hdf5_collection.py +104 -0
  28. linkml_store/api/stores/hdf5/hdf5_database.py +79 -0
  29. linkml_store/api/stores/ibis/__init__.py +5 -0
  30. linkml_store/api/stores/ibis/ibis_collection.py +488 -0
  31. linkml_store/api/stores/ibis/ibis_database.py +328 -0
  32. linkml_store/api/stores/mongodb/__init__.py +25 -0
  33. linkml_store/api/stores/mongodb/mongodb_collection.py +379 -0
  34. linkml_store/api/stores/mongodb/mongodb_database.py +114 -0
  35. linkml_store/api/stores/neo4j/__init__.py +0 -0
  36. linkml_store/api/stores/neo4j/neo4j_collection.py +429 -0
  37. linkml_store/api/stores/neo4j/neo4j_database.py +154 -0
  38. linkml_store/api/stores/solr/__init__.py +3 -0
  39. linkml_store/api/stores/solr/solr_collection.py +224 -0
  40. linkml_store/api/stores/solr/solr_database.py +83 -0
  41. linkml_store/api/stores/solr/solr_utils.py +0 -0
  42. linkml_store/api/types.py +4 -0
  43. linkml_store/cli.py +1147 -0
  44. linkml_store/constants.py +7 -0
  45. linkml_store/graphs/__init__.py +0 -0
  46. linkml_store/graphs/graph_map.py +24 -0
  47. linkml_store/index/__init__.py +53 -0
  48. linkml_store/index/implementations/__init__.py +0 -0
  49. linkml_store/index/implementations/llm_indexer.py +174 -0
  50. linkml_store/index/implementations/simple_indexer.py +43 -0
  51. linkml_store/index/indexer.py +211 -0
  52. linkml_store/inference/__init__.py +13 -0
  53. linkml_store/inference/evaluation.py +195 -0
  54. linkml_store/inference/implementations/__init__.py +0 -0
  55. linkml_store/inference/implementations/llm_inference_engine.py +154 -0
  56. linkml_store/inference/implementations/rag_inference_engine.py +276 -0
  57. linkml_store/inference/implementations/rule_based_inference_engine.py +169 -0
  58. linkml_store/inference/implementations/sklearn_inference_engine.py +314 -0
  59. linkml_store/inference/inference_config.py +66 -0
  60. linkml_store/inference/inference_engine.py +209 -0
  61. linkml_store/inference/inference_engine_registry.py +74 -0
  62. linkml_store/plotting/__init__.py +5 -0
  63. linkml_store/plotting/cli.py +826 -0
  64. linkml_store/plotting/dimensionality_reduction.py +453 -0
  65. linkml_store/plotting/embedding_plot.py +489 -0
  66. linkml_store/plotting/facet_chart.py +73 -0
  67. linkml_store/plotting/heatmap.py +383 -0
  68. linkml_store/utils/__init__.py +0 -0
  69. linkml_store/utils/change_utils.py +17 -0
  70. linkml_store/utils/dat_parser.py +95 -0
  71. linkml_store/utils/embedding_matcher.py +424 -0
  72. linkml_store/utils/embedding_utils.py +299 -0
  73. linkml_store/utils/enrichment_analyzer.py +217 -0
  74. linkml_store/utils/file_utils.py +37 -0
  75. linkml_store/utils/format_utils.py +550 -0
  76. linkml_store/utils/io.py +38 -0
  77. linkml_store/utils/llm_utils.py +122 -0
  78. linkml_store/utils/mongodb_utils.py +145 -0
  79. linkml_store/utils/neo4j_utils.py +42 -0
  80. linkml_store/utils/object_utils.py +190 -0
  81. linkml_store/utils/pandas_utils.py +93 -0
  82. linkml_store/utils/patch_utils.py +126 -0
  83. linkml_store/utils/query_utils.py +89 -0
  84. linkml_store/utils/schema_utils.py +23 -0
  85. linkml_store/utils/sklearn_utils.py +193 -0
  86. linkml_store/utils/sql_utils.py +177 -0
  87. linkml_store/utils/stats_utils.py +53 -0
  88. linkml_store/utils/vector_utils.py +158 -0
  89. linkml_store/webapi/__init__.py +0 -0
  90. linkml_store/webapi/html/__init__.py +3 -0
  91. linkml_store/webapi/html/base.html.j2 +24 -0
  92. linkml_store/webapi/html/collection_details.html.j2 +15 -0
  93. linkml_store/webapi/html/database_details.html.j2 +16 -0
  94. linkml_store/webapi/html/databases.html.j2 +14 -0
  95. linkml_store/webapi/html/generic.html.j2 +43 -0
  96. linkml_store/webapi/main.py +855 -0
  97. linkml_store-0.3.0.dist-info/METADATA +226 -0
  98. linkml_store-0.3.0.dist-info/RECORD +101 -0
  99. linkml_store-0.3.0.dist-info/WHEEL +4 -0
  100. linkml_store-0.3.0.dist-info/entry_points.txt +3 -0
  101. linkml_store-0.3.0.dist-info/licenses/LICENSE +22 -0
linkml_store/utils/embedding_utils.py
@@ -0,0 +1,299 @@
+ """Utilities for extracting and processing embeddings from indexed collections."""
+
+ import logging
+ from typing import Dict, List, Optional, Tuple, Union
+ import numpy as np
+ from dataclasses import dataclass, field
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class EmbeddingData:
+     """Container for embedding data from collections."""
+
+     vectors: np.ndarray
+     metadata: List[Dict]
+     collection_names: List[str]
+     collection_indices: List[int]
+     object_ids: List[str]
+
+     @property
+     def n_samples(self) -> int:
+         """Number of samples."""
+         return len(self.vectors)
+
+     @property
+     def n_dimensions(self) -> int:
+         """Number of dimensions in embeddings."""
+         return self.vectors.shape[1] if len(self.vectors.shape) > 1 else 0
+
+     def get_metadata_values(self, field: str) -> List:
+         """Extract values for a specific metadata field."""
+         return [m.get(field) for m in self.metadata]
+
+
+ def extract_embeddings_from_collection(
+     collection,
+     index_name: str = None,
+     limit: Optional[int] = None,
+     include_metadata: bool = True,
+     metadata_fields: Optional[List[str]] = None
+ ) -> EmbeddingData:
+     """
+     Extract embeddings from an indexed collection.
+
+     Args:
+         collection: LinkML collection object
+         index_name: Name of the index to use (defaults to first available)
+         limit: Maximum number of embeddings to extract
+         include_metadata: Whether to include source object metadata
+         metadata_fields: Specific metadata fields to include (None = all)
+
+     Returns:
+         EmbeddingData object containing vectors and metadata
+     """
+     # Get the index name - handle collections without loaded indexers
+     if index_name is None:
+         # Try to find index collections directly
+         db = collection.parent
+         all_collections = db.list_collection_names()
+         # TODO: use the indexer metadata to find the index name
+         index_prefix = f"internal__index__{collection.alias}__"
+         index_collections = [c for c in all_collections if c.startswith(index_prefix)]
+
+         if not index_collections:
+             raise ValueError(f"Collection {collection.alias} has no indexes")
+
+         # Extract index name from first index collection
+         index_name = index_collections[0].replace(index_prefix, "")
+         if len(index_collections) > 1:
+             logger.warning(f"Multiple indexes found, using: {index_name}")
+
+     # Get the index collection
+     index_collection_name = f"internal__index__{collection.alias}__{index_name}"
+     index_collection = collection.parent.get_collection(index_collection_name)
+
+     # Query the index collection
+     query_result = index_collection.find(limit=limit)
+
+     if query_result.num_rows == 0:
+         raise ValueError(f"No indexed data found in {index_collection_name}")
+
+     vectors = []
+     metadata = []
+     object_ids = []
+
+     for row in query_result.rows:
+         # Extract vector (usually stored in __index__ field)
+         vector = row.get("__index__")
+         if vector is None:
+             logger.warning(f"No vector found for object {row.get('id')}")
+             continue
+
+         vectors.append(vector)
+
+         # Extract object ID
+         obj_id = row.get("id") or row.get("_id") or str(len(vectors))
+         object_ids.append(obj_id)
+
+         # Extract metadata
+         if include_metadata:
+             meta = {}
+             if metadata_fields:
+                 # Only include specified fields
+                 for field in metadata_fields:
+                     if field in row:
+                         meta[field] = row[field]
+             else:
+                 # Include all fields except the vector
+                 meta = {k: v for k, v in row.items() if k != "__index__"}
+             metadata.append(meta)
+
+     return EmbeddingData(
+         vectors=np.array(vectors),
+         metadata=metadata,
+         collection_names=[collection.alias] * len(vectors),
+         collection_indices=[0] * len(vectors),
+         object_ids=object_ids
+     )
+
+
+ def extract_embeddings_from_multiple_collections(
+     database,
+     collection_names: List[str],
+     index_name: Optional[str] = None,
+     limit_per_collection: Optional[int] = None,
+     include_metadata: bool = True,
+     metadata_fields: Optional[List[str]] = None,
+     normalize: bool = False
+ ) -> EmbeddingData:
+     """
+     Extract embeddings from multiple collections.
+
+     Args:
+         database: LinkML database object
+         collection_names: List of collection names
+         index_name: Name of index to use (must be same across collections)
+         limit_per_collection: Max embeddings per collection
+         include_metadata: Whether to include source object metadata
+         metadata_fields: Specific metadata fields to include
+         normalize: Whether to normalize vectors to unit length
+
+     Returns:
+         Combined EmbeddingData object
+     """
+     all_vectors = []
+     all_metadata = []
+     all_collection_names = []
+     all_collection_indices = []
+     all_object_ids = []
+
+     for i, coll_name in enumerate(collection_names):
+         try:
+             collection = database.get_collection(coll_name)
+             data = extract_embeddings_from_collection(
+                 collection,
+                 index_name=index_name,
+                 limit=limit_per_collection,
+                 include_metadata=include_metadata,
+                 metadata_fields=metadata_fields
+             )
+
+             all_vectors.append(data.vectors)
+             all_metadata.extend(data.metadata)
+             all_collection_names.extend([coll_name] * data.n_samples)
+             all_collection_indices.extend([i] * data.n_samples)
+             all_object_ids.extend(data.object_ids)
+
+         except Exception as e:
+             logger.error(f"Failed to extract embeddings from {coll_name}: {e}")
+             continue
+
+     if not all_vectors:
+         raise ValueError("No embeddings extracted from any collection")
+
+     combined_vectors = np.vstack(all_vectors)
+
+     # Normalize if requested
+     if normalize:
+         # Ensure vectors are float type for division
+         combined_vectors = combined_vectors.astype(np.float64)
+         norms = np.linalg.norm(combined_vectors, axis=1, keepdims=True)
+         # Avoid division by zero
+         norms = np.where(norms == 0, 1, norms)
+         combined_vectors = combined_vectors / norms
+
+     return EmbeddingData(
+         vectors=combined_vectors,
+         metadata=all_metadata,
+         collection_names=all_collection_names,
+         collection_indices=all_collection_indices,
+         object_ids=all_object_ids
+     )
+
+
+ def sample_embeddings(
+     embedding_data: EmbeddingData,
+     n_samples: int = 1000,
+     method: str = "random",
+     random_state: Optional[int] = None
+ ) -> EmbeddingData:
+     """
+     Sample embeddings for visualization.
+
+     Args:
+         embedding_data: Original embedding data
+         n_samples: Number of samples to select
+         method: Sampling method ('random' or 'uniform')
+         random_state: Random seed for reproducibility
+
+     Returns:
+         Sampled EmbeddingData object
+     """
+     if embedding_data.n_samples <= n_samples:
+         return embedding_data
+
+     if random_state is not None:
+         np.random.seed(random_state)
+
+     if method == "random":
+         indices = np.random.choice(
+             embedding_data.n_samples,
+             size=n_samples,
+             replace=False
+         )
+     elif method == "uniform":
+         # Sample uniformly across collections
+         indices = []
+         for coll_idx in set(embedding_data.collection_indices):
+             coll_mask = np.array(embedding_data.collection_indices) == coll_idx
+             coll_indices = np.where(coll_mask)[0]
+             n_from_coll = min(
+                 len(coll_indices),
+                 n_samples // len(set(embedding_data.collection_indices))
+             )
+             indices.extend(
+                 np.random.choice(coll_indices, size=n_from_coll, replace=False)
+             )
+         indices = np.array(indices[:n_samples])
+     else:
+         raise ValueError(f"Unknown sampling method: {method}")
+
+     return EmbeddingData(
+         vectors=embedding_data.vectors[indices],
+         metadata=[embedding_data.metadata[i] for i in indices],
+         collection_names=[embedding_data.collection_names[i] for i in indices],
+         collection_indices=[embedding_data.collection_indices[i] for i in indices],
+         object_ids=[embedding_data.object_ids[i] for i in indices]
+     )
+
+
+ def compute_embedding_statistics(embedding_data: EmbeddingData) -> Dict:
+     """
+     Compute statistics about embeddings.
+
+     Args:
+         embedding_data: Embedding data
+
+     Returns:
+         Dictionary of statistics
+     """
+     stats = {
+         "n_samples": embedding_data.n_samples,
+         "n_dimensions": embedding_data.n_dimensions,
+         "n_collections": len(set(embedding_data.collection_names)),
+         "collections": list(set(embedding_data.collection_names)),
+     }
+
+     # Per-collection counts
+     from collections import Counter
+     collection_counts = Counter(embedding_data.collection_names)
+     stats["samples_per_collection"] = dict(collection_counts)
+
+     # Vector statistics
+     if embedding_data.n_samples > 0:
+         stats["mean_norm"] = float(np.mean(np.linalg.norm(embedding_data.vectors, axis=1)))
+         stats["std_norm"] = float(np.std(np.linalg.norm(embedding_data.vectors, axis=1)))
+
+         # Compute average pairwise similarity (on sample if large)
+         sample_size = min(100, embedding_data.n_samples)
+         if sample_size > 1:
+             sample_indices = np.random.choice(
+                 embedding_data.n_samples,
+                 size=sample_size,
+                 replace=False
+             )
+             sample_vectors = embedding_data.vectors[sample_indices]
+
+             # Normalize for cosine similarity
+             norms = np.linalg.norm(sample_vectors, axis=1, keepdims=True)
+             normalized = sample_vectors / (norms + 1e-10)
+             similarities = np.dot(normalized, normalized.T)
+
+             # Extract upper triangle (excluding diagonal)
+             upper_tri = similarities[np.triu_indices(sample_size, k=1)]
+             stats["mean_similarity"] = float(np.mean(upper_tri))
+             stats["std_similarity"] = float(np.std(upper_tri))
+
+     return stats
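
Taken together, these helpers form a small extract → sample → summarize pipeline. A minimal usage sketch, assuming `collection` is an already-indexed linkml-store Collection obtained elsewhere (the limit, sample size, and seed are illustrative):

    from linkml_store.utils.embedding_utils import (
        extract_embeddings_from_collection,
        sample_embeddings,
        compute_embedding_statistics,
    )

    # `collection` is assumed to be an indexed Collection (see extract_embeddings_from_collection above)
    data = extract_embeddings_from_collection(collection, limit=5000)

    # Downsample before plotting; "uniform" balances across source collections, "random" does not
    subset = sample_embeddings(data, n_samples=1000, method="random", random_state=42)

    stats = compute_embedding_statistics(subset)
    print(stats["n_samples"], stats["n_dimensions"], stats.get("mean_similarity"))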
linkml_store/utils/enrichment_analyzer.py
@@ -0,0 +1,217 @@
+ from collections import Counter
+ from typing import Dict, List
+
+ import numpy as np
+ import pandas as pd
+ from pydantic import BaseModel
+ from scipy import stats
+
+ from linkml_store.api import Collection
+
+
+ class EnrichedCategory(BaseModel):
+     """
+     Information about a category enriched in a sample
+     """
+
+     category: str
+     fold_change: float
+     original_p_value: float
+     adjusted_p_value: float
+
+
+ class EnrichmentAnalyzer:
+     def __init__(self, df: pd.DataFrame, sample_key: str, classification_key: str):
+         """
+         Initialize the analyzer with a DataFrame and key column names.
+         Precomputes category frequencies for the entire dataset.
+
+         Args:
+             df: DataFrame containing the data
+             sample_key: Column name for sample IDs
+             classification_key: Column name for category lists
+         """
+         self.df = df
+         self.sample_key = sample_key
+         self.classification_key = classification_key
+
+         # Precompute global category statistics
+         self.global_stats = self._compute_global_stats()
+
+         # Cache for sample-specific category counts
+         self.sample_cache: Dict[str, Counter] = {}
+
+     @classmethod
+     def from_collection(cls, collection: Collection, sample_key: str, classification_key: str) -> "EnrichmentAnalyzer":
+         """
+         Initialize the analyzer with a Collection and key column names.
+         Precomputes category frequencies for the entire dataset.
+
+         Args:
+             collection: Collection containing the data
+             sample_key: Column name for sample IDs
+             classification_key: Column name for category lists
+         """
+         column_atts = [sample_key, classification_key]
+         results = collection.find(select_cols=column_atts, limit=-1)
+         df = results.rows_dataframe
+         ea = cls(df, sample_key=sample_key, classification_key=classification_key)
+         return ea
+
+     def _compute_global_stats(self) -> Dict[str, int]:
+         """
+         Compute global category frequencies across all samples.
+         Returns a dictionary of category -> count
+         """
+         global_counter = Counter()
+
+         # Flatten all categories and count
+         for categories in self.df[self.classification_key]:
+             if isinstance(categories, list):
+                 global_counter.update(categories)
+             else:
+                 # Handle case where categories might be a string
+                 global_counter.update([categories])
+
+         return global_counter
+
+     @property
+     def sample_ids(self) -> List[str]:
+         df = self.df
+         return df[self.sample_key].unique().tolist()
+
+     def _get_sample_stats(self, sample_id: str) -> Counter:
+         """
+         Get category frequencies for a specific sample.
+         Uses caching to avoid recomputation.
+         """
+         if sample_id in self.sample_cache:
+             return self.sample_cache[sample_id]
+
+         sample_data = self.df[self.df[self.sample_key] == sample_id]
+         if sample_data.empty:
+             raise KeyError(f"Sample ID '{sample_id}' not found")
+         sample_data = sample_data.dropna()
+         # if sample_data.empty:
+         #     raise ValueError(f"Sample ID '{sample_id}' has missing values after dropping NA")
+         counter = Counter()
+
+         for categories in sample_data[self.classification_key]:
+             if isinstance(categories, list):
+                 counter.update(categories)
+             else:
+                 counter.update([categories])
+
+         self.sample_cache[sample_id] = counter
+         return counter
+
+     def find_enriched_categories(
+         self,
+         sample_id: str,
+         min_occurrences: int = 5,
+         p_value_threshold: float = 0.05,
+         multiple_testing_correction: str = "bh",
+     ) -> List[EnrichedCategory]:
+         """
+         Find categories that are enriched in the given sample.
+
+         Args:
+             sample_id: ID of the sample to analyze
+             min_occurrences: Minimum number of occurrences required for a category
+             p_value_threshold: P-value threshold for significance
+             multiple_testing_correction: Correction method: "bh" (Benjamini-Hochberg),
+                 "bonf" (Bonferroni), or any other value for no correction
+
+         Returns:
+             List of EnrichedCategory objects sorted by adjusted p-value
+         """
+         sample_stats = self._get_sample_stats(sample_id)
+         total_sample_annotations = sum(sample_stats.values())
+         total_global_annotations = sum(self.global_stats.values())
+
+         results = []
+
+         for category, sample_count in sample_stats.items():
+             global_count = self.global_stats[category]
+
+             # Skip rare categories
+             if global_count < min_occurrences:
+                 continue
+
+             # Calculate fold change
+             sample_freq = sample_count / total_sample_annotations
+             global_freq = global_count / total_global_annotations
+             fold_change = sample_freq / global_freq if global_freq > 0 else float("inf")
+
+             # Perform Fisher's exact test
+             contingency_table = np.array(
+                 [
+                     [sample_count, global_count - sample_count],
+                     [
+                         total_sample_annotations - sample_count,
+                         total_global_annotations - total_sample_annotations - (global_count - sample_count),
+                     ],
+                 ]
+             )
+
+             _, p_value = stats.fisher_exact(contingency_table)
+
+             if p_value < p_value_threshold:
+                 results.append((category, fold_change, p_value))
+
+         if not results:
+             return results
+
+         # Sort by p-value
+         results.sort(key=lambda x: x[2])
+
+         # Apply multiple testing correction
+         categories, fold_changes, p_values = zip(*results)
+
+         if multiple_testing_correction.lower() == "bonf":
+             # Bonferroni correction
+             n_tests = len(self.global_stats)  # Total number of categories tested
+             adjusted_p_values = [min(1.0, p * n_tests) for p in p_values]
+
+         elif multiple_testing_correction.lower() == "bh":
+             # Benjamini-Hochberg correction
+             n = len(p_values)
+             sorted_indices = np.argsort(p_values)
+             sorted_p_values = np.array(p_values)[sorted_indices]
+
+             # Calculate BH adjusted p-values
+             adjusted_p_values = np.zeros(n)
+             for i, p in enumerate(sorted_p_values):
+                 adjusted_p_values[i] = p * n / (i + 1)
+
+             # Ensure monotonicity
+             for i in range(n - 2, -1, -1):
+                 adjusted_p_values[i] = min(adjusted_p_values[i], adjusted_p_values[i + 1])
+
+             # Restore original order
+             inverse_indices = np.argsort(sorted_indices)
+             adjusted_p_values = adjusted_p_values[inverse_indices]
+
+             # Ensure we don't exceed 1.0
+             adjusted_p_values = np.minimum(adjusted_p_values, 1.0)
+
+         else:
+             # No correction
+             adjusted_p_values = p_values
+
+         # Filter by adjusted p-value threshold and build EnrichedCategory objects
+         final_results = [
+             EnrichedCategory(category=cat, fold_change=fc, original_p_value=p, adjusted_p_value=adj_p)
+             for cat, fc, p, adj_p in zip(categories, fold_changes, p_values, adjusted_p_values)
+             if adj_p < p_value_threshold
+         ]
+
+         # Sort by adjusted p-value
+         final_results.sort(key=lambda x: x.adjusted_p_value)
+         return final_results
+
+
+ # Example usage:
+ # analyzer = EnrichmentAnalyzer(df, 'sample_id', 'categories')
+ # enriched = analyzer.find_enriched_categories('sample1')
+ # for hit in enriched:
+ #     print(f"{hit.category}: {hit.fold_change:.2f}x enrichment (adjusted p={hit.adjusted_p_value:.2e})")
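
Beyond the commented example above, the analyzer can be exercised standalone because the constructor only needs a DataFrame with a sample-ID column and a category-list column. A small self-contained sketch with invented toy data (column names, values, and thresholds are illustrative):

    import pandas as pd
    from linkml_store.utils.enrichment_analyzer import EnrichmentAnalyzer

    # Toy data: each row is one annotated object belonging to a sample
    df = pd.DataFrame(
        {
            "sample_id": ["s1"] * 6 + ["s2"] * 6,
            "categories": (
                [["kinase"]] * 5 + [["membrane"]]        # sample s1: mostly kinase
                + [["membrane"], ["transport"]] * 3      # sample s2: membrane / transport
            ),
        }
    )

    analyzer = EnrichmentAnalyzer(df, sample_key="sample_id", classification_key="categories")
    for hit in analyzer.find_enriched_categories("s1", min_occurrences=2, p_value_threshold=0.2):
        print(hit.category, round(hit.fold_change, 2), f"{hit.adjusted_p_value:.3g}")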
linkml_store/utils/file_utils.py
@@ -0,0 +1,37 @@
+ import logging
+ import shutil
+ import tempfile
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Optional
+
+ # Set up logging
+ logger = logging.getLogger(__name__)
+
+
+ def safe_remove_directory(dir_path: Path, no_backup: bool = False) -> Optional[Path]:
+     """
+     Move a directory out of the way instead of deleting it outright.
+
+     By default the directory is renamed to a timestamped ``*_backup_*`` sibling and
+     that backup path is returned. With ``no_backup=True`` the directory is routed
+     through a temporary location and discarded. Returns None on error or when no
+     backup is kept.
+     """
+     # Ensure the directory exists
+     if not dir_path.exists():
+         raise FileNotFoundError(f"Directory does not exist: {dir_path}")
+     try:
+         if no_backup:
+             # Move to a temporary directory instead of permanent removal
+             with tempfile.TemporaryDirectory() as tmpdir:
+                 tmp_path = Path(tmpdir) / dir_path.name
+                 shutil.move(str(dir_path), str(tmp_path))
+                 logger.info(f"Directory moved to temporary location: {tmp_path}")
+                 # The directory will be automatically removed when exiting the context manager
+                 return None
+         else:
+             # Create a backup directory name with timestamp
+             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+             backup_dir = dir_path.with_name(f"{dir_path.name}_backup_{timestamp}")
+
+             # Move the directory to the backup location
+             shutil.move(str(dir_path), str(backup_dir))
+             logger.info(f"Directory backed up to: {backup_dir}")
+             return backup_dir
+
+     except Exception as e:
+         logger.error(f"An error occurred: {e}")
+         return None
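
A short usage sketch for the helper above; the directory name is illustrative and is created here only so the call has something to operate on:

    from pathlib import Path
    from linkml_store.utils.file_utils import safe_remove_directory

    work_dir = Path("scratch_output")   # illustrative path
    work_dir.mkdir(exist_ok=True)

    # Default behaviour: the directory is moved aside to a timestamped "*_backup_*" sibling
    backup = safe_remove_directory(work_dir)
    print(backup)  # e.g. scratch_output_backup_20250101_120000, or None if an error occurred

    # With no_backup=True the directory is routed through a temporary location and discarded:
    # safe_remove_directory(Path("scratch_output"), no_backup=True)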