gsMap3D 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,590 @@
"""
Rank calculation from latent representations.

Extracts and processes the rank calculation logic from find_latent_representation.py.
"""

import gc
import logging
import time
from pathlib import Path
from typing import Any

import anndata as ad
import h5py
import jax
import jax.numpy as jnp
import jax.scipy
import numpy as np
import pandas as pd
import scanpy as sc
from anndata._io.specs import read_elem
from jax import jit
from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TextColumn,
    TimeRemainingColumn,
)
from rich.table import Table
from scipy.sparse import csr_matrix

from gsMap.config import LatentToGeneConfig

from .memmap_io import MemMapDense

logger = logging.getLogger(__name__)


def filter_cells(obs_df, X=None, annotation_key=None, min_cells_per_type=None, min_genes=20, precomputed_gene_counts=None):
    """
    Apply cell filtering based on annotation and gene expression criteria.

    Args:
        obs_df: DataFrame with cell metadata
        X: Optional expression matrix (sparse or dense) for gene count filtering
        annotation_key: Optional annotation column to filter by
        min_cells_per_type: Minimum cells required per annotation group
        min_genes: Minimum number of genes expressed per cell (default: 20)
        precomputed_gene_counts: Optional precomputed gene counts per cell (e.g., from CSR indptr)

    Returns:
        Tuple of (filtered_obs_df, filtering_stats_dict, keep_mask)
    """
    n_cells_before = obs_df.shape[0]
    stats = {
        "input_cells": n_cells_before,
        "nan_removed": 0,
        "small_group_removed": 0,
        "low_gene_count_removed": 0,
        "final_cells": 0
    }

    # Create a mask for cells to keep (start with all True)
    keep_mask = pd.Series(True, index=obs_df.index)

    # Filter 1: Remove cells with NaN in annotation_key
    if annotation_key is not None and annotation_key in obs_df.columns:
        nan_mask = obs_df[annotation_key].notna()
        stats["nan_removed"] = (~nan_mask).sum()
        keep_mask &= nan_mask

    # Filter 2: Remove cells from small annotation groups
    if annotation_key is not None and annotation_key in obs_df.columns and min_cells_per_type is not None:
        # Apply on the currently filtered data
        temp_obs = obs_df[keep_mask]
        annotation_counts = temp_obs[annotation_key].value_counts()
        valid_annotations = annotation_counts[annotation_counts >= min_cells_per_type].index

        if len(valid_annotations) < len(annotation_counts):
            small_group_mask = obs_df[annotation_key].isin(valid_annotations)
            stats["small_group_removed"] = keep_mask.sum() - (keep_mask & small_group_mask).sum()
            keep_mask &= small_group_mask

    # Filter 3: Remove cells with too few genes
    if min_genes is not None:
        if precomputed_gene_counts is not None:
            # Use precomputed gene counts (efficient for CSR format, via indptr)
            nfeature_mask = precomputed_gene_counts >= min_genes
            nfeature_mask = pd.Series(nfeature_mask, index=obs_df.index)
        elif X is not None:
            nfeature_mask, n_genes_per_cell = sc.pp.filter_cells(
                X,
                min_genes=min_genes,
                inplace=False
            )
            # Convert to pandas Series for consistent indexing
            nfeature_mask = pd.Series(nfeature_mask, index=obs_df.index)
        else:
            # No gene count information available, skip this filter
            nfeature_mask = pd.Series(True, index=obs_df.index)

        stats["low_gene_count_removed"] = keep_mask.sum() - (keep_mask & nfeature_mask).sum()
        keep_mask &= nfeature_mask

    # Apply the combined filter
    filtered_obs = obs_df[keep_mask].copy()
    stats["final_cells"] = filtered_obs.shape[0]

    return filtered_obs, stats, keep_mask


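# --- Usage sketch (illustrative only; not part of the packaged module). ---
# A minimal example of filter_cells on a toy obs table, reusing the pandas/
# numpy imports above. The "cell_type" column name and the counts are made
# up for the example; any annotation column works.
toy_obs = pd.DataFrame({"cell_type": ["A", "A", None, "B"]},
                       index=[f"cell{i}" for i in range(4)])
toy_counts = np.array([50, 10, 40, 60])  # non-zero genes per cell
kept_obs, toy_stats, toy_mask = filter_cells(
    toy_obs,
    annotation_key="cell_type",
    min_cells_per_type=2,
    min_genes=20,
    precomputed_gene_counts=toy_counts,
)
# cell2 drops (NaN annotation), then group "B" falls below min_cells_per_type,
# then cell1 drops (only 10 genes expressed); kept_obs retains just cell0,
# and toy_stats tallies the removals at each step.
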
@jit
def jax_process_chunk(dense_matrix, n_genes):
    """Rank a dense chunk of cells and return per-gene accumulators.

    Note: n_genes is unused here (the gene axis is implied by the matrix
    shape) but is kept to mirror the caller's signature.
    """
    nonzero_mask = dense_matrix != 0

    # Rank each cell's expression values across genes (ties get the average rank)
    ranks = jax.scipy.stats.rankdata(dense_matrix, method='average', axis=1)
    log_ranks = jnp.log(ranks)
    # Per-gene sum of log ranks across cells
    sum_log_ranks = log_ranks.sum(axis=0)
    # Per-gene count of non-zero cells (numerator of the expression fraction)
    sum_frac = nonzero_mask.sum(axis=0)

    return log_ranks, sum_log_ranks, sum_frac


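# --- Worked example (illustrative only; not part of the packaged module). ---
# What jax_process_chunk computes, on a 2-cell x 3-gene toy matrix. Assumes a
# JAX version that ships jax.scipy.stats.rankdata.
toy_chunk = jnp.array([[0., 5., 2.],
                       [1., 1., 0.]])
toy_ranks = jax.scipy.stats.rankdata(toy_chunk, method='average', axis=1)
# per cell: [1., 3., 2.] and [2.5, 2.5, 1.]  (ties share the average rank)
toy_log_ranks = jnp.log(toy_ranks)
toy_sum_log = toy_log_ranks.sum(axis=0)     # per-gene sum over cells
toy_nonzero = (toy_chunk != 0).sum(axis=0)  # per-gene non-zero cells -> [1, 2, 1]
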
def rank_data_jax(X: csr_matrix, n_genes,
                  memmap_dense=None,
                  metadata: dict[str, Any] | None = None,
                  chunk_size: int = 1000,
                  write_interval: int = 10,
                  current_row_offset: int = 0,
                  progress=None,
                  progress_task=None
                  ):
    """JAX-optimized rank calculation with batched writing to memory-mapped storage.

    Args:
        X: Input sparse matrix
        n_genes: Total number of genes
        memmap_dense: Optional MemMapDense instance for writing
        metadata: Optional metadata dictionary (currently unused)
        chunk_size: Size of chunks for processing
        write_interval: How often (in chunks) to flush pending chunks to the memory map
        current_row_offset: Offset for writing to the memory map (for multiple sections)
        progress: Progress instance for updates
        progress_task: Task ID for progress updates

    Returns:
        Tuple of (sum_log_ranks, sum_frac) as numpy arrays
    """
    assert X.nnz != 0, "Input matrix must not be empty"

    n_rows, n_cols = X.shape

    # Initialize accumulators (float32 to limit precision loss across chunks)
    sum_log_ranks = jnp.zeros(n_genes, dtype=jnp.float32)
    sum_frac = jnp.zeros(n_genes, dtype=jnp.float32)

    # Process in chunks to manage memory
    chunk_size = min(chunk_size, n_rows)
    pending_chunks = []   # Buffer for batching writes
    pending_indices = []  # Track global row indices for writing
    chunks_processed = 0

    # Track speed at the chunk level
    chunk_start_time = time.time()
    processed_cells = 0

    for start_idx in range(0, n_rows, chunk_size):
        end_idx = min(start_idx + chunk_size, n_rows)

        # Convert the chunk to dense; jnp.asarray avoids an extra copy where possible
        chunk_X = X[start_idx:end_idx]
        chunk_dense = chunk_X.toarray().astype(np.float32)
        chunk_jax = jnp.asarray(chunk_dense)

        # Process the chunk with the JIT-compiled function (ranking + accumulators)
        chunk_log_ranks, chunk_sum_log_ranks, chunk_sum_frac = jax_process_chunk(chunk_jax, n_genes)

        # Update global accumulators
        sum_log_ranks += chunk_sum_log_ranks
        sum_frac += chunk_sum_frac

        # Convert the JAX array to numpy float16 for storage efficiency
        # (halves memory usage compared to float32)
        chunk_log_ranks_np = np.array(chunk_log_ranks, dtype=np.float16)
        pending_chunks.append(chunk_log_ranks_np)
        # Calculate global indices for this chunk
        global_start = current_row_offset + start_idx
        global_end = current_row_offset + end_idx
        pending_indices.append((global_start, global_end))
        chunks_processed += 1

        # Write to the memory map periodically
        if memmap_dense and chunks_processed % write_interval == 0:
            # Combine pending chunks for a batch write
            combined_data = np.vstack(pending_chunks)
            # Calculate row indices for the batch write
            start_row = pending_indices[0][0]
            end_row = pending_indices[-1][1]
            # Write as a contiguous block
            memmap_dense.write_batch(combined_data, row_indices=slice(start_row, end_row))
            pending_chunks.clear()
            pending_indices.clear()

        # Update the progress bar with a speed calculation
        if progress and progress_task is not None:
            chunk_cells = end_idx - start_idx
            processed_cells += chunk_cells
            elapsed_time = time.time() - chunk_start_time
            speed = processed_cells / elapsed_time if elapsed_time > 0 else 0
            progress.update(progress_task, advance=chunk_cells, speed=f"{speed:.0f}")

    # Write any remaining chunks
    if memmap_dense and pending_chunks:
        combined_data = np.vstack(pending_chunks)
        start_row = pending_indices[0][0]
        end_row = pending_indices[-1][1]
        memmap_dense.write_batch(combined_data, row_indices=slice(start_row, end_row))

    return np.array(sum_log_ranks), np.array(sum_frac)


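# --- Pattern sketch (illustrative only; not part of the packaged module). ---
# The buffer-then-flush write pattern used by rank_data_jax above, with a
# plain np.memmap standing in for the package's MemMapDense wrapper (an
# assumption made for the sketch; write_batch is the wrapper's own API).
sketch_out = np.memmap("ranks_sketch.f16", dtype=np.float16, mode="w+", shape=(6, 4))
sketch_pending, sketch_row = [], 0
for i in range(3):                          # three chunks of two rows each
    sketch_pending.append(np.full((2, 4), i, dtype=np.float16))
    if len(sketch_pending) == 2:            # write_interval == 2: flush as one block
        sketch_block = np.vstack(sketch_pending)
        sketch_out[sketch_row:sketch_row + len(sketch_block)] = sketch_block
        sketch_row += len(sketch_block)
        sketch_pending.clear()
if sketch_pending:                          # flush the tail, as rank_data_jax does
    sketch_block = np.vstack(sketch_pending)
    sketch_out[sketch_row:sketch_row + len(sketch_block)] = sketch_block
sketch_out.flush()
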
class RankCalculator:
    """Calculate gene expression ranks and create concatenated latent representations."""

    def __init__(self, config: LatentToGeneConfig):
        """
        Initialize RankCalculator with configuration.

        Args:
            config: LatentToGeneConfig object with all necessary parameters
        """
        self.config = config
        self.latent_dir = Path(config.latent_dir)
        self.output_dir = Path(config.latent2gene_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.console = Console()

    def calculate_ranks_and_concatenate(
            self,
            sample_h5ad_dict: dict[str, Path] | None = None,
            annotation_key: str | None = None,
            data_layer: str = "counts"
    ) -> dict[str, Any]:
        """
        Calculate expression ranks and create the concatenated latent representation.

        This combines the rank calculation and concatenation logic from
        find_latent_representation.py.

        Args:
            sample_h5ad_dict: Optional dict of sample_name -> h5ad path. If None, uses config.sample_h5ad_dict
            annotation_key: Optional annotation to filter cells by. If None, uses config.annotation
            data_layer: Data layer to use for expression

        Returns:
            Dictionary with paths to:
                - concatenated_latent_adata: Path to concatenated latent representations
                - rank_memmap: Path to the rank memory map
                - mean_frac: Path to the mean expression fraction parquet
        """
        # Use the provided sample_h5ad_dict or fall back to config
        if sample_h5ad_dict is None:
            sample_h5ad_dict = self.config.sample_h5ad_dict

        # Use the provided annotation_key or fall back to config
        if annotation_key is None:
            annotation_key = self.config.annotation

        # Output paths from config
        concat_adata_path = Path(self.config.concatenated_latent_adata_path)
        rank_memmap_path = Path(self.config.rank_memmap_path)
        mean_frac_path = Path(self.config.mean_frac_path)

        # Check whether the outputs already exist
        if concat_adata_path.exists() and mean_frac_path.exists() and MemMapDense.check_complete(rank_memmap_path)[0]:
            logger.info(f"Rank outputs already exist in {self.output_dir}")
            return {
                "concatenated_latent_adata": str(concat_adata_path),
                "rank_memmap": str(rank_memmap_path),
                "mean_frac": str(mean_frac_path)
            }

        logger.info("Starting rank calculation and concatenation...")
        logger.info(f"Processing {len(sample_h5ad_dict)} samples")

        # Process each section
        adata_list = []
        n_total_cells = 0
        gene_list = None

        # Initialize global accumulators
        sum_log_ranks = None
        sum_frac = None
        total_cells = 0
        rank_memmap = None
        current_row_offset = 0  # Track the current position in the rank memory map

        # First pass: count total cells and determine which cells to keep
        logger.info("Counting total cells across all sections...")
        total_cells_expected = 0
        filtering_stats = []
        sample_keep_indices = {}  # Keep mask per sample

        for sample_name, h5ad_path in sample_h5ad_dict.items():
            # Apply the same filtering logic as in the main loop
            with h5py.File(h5ad_path, 'r') as f:
                adata_temp_obs = read_elem(f['obs'])

                # Efficiently read gene counts per cell without loading the full
                # expression matrix; for CSR format, indptr alone yields the counts
                X_group = f['X']
                if 'encoding-type' in X_group.attrs and X_group.attrs['encoding-type'] == 'csr_matrix':
                    # CSR format: number of non-zero genes = indptr[i+1] - indptr[i]
                    indptr = X_group['indptr'][:]
                    n_genes_per_cell = np.diff(indptr)
                    # Pass precomputed counts to the filter function
                    X_temp = None
                    use_precomputed_counts = True
                else:
                    # Not CSR or unknown format; read the full matrix
                    X_temp = read_elem(X_group)
                    n_genes_per_cell = None
                    use_precomputed_counts = False

            if annotation_key is not None:
                assert annotation_key and annotation_key in adata_temp_obs.columns, \
                    f"Annotation key '{annotation_key}' not found in the obs of {sample_name}"

            # Apply the unified filtering function
            if use_precomputed_counts:
                filtered_obs, stats, keep_mask = filter_cells(
                    adata_temp_obs,
                    X=None,
                    annotation_key=annotation_key,
                    min_cells_per_type=self.config.min_cells_per_type,
                    min_genes=20,
                    precomputed_gene_counts=n_genes_per_cell
                )
            else:
                filtered_obs, stats, keep_mask = filter_cells(
                    adata_temp_obs,
                    X=X_temp,
                    annotation_key=annotation_key,
                    min_cells_per_type=self.config.min_cells_per_type,
                    min_genes=20
                )

            # Store the boolean mask of cells to keep for this sample
            sample_keep_indices[sample_name] = keep_mask

            total_cells_expected += stats["final_cells"]

            filtering_stats.append({
                "Sample": sample_name,
                **stats
            })

        # Display the filtering summary table
        table = Table(title="[bold]Cell Filtering Summary[/bold]", show_header=True, header_style="bold magenta")
        table.add_column("Sample", style="dim")
        table.add_column("Input Cells", justify="right")
        table.add_column("NaN Removed", justify="right", style="red")
        table.add_column("Small Group Removed", justify="right", style="red")
        table.add_column("Low Gene Count Removed", justify="right", style="red")
        table.add_column("Final Cells", justify="right", style="green")

        for stat in filtering_stats:
            table.add_row(
                stat["Sample"],
                str(stat["input_cells"]),
                str(stat["nan_removed"]),
                str(stat["small_group_removed"]),
                str(stat["low_gene_count_removed"]),
                str(stat["final_cells"])
            )

        # Add a total row
        table.add_section()
        table.add_row(
            "[bold]Total[/bold]",
            str(sum(s["input_cells"] for s in filtering_stats)),
            str(sum(s["nan_removed"] for s in filtering_stats)),
            str(sum(s["small_group_removed"] for s in filtering_stats)),
            str(sum(s["low_gene_count_removed"] for s in filtering_stats)),
            f"[bold green]{total_cells_expected}[/bold green]"
        )

        self.console.print(table)
        logger.info(f"Expected total cells after filtering: {total_cells_expected}")

        # Create overall section progress tracking
        with Progress(
            SpinnerColumn(),
            TextColumn("[bold blue]{task.description}"),
            BarColumn(bar_width=None),
            MofNCompleteColumn(),
            TaskProgressColumn(),
            TimeRemainingColumn(),
            refresh_per_second=1
        ) as section_progress:
            # Overall section progress task
            section_task = section_progress.add_task(
                "Processing sections",
                total=len(sample_h5ad_dict)
            )

            processed_cells_total = 0

            for st_id, (sample_name, h5ad_path) in enumerate(sample_h5ad_dict.items()):

                section_progress.console.log(f"Loading {sample_name} ({st_id + 1}/{len(sample_h5ad_dict)})...")

                # Load the h5ad file (which should already contain latent representations)
                adata = sc.read_h5ad(h5ad_path)

                # Add slice information
                adata.obs['slice_id'] = st_id
                adata.obs['slice_name'] = sample_name
                adata.obs['sample_name'] = sample_name

                # Make the index unique
                adata.obs_names_make_unique()
                adata.obs_names = adata.obs_names.astype(str) + '|' + adata.obs['slice_name'].astype(str)

                # Apply the pre-computed filter mask from the first pass
                keep_mask = sample_keep_indices[sample_name]
                adata = adata[keep_mask].copy()

                # Get the gene list (should be consistent across sections)
                if gene_list is None:
                    gene_list = adata.var_names.tolist()
                    n_genes = len(gene_list)
                    # Initialize the rank memory map as a dense float16 matrix for
                    # 50% space savings; log ranks typically retain sufficient
                    # precision in float16
                    rank_memmap = MemMapDense(
                        str(rank_memmap_path),
                        shape=(total_cells_expected, n_genes),
                        dtype=np.float16,  # Use float16 to save space
                        mode='w',
                        num_write_workers=self.config.mkscore_write_workers
                    )
                    # Initialize global accumulators
                    sum_log_ranks = np.zeros(n_genes, dtype=np.float64)
                    sum_frac = np.zeros(n_genes, dtype=np.float64)
                else:
                    # Verify gene list consistency
                    assert adata.var_names.tolist() == gene_list, \
                        f"Gene list mismatch in section {st_id}"

                # Get the expression data for ranking
                if data_layer in adata.layers:
                    X = adata.layers[data_layer]
                else:
                    X = adata.X

                # Efficient sparse matrix conversion
                if not hasattr(X, 'tocsr'):
                    X = csr_matrix(X, dtype=np.float32)
                else:
                    X = X.tocsr()
                    if X.dtype != np.float32:
                        X = X.astype(np.float32)

                # Sort indices for better cache performance
                X.sort_indices()

                # Number of cells after filtering
                n_cells = X.shape[0]

                # Use a nested progress bar for detailed chunk processing
                with Progress(
                    SpinnerColumn(),
                    TextColumn(f"[bold blue]Ranking {sample_name} ({{task.fields[cells]}} cells)"),
                    BarColumn(bar_width=None),
                    MofNCompleteColumn(),
                    TaskProgressColumn(),
                    TextColumn("[bold green]{task.fields[speed]} cells/s"),
                    TimeRemainingColumn(),
                    refresh_per_second=2,
                    transient=True
                ) as chunk_progress:
                    # Detailed chunk progress task
                    chunk_task = chunk_progress.add_task(
                        "Processing chunks",
                        total=n_cells,
                        cells=n_cells,
                        speed="0"
                    )

                    # Use JAX rank calculation with nested progress
                    metadata = {'name': sample_name, 'cells': n_cells, 'study_id': st_id}

                    batch_sum_log_ranks, batch_frac = rank_data_jax(
                        X,
                        n_genes,
                        memmap_dense=rank_memmap,
                        metadata=metadata,
                        chunk_size=self.config.rank_batch_size,
                        write_interval=self.config.rank_write_interval,
                        current_row_offset=current_row_offset,
                        progress=chunk_progress,
                        progress_task=chunk_task
                    )

                # Update global sums
                sum_log_ranks += batch_sum_log_ranks
                sum_frac += batch_frac
                total_cells += n_cells
                current_row_offset += n_cells  # Update the offset for the next section
                processed_cells_total += n_cells

                # Update section progress
                section_progress.update(section_task, advance=1)

                # Create a minimal AnnData with an empty X matrix but keep obs and obsm
                minimal_adata = ad.AnnData(
                    X=csr_matrix((adata.n_obs, n_genes), dtype=np.float32),
                    obs=adata.obs.copy(),
                    var=pd.DataFrame(index=gene_list),
                    obsm=adata.obsm.copy()  # Keep all latent representations
                )

                adata_list.append(minimal_adata)
                n_total_cells += n_cells

                # Clean up memory
                del adata, X, minimal_adata
                gc.collect()

        # Close the rank memory map
        if rank_memmap is not None:
            rank_memmap.close()
            logger.info(f"Saved rank matrix to {rank_memmap_path}")

        # Calculate mean log ranks and the mean expression fraction
        mean_log_ranks = sum_log_ranks / total_cells
        mean_frac = sum_frac / total_cells

        # Save the means and fractions to a parquet file
        mean_frac_df = pd.DataFrame(
            data=dict(
                G_Mean=mean_log_ranks,
                frac=mean_frac,
                gene_name=gene_list,
            ),
            index=gene_list,
        )
        # Save outside the progress context
        mean_frac_df.to_parquet(
            mean_frac_path,
            index=True,
            compression="gzip",
        )
        logger.info(f"Mean fraction data saved to {mean_frac_path}")

        # Concatenate all sections
        if adata_list:
            with self.console.status("[bold blue]Concatenating and saving latent representations..."):
                concatenated_adata = ad.concat(adata_list, axis=0, join='outer', merge='same')

                # Ensure the var_names are the common genes
                concatenated_adata.var_names = gene_list

                # Save the concatenated adata
                concatenated_adata.write_h5ad(concat_adata_path)
                logger.info(f"Saved concatenated latent representations to {concat_adata_path}")
                logger.info(f"  - Total cells: {concatenated_adata.n_obs}")
                logger.info(f"  - Total genes: {concatenated_adata.n_vars}")
                logger.info(f"  - Latent representations in obsm: {list(concatenated_adata.obsm.keys())}")
                if 'slice_id' in concatenated_adata.obs.columns:
                    logger.info(f"  - Number of slices: {concatenated_adata.obs['slice_id'].nunique()}")

            # Clean up
            del adata_list, concatenated_adata
            gc.collect()

        # Final completion message
        logger.info("Rank calculation and concatenation completed successfully")

        return {
            "concatenated_latent_adata": str(concat_adata_path),
            "rank_memmap": str(rank_memmap_path),
            "mean_frac": str(mean_frac_path)
        }
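
# --- Worked example (illustrative only; not part of the packaged module). ---
# The indptr trick used in the first pass above: for a CSR matrix, the number
# of non-zero genes in cell i is indptr[i + 1] - indptr[i], so per-cell gene
# counts come from np.diff on indptr without touching the data array.
demo_X = csr_matrix(np.array([[0, 2, 0, 1],
                              [0, 0, 0, 0],
                              [3, 0, 4, 5]], dtype=np.float32))
demo_counts = np.diff(demo_X.indptr)   # -> array([2, 0, 3])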