chatspatial-1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
chatspatial/tools/differential.py
@@ -0,0 +1,625 @@
"""
Differential expression analysis tools for spatial transcriptomics data.
"""

import numpy as np
import pandas as pd
import scanpy as sc
from scipy import sparse

from ..models.analysis import DifferentialExpressionResult
from ..models.data import DifferentialExpressionParameters
from ..spatial_mcp_adapter import ToolContext
from ..utils import validate_obs_column
from ..utils.adata_utils import store_analysis_metadata, to_dense
from ..utils.dependency_manager import require
from ..utils.exceptions import (
    DataError,
    DataNotFoundError,
    ParameterError,
    ProcessingError,
)


async def differential_expression(
    data_id: str,
    ctx: ToolContext,
    params: DifferentialExpressionParameters,
) -> DifferentialExpressionResult:
    """Perform differential expression analysis.

    Args:
        data_id: Dataset ID
        ctx: Tool context for data access and logging
        params: Differential expression parameters

    Returns:
        Differential expression analysis result
    """
    # Extract parameters from params object
    group_key = params.group_key
    group1 = params.group1
    group2 = params.group2
    method = params.method
    n_top_genes = params.n_top_genes
    pseudocount = params.pseudocount
    min_cells = params.min_cells

    # Dispatch to pydeseq2 if requested
    if method == "pydeseq2":
        return await _run_pydeseq2(data_id, ctx, params)

    # Get AnnData directly via ToolContext (no redundant dict wrapping)
    adata = await ctx.get_adata(data_id)

    # Check if the group_key exists in adata.obs
    validate_obs_column(adata, group_key, "Group")

    # Check if dtype conversion is needed (numba doesn't support float16)
    # Defer conversion to after subsetting for memory efficiency
    needs_dtype_fix = hasattr(adata.X, "dtype") and adata.X.dtype == np.float16

    # If group1 is None, find markers for all groups
    if group1 is None:

        # Filter out groups with too few cells (user-configurable threshold)
        group_sizes = adata.obs[group_key].value_counts()
        # min_cells is now a parameter (default=3, minimum for Wilcoxon test)
        valid_groups = group_sizes[group_sizes >= min_cells]
        skipped_groups = group_sizes[group_sizes < min_cells]

        # Warn about skipped groups
        if len(skipped_groups) > 0:
            skipped_list = "\n".join(
                [f" - {g}: {n} cell(s)" for g, n in skipped_groups.items()]
            )
            await ctx.warning(
                f"Skipped {len(skipped_groups)} group(s) with <{min_cells} cells:\n{skipped_list}"
            )

        # Check if any valid groups remain
        if len(valid_groups) == 0:
            all_sizes = "\n".join(
                [f" • {g}: {n} cell(s)" for g, n in group_sizes.items()]
            )
            raise DataError(
                f"All groups have <{min_cells} cells. Cannot perform {method} test.\n\n"
                f"Group sizes:\n{all_sizes}\n\n"
                f"Try: find_markers(group_key='leiden') or merge small groups"
            )

        # Filter data to only include valid groups
        adata_filtered = adata[adata.obs[group_key].isin(valid_groups.index)].copy()

        # Convert dtype after subsetting (4x more memory efficient than copying first)
        if needs_dtype_fix:
            adata_filtered.X = adata_filtered.X.astype(np.float32)

        # Run rank_genes_groups on filtered data
        sc.tl.rank_genes_groups(
            adata_filtered,
            groupby=group_key,
            method=method,
            n_genes=n_top_genes,
            reference="rest",
        )

        # Get all groups (from filtered data)
        groups = adata_filtered.obs[group_key].unique()

        # Collect top genes from all groups
        all_top_genes = []
        if (
            "rank_genes_groups" in adata_filtered.uns
            and "names" in adata_filtered.uns["rank_genes_groups"]
        ):
            gene_names = adata_filtered.uns["rank_genes_groups"]["names"]
            for group in groups:
                if str(group) in gene_names.dtype.names:
                    genes = list(gene_names[str(group)][:n_top_genes])
                    all_top_genes.extend(genes)

        # Remove duplicates while preserving order
        seen = set()
        top_genes = []
        for gene in all_top_genes:
            if gene not in seen:
                seen.add(gene)
                top_genes.append(gene)

        # Limit to n_top_genes
        top_genes = top_genes[:n_top_genes]

        # Copy results back to original adata for persistence
        adata.uns["rank_genes_groups"] = adata_filtered.uns["rank_genes_groups"]

        # Store metadata for scientific provenance tracking
        store_analysis_metadata(
            adata,
            analysis_name="differential_expression",
            method=method,
            parameters={
                "group_key": group_key,
                "comparison_type": "all_groups",
                "n_top_genes": n_top_genes,
            },
            results_keys={"uns": ["rank_genes_groups"]},
            statistics={
                "method": method,
                "n_groups": len(groups),
                "groups": list(map(str, groups)),
                "n_cells_analyzed": adata_filtered.n_obs,
                "n_genes_analyzed": adata_filtered.n_vars,
            },
        )

        return DifferentialExpressionResult(
            data_id=data_id,
            comparison=f"All groups in {group_key}",
            n_genes=len(top_genes),
            top_genes=top_genes,
            statistics={
                "method": method,
                "n_groups": len(groups),
                "groups": list(map(str, groups)),
            },
        )

    # Original logic for specific group comparison
    # Check if the groups exist in the group_key
    if group1 not in adata.obs[group_key].values:
        raise ParameterError(f"Group '{group1}' not found in adata.obs['{group_key}']")

    # Special case for 'rest' as group2 or if group2 is None
    use_rest_as_reference = False
    if group2 is None or group2 == "rest":
        use_rest_as_reference = True
        group2 = "rest"  # Set it explicitly for display purposes
    elif group2 not in adata.obs[group_key].values:
        raise ParameterError(f"Group '{group2}' not found in adata.obs['{group_key}']")

    # Perform differential expression analysis

    # Prepare the AnnData object for analysis
    if use_rest_as_reference:
        # Use the full AnnData object when comparing with 'rest'
        temp_adata = adata.copy()
    else:
        # Create a temporary copy of the AnnData object with only the two groups
        temp_adata = adata[adata.obs[group_key].isin([group1, group2])].copy()

    # Convert dtype after subsetting (4x more memory efficient than copying first)
    if needs_dtype_fix:
        temp_adata.X = temp_adata.X.astype(np.float32)

    # Run rank_genes_groups
    sc.tl.rank_genes_groups(
        temp_adata,
        groupby=group_key,
        groups=[group1],
        reference="rest" if use_rest_as_reference else group2,
        method=method,
        n_genes=n_top_genes,
    )

    # Extract results

    # Get the top genes
    top_genes = []
    if (
        hasattr(temp_adata, "uns")
        and "rank_genes_groups" in temp_adata.uns
        and "names" in temp_adata.uns["rank_genes_groups"]
    ):
        # Get the top genes for the first group (should be group1)
        gene_names = temp_adata.uns["rank_genes_groups"]["names"]
        if group1 in gene_names.dtype.names:
            top_genes = list(gene_names[group1][:n_top_genes])
        else:
            # If group1 is not in the names, use the first column
            top_genes = list(gene_names[gene_names.dtype.names[0]][:n_top_genes])

    # If no genes were found, fail honestly
    if not top_genes:
        raise ProcessingError(
            f"No DE genes found between {group1} and {group2}. "
            f"Check sample sizes and expression differences."
        )

    # Get statistics
    n_cells_group1 = np.sum(adata.obs[group_key] == group1)
    if use_rest_as_reference:
        n_cells_group2 = adata.n_obs - n_cells_group1  # All cells except group1
    else:
        n_cells_group2 = np.sum(adata.obs[group_key] == group2)

    # Get p-values from scanpy results
    pvals = []
    if (
        hasattr(temp_adata, "uns")
        and "rank_genes_groups" in temp_adata.uns
        and "pvals_adj" in temp_adata.uns["rank_genes_groups"]
        and group1 in temp_adata.uns["rank_genes_groups"]["pvals_adj"].dtype.names
    ):
        pvals = list(
            temp_adata.uns["rank_genes_groups"]["pvals_adj"][group1][:n_top_genes]
        )

    # Calculate TRUE fold change from raw counts (Bug #3 Fix)
    # Issue: scanpy's logfoldchanges uses mean(log(counts)), which is mathematically incorrect
    # Solution: calculate log(mean(counts1) / mean(counts2)) from raw data
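    # Worked example of the distinction: for counts {1, 100}, the mean is 50.5
    # (log2 ≈ 5.66), while the mean of the log2 values is (0 + 6.64) / 2 ≈ 3.32,
    # i.e. a geometric mean of only 10. Averaging in log space can therefore
    # badly misstate fold changes for genes with skewed counts, hence the ratio
    # of group means is computed below instead.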

    # Check if raw count data is available
    if adata.raw is None:
        raise DataNotFoundError(
            "Raw count data (adata.raw) required for fold change calculation. "
            "Run preprocess_data() first to preserve raw counts."
        )

    # Get raw count data
    raw_adata = adata.raw
    log2fc_values = []

    # Create masks for the two groups
    if use_rest_as_reference:
        group1_mask = adata.obs[group_key] == group1
        group2_mask = ~group1_mask
    else:
        group1_mask = adata.obs[group_key] == group1
        group2_mask = adata.obs[group_key] == group2

    # CRITICAL: Normalize by library size to avoid composition bias
    # Library size = total UMI counts per spot
    if hasattr(raw_adata.X, "toarray"):
        lib_sizes = np.array(raw_adata.X.sum(axis=1)).flatten()
    else:
        lib_sizes = raw_adata.X.sum(axis=1).flatten()

    median_lib_size = float(np.median(lib_sizes))

    # Calculate fold change for each top gene
    for gene in top_genes:
        if gene in raw_adata.var_names:
            gene_idx = raw_adata.var_names.get_loc(gene)

            # Get raw counts for this gene
            gene_raw_counts = to_dense(raw_adata.X[:, gene_idx]).flatten()

            # Normalize by library size (CPM-like normalization)
            # normalized_counts = raw_counts * (median_lib_size / spot_lib_size)
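            # e.g. a spot with a 5,000-count library against a 10,000-count
            # median has its counts scaled by 2, so deep and shallow spots
            # contribute comparably to the group means computed below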
            gene_norm_counts = gene_raw_counts * (median_lib_size / lib_sizes)

            # Calculate mean normalized counts for each group
            mean_group1 = float(gene_norm_counts[group1_mask].mean())
            mean_group2 = float(gene_norm_counts[group2_mask].mean())

            # Calculate true log2 fold change from normalized counts
            # Add user-configurable pseudocount to avoid log(0)
            true_log2fc = np.log2(
                (mean_group1 + pseudocount) / (mean_group2 + pseudocount)
            )
            log2fc_values.append(float(true_log2fc))
        else:
            # Gene not in raw data (should not happen, but handle gracefully)
            await ctx.warning(
                f"Gene {gene} not found in raw data, skipping fold change calculation"
            )
            log2fc_values.append(None)

    # Calculate mean log2fc (filtering out None values)
    valid_log2fc = [fc for fc in log2fc_values if fc is not None]
    mean_log2fc = np.mean(valid_log2fc) if valid_log2fc else None
    median_pvalue = np.median(pvals) if pvals else None

    # Warn if fold change values are suspiciously high (indicating calculation errors)
    if mean_log2fc is not None and abs(mean_log2fc) > 10:
        await ctx.warning(
            f"Extreme fold change: mean log2FC = {mean_log2fc:.2f} (>1024x). "
            f"May indicate sparse expression or low cell counts."
        )

    # Create statistics dictionary
    statistics = {
        "method": method,
        "n_cells_group1": int(n_cells_group1),
        "n_cells_group2": int(n_cells_group2),
        "mean_log2fc": float(mean_log2fc) if mean_log2fc is not None else None,
        "median_pvalue": float(median_pvalue) if median_pvalue is not None else None,
    }

    # Create comparison string
    comparison = f"{group1} vs {group2}"

    # Copy results back to original adata for persistence
    adata.uns["rank_genes_groups"] = temp_adata.uns["rank_genes_groups"]

    # Store metadata for scientific provenance tracking
    store_analysis_metadata(
        adata,
        analysis_name="differential_expression",
        method=method,
        parameters={
            "group_key": group_key,
            "group1": group1,
            "group2": group2,
            "comparison_type": "specific_groups",
            "n_top_genes": n_top_genes,
            "pseudocount": pseudocount,  # Track for reproducibility
        },
        results_keys={"uns": ["rank_genes_groups"]},
        statistics={
            "method": method,
            "group1": group1,
            "group2": group2,
            "n_cells_group1": int(n_cells_group1),
            "n_cells_group2": int(n_cells_group2),
            "n_genes_analyzed": temp_adata.n_vars,
            "mean_log2fc": float(mean_log2fc) if mean_log2fc is not None else None,
            "median_pvalue": (
                float(median_pvalue) if median_pvalue is not None else None
            ),
            "pseudocount_used": pseudocount,  # Document in statistics
        },
    )

    return DifferentialExpressionResult(
        data_id=data_id,
        comparison=comparison,
        n_genes=len(top_genes),
        top_genes=top_genes,
        statistics=statistics,
    )


async def _run_pydeseq2(
    data_id: str,
    ctx: ToolContext,
    params: DifferentialExpressionParameters,
) -> DifferentialExpressionResult:
    """Run PyDESeq2 pseudobulk differential expression analysis.

    This function performs pseudobulk aggregation by summing raw counts within
    each sample/group combination, then uses PyDESeq2 for DE analysis.

    Args:
        data_id: Dataset ID
        ctx: Tool context for data access and logging
        params: Differential expression parameters

    Returns:
        Differential expression analysis result

    Raises:
        ParameterError: If sample_key is not provided
        ImportError: If pydeseq2 is not installed
    """
    # Validate sample_key is provided
    if params.sample_key is None:
        raise ParameterError(
            "sample_key is required for pydeseq2 method.\n"
            "Provide a column in adata.obs that identifies biological replicates "
            "(e.g., 'sample', 'patient_id', 'batch').\n"
            "Example: find_markers(group_key='cell_type', method='pydeseq2', "
            "sample_key='sample')"
        )

    # Import pydeseq2 (require() raises ImportError if not available)
    require("pydeseq2", ctx, feature="DESeq2 differential expression")
    from pydeseq2.dds import DeseqDataSet
    from pydeseq2.ds import DeseqStats

    # Get data
    adata = await ctx.get_adata(data_id)

    # Validate columns
    validate_obs_column(adata, params.group_key, "Group")
    validate_obs_column(adata, params.sample_key, "Sample")

    # Get raw counts (required for DESeq2)
    if adata.raw is not None:
        raw_X = adata.raw.X
        var_names = adata.raw.var_names
    else:
        raw_X = adata.X
        var_names = adata.var_names

    # Convert to dense if sparse
    if sparse.issparse(raw_X):
        raw_X = raw_X.toarray()

    # Validate counts are integers (DESeq2 requirement)
    if not np.allclose(raw_X, raw_X.astype(int)):
        await ctx.warning(
            "Data appears to be normalized. DESeq2 requires raw integer counts. "
            "Results may be inaccurate."
        )

    # Determine comparison groups
    group_key = params.group_key
    sample_key = params.sample_key
    group1 = params.group1
    group2 = params.group2

    # If group1 is None, find first two groups for pairwise comparison
    unique_groups = adata.obs[group_key].unique()
    if group1 is None:
        if len(unique_groups) < 2:
            raise DataError(
                f"Need at least 2 groups for DE analysis, found {len(unique_groups)}"
            )
        group1 = str(unique_groups[0])
        group2 = str(unique_groups[1])
        await ctx.info(
            f"No group specified, comparing first two groups: {group1} vs {group2}"
        )
    elif group2 is None or group2 == "rest":
        # Compare group1 vs all others combined as "rest"
        group2 = "rest"

    # Create pseudobulk aggregation
    await ctx.info(f"Creating pseudobulk samples by {sample_key} and {group_key}...")

    # Build aggregation key
    if group2 == "rest":
        # Binary comparison: group1 vs rest
        condition = adata.obs[group_key].apply(
            lambda x: group1 if x == group1 else "rest"
        )
    else:
        # Pairwise comparison: filter to only group1 and group2
        mask = adata.obs[group_key].isin([group1, group2])
        adata = adata[mask].copy()
        raw_X = raw_X[mask.values]
        condition = adata.obs[group_key].astype(str)

    # Create pseudobulk by aggregating (summing) counts per sample+condition
    adata.obs["_de_condition"] = condition.values
    adata.obs["_pseudobulk_id"] = (
        adata.obs[sample_key].astype(str) + "_" + adata.obs["_de_condition"].astype(str)
    )
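    # e.g. cells from sample "patient1" under condition "tumor" all share the
    # ID "patient1_tumor" and are summed into a single row of the pseudobulk
    # count matrix built below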

    # Aggregate counts
    pseudobulk_groups = adata.obs.groupby("_pseudobulk_id")
    pseudobulk_ids = list(pseudobulk_groups.groups.keys())
    n_samples = len(pseudobulk_ids)

    await ctx.info(f"Aggregated into {n_samples} pseudobulk samples")

    if n_samples < 4:
        raise DataError(
            f"DESeq2 requires at least 2 samples per group. "
            f"Found only {n_samples} total pseudobulk samples. "
            f"Add more biological replicates or use a different method (wilcoxon)."
        )

    # Build pseudobulk count matrix
    pseudobulk_counts = np.zeros((n_samples, raw_X.shape[1]), dtype=np.int64)
    pseudobulk_metadata = []

    for i, pb_id in enumerate(pseudobulk_ids):
        # Get indices for this pseudobulk group
        group_labels = pseudobulk_groups.groups[pb_id]
        # Convert pandas Index to integer positional indices for numpy array indexing
        int_idx = adata.obs.index.get_indexer(group_labels)
        pseudobulk_counts[i] = raw_X[int_idx].sum(axis=0).astype(np.int64)
        # Get condition from first cell in this group
        first_int_idx = int_idx[0]
        pseudobulk_metadata.append(
            {
                "sample_id": pb_id,
                "condition": adata.obs.iloc[first_int_idx]["_de_condition"],
                "sample": adata.obs.iloc[first_int_idx][sample_key],
            }
        )

    # Create metadata DataFrame
    metadata_df = pd.DataFrame(pseudobulk_metadata)
    metadata_df = metadata_df.set_index("sample_id")

    # Create count DataFrame
    counts_df = pd.DataFrame(pseudobulk_counts, index=pseudobulk_ids, columns=var_names)

    # Check sample counts per condition
    condition_counts = metadata_df["condition"].value_counts()
    await ctx.info(f"Samples per condition: {condition_counts.to_dict()}")

    if any(condition_counts < 2):
        raise DataError(
            f"DESeq2 requires at least 2 samples per condition. "
            f"Current counts: {condition_counts.to_dict()}"
        )

    # Run PyDESeq2
    await ctx.info("Running PyDESeq2 differential expression analysis...")

    try:
        # Create DESeq2 dataset
        dds = DeseqDataSet(
            counts=counts_df,
            metadata=metadata_df,
            design_factors="condition",
        )

        # Run DESeq2 pipeline
        dds.deseq2()

        # Get results
        stat_res = DeseqStats(dds, contrast=["condition", group1, group2])
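        # contrast follows the DESeq2 convention [factor, tested level,
        # reference level], so positive log2 fold changes mean higher
        # expression in group1 relative to group2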
        stat_res.summary()

        # Get results DataFrame
        results_df = stat_res.results_df

    except Exception as e:
        raise ProcessingError(
            f"PyDESeq2 analysis failed: {e}\n"
            "This may be due to low sample counts or data issues."
        ) from e

    # Extract top DE genes
    # Sort by adjusted p-value, filter significant genes
    results_df = results_df.dropna(subset=["padj"])
    results_df = results_df.sort_values("padj")

    top_genes = results_df.head(params.n_top_genes).index.tolist()

    if not top_genes:
        raise ProcessingError(
            f"No DE genes found between {group1} and {group2}. "
            "Check sample sizes and expression differences."
        )

    # Get statistics
    n_sig_genes = (results_df["padj"] < 0.05).sum()
    mean_log2fc = results_df.head(params.n_top_genes)["log2FoldChange"].mean()
    median_pvalue = results_df.head(params.n_top_genes)["padj"].median()

    # Store results in adata.uns for persistence
    adata.uns["pydeseq2_results"] = {
        "results_df": results_df.to_dict(),
        "comparison": f"{group1} vs {group2}",
        "n_samples": n_samples,
    }

    # Store metadata for scientific provenance tracking
    store_analysis_metadata(
        adata,
        analysis_name="differential_expression",
        method="pydeseq2",
        parameters={
            "group_key": group_key,
            "sample_key": sample_key,
            "group1": group1,
            "group2": group2,
            "comparison_type": "pseudobulk",
            "n_top_genes": params.n_top_genes,
        },
        results_keys={"uns": ["pydeseq2_results"]},
        statistics={
            "method": "pydeseq2",
            "group1": group1,
            "group2": group2,
            "n_pseudobulk_samples": n_samples,
            "n_significant_genes": int(n_sig_genes),
            "mean_log2fc": float(mean_log2fc) if not np.isnan(mean_log2fc) else None,
            "median_padj": (
                float(median_pvalue) if not np.isnan(median_pvalue) else None
            ),
        },
    )

    return DifferentialExpressionResult(
        data_id=data_id,
        comparison=f"{group1} vs {group2}",
        n_genes=len(top_genes),
        top_genes=top_genes,
        statistics={
            "method": "pydeseq2",
            "n_pseudobulk_samples": n_samples,
            "n_significant_genes": int(n_sig_genes),
            "mean_log2fc": float(mean_log2fc) if not np.isnan(mean_log2fc) else None,
            "median_padj": (
                float(median_pvalue) if not np.isnan(median_pvalue) else None
            ),
        },
    )
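For orientation, a minimal sketch of how the tool above might be driven. It is illustrative only: the dataset ID "my_dataset", the obs columns "cell_type" and "sample", the group label "Tumor", and the assumption that DifferentialExpressionParameters supplies defaults for fields left unset (pseudocount, min_cells) are all hypothetical; the function, its signature, and the result fields are taken from the diff itself.

from chatspatial.models.data import DifferentialExpressionParameters
from chatspatial.tools.differential import differential_expression


async def run_de(ctx):
    # ctx: a ToolContext already wired to a loaded dataset (assumed)

    # Wilcoxon test of one group against all remaining cells ("rest")
    params = DifferentialExpressionParameters(
        group_key="cell_type",  # hypothetical obs column
        group1="Tumor",  # hypothetical group label
        group2="rest",
        method="wilcoxon",
        n_top_genes=50,
    )
    result = await differential_expression("my_dataset", ctx, params)
    print(result.comparison, result.n_genes, result.top_genes[:10])

    # Pseudobulk DESeq2 run; sample_key must name a biological replicate column
    pb_params = DifferentialExpressionParameters(
        group_key="cell_type",
        group1="Tumor",
        method="pydeseq2",
        sample_key="sample",  # hypothetical replicate column
    )
    pb_result = await differential_expression("my_dataset", ctx, pb_params)
    print(pb_result.statistics["n_significant_genes"])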