chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,595 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Multi-sample condition comparison analysis for spatial transcriptomics data.
|
|
3
|
+
|
|
4
|
+
This module implements pseudobulk differential expression analysis for comparing
|
|
5
|
+
experimental conditions (e.g., Treatment vs Control) across biological samples.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Optional
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
from scipy import sparse
|
|
13
|
+
|
|
14
|
+
from ..models.analysis import (
|
|
15
|
+
CellTypeComparisonResult,
|
|
16
|
+
ConditionComparisonResult,
|
|
17
|
+
DEGene,
|
|
18
|
+
)
|
|
19
|
+
from ..models.data import ConditionComparisonParameters
|
|
20
|
+
from ..spatial_mcp_adapter import ToolContext
|
|
21
|
+
from ..utils import validate_obs_column
|
|
22
|
+
from ..utils.adata_utils import store_analysis_metadata
|
|
23
|
+
from ..utils.dependency_manager import require
|
|
24
|
+
from ..utils.exceptions import DataError, ParameterError, ProcessingError
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
async def compare_conditions(
    data_id: str,
    ctx: ToolContext,
    params: ConditionComparisonParameters,
) -> ConditionComparisonResult:
    """Compare experimental conditions across multiple biological samples.

    This function performs pseudobulk differential expression analysis using
    DESeq2. Cells are summed per sample, and the per-sample profiles of
    ``condition1`` are compared against ``condition2`` (e.g., Treatment vs
    Control).

    Optionally, analysis can be stratified by cell type to identify cell
    type-specific condition effects.

    Every value of ``sample_key`` must belong to exactly one condition: a
    sample whose cells span both conditions cannot become a single pseudobulk
    replicate without mixing the two groups being compared.

    Args:
        data_id: Dataset ID
        ctx: Tool context for data access and logging
        params: Condition comparison parameters

    Returns:
        ConditionComparisonResult with differential expression results. The
        comparison/method/statistics summary is also written to
        ``adata.uns["condition_comparison_<condition1>_vs_<condition2>"]``.

    Raises:
        ParameterError: If a condition is not present in
            ``params.condition_key`` or both conditions are identical.
        DataError: If a sample contains cells from both conditions, or a
            condition has fewer than ``params.min_samples_per_condition``
            samples. Helper functions may raise further DataError /
            ProcessingError instances.

    Example:
        # Global comparison (all cells)
        compare_conditions(
            data_id="data1",
            condition_key="treatment",
            condition1="Drug",
            condition2="Control",
            sample_key="patient_id"
        )

        # Cell type stratified comparison
        compare_conditions(
            data_id="data1",
            condition_key="treatment",
            condition1="Drug",
            condition2="Control",
            sample_key="patient_id",
            cell_type_key="cell_type"
        )
    """
    # Check pydeseq2 availability early (required for pseudobulk analysis)
    require("pydeseq2", ctx, feature="Condition comparison with DESeq2")

    # Get data
    adata = await ctx.get_adata(data_id)

    # Validate required columns
    validate_obs_column(adata, params.condition_key, "Condition")
    validate_obs_column(adata, params.sample_key, "Sample")
    if params.cell_type_key is not None:
        validate_obs_column(adata, params.cell_type_key, "Cell type")

    # Validate conditions exist (condition1 is checked first, then condition2)
    unique_conditions = adata.obs[params.condition_key].unique()
    for condition in (params.condition1, params.condition2):
        if condition not in unique_conditions:
            raise ParameterError(
                f"Condition '{condition}' not found in '{params.condition_key}'.\n"
                f"Available conditions: {list(unique_conditions)}"
            )

    # Comparing a condition with itself would only fail later inside DESeq2
    # with a far less helpful message.
    if params.condition1 == params.condition2:
        raise ParameterError(
            f"condition1 and condition2 must differ "
            f"(both are '{params.condition1}')."
        )

    # Filter to only the two conditions of interest
    mask = adata.obs[params.condition_key].isin([params.condition1, params.condition2])
    adata_filtered = adata[mask].copy()

    await ctx.info(
        f"Comparing {params.condition1} vs {params.condition2}: "
        f"{adata_filtered.n_obs} cells from {adata_filtered.obs[params.sample_key].nunique()} samples"
    )

    # Pseudobulk replicates must be nested within conditions. Without this
    # check, a sample containing both conditions is summed into one profile
    # and labelled with whichever condition its first cell has.
    conditions_per_sample = adata_filtered.obs.groupby(
        params.sample_key, observed=True
    )[params.condition_key].nunique()
    mixed_samples = conditions_per_sample.index[conditions_per_sample > 1]
    if len(mixed_samples) > 0:
        raise DataError(
            f"Each sample in '{params.sample_key}' must belong to a single "
            f"condition, but {len(mixed_samples)} sample(s) contain cells from both "
            f"'{params.condition1}' and '{params.condition2}': "
            f"{[str(s) for s in mixed_samples[:10]]}. "
            "Use a sample key that identifies each biological sample separately "
            "(e.g., combine patient and condition)."
        )

    # Get raw counts (required for DESeq2). NOTE: this densifies the matrix.
    raw_X, var_names = _get_raw_counts(adata_filtered)

    # Validate counts are integers
    if not np.allclose(raw_X, raw_X.astype(int)):
        await ctx.warning(
            "Data appears to be normalized. DESeq2 requires raw integer counts. "
            "Results may be inaccurate. Consider using adata.raw."
        )

    # Count samples per condition (each sample maps to one condition, see above)
    sample_condition_map = adata_filtered.obs.groupby(
        params.sample_key, observed=True
    )[params.condition_key].first()
    n_samples_cond1 = (sample_condition_map == params.condition1).sum()
    n_samples_cond2 = (sample_condition_map == params.condition2).sum()

    await ctx.info(
        f"Sample distribution: {params.condition1}={n_samples_cond1}, "
        f"{params.condition2}={n_samples_cond2}"
    )

    # Check minimum samples requirement
    if n_samples_cond1 < params.min_samples_per_condition:
        raise DataError(
            f"Insufficient samples for {params.condition1}: {n_samples_cond1} "
            f"(minimum: {params.min_samples_per_condition})"
        )
    if n_samples_cond2 < params.min_samples_per_condition:
        raise DataError(
            f"Insufficient samples for {params.condition2}: {n_samples_cond2} "
            f"(minimum: {params.min_samples_per_condition})"
        )

    # Determine analysis mode
    if params.cell_type_key is None:
        # Global analysis (all cells together)
        result = await _run_global_comparison(
            adata_filtered, raw_X, var_names, ctx, params
        )
    else:
        # Cell type stratified analysis
        result = await _run_stratified_comparison(
            adata_filtered, raw_X, var_names, ctx, params
        )

    # Fill in the fields the helpers leave as placeholders
    result.data_id = data_id
    result.n_samples_condition1 = int(n_samples_cond1)
    result.n_samples_condition2 = int(n_samples_cond2)

    # Store results in the original (unfiltered) adata
    results_key = f"condition_comparison_{params.condition1}_vs_{params.condition2}"
    adata.uns[results_key] = {
        "comparison": result.comparison,
        "method": result.method,
        "statistics": result.statistics,
    }

    # Store metadata for provenance
    store_analysis_metadata(
        adata,
        analysis_name="condition_comparison",
        method="pseudobulk_deseq2",
        parameters={
            "condition_key": params.condition_key,
            "condition1": params.condition1,
            "condition2": params.condition2,
            "sample_key": params.sample_key,
            "cell_type_key": params.cell_type_key,
        },
        results_keys={"uns": [results_key]},
        statistics=result.statistics,
    )

    result.results_key = results_key
    return result
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _get_raw_counts(adata) -> tuple[np.ndarray, pd.Index]:
    """Return the count matrix DESeq2 should consume, plus its gene names.

    ``adata.raw`` is preferred whenever it is set; otherwise ``adata.X`` is
    used. A sparse matrix is converted to a dense ``ndarray``; a dense input
    is returned as-is (no copy).

    Args:
        adata: AnnData object

    Returns:
        Tuple of (count_matrix, var_names)
    """
    # Pick the container once so matrix and gene names always come from the
    # same place (raw and X may have different gene sets).
    source = adata if adata.raw is None else adata.raw
    counts = source.X
    genes = source.var_names

    return (counts.toarray() if sparse.issparse(counts) else counts), genes
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _create_pseudobulk(
    adata,
    raw_X: np.ndarray,
    var_names: pd.Index,
    sample_key: str,
    condition_key: str,
    cell_type: Optional[str] = None,
    cell_type_key: Optional[str] = None,
    min_cells_per_sample: int = 10,
) -> tuple[pd.DataFrame, pd.DataFrame, dict[str, int]]:
    """Create pseudobulk count matrix by summing cells per sample.

    Only ``adata.obs`` is read from ``adata``; row ``i`` of ``raw_X`` must
    correspond to row ``i`` of ``adata.obs``. Rows are matched by position,
    so duplicated obs names (common after concatenating samples) are fine.

    Args:
        adata: AnnData object (only ``.obs`` is used)
        raw_X: Dense count matrix aligned row-wise with ``adata.obs``
        var_names: Gene names (columns of the returned count matrix)
        sample_key: Column for sample identification
        condition_key: Column for condition
        cell_type: Specific cell type to filter (optional)
        cell_type_key: Column for cell type (required if cell_type is provided)
        min_cells_per_sample: Samples with fewer cells are dropped

    Returns:
        Tuple of (counts_df, metadata_df, cell_counts):
        counts_df is samples x genes (int64, summed counts rounded to the
        nearest integer); metadata_df is indexed by sample id with a
        ``condition`` column (taken from the sample's first cell);
        cell_counts maps ``str(sample_id)`` to the number of cells summed.

    Raises:
        DataError: If no sample has at least ``min_cells_per_sample`` cells.
    """
    obs = adata.obs

    # Filter to specific cell type if provided. Only obs and the matrix rows
    # are needed, so avoid copying the whole AnnData object.
    if cell_type is not None and cell_type_key is not None:
        keep = (obs[cell_type_key] == cell_type).to_numpy()
        obs = obs.loc[keep]
        raw_X = raw_X[keep]

    # A positional index makes each group's index directly usable as row
    # numbers into raw_X (label lookups fail on duplicated obs names).
    obs = obs.reset_index(drop=True)

    pseudobulk_data = []
    metadata_list = []
    cell_counts: dict[str, int] = {}

    # observed=True: skip unused categories of a categorical sample column
    for sample_id, group in obs.groupby(sample_key, observed=True):
        n_cells = len(group)
        if n_cells < min_cells_per_sample:
            continue

        rows = group.index.to_numpy()

        # Sum in float64 and round: truncating with astype would turn e.g.
        # 2.9999999 into 2.
        summed = raw_X[rows].sum(axis=0, dtype=np.float64)
        pseudobulk_data.append(np.rint(summed).astype(np.int64))

        metadata_list.append(
            {
                "sample_id": sample_id,
                "condition": group[condition_key].iloc[0],
            }
        )
        cell_counts[str(sample_id)] = n_cells

    if not pseudobulk_data:
        raise DataError(
            f"No samples have >= {min_cells_per_sample} cells. "
            "Try lowering min_cells_per_sample."
        )

    # Create DataFrames
    sample_ids = [m["sample_id"] for m in metadata_list]
    counts_df = pd.DataFrame(
        np.array(pseudobulk_data),
        index=sample_ids,
        columns=var_names,
    )
    metadata_df = pd.DataFrame(metadata_list).set_index("sample_id")

    return counts_df, metadata_df, cell_counts
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def _run_deseq2(
    counts_df: pd.DataFrame,
    metadata_df: pd.DataFrame,
    condition1: str,
    condition2: str,
    n_top_genes: int = 50,
    padj_threshold: float = 0.05,
    log2fc_threshold: float = 0.0,
) -> tuple[list[DEGene], list[DEGene], int, pd.DataFrame]:
    """Run DESeq2 (pydeseq2) on pseudobulk data.

    Args:
        counts_df: Pseudobulk count matrix (samples x genes)
        metadata_df: Sample metadata with a ``condition`` column, indexed like
            ``counts_df``
        condition1: First condition (numerator of the fold change)
        condition2: Second condition (reference/control)
        n_top_genes: Number of top genes to return per direction
        padj_threshold: Genes need ``padj < padj_threshold`` to be significant
        log2fc_threshold: Genes need ``|log2FC| > log2fc_threshold``

    Returns:
        Tuple of (top_upregulated, top_downregulated, n_significant,
        results_df). The top lists are sorted by padj and truncated to
        ``n_top_genes``; ``n_significant`` and ``results_df`` (all genes with
        a non-NaN padj) are not truncated.
    """
    from pydeseq2.dds import DeseqDataSet
    from pydeseq2.ds import DeseqStats

    # quiet=True suppresses pydeseq2's status/summary output, which would
    # otherwise be emitted on every call (once per cell type when stratified).
    # NOTE(review): `design_factors` is deprecated in newer pydeseq2 in favour
    # of `design="~condition"`; kept for compatibility with older releases.
    dds = DeseqDataSet(
        counts=counts_df,
        metadata=metadata_df,
        design_factors="condition",
        quiet=True,
    )
    dds.deseq2()

    # Wald test for condition1 vs condition2 (positive LFC = higher in condition1)
    stat_res = DeseqStats(
        dds, contrast=["condition", condition1, condition2], quiet=True
    )
    stat_res.summary()

    results_df = stat_res.results_df.dropna(subset=["padj"])

    # Compute the threshold masks once and reuse them
    passes_padj = results_df["padj"] < padj_threshold
    lfc = results_df["log2FoldChange"]
    n_significant = int((passes_padj & (lfc.abs() > log2fc_threshold)).sum())

    upregulated = results_df[passes_padj & (lfc > log2fc_threshold)].sort_values("padj")
    downregulated = results_df[passes_padj & (lfc < -log2fc_threshold)].sort_values(
        "padj"
    )

    def _to_degenes(df: pd.DataFrame) -> list[DEGene]:
        """Convert the first ``n_top_genes`` rows of a results frame to DEGene."""
        return [
            DEGene(
                gene=str(gene_name),
                log2fc=float(row["log2FoldChange"]),
                pvalue=float(row["pvalue"]),
                padj=float(row["padj"]),
            )
            for gene_name, row in df.head(n_top_genes).iterrows()
        ]

    return _to_degenes(upregulated), _to_degenes(downregulated), n_significant, results_df
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
async def _run_global_comparison(
    adata,
    raw_X: np.ndarray,
    var_names: pd.Index,
    ctx: ToolContext,
    params: ConditionComparisonParameters,
) -> ConditionComparisonResult:
    """Run global comparison (all cells, no cell type stratification).

    Args:
        adata: Filtered AnnData object (only the two compared conditions)
        raw_X: Raw count matrix aligned with ``adata.obs``
        var_names: Gene names
        ctx: Tool context
        params: Comparison parameters

    Returns:
        ConditionComparisonResult. ``data_id``, ``n_samples_condition*`` and
        ``results_key`` are placeholders filled in by the caller.

    Raises:
        DataError: If, after dropping samples with fewer than
            ``params.min_cells_per_sample`` cells, either condition has fewer
            than ``max(2, params.min_samples_per_condition)`` samples.
        ProcessingError: If DESeq2 itself fails.
    """
    await ctx.info("Running global pseudobulk analysis (all cells)...")

    # Create pseudobulk
    counts_df, metadata_df, _ = _create_pseudobulk(
        adata,
        raw_X,
        var_names,
        sample_key=params.sample_key,
        condition_key=params.condition_key,
        min_cells_per_sample=params.min_cells_per_sample,
    )

    # Check sample distribution after small samples were dropped
    cond_counts = metadata_df["condition"].value_counts()
    n_cond1 = cond_counts.get(params.condition1, 0)
    n_cond2 = cond_counts.get(params.condition2, 0)

    # DESeq2 needs >= 2 replicates per group; the user's minimum was only
    # checked by the caller *before* min_cells_per_sample filtering.
    min_required = max(2, params.min_samples_per_condition)
    if n_cond1 < min_required or n_cond2 < min_required:
        raise DataError(
            f"At least {min_required} pseudobulk samples per condition are required "
            f"(DESeq2 needs >= 2; min_samples_per_condition="
            f"{params.min_samples_per_condition}) after dropping samples with fewer "
            f"than {params.min_cells_per_sample} cells. "
            f"Found: {params.condition1}={n_cond1}, {params.condition2}={n_cond2}"
        )

    await ctx.info(
        f"Created {len(counts_df)} pseudobulk samples "
        f"({params.condition1}={n_cond1}, {params.condition2}={n_cond2})"
    )

    # Run DESeq2
    try:
        top_up, top_down, n_significant, results_df = _run_deseq2(
            counts_df,
            metadata_df,
            condition1=params.condition1,
            condition2=params.condition2,
            n_top_genes=params.n_top_genes,
            padj_threshold=params.padj_threshold,
            log2fc_threshold=params.log2fc_threshold,
        )
    except Exception as e:
        raise ProcessingError(f"DESeq2 analysis failed: {e}") from e

    await ctx.info(f"Found {n_significant} significant DE genes")

    # Count significant genes per direction from the full results table;
    # top_up / top_down are truncated to n_top_genes and would cap the counts.
    passes_padj = results_df["padj"] < params.padj_threshold
    lfc = results_df["log2FoldChange"]
    n_upregulated = int((passes_padj & (lfc > params.log2fc_threshold)).sum())
    n_downregulated = int((passes_padj & (lfc < -params.log2fc_threshold)).sum())

    # Build result
    comparison = f"{params.condition1} vs {params.condition2}"

    return ConditionComparisonResult(
        data_id="",  # Will be filled by caller
        method="pseudobulk",
        comparison=comparison,
        condition_key=params.condition_key,
        condition1=params.condition1,
        condition2=params.condition2,
        sample_key=params.sample_key,
        cell_type_key=None,
        n_samples_condition1=0,  # Will be filled by caller
        n_samples_condition2=0,  # Will be filled by caller
        global_n_significant=n_significant,
        global_top_upregulated=top_up,
        global_top_downregulated=top_down,
        cell_type_results=None,
        results_key="",  # Will be filled by caller
        statistics={
            "analysis_type": "global",
            "n_pseudobulk_samples": len(counts_df),
            "n_significant_genes": n_significant,
            "n_upregulated": n_upregulated,
            "n_downregulated": n_downregulated,
        },
    )
|
|
454
|
+
|
|
455
|
+
|
|
456
|
+
async def _run_stratified_comparison(
    adata,
    raw_X: np.ndarray,
    var_names: pd.Index,
    ctx: ToolContext,
    params: ConditionComparisonParameters,
) -> ConditionComparisonResult:
    """Run a separate pseudobulk DESeq2 comparison for each cell type.

    Cell types are skipped (with a warning) when they have too few cells,
    too few samples per condition after ``min_cells_per_sample`` filtering,
    or when their analysis raises. Analysis is best-effort per cell type.

    Args:
        adata: Filtered AnnData object (only the two compared conditions)
        raw_X: Raw count matrix aligned with ``adata.obs``
        var_names: Gene names
        ctx: Tool context
        params: Comparison parameters

    Returns:
        ConditionComparisonResult with cell type stratified results.
        ``data_id``, ``n_samples_condition*`` and ``results_key`` are
        placeholders filled in by the caller.

    Raises:
        ProcessingError: If no cell type could be analyzed.
    """
    await ctx.info(f"Running stratified analysis by {params.cell_type_key}...")

    obs = adata.obs

    # Missing labels are not a cell type; analyzing them would only produce
    # a spurious "Skipping nan" warning.
    cell_types = obs[params.cell_type_key].dropna().unique()
    await ctx.info(f"Found {len(cell_types)} cell types")

    # Loop-invariant masks and thresholds
    is_cond1 = (obs[params.condition_key] == params.condition1).to_numpy()
    is_cond2 = (obs[params.condition_key] == params.condition2).to_numpy()
    min_cells_for_type = params.min_cells_per_sample * 2
    # DESeq2 needs >= 2 replicates per group; also honour the user's minimum,
    # which the caller only checked before per-cell-type filtering.
    min_required = max(2, params.min_samples_per_condition)

    cell_type_results: list[CellTypeComparisonResult] = []
    total_significant = 0

    for ct in cell_types:
        ct_mask = (obs[params.cell_type_key] == ct).to_numpy()
        n_cells_ct = int(ct_mask.sum())

        if n_cells_ct < min_cells_for_type:
            await ctx.warning(
                f"Skipping {ct}: only {n_cells_ct} cells "
                f"(need {min_cells_for_type})"
            )
            continue

        try:
            # Create pseudobulk for this cell type
            counts_df, metadata_df, _ = _create_pseudobulk(
                adata,
                raw_X,
                var_names,
                sample_key=params.sample_key,
                condition_key=params.condition_key,
                cell_type=ct,
                cell_type_key=params.cell_type_key,
                min_cells_per_sample=params.min_cells_per_sample,
            )

            # Check sample distribution
            cond_counts = metadata_df["condition"].value_counts()
            n_cond1 = cond_counts.get(params.condition1, 0)
            n_cond2 = cond_counts.get(params.condition2, 0)

            if n_cond1 < min_required or n_cond2 < min_required:
                await ctx.warning(
                    f"Skipping {ct}: insufficient samples "
                    f"({params.condition1}={n_cond1}, {params.condition2}={n_cond2}; "
                    f"need {min_required} per condition)"
                )
                continue

            # Run DESeq2
            top_up, top_down, n_significant, results_df = _run_deseq2(
                counts_df,
                metadata_df,
                condition1=params.condition1,
                condition2=params.condition2,
                n_top_genes=params.n_top_genes,
                padj_threshold=params.padj_threshold,
                log2fc_threshold=params.log2fc_threshold,
            )

            total_significant += n_significant

            # Count cells per condition for this cell type
            n_cells_cond1 = int((ct_mask & is_cond1).sum())
            n_cells_cond2 = int((ct_mask & is_cond2).sum())

            cell_type_results.append(
                CellTypeComparisonResult(
                    cell_type=str(ct),
                    n_cells_condition1=n_cells_cond1,
                    n_cells_condition2=n_cells_cond2,
                    n_samples_condition1=int(n_cond1),
                    n_samples_condition2=int(n_cond2),
                    n_significant_genes=n_significant,
                    top_upregulated=top_up,
                    top_downregulated=top_down,
                )
            )

            # Report untruncated per-direction counts (top_up/top_down are
            # capped at n_top_genes).
            passes_padj = results_df["padj"] < params.padj_threshold
            lfc = results_df["log2FoldChange"]
            n_up = int((passes_padj & (lfc > params.log2fc_threshold)).sum())
            n_down = int((passes_padj & (lfc < -params.log2fc_threshold)).sum())

            await ctx.info(
                f"{ct}: {n_significant} significant genes "
                f"({n_up} up, {n_down} down)"
            )

        except Exception as e:
            # Deliberately best-effort: one failing cell type must not abort
            # the others.
            await ctx.warning(f"Analysis failed for {ct}: {e}")
            continue

    if not cell_type_results:
        raise ProcessingError(
            "No cell types had sufficient samples for DESeq2 analysis. "
            "Try lowering min_cells_per_sample or min_samples_per_condition."
        )

    comparison = f"{params.condition1} vs {params.condition2}"

    return ConditionComparisonResult(
        data_id="",  # Will be filled by caller
        method="pseudobulk",
        comparison=comparison,
        condition_key=params.condition_key,
        condition1=params.condition1,
        condition2=params.condition2,
        sample_key=params.sample_key,
        cell_type_key=params.cell_type_key,
        n_samples_condition1=0,  # Will be filled by caller
        n_samples_condition2=0,  # Will be filled by caller
        global_n_significant=None,
        global_top_upregulated=None,
        global_top_downregulated=None,
        cell_type_results=cell_type_results,
        results_key="",  # Will be filled by caller
        statistics={
            "analysis_type": "cell_type_stratified",
            "n_cell_types_analyzed": len(cell_type_results),
            # Sum over cell types: a gene significant in k types counts k times
            "total_significant_genes": total_significant,
            "cell_types_with_de_genes": sum(
                1 for r in cell_type_results if r.n_significant_genes > 0
            ),
        },
    )
|