chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,595 @@
1
+ """
2
+ Multi-sample condition comparison analysis for spatial transcriptomics data.
3
+
4
+ This module implements pseudobulk differential expression analysis for comparing
5
+ experimental conditions (e.g., Treatment vs Control) across biological samples.
6
+ """
7
+
8
+ from typing import Optional
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ from scipy import sparse
13
+
14
+ from ..models.analysis import (
15
+ CellTypeComparisonResult,
16
+ ConditionComparisonResult,
17
+ DEGene,
18
+ )
19
+ from ..models.data import ConditionComparisonParameters
20
+ from ..spatial_mcp_adapter import ToolContext
21
+ from ..utils import validate_obs_column
22
+ from ..utils.adata_utils import store_analysis_metadata
23
+ from ..utils.dependency_manager import require
24
+ from ..utils.exceptions import DataError, ParameterError, ProcessingError
25
+
26
+
27
async def compare_conditions(
    data_id: str,
    ctx: ToolContext,
    params: ConditionComparisonParameters,
) -> ConditionComparisonResult:
    """Compare experimental conditions across multiple biological samples.

    This function performs pseudobulk differential expression analysis using DESeq2.
    It aggregates cells by sample, then compares conditions (e.g., Treatment vs Control).

    Optionally, analysis can be stratified by cell type to identify cell type-specific
    condition effects.

    Each sample in ``params.sample_key`` must contain cells from only one of the two
    compared conditions: every sample becomes a single pseudobulk replicate that is
    labelled with exactly one condition.

    Args:
        data_id: Dataset ID
        ctx: Tool context for data access and logging
        params: Condition comparison parameters

    Returns:
        ConditionComparisonResult with differential expression results. A summary
        (comparison, method, statistics) is also written to
        ``adata.uns["condition_comparison_<condition1>_vs_<condition2>"]`` and
        provenance metadata is recorded via ``store_analysis_metadata``.

    Raises:
        ParameterError: If the two conditions are identical, or either one is not
            a value of ``params.condition_key``.
        DataError: If a sample contains cells from both conditions, or a condition
            has fewer than ``params.min_samples_per_condition`` samples.

    Example:
        # Global comparison (all cells)
        params = ConditionComparisonParameters(
            condition_key="treatment",
            condition1="Drug",
            condition2="Control",
            sample_key="patient_id",
        )
        result = await compare_conditions("data1", ctx, params)

        # Cell type stratified comparison: additionally set
        # cell_type_key="cell_type" on the parameters.
    """
    # Check pydeseq2 availability early (required for pseudobulk analysis)
    require("pydeseq2", ctx, feature="Condition comparison with DESeq2")

    adata = await ctx.get_adata(data_id)

    # Validate required columns
    validate_obs_column(adata, params.condition_key, "Condition")
    validate_obs_column(adata, params.sample_key, "Sample")
    if params.cell_type_key is not None:
        validate_obs_column(adata, params.cell_type_key, "Cell type")

    # Comparing a condition with itself would leave DESeq2 a single-level design.
    if params.condition1 == params.condition2:
        raise ParameterError(
            f"condition1 and condition2 must differ (both are '{params.condition1}')."
        )

    # Validate that both conditions exist (condition1 is checked first).
    unique_conditions = adata.obs[params.condition_key].unique()
    for condition in (params.condition1, params.condition2):
        if condition not in unique_conditions:
            raise ParameterError(
                f"Condition '{condition}' not found in '{params.condition_key}'.\n"
                f"Available conditions: {list(unique_conditions)}"
            )

    # Filter to only the two conditions of interest
    mask = adata.obs[params.condition_key].isin([params.condition1, params.condition2])
    adata_filtered = adata[mask].copy()

    await ctx.info(
        f"Comparing {params.condition1} vs {params.condition2}: "
        f"{adata_filtered.n_obs} cells from {adata_filtered.obs[params.sample_key].nunique()} samples"
    )

    # observed=True: with a categorical sample column, categories that have no
    # cells after filtering must not appear as (empty) groups.
    conditions_by_sample = adata_filtered.obs.groupby(
        params.sample_key, observed=True
    )[params.condition_key]

    # A sample spanning both conditions would be pooled into one pseudobulk
    # replicate carrying only one of the labels, which corrupts the comparison.
    n_conditions_per_sample = conditions_by_sample.nunique()
    mixed_samples = n_conditions_per_sample.index[n_conditions_per_sample > 1].tolist()
    if mixed_samples:
        raise DataError(
            f"{len(mixed_samples)} sample(s) in '{params.sample_key}' contain cells from "
            f"both {params.condition1} and {params.condition2}: {mixed_samples[:10]}. "
            "Pseudobulk analysis requires each sample to belong to a single condition; "
            "use a sample_key that identifies individual specimens."
        )

    # Get raw counts (required for DESeq2)
    raw_X, var_names = _get_raw_counts(adata_filtered)

    # DESeq2 models raw integer counts. Integer dtypes pass trivially; float
    # matrices are compared element-wise with their rounded values.
    is_integral = np.issubdtype(raw_X.dtype, np.integer) or np.allclose(
        raw_X, np.rint(raw_X)
    )
    if not is_integral:
        await ctx.warning(
            "Data appears to be normalized. DESeq2 requires raw integer counts. "
            "Results may be inaccurate. Consider using adata.raw."
        )

    # Count samples per condition (each sample maps to exactly one condition here)
    sample_condition_map = conditions_by_sample.first()
    n_samples_cond1 = (sample_condition_map == params.condition1).sum()
    n_samples_cond2 = (sample_condition_map == params.condition2).sum()

    await ctx.info(
        f"Sample distribution: {params.condition1}={n_samples_cond1}, "
        f"{params.condition2}={n_samples_cond2}"
    )

    # Check minimum samples requirement
    for condition, n_samples in (
        (params.condition1, n_samples_cond1),
        (params.condition2, n_samples_cond2),
    ):
        if n_samples < params.min_samples_per_condition:
            raise DataError(
                f"Insufficient samples for {condition}: {n_samples} "
                f"(minimum: {params.min_samples_per_condition})"
            )

    # Determine analysis mode
    if params.cell_type_key is None:
        # Global analysis (all cells together)
        result = await _run_global_comparison(
            adata_filtered, raw_X, var_names, ctx, params
        )
    else:
        # Cell type stratified analysis
        result = await _run_stratified_comparison(
            adata_filtered, raw_X, var_names, ctx, params
        )

    # Fill in the fields the helpers leave as placeholders
    result.data_id = data_id
    result.n_samples_condition1 = int(n_samples_cond1)
    result.n_samples_condition2 = int(n_samples_cond2)

    # Store a summary on the original (unfiltered) AnnData
    results_key = f"condition_comparison_{params.condition1}_vs_{params.condition2}"
    adata.uns[results_key] = {
        "comparison": result.comparison,
        "method": result.method,
        "statistics": result.statistics,
    }

    # Store metadata for provenance
    store_analysis_metadata(
        adata,
        analysis_name="condition_comparison",
        method="pseudobulk_deseq2",
        parameters={
            "condition_key": params.condition_key,
            "condition1": params.condition1,
            "condition2": params.condition2,
            "sample_key": params.sample_key,
            "cell_type_key": params.cell_type_key,
        },
        results_keys={"uns": [results_key]},
        statistics=result.statistics,
    )

    result.results_key = results_key
    return result
179
+
180
+
181
def _get_raw_counts(adata) -> tuple[np.ndarray, pd.Index]:
    """Extract the count matrix used for pseudobulk aggregation.

    ``adata.raw`` is used when present; otherwise ``adata.X``. This only selects
    the source — it does not verify that the values are raw integer counts.

    Sparse matrices are densified, and the result is always a plain
    ``np.ndarray`` (never ``np.matrix``), so that row slicing followed by
    ``.sum(axis=0)`` yields a 1-D vector.

    Args:
        adata: AnnData object

    Returns:
        Tuple of (count_matrix, var_names)
    """
    source = adata.raw if adata.raw is not None else adata
    matrix = source.X

    # Convert to dense if sparse
    if sparse.issparse(matrix):
        matrix = matrix.toarray()

    # np.asarray also converts legacy np.matrix inputs, whose reductions stay 2-D.
    return np.asarray(matrix), source.var_names
202
+
203
+
204
def _create_pseudobulk(
    adata,
    raw_X: np.ndarray,
    var_names: pd.Index,
    sample_key: str,
    condition_key: str,
    cell_type: Optional[str] = None,
    cell_type_key: Optional[str] = None,
    min_cells_per_sample: int = 10,
) -> tuple[pd.DataFrame, pd.DataFrame, dict[str, int]]:
    """Create a pseudobulk count matrix by summing cells per sample.

    Only ``adata.obs`` is read; ``raw_X`` must be row-aligned with it.

    Args:
        adata: AnnData object (only ``.obs`` is used)
        raw_X: Count matrix, one row per cell of ``adata.obs``
        var_names: Gene names (columns of the returned counts)
        sample_key: Column for sample identification
        condition_key: Column for condition
        cell_type: Specific cell type to filter (optional)
        cell_type_key: Column for cell type (required if cell_type is provided)
        min_cells_per_sample: Samples with fewer cells are dropped

    Returns:
        Tuple of (counts_df, metadata_df, cell_counts):
          - counts_df: samples x genes int64 sums, indexed by ``str(sample_id)``
          - metadata_df: a single "condition" column with the same index
            (index name "sample_id")
          - cell_counts: ``str(sample_id)`` -> number of aggregated cells

    Raises:
        DataError: If a sample contains cells from more than one condition, or
            if no sample has at least ``min_cells_per_sample`` cells.
    """
    # Use a positional RangeIndex so each group's index equals its row positions
    # in raw_X; this also works when obs_names contain duplicates.
    obs = adata.obs.reset_index(drop=True)

    # Filter to a specific cell type if provided (obs and raw_X stay aligned)
    if cell_type is not None and cell_type_key is not None:
        keep = (obs[cell_type_key] == cell_type).to_numpy()
        obs = obs.loc[keep].reset_index(drop=True)
        raw_X = raw_X[keep]

    pseudobulk_rows: list[np.ndarray] = []
    sample_ids: list[str] = []
    conditions: list = []
    cell_counts: dict[str, int] = {}

    # observed=True: skip categories of a categorical sample column that have
    # no cells (e.g. after the cell type filter).
    for sample_id, group in obs.groupby(sample_key, observed=True):
        n_cells = len(group)
        if n_cells < min_cells_per_sample:
            continue

        sample_conditions = group[condition_key].unique()
        if len(sample_conditions) > 1:
            raise DataError(
                f"Sample '{sample_id}' contains cells from multiple conditions "
                f"({list(sample_conditions)}). Each sample in '{sample_key}' must "
                "belong to exactly one condition for pseudobulk analysis."
            )

        rows = group.index.to_numpy()
        # Accumulate in float64 (float32 loses integer precision above 2**24) and
        # round rather than truncate, so 1.9999999 becomes 2, not 1.
        summed = np.asarray(raw_X[rows].sum(axis=0, dtype=np.float64)).ravel()
        pseudobulk_rows.append(np.rint(summed).astype(np.int64))

        sample_ids.append(str(sample_id))
        conditions.append(sample_conditions[0])
        cell_counts[str(sample_id)] = n_cells

    if not pseudobulk_rows:
        raise DataError(
            f"No samples have >= {min_cells_per_sample} cells. "
            "Try lowering min_cells_per_sample."
        )

    counts_df = pd.DataFrame(
        np.vstack(pseudobulk_rows),
        index=sample_ids,
        columns=var_names,
    )
    metadata_df = pd.DataFrame(
        {"condition": conditions},
        index=pd.Index(sample_ids, name="sample_id"),
    )

    return counts_df, metadata_df, cell_counts
281
+
282
+
283
def _run_deseq2(
    counts_df: pd.DataFrame,
    metadata_df: pd.DataFrame,
    condition1: str,
    condition2: str,
    n_top_genes: int = 50,
    padj_threshold: float = 0.05,
    log2fc_threshold: float = 0.0,
) -> tuple[list[DEGene], list[DEGene], int, pd.DataFrame]:
    """Run DESeq2 analysis on pseudobulk data.

    Args:
        counts_df: Pseudobulk count matrix (samples x genes)
        metadata_df: Sample metadata with a "condition" column, same index as counts_df
        condition1: First condition (experimental)
        condition2: Second condition (reference/control)
        n_top_genes: Number of top genes to return per direction
        padj_threshold: Adjusted p-value threshold for significance
        log2fc_threshold: Log2 fold change threshold

    Returns:
        Tuple of (top_upregulated, top_downregulated, n_significant, results_df).
        The two gene lists are sorted by padj and truncated to ``n_top_genes``.
        ``n_significant`` counts all genes passing both thresholds. ``results_df``
        is pydeseq2's results table restricted to genes with a non-NaN padj.
    """
    import contextlib
    import sys

    from pydeseq2.dds import DeseqDataSet
    from pydeseq2.ds import DeseqStats

    # Use string factor levels so the design and the contrast refer to the same
    # level names regardless of the original dtype (e.g. categorical or numeric).
    metadata_df = metadata_df.assign(condition=metadata_df["condition"].astype(str))
    contrast = ["condition", str(condition1), str(condition2)]

    # pydeseq2 may print progress and the results table. When the server runs
    # over the stdio transport, stdout carries the protocol stream, so divert
    # that output to stderr.
    # NOTE(review): newer pydeseq2 releases deprecate design_factors in favor
    # of design="~condition" — confirm against the pinned version.
    with contextlib.redirect_stdout(sys.stderr):
        dds = DeseqDataSet(
            counts=counts_df,
            metadata=metadata_df,
            design_factors="condition",
        )
        dds.deseq2()

        # condition1 vs condition2 (positive log2FC = higher in condition1)
        stat_res = DeseqStats(dds, contrast=contrast)
        stat_res.summary()

    results_df = stat_res.results_df.dropna(subset=["padj"])

    log2fc = results_df["log2FoldChange"]
    passes_padj = results_df["padj"] < padj_threshold
    n_significant = int((passes_padj & (log2fc.abs() > log2fc_threshold)).sum())

    upregulated = results_df[passes_padj & (log2fc > log2fc_threshold)].sort_values("padj")
    downregulated = results_df[passes_padj & (log2fc < -log2fc_threshold)].sort_values(
        "padj"
    )

    def to_degenes(df: pd.DataFrame) -> list[DEGene]:
        """Convert the first n_top_genes rows of a results table to DEGene objects."""
        return [
            DEGene(
                gene=str(gene_name),
                log2fc=float(row["log2FoldChange"]),
                pvalue=float(row["pvalue"]),
                padj=float(row["padj"]),
            )
            for gene_name, row in df.head(n_top_genes).iterrows()
        ]

    return to_degenes(upregulated), to_degenes(downregulated), n_significant, results_df
360
+
361
+
362
async def _run_global_comparison(
    adata,
    raw_X: np.ndarray,
    var_names: pd.Index,
    ctx: ToolContext,
    params: ConditionComparisonParameters,
) -> ConditionComparisonResult:
    """Run global comparison (all cells pooled per sample, no cell type stratification).

    Args:
        adata: Filtered AnnData object
        raw_X: Raw count matrix, row-aligned with ``adata.obs``
        var_names: Gene names
        ctx: Tool context
        params: Comparison parameters

    Returns:
        ConditionComparisonResult. ``data_id``, ``n_samples_condition*`` and
        ``results_key`` are placeholders that the caller fills in.

    Raises:
        DataError: If fewer than 2 pseudobulk samples remain for either condition
            (or as raised by pseudobulk creation).
        ProcessingError: If DESeq2 fails.
    """
    await ctx.info("Running global pseudobulk analysis (all cells)...")

    counts_df, metadata_df, _cell_counts = _create_pseudobulk(
        adata,
        raw_X,
        var_names,
        sample_key=params.sample_key,
        condition_key=params.condition_key,
        min_cells_per_sample=params.min_cells_per_sample,
    )

    # Samples dropped by min_cells_per_sample can leave a condition underpowered.
    cond_counts = metadata_df["condition"].value_counts()
    n_cond1 = cond_counts.get(params.condition1, 0)
    n_cond2 = cond_counts.get(params.condition2, 0)

    if n_cond1 < 2 or n_cond2 < 2:
        raise DataError(
            f"DESeq2 requires at least 2 samples per condition. "
            f"Found: {params.condition1}={n_cond1}, {params.condition2}={n_cond2}"
        )

    await ctx.info(
        f"Created {len(counts_df)} pseudobulk samples "
        f"({params.condition1}={n_cond1}, {params.condition2}={n_cond2})"
    )

    try:
        top_up, top_down, n_significant, results_df = _run_deseq2(
            counts_df,
            metadata_df,
            condition1=params.condition1,
            condition2=params.condition2,
            n_top_genes=params.n_top_genes,
            padj_threshold=params.padj_threshold,
            log2fc_threshold=params.log2fc_threshold,
        )
    except Exception as e:
        raise ProcessingError(f"DESeq2 analysis failed: {e}") from e

    await ctx.info(f"Found {n_significant} significant DE genes")

    # top_up / top_down are truncated to n_top_genes, so the direction totals
    # must be counted on the full results table.
    passes_padj = results_df["padj"] < params.padj_threshold
    log2fc = results_df["log2FoldChange"]
    n_upregulated = int((passes_padj & (log2fc > params.log2fc_threshold)).sum())
    n_downregulated = int((passes_padj & (log2fc < -params.log2fc_threshold)).sum())

    comparison = f"{params.condition1} vs {params.condition2}"

    return ConditionComparisonResult(
        data_id="",  # Will be filled by caller
        method="pseudobulk",
        comparison=comparison,
        condition_key=params.condition_key,
        condition1=params.condition1,
        condition2=params.condition2,
        sample_key=params.sample_key,
        cell_type_key=None,
        n_samples_condition1=0,  # Will be filled by caller
        n_samples_condition2=0,  # Will be filled by caller
        global_n_significant=n_significant,
        global_top_upregulated=top_up,
        global_top_downregulated=top_down,
        cell_type_results=None,
        results_key="",  # Will be filled by caller
        statistics={
            "analysis_type": "global",
            "n_pseudobulk_samples": len(counts_df),
            "n_significant_genes": n_significant,
            "n_upregulated": n_upregulated,
            "n_downregulated": n_downregulated,
        },
    )
454
+
455
+
456
async def _run_stratified_comparison(
    adata,
    raw_X: np.ndarray,
    var_names: pd.Index,
    ctx: ToolContext,
    params: ConditionComparisonParameters,
) -> ConditionComparisonResult:
    """Run a separate pseudobulk DESeq2 comparison for each cell type.

    A cell type is skipped with a warning in two cases:
      - it has fewer than ``2 * min_cells_per_sample`` cells; or
      - fewer than 2 pseudobulk samples remain in either condition.

    An exception inside one cell type's analysis is reported as a warning and
    does not abort the remaining cell types.

    Args:
        adata: Filtered AnnData object
        raw_X: Raw count matrix, row-aligned with ``adata.obs``
        var_names: Gene names
        ctx: Tool context
        params: Comparison parameters

    Returns:
        ConditionComparisonResult with one CellTypeComparisonResult per analyzed
        cell type. ``statistics["total_significant_genes"]`` sums the per-cell-type
        counts, so a gene significant in several cell types is counted once per
        cell type.

    Raises:
        ProcessingError: If no cell type produced a result. The message lists any
            per-cell-type analysis failures.
    """
    await ctx.info(f"Running stratified analysis by {params.cell_type_key}...")

    cell_type_labels = adata.obs[params.cell_type_key]
    condition_values = adata.obs[params.condition_key].to_numpy()
    # Missing labels cannot match any cell type, so don't iterate over them.
    cell_types = cell_type_labels.dropna().unique()
    await ctx.info(f"Found {len(cell_types)} cell types")

    cell_type_results: list[CellTypeComparisonResult] = []
    failures: dict[str, str] = {}
    total_significant = 0
    min_cells_needed = params.min_cells_per_sample * 2

    for ct in cell_types:
        ct_mask = (cell_type_labels == ct).to_numpy()
        n_cells_ct = int(ct_mask.sum())

        if n_cells_ct < min_cells_needed:
            await ctx.warning(
                f"Skipping {ct}: only {n_cells_ct} cells "
                f"(need {min_cells_needed})"
            )
            continue

        try:
            counts_df, metadata_df, _cell_counts = _create_pseudobulk(
                adata,
                raw_X,
                var_names,
                sample_key=params.sample_key,
                condition_key=params.condition_key,
                cell_type=ct,
                cell_type_key=params.cell_type_key,
                min_cells_per_sample=params.min_cells_per_sample,
            )

            cond_counts = metadata_df["condition"].value_counts()
            n_cond1 = cond_counts.get(params.condition1, 0)
            n_cond2 = cond_counts.get(params.condition2, 0)

            if n_cond1 < 2 or n_cond2 < 2:
                await ctx.warning(
                    f"Skipping {ct}: insufficient samples "
                    f"({params.condition1}={n_cond1}, {params.condition2}={n_cond2})"
                )
                continue

            top_up, top_down, n_significant, results_df = _run_deseq2(
                counts_df,
                metadata_df,
                condition1=params.condition1,
                condition2=params.condition2,
                n_top_genes=params.n_top_genes,
                padj_threshold=params.padj_threshold,
                log2fc_threshold=params.log2fc_threshold,
            )

            # Count cells per condition for this cell type
            ct_conditions = condition_values[ct_mask]
            n_cells_cond1 = int((ct_conditions == params.condition1).sum())
            n_cells_cond2 = int((ct_conditions == params.condition2).sum())

            # top_up / top_down are capped at n_top_genes; report full totals.
            passes_padj = results_df["padj"] < params.padj_threshold
            log2fc = results_df["log2FoldChange"]
            n_up = int((passes_padj & (log2fc > params.log2fc_threshold)).sum())
            n_down = int((passes_padj & (log2fc < -params.log2fc_threshold)).sum())

            cell_type_results.append(
                CellTypeComparisonResult(
                    cell_type=str(ct),
                    n_cells_condition1=n_cells_cond1,
                    n_cells_condition2=n_cells_cond2,
                    n_samples_condition1=int(n_cond1),
                    n_samples_condition2=int(n_cond2),
                    n_significant_genes=n_significant,
                    top_upregulated=top_up,
                    top_downregulated=top_down,
                )
            )
            # Only count genes from cell types whose result was recorded.
            total_significant += n_significant

            await ctx.info(
                f"{ct}: {n_significant} significant genes "
                f"({n_up} up, {n_down} down)"
            )

        except Exception as e:
            # Deliberate best-effort per cell type; reasons are kept for the
            # final error in case every cell type fails.
            failures[str(ct)] = str(e)
            await ctx.warning(f"Analysis failed for {ct}: {e}")
            continue

    if not cell_type_results:
        message = (
            "No cell types had sufficient samples for DESeq2 analysis. "
            "Try lowering min_cells_per_sample or min_samples_per_condition."
        )
        if failures:
            details = "; ".join(f"{name}: {err}" for name, err in failures.items())
            message += f" Analysis failed for {len(failures)} cell type(s): {details}"
        raise ProcessingError(message)

    comparison = f"{params.condition1} vs {params.condition2}"

    return ConditionComparisonResult(
        data_id="",  # Will be filled by caller
        method="pseudobulk",
        comparison=comparison,
        condition_key=params.condition_key,
        condition1=params.condition1,
        condition2=params.condition2,
        sample_key=params.sample_key,
        cell_type_key=params.cell_type_key,
        n_samples_condition1=0,  # Will be filled by caller
        n_samples_condition2=0,  # Will be filled by caller
        global_n_significant=None,
        global_top_upregulated=None,
        global_top_downregulated=None,
        cell_type_results=cell_type_results,
        results_key="",  # Will be filled by caller
        statistics={
            "analysis_type": "cell_type_stratified",
            "n_cell_types_analyzed": len(cell_type_results),
            "total_significant_genes": total_significant,
            "cell_types_with_de_genes": sum(
                1 for r in cell_type_results if r.n_significant_genes > 0
            ),
        },
    )