chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,625 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Differential expression analysis tools for spatial transcriptomics data.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import scanpy as sc
|
|
8
|
+
from scipy import sparse
|
|
9
|
+
|
|
10
|
+
from ..models.analysis import DifferentialExpressionResult
|
|
11
|
+
from ..models.data import DifferentialExpressionParameters
|
|
12
|
+
from ..spatial_mcp_adapter import ToolContext
|
|
13
|
+
from ..utils import validate_obs_column
|
|
14
|
+
from ..utils.adata_utils import store_analysis_metadata, to_dense
|
|
15
|
+
from ..utils.dependency_manager import require
|
|
16
|
+
from ..utils.exceptions import (
|
|
17
|
+
DataError,
|
|
18
|
+
DataNotFoundError,
|
|
19
|
+
ParameterError,
|
|
20
|
+
ProcessingError,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
async def differential_expression(
    data_id: str,
    ctx: ToolContext,
    params: DifferentialExpressionParameters,
) -> DifferentialExpressionResult:
    """Perform differential expression analysis.

    Three modes are supported, selected from ``params``:

    * ``params.method == "pydeseq2"``: delegated to ``_run_pydeseq2``.
    * ``params.group1 is None``: marker genes for every group in
      ``params.group_key`` (each group against the rest).
    * otherwise: ``group1`` against ``group2`` (or against all other
      observations when ``group2`` is ``None`` or ``"rest"``).

    Args:
        data_id: Dataset ID
        ctx: Tool context for data access and logging
        params: Differential expression parameters

    Returns:
        Differential expression analysis result

    Raises:
        ParameterError: If ``group1`` or ``group2`` is not a value of
            ``adata.obs[group_key]``.
        DataError: If every group has fewer than ``min_cells`` observations
            (all-groups mode).
        DataNotFoundError: If ``adata.raw`` is missing (specific-group mode
            needs raw counts for the fold change).
        ProcessingError: If no genes are returned for the comparison.
    """
    # PyDESeq2 has its own pseudobulk pipeline and fetches the data itself.
    if params.method == "pydeseq2":
        return await _run_pydeseq2(data_id, ctx, params)

    adata = await ctx.get_adata(data_id)
    validate_obs_column(adata, params.group_key, "Group")

    # float16 matrices are upcast to float32 before testing (numba does not
    # support float16). The cast is deferred to the working copy made by the
    # helpers so the stored dataset is left untouched.
    needs_dtype_fix = hasattr(adata.X, "dtype") and adata.X.dtype == np.float16

    if params.group1 is None:
        return await _find_markers_for_all_groups(
            data_id, ctx, params, adata, needs_dtype_fix
        )
    return await _compare_selected_groups(
        data_id, ctx, params, adata, needs_dtype_fix
    )


async def _find_markers_for_all_groups(
    data_id: str,
    ctx: ToolContext,
    params: DifferentialExpressionParameters,
    adata,
    needs_dtype_fix: bool,
) -> DifferentialExpressionResult:
    """Rank genes for every sufficiently large group against the rest."""
    group_key = params.group_key
    method = params.method
    n_top_genes = params.n_top_genes
    min_cells = params.min_cells

    # Groups below the user-configurable size threshold are excluded.
    group_sizes = adata.obs[group_key].value_counts()
    valid_groups = group_sizes[group_sizes >= min_cells]
    skipped_groups = group_sizes[group_sizes < min_cells]

    if len(skipped_groups) > 0:
        skipped_list = "\n".join(
            f" - {g}: {n} cell(s)" for g, n in skipped_groups.items()
        )
        await ctx.warning(
            f"Skipped {len(skipped_groups)} group(s) with <{min_cells} cells:\n{skipped_list}"
        )

    if len(valid_groups) == 0:
        all_sizes = "\n".join(
            f" • {g}: {n} cell(s)" for g, n in group_sizes.items()
        )
        raise DataError(
            f"All groups have <{min_cells} cells. Cannot perform {method} test.\n\n"
            f"Group sizes:\n{all_sizes}\n\n"
            f"Try: find_markers(group_key='leiden') or merge small groups"
        )

    # Work on a copy restricted to the valid groups; the dtype cast happens
    # on this (smaller) copy rather than on the full dataset.
    adata_filtered = adata[adata.obs[group_key].isin(valid_groups.index)].copy()
    if needs_dtype_fix:
        adata_filtered.X = adata_filtered.X.astype(np.float32)

    sc.tl.rank_genes_groups(
        adata_filtered,
        groupby=group_key,
        method=method,
        n_genes=n_top_genes,
        reference="rest",
    )

    groups = adata_filtered.obs[group_key].unique()
    ranked = adata_filtered.uns["rank_genes_groups"]

    # Concatenate the per-group rankings in group order.
    candidate_genes = []
    if "names" in ranked:
        gene_names = ranked["names"]
        for group in groups:
            if str(group) in gene_names.dtype.names:
                candidate_genes.extend(
                    str(g) for g in gene_names[str(group)][:n_top_genes]
                )

    # De-duplicate while preserving order, then cap at n_top_genes overall.
    top_genes = list(dict.fromkeys(candidate_genes))[:n_top_genes]

    # Copy results back to the stored dataset for persistence.
    adata.uns["rank_genes_groups"] = ranked

    group_names = list(map(str, groups))
    store_analysis_metadata(
        adata,
        analysis_name="differential_expression",
        method=method,
        parameters={
            "group_key": group_key,
            "comparison_type": "all_groups",
            "n_top_genes": n_top_genes,
        },
        results_keys={"uns": ["rank_genes_groups"]},
        statistics={
            "method": method,
            "n_groups": len(groups),
            "groups": group_names,
            "n_cells_analyzed": adata_filtered.n_obs,
            "n_genes_analyzed": adata_filtered.n_vars,
        },
    )

    return DifferentialExpressionResult(
        data_id=data_id,
        comparison=f"All groups in {group_key}",
        n_genes=len(top_genes),
        top_genes=top_genes,
        statistics={
            "method": method,
            "n_groups": len(groups),
            "groups": group_names,
        },
    )


async def _compare_selected_groups(
    data_id: str,
    ctx: ToolContext,
    params: DifferentialExpressionParameters,
    adata,
    needs_dtype_fix: bool,
) -> DifferentialExpressionResult:
    """Compare ``group1`` against ``group2`` (or against the rest)."""
    group_key = params.group_key
    group1 = params.group1
    group2 = params.group2
    method = params.method
    n_top_genes = params.n_top_genes
    pseudocount = params.pseudocount

    if group1 not in adata.obs[group_key].values:
        raise ParameterError(f"Group '{group1}' not found in adata.obs['{group_key}']")

    # group2 of None / "rest" means "every observation not in group1".
    use_rest_as_reference = group2 is None or group2 == "rest"
    if use_rest_as_reference:
        group2 = "rest"  # explicit value for display and provenance
    elif group2 not in adata.obs[group_key].values:
        raise ParameterError(f"Group '{group2}' not found in adata.obs['{group_key}']")

    # Fail fast: the fold change below needs raw counts, so check before
    # paying for a dataset copy and the statistical test.
    if adata.raw is None:
        raise DataNotFoundError(
            "Raw count data (adata.raw) required for fold change calculation. "
            "Run preprocess_data() first to preserve raw counts."
        )

    # Working copy: the full dataset for "rest", otherwise only the two groups.
    if use_rest_as_reference:
        temp_adata = adata.copy()
    else:
        temp_adata = adata[adata.obs[group_key].isin([group1, group2])].copy()

    if needs_dtype_fix:
        temp_adata.X = temp_adata.X.astype(np.float32)

    sc.tl.rank_genes_groups(
        temp_adata,
        groupby=group_key,
        groups=[group1],
        reference="rest" if use_rest_as_reference else group2,
        method=method,
        n_genes=n_top_genes,
    )
    ranked = temp_adata.uns["rank_genes_groups"]

    top_genes = []
    if "names" in ranked:
        gene_names = ranked["names"]
        # Fall back to the first result column if group1 is not a field name.
        column = group1 if group1 in gene_names.dtype.names else gene_names.dtype.names[0]
        top_genes = [str(g) for g in gene_names[column][:n_top_genes]]

    if not top_genes:
        raise ProcessingError(
            f"No DE genes found between {group1} and {group2}. "
            f"Check sample sizes and expression differences."
        )

    # Boolean masks over the observations of the stored dataset (adata.raw
    # shares its observation axis with adata).
    group1_mask = (adata.obs[group_key] == group1).to_numpy()
    if use_rest_as_reference:
        group2_mask = ~group1_mask
    else:
        group2_mask = (adata.obs[group_key] == group2).to_numpy()
    n_cells_group1 = int(group1_mask.sum())
    n_cells_group2 = int(group2_mask.sum())

    # Adjusted p-values reported by scanpy for the selected genes.
    pvals = []
    if "pvals_adj" in ranked and group1 in ranked["pvals_adj"].dtype.names:
        pvals = list(ranked["pvals_adj"][group1][:n_top_genes])

    # Fold change is recomputed from library-size-normalised raw counts as
    # log2(mean(group1) / mean(group2)) instead of using scanpy's value.
    log2fc_values = await _raw_log2_fold_changes(
        ctx, adata.raw, top_genes, group1_mask, group2_mask, pseudocount
    )

    valid_log2fc = [fc for fc in log2fc_values if fc is not None]
    mean_log2fc = float(np.mean(valid_log2fc)) if valid_log2fc else None
    median_pvalue = float(np.median(pvals)) if pvals else None

    if mean_log2fc is not None and abs(mean_log2fc) > 10:
        await ctx.warning(
            f"Extreme fold change: mean log2FC = {mean_log2fc:.2f} (>1024x). "
            f"May indicate sparse expression or low cell counts."
        )

    # Copy results back to the stored dataset for persistence.
    adata.uns["rank_genes_groups"] = ranked

    store_analysis_metadata(
        adata,
        analysis_name="differential_expression",
        method=method,
        parameters={
            "group_key": group_key,
            "group1": group1,
            "group2": group2,
            "comparison_type": "specific_groups",
            "n_top_genes": n_top_genes,
            "pseudocount": pseudocount,  # Track for reproducibility
        },
        results_keys={"uns": ["rank_genes_groups"]},
        statistics={
            "method": method,
            "group1": group1,
            "group2": group2,
            "n_cells_group1": n_cells_group1,
            "n_cells_group2": n_cells_group2,
            "n_genes_analyzed": temp_adata.n_vars,
            "mean_log2fc": mean_log2fc,
            "median_pvalue": median_pvalue,
            "pseudocount_used": pseudocount,  # Document in statistics
        },
    )

    return DifferentialExpressionResult(
        data_id=data_id,
        comparison=f"{group1} vs {group2}",
        n_genes=len(top_genes),
        top_genes=top_genes,
        statistics={
            "method": method,
            "n_cells_group1": n_cells_group1,
            "n_cells_group2": n_cells_group2,
            "mean_log2fc": mean_log2fc,
            "median_pvalue": median_pvalue,
        },
    )


async def _raw_log2_fold_changes(
    ctx: ToolContext,
    raw_adata,
    genes,
    group1_mask,
    group2_mask,
    pseudocount,
):
    """Return one log2 fold change per gene from normalised raw counts.

    Each observation is scaled by ``median_library_size / library_size``
    before the group means are taken. Genes missing from ``raw_adata`` yield
    ``None`` (with a warning) so the output stays aligned with ``genes``.
    """
    lib_sizes = np.asarray(raw_adata.X.sum(axis=1), dtype=float).ravel()
    median_lib_size = float(np.median(lib_sizes))

    # Observations with zero total counts have all-zero rows; give them a
    # scale of 0 instead of dividing by zero (0 * inf would produce NaN and
    # contaminate the group means).
    scale = np.zeros(lib_sizes.shape, dtype=float)
    has_counts = lib_sizes > 0
    scale[has_counts] = median_lib_size / lib_sizes[has_counts]

    log2fc_values = []
    for gene in genes:
        if gene not in raw_adata.var_names:
            await ctx.warning(
                f"Gene {gene} not found in raw data, skipping fold change calculation"
            )
            log2fc_values.append(None)
            continue

        gene_idx = raw_adata.var_names.get_loc(gene)
        gene_raw_counts = to_dense(raw_adata.X[:, gene_idx]).flatten()
        gene_norm_counts = gene_raw_counts * scale

        mean_group1 = float(gene_norm_counts[group1_mask].mean())
        mean_group2 = float(gene_norm_counts[group2_mask].mean())

        # The pseudocount keeps the ratio finite when a group mean is zero.
        log2fc = np.log2((mean_group1 + pseudocount) / (mean_group2 + pseudocount))
        log2fc_values.append(float(log2fc))

    return log2fc_values
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
async def _run_pydeseq2(
    data_id: str,
    ctx: ToolContext,
    params: DifferentialExpressionParameters,
) -> DifferentialExpressionResult:
    """Run PyDESeq2 pseudobulk differential expression analysis.

    Counts are summed within each sample/condition combination to form
    pseudobulk samples, which are then tested with PyDESeq2. Results and
    provenance metadata are written to the stored dataset's ``uns``; its
    ``obs`` is not modified.

    Args:
        data_id: Dataset ID
        ctx: Tool context for data access and logging
        params: Differential expression parameters

    Returns:
        Differential expression analysis result

    Raises:
        ParameterError: If sample_key is not provided, or if group1/group2 is
            not a value of ``adata.obs[group_key]``
        DataError: If there are too few groups or pseudobulk samples
        ProcessingError: If PyDESeq2 fails or returns no genes
        ImportError: If pydeseq2 is not installed (raised by ``require``,
            per the original note on that call)
    """
    if params.sample_key is None:
        raise ParameterError(
            "sample_key is required for pydeseq2 method.\n"
            "Provide a column in adata.obs that identifies biological replicates "
            "(e.g., 'sample', 'patient_id', 'batch').\n"
            "Example: find_markers(group_key='cell_type', method='pydeseq2', "
            "sample_key='sample')"
        )

    require("pydeseq2", ctx, feature="DESeq2 differential expression")
    from pydeseq2.dds import DeseqDataSet
    from pydeseq2.ds import DeseqStats

    adata = await ctx.get_adata(data_id)

    group_key = params.group_key
    sample_key = params.sample_key
    validate_obs_column(adata, group_key, "Group")
    validate_obs_column(adata, sample_key, "Sample")

    # Prefer the raw layer; fall back to X when no raw layer exists.
    if adata.raw is not None:
        raw_X = adata.raw.X
        var_names = adata.raw.var_names
    else:
        raw_X = adata.X
        var_names = adata.var_names

    # Keep sparse input sparse (CSR for row slicing) instead of densifying
    # the whole matrix. Implicit zeros are integers, so only the stored
    # values need the integer check.
    if sparse.issparse(raw_X):
        raw_X = raw_X.tocsr()
        stored_values = raw_X.data
    else:
        stored_values = np.asarray(raw_X)

    if not np.allclose(stored_values, stored_values.astype(int)):
        await ctx.warning(
            "Data appears to be normalized. DESeq2 requires raw integer counts. "
            "Results may be inaccurate."
        )

    # Labels are compared as strings throughout, so a group named "1"
    # matches whether the column stores "1" or 1.
    unique_groups = adata.obs[group_key].dropna().astype(str).unique()
    group1 = params.group1
    group2 = params.group2

    if group1 is None:
        # No group requested: compare the first two groups found.
        if len(unique_groups) < 2:
            raise DataError(
                f"Need at least 2 groups for DE analysis, found {len(unique_groups)}"
            )
        group1 = str(unique_groups[0])
        group2 = str(unique_groups[1])
        await ctx.info(
            f"No group specified, comparing first two groups: {group1} vs {group2}"
        )
    else:
        group1 = str(group1)
        if group1 not in unique_groups:
            raise ParameterError(
                f"Group '{group1}' not found in adata.obs['{group_key}']"
            )
        if group2 is None or group2 == "rest":
            # Compare group1 vs all other observations combined as "rest".
            group2 = "rest"
        else:
            group2 = str(group2)
            if group2 not in unique_groups:
                raise ParameterError(
                    f"Group '{group2}' not found in adata.obs['{group_key}']"
                )

    await ctx.info(f"Creating pseudobulk samples by {sample_key} and {group_key}...")

    # Per-observation scratch table. Building it separately (instead of
    # adding columns to adata.obs or rebinding adata to a subset copy) keeps
    # the stored dataset clean and ensures results below are written to it.
    # "row" is the observation's position in raw_X.
    cells = pd.DataFrame(
        {
            "row": np.arange(adata.n_obs),
            "label": adata.obs[group_key].astype(str).to_numpy(),
            "sample": adata.obs[sample_key].to_numpy(),
        }
    )
    if group2 == "rest":
        # Binary comparison: group1 vs rest
        cells["condition"] = np.where(cells["label"] == group1, group1, "rest")
    else:
        # Pairwise comparison: keep only group1 and group2
        cells = cells[cells["label"].isin([group1, group2])].reset_index(drop=True)
        cells["condition"] = cells["label"]

    pseudobulk_key = cells["sample"].astype(str) + "_" + cells["condition"].astype(str)
    codes, unique_ids = pd.factorize(pseudobulk_key, sort=True)
    pseudobulk_ids = list(unique_ids)
    n_samples = len(pseudobulk_ids)

    await ctx.info(f"Aggregated into {n_samples} pseudobulk samples")

    if n_samples < 4:
        raise DataError(
            f"DESeq2 requires at least 2 samples per group. "
            f"Found only {n_samples} total pseudobulk samples. "
            f"Add more biological replicates or use a different method (wilcoxon)."
        )

    # Sum counts over the observations of each pseudobulk sample.
    pseudobulk_counts = np.zeros((n_samples, raw_X.shape[1]), dtype=np.int64)
    pseudobulk_metadata = []

    for i, pb_id in enumerate(pseudobulk_ids):
        members = cells[codes == i]
        rows = members["row"].to_numpy()
        # Sparse sums come back 2-D; ravel to a flat gene vector.
        summed = np.asarray(raw_X[rows].sum(axis=0)).ravel()
        pseudobulk_counts[i] = summed.astype(np.int64)
        # All members share sample and condition; read them from the first.
        first = members.iloc[0]
        pseudobulk_metadata.append(
            {
                "sample_id": pb_id,
                "condition": first["condition"],
                "sample": first["sample"],
            }
        )

    metadata_df = pd.DataFrame(pseudobulk_metadata).set_index("sample_id")
    counts_df = pd.DataFrame(pseudobulk_counts, index=pseudobulk_ids, columns=var_names)

    condition_counts = metadata_df["condition"].value_counts()
    await ctx.info(f"Samples per condition: {condition_counts.to_dict()}")

    if (condition_counts < 2).any():
        raise DataError(
            f"DESeq2 requires at least 2 samples per condition. "
            f"Current counts: {condition_counts.to_dict()}"
        )

    await ctx.info("Running PyDESeq2 differential expression analysis...")

    try:
        dds = DeseqDataSet(
            counts=counts_df,
            metadata=metadata_df,
            design_factors="condition",
        )
        dds.deseq2()

        stat_res = DeseqStats(dds, contrast=["condition", group1, group2])
        stat_res.summary()
        results_df = stat_res.results_df
    except Exception as e:
        # Boundary around the third-party pipeline: re-raise as a project
        # error with the original exception chained.
        raise ProcessingError(
            f"PyDESeq2 analysis failed: {e}\n"
            "This may be due to low sample counts or data issues."
        ) from e

    # Rank genes by adjusted p-value, dropping genes without one.
    results_df = results_df.dropna(subset=["padj"]).sort_values("padj")
    top_results = results_df.head(params.n_top_genes)
    top_genes = top_results.index.tolist()

    if not top_genes:
        raise ProcessingError(
            f"No DE genes found between {group1} and {group2}. "
            "Check sample sizes and expression differences."
        )

    n_sig_genes = int((results_df["padj"] < 0.05).sum())
    mean_log2fc = top_results["log2FoldChange"].mean()
    median_pvalue = top_results["padj"].median()
    mean_log2fc = float(mean_log2fc) if not np.isnan(mean_log2fc) else None
    median_padj = float(median_pvalue) if not np.isnan(median_pvalue) else None

    # `adata` is still the stored dataset here, so this persists.
    adata.uns["pydeseq2_results"] = {
        "results_df": results_df.to_dict(),
        "comparison": f"{group1} vs {group2}",
        "n_samples": n_samples,
    }

    store_analysis_metadata(
        adata,
        analysis_name="differential_expression",
        method="pydeseq2",
        parameters={
            "group_key": group_key,
            "sample_key": sample_key,
            "group1": group1,
            "group2": group2,
            "comparison_type": "pseudobulk",
            "n_top_genes": params.n_top_genes,
        },
        results_keys={"uns": ["pydeseq2_results"]},
        statistics={
            "method": "pydeseq2",
            "group1": group1,
            "group2": group2,
            "n_pseudobulk_samples": n_samples,
            "n_significant_genes": n_sig_genes,
            "mean_log2fc": mean_log2fc,
            "median_padj": median_padj,
        },
    )

    return DifferentialExpressionResult(
        data_id=data_id,
        comparison=f"{group1} vs {group2}",
        n_genes=len(top_genes),
        top_genes=top_genes,
        statistics={
            "method": "pydeseq2",
            "n_pseudobulk_samples": n_samples,
            "n_significant_genes": n_sig_genes,
            "mean_log2fc": mean_log2fc,
            "median_padj": median_padj,
        },
    )
|