gsMap3D 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,590 @@
1
+ """
2
+ Rank calculation from latent representations
3
+ Extracts and processes the rank calculation logic from find_latent_representation.py
4
+ """
5
+
6
+ import gc
7
+ import logging
8
+ import time
9
+ from pathlib import Path
10
+ from typing import Any
11
+
12
+ import anndata as ad
13
+ import h5py
14
+ import jax
15
+ import jax.numpy as jnp
16
+ import jax.scipy
17
+ import numpy as np
18
+ import pandas as pd
19
+ import scanpy as sc
20
+ from anndata._io.specs import read_elem
21
+ from jax import jit
22
+ from rich.console import Console
23
+ from rich.progress import (
24
+ BarColumn,
25
+ MofNCompleteColumn,
26
+ Progress,
27
+ SpinnerColumn,
28
+ TaskProgressColumn,
29
+ TextColumn,
30
+ TimeRemainingColumn,
31
+ )
32
+ from rich.table import Table
33
+ from scipy.sparse import csr_matrix
34
+
35
+ from gsMap.config import LatentToGeneConfig
36
+
37
+ from .memmap_io import MemMapDense
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
def filter_cells(obs_df, X=None, annotation_key=None, min_cells_per_type=None, min_genes=20, precomputed_gene_counts=None):
    """
    Apply cell filtering based on annotation and gene expression criteria.

    Three filters narrow a boolean keep-mask in sequence: (1) drop cells whose
    annotation value is NaN, (2) drop cells belonging to annotation groups
    smaller than ``min_cells_per_type``, and (3) drop cells expressing fewer
    than ``min_genes`` genes.

    Args:
        obs_df: DataFrame with cell metadata
        X: Optional expression matrix (sparse or dense) for gene count filtering
        annotation_key: Optional annotation column to filter by
        min_cells_per_type: Minimum cells required per annotation group
        min_genes: Minimum number of genes expressed per cell (default: 20)
        precomputed_gene_counts: Optional precomputed gene counts per cell
            (e.g., derived from a CSR matrix's indptr)

    Returns:
        Tuple of (filtered_obs_df, filtering_stats_dict, keep_mask)
    """
    stats = {
        "input_cells": obs_df.shape[0],
        "nan_removed": 0,
        "small_group_removed": 0,
        "low_gene_count_removed": 0,
        "final_cells": 0,
    }

    # Boolean keep-mask over all cells; each filter ANDs into it.
    retained = pd.Series(True, index=obs_df.index)
    has_annotation = annotation_key is not None and annotation_key in obs_df.columns

    # Filter 1: drop cells whose annotation value is missing.
    if has_annotation:
        not_nan = obs_df[annotation_key].notna()
        stats["nan_removed"] = (~not_nan).sum()
        retained &= not_nan

    # Filter 2: drop cells in annotation groups that are too small
    # (group sizes are counted on the currently-retained cells only).
    if has_annotation and min_cells_per_type is not None:
        group_sizes = obs_df.loc[retained, annotation_key].value_counts()
        big_enough = group_sizes[group_sizes >= min_cells_per_type].index
        if len(big_enough) < len(group_sizes):
            in_big_group = obs_df[annotation_key].isin(big_enough)
            stats["small_group_removed"] = retained.sum() - (retained & in_big_group).sum()
            retained &= in_big_group

    # Filter 3: drop cells that express too few genes.
    if min_genes is not None:
        if precomputed_gene_counts is not None:
            # Cheap path: counts derived without loading the expression matrix.
            enough_genes = pd.Series(precomputed_gene_counts >= min_genes, index=obs_df.index)
        elif X is not None:
            cell_subset, _ = sc.pp.filter_cells(X, min_genes=min_genes, inplace=False)
            enough_genes = pd.Series(cell_subset, index=obs_df.index)
        else:
            # No gene-count information available: keep every cell.
            enough_genes = pd.Series(True, index=obs_df.index)

        stats["low_gene_count_removed"] = retained.sum() - (retained & enough_genes).sum()
        retained &= enough_genes

    # Materialize the surviving rows and record the final count.
    filtered_obs = obs_df[retained].copy()
    stats["final_cells"] = filtered_obs.shape[0]

    return filtered_obs, stats, retained
113
+
114
+
115
@jit
def jax_process_chunk(dense_matrix, n_genes):
    """Rank one dense chunk of cells and accumulate per-gene statistics.

    Each row (cell) is ranked independently with average ties, then
    log-transformed. Returns the full log-rank matrix plus column-wise
    totals used as global accumulators by the caller.

    Args:
        dense_matrix: Dense (cells x genes) expression chunk.
        n_genes: Unused here; kept for interface compatibility with callers.

    Returns:
        Tuple of (log_ranks, sum_log_ranks, sum_frac) where the last two are
        per-gene column sums over the chunk.
    """
    # Average ranks within each cell, then take logs.
    row_ranks = jax.scipy.stats.rankdata(dense_matrix, method='average', axis=1)
    log_ranks = jnp.log(row_ranks)

    # Column totals: summed log ranks, and the count of non-zero entries
    # (i.e. how many cells in this chunk express each gene).
    expressed = dense_matrix != 0
    return log_ranks, log_ranks.sum(axis=0), expressed.sum(axis=0)
128
+
129
+
130
def rank_data_jax(X: csr_matrix, n_genes,
                  memmap_dense=None,
                  metadata: dict[str, Any] | None = None,
                  chunk_size: int = 1000,
                  write_interval: int = 10,
                  current_row_offset: int = 0,
                  progress=None,
                  progress_task=None
                  ):
    """JAX-optimized rank calculation with batched writing to memory-mapped storage.

    Processes ``X`` in row chunks: each chunk is densified, ranked per cell by
    ``jax_process_chunk``, accumulated into per-gene totals, and — when a
    writer is supplied — buffered as float16 log-ranks and flushed to
    ``memmap_dense`` every ``write_interval`` chunks.

    Args:
        X: Input sparse matrix (cells x genes); must contain at least one non-zero.
        n_genes: Total number of genes (length of the accumulator vectors).
        memmap_dense: Optional MemMapDense instance for writing log-ranks.
        metadata: Optional metadata dictionary (currently unused; kept for
            interface compatibility).
        chunk_size: Number of rows processed per chunk.
        write_interval: Flush buffered chunks to the memory map every this
            many chunks.
        current_row_offset: Row offset for writing to the memory map
            (for multiple sections).
        progress: Optional rich Progress instance for updates.
        progress_task: Task ID for progress updates.

    Returns:
        Tuple of (sum_log_ranks, sum_frac) as numpy arrays of length n_genes.
    """
    assert X.nnz != 0, "Input matrix must not be empty"

    n_rows, _ = X.shape

    # Per-gene accumulators (float32 matches the chunk kernel's output dtype).
    sum_log_ranks = jnp.zeros(n_genes, dtype=jnp.float32)
    sum_frac = jnp.zeros(n_genes, dtype=jnp.float32)

    chunk_size = min(chunk_size, n_rows)
    pending_chunks = []   # float16 log-rank blocks awaiting a batch write
    pending_indices = []  # (global_start, global_end) row ranges per block
    chunks_processed = 0

    def _flush_pending():
        # Combine buffered blocks and write them as one contiguous row range.
        combined_data = np.vstack(pending_chunks)
        start_row = pending_indices[0][0]
        end_row = pending_indices[-1][1]
        memmap_dense.write_batch(combined_data, row_indices=slice(start_row, end_row))
        pending_chunks.clear()
        pending_indices.clear()

    # Track throughput at chunk granularity for the progress display.
    chunk_start_time = time.time()
    processed_cells = 0

    for start_idx in range(0, n_rows, chunk_size):
        end_idx = min(start_idx + chunk_size, n_rows)

        # Densify the chunk; jnp.asarray avoids a copy when possible.
        chunk_dense = X[start_idx:end_idx].toarray().astype(np.float32)
        chunk_jax = jnp.asarray(chunk_dense)

        # JIT-compiled ranking plus per-chunk column accumulators.
        chunk_log_ranks, chunk_sum_log_ranks, chunk_sum_frac = jax_process_chunk(chunk_jax, n_genes)

        # Update global accumulators.
        sum_log_ranks += chunk_sum_log_ranks
        sum_frac += chunk_sum_frac
        chunks_processed += 1

        if memmap_dense:
            # Buffer only when a writer exists. (Fix: previously every chunk was
            # converted to float16 and appended unconditionally, so with
            # memmap_dense=None the buffer grew for the whole run and was never
            # released — an unbounded memory leak with no purpose.)
            # float16 halves storage relative to float32.
            pending_chunks.append(np.array(chunk_log_ranks, dtype=np.float16))
            pending_indices.append(
                (current_row_offset + start_idx, current_row_offset + end_idx)
            )
            if chunks_processed % write_interval == 0:
                _flush_pending()

        # Update the progress bar with a cells/second speed estimate.
        if progress and progress_task is not None:
            chunk_cells = end_idx - start_idx
            processed_cells += chunk_cells
            elapsed_time = time.time() - chunk_start_time
            speed = processed_cells / elapsed_time if elapsed_time > 0 else 0
            progress.update(progress_task, advance=chunk_cells, speed=f"{speed:.0f}")

    # Flush any blocks left over after the final chunk.
    if memmap_dense and pending_chunks:
        _flush_pending()

    return np.array(sum_log_ranks), np.array(sum_frac)
226
+
227
+
228
class RankCalculator:
    """Calculate gene expression ranks and create concatenated latent representations.

    Drives a two-pass pipeline over per-sample h5ad files: a cheap metadata-only
    pass to decide which cells survive filtering (and thus the final row count),
    then a full pass that ranks expression per cell into a float16 memory-mapped
    matrix and concatenates the (expression-stripped) AnnData objects.
    """

    def __init__(self, config: LatentToGeneConfig):
        """
        Initialize RankCalculator with configuration

        Args:
            config: LatentToGeneConfig object with all necessary parameters
        """
        self.config = config
        # Directory holding the per-sample latent h5ad inputs.
        self.latent_dir = Path(config.latent_dir)
        # All outputs (rank memmap, mean-frac parquet, concatenated h5ad) go here.
        self.output_dir = Path(config.latent2gene_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.console = Console()

    def calculate_ranks_and_concatenate(
        self,
        sample_h5ad_dict: dict[str, Path] | None = None,
        annotation_key: str | None = None,
        data_layer: str = "counts"
    ) -> dict[str, Any]:
        """
        Calculate expression ranks and create concatenated latent representation.

        Combines the rank calculation and concatenation logic originally in
        find_latent_representation.py. Idempotent: returns immediately if all
        three outputs already exist and the memmap is complete.

        Args:
            sample_h5ad_dict: Optional dict of sample_name -> h5ad path.
                If None, uses config.sample_h5ad_dict
            annotation_key: Optional annotation to filter cells.
                If None, uses config.annotation
            data_layer: Data layer to use for expression (falls back to .X
                when the layer is absent)

        Returns:
            Dictionary with paths to:
                - concatenated_latent_adata: concatenated latent representations
                - rank_memmap: rank memory-mapped matrix
                - mean_frac: mean expression fraction parquet
        """

        # Fall back to config for any argument the caller did not supply.
        if sample_h5ad_dict is None:
            sample_h5ad_dict = self.config.sample_h5ad_dict

        if annotation_key is None:
            annotation_key = self.config.annotation

        # Output paths from config
        concat_adata_path = Path(self.config.concatenated_latent_adata_path)
        rank_memmap_path = Path(self.config.rank_memmap_path)
        mean_frac_path = Path(self.config.mean_frac_path)

        # Short-circuit if all outputs already exist and the memmap is complete.
        if concat_adata_path.exists() and mean_frac_path.exists() and MemMapDense.check_complete(rank_memmap_path)[0]:
            logger.info(f"Rank outputs already exist in {self.output_dir}")
            return {
                "concatenated_latent_adata": str(concat_adata_path),
                "rank_memmap": str(rank_memmap_path),
                "mean_frac": str(mean_frac_path)
            }

        logger.info("Starting rank calculation and concatenation...")
        logger.info(f"Processing {len(sample_h5ad_dict)} samples")

        # Per-section accumulators for the main (second) pass.
        adata_list = []
        n_total_cells = 0
        gene_list = None  # set from the first section; all others must match

        # Global accumulators, allocated lazily once n_genes is known.
        sum_log_ranks = None
        sum_frac = None
        total_cells = 0
        rank_memmap = None
        current_row_offset = 0  # Track current position in rank memory map

        # ---- First pass: count total cells and determine which cells to keep.
        # Only obs metadata (and, for CSR files, the indptr) is read, so the
        # expensive expression matrices stay on disk.
        logger.info("Counting total cells across all sections...")
        total_cells_expected = 0
        filtering_stats = []
        sample_keep_indices = {}  # Store cell indices to keep for each sample

        for sample_name, h5ad_path in sample_h5ad_dict.items():
            # Apply same filtering logic as in main loop
            with h5py.File(h5ad_path, 'r') as f:
                adata_temp_obs = read_elem(f['obs'])

                # Efficiently read gene counts per cell without loading the
                # full expression matrix. For CSR storage only indptr is needed.
                X_group = f['X']
                if 'encoding-type' in X_group.attrs and X_group.attrs['encoding-type'] == 'csr_matrix':
                    # CSR format: non-zero genes per cell = indptr[i+1] - indptr[i]
                    indptr = X_group['indptr'][:]
                    n_genes_per_cell = np.diff(indptr)
                    X_temp = None
                    use_precomputed_counts = True
                else:
                    # Not CSR or unknown format: must read the full matrix.
                    X_temp = read_elem(X_group)
                    n_genes_per_cell = None
                    use_precomputed_counts = False

            if annotation_key is not None:
                assert annotation_key and annotation_key in adata_temp_obs.columns, \
                    f"Annotation key '{annotation_key}' not found in the obs of {sample_name}"

            # Apply the unified filtering function via whichever path is cheaper.
            if use_precomputed_counts:
                filtered_obs, stats, keep_mask = filter_cells(
                    adata_temp_obs,
                    X=None,
                    annotation_key=annotation_key,
                    min_cells_per_type=self.config.min_cells_per_type,
                    min_genes=20,
                    precomputed_gene_counts=n_genes_per_cell
                )
            else:
                filtered_obs, stats, keep_mask = filter_cells(
                    adata_temp_obs,
                    X=X_temp,
                    annotation_key=annotation_key,
                    min_cells_per_type=self.config.min_cells_per_type,
                    min_genes=20
                )

            # Remember the keep-mask so the second pass filters identically.
            sample_keep_indices[sample_name] = keep_mask

            total_cells_expected += stats["final_cells"]

            filtering_stats.append({
                "Sample": sample_name,
                **stats
            })

        # Display a per-sample filtering summary table.
        table = Table(title="[bold]Cell Filtering Summary[/bold]", show_header=True, header_style="bold magenta")
        table.add_column("Sample", style="dim")
        table.add_column("Input Cells", justify="right")
        table.add_column("NaN Removed", justify="right", style="red")
        table.add_column("Small Group Removed", justify="right", style="red")
        table.add_column("Low Gene Count Removed", justify="right", style="red")
        table.add_column("Final Cells", justify="right", style="green")

        for stat in filtering_stats:
            table.add_row(
                stat["Sample"],
                str(stat["input_cells"]),
                str(stat["nan_removed"]),
                str(stat["small_group_removed"]),
                str(stat["low_gene_count_removed"]),
                str(stat["final_cells"])
            )

        # Add a totals row summing over all samples.
        table.add_section()
        table.add_row(
            "[bold]Total[/bold]",
            str(sum(s["input_cells"] for s in filtering_stats)),
            str(sum(s["nan_removed"] for s in filtering_stats)),
            str(sum(s["small_group_removed"] for s in filtering_stats)),
            str(sum(s["low_gene_count_removed"] for s in filtering_stats)),
            f"[bold green]{total_cells_expected}[/bold green]"
        )

        self.console.print(table)
        logger.info(f"Expected total cells after filtering: {total_cells_expected}")

        # ---- Second pass: rank each section and build minimal AnnData objects.
        with Progress(
            SpinnerColumn(),
            TextColumn("[bold blue]{task.description}"),
            BarColumn(bar_width=None),
            MofNCompleteColumn(),
            TaskProgressColumn(),
            TimeRemainingColumn(),
            refresh_per_second=1
        ) as section_progress:
            # Overall section-level progress task.
            section_task = section_progress.add_task(
                "Processing sections",
                total=len(sample_h5ad_dict)
            )

            processed_cells_total = 0

            for st_id, (sample_name, h5ad_path) in enumerate(sample_h5ad_dict.items()):

                section_progress.console.log(f"Loading {sample_name} ({st_id + 1}/{len(sample_h5ad_dict)})...")

                # Load the h5ad file (expected to already contain latent
                # representations in obsm).
                adata = sc.read_h5ad(h5ad_path)

                # Tag every cell with its slice of origin.
                adata.obs['slice_id'] = st_id
                adata.obs['slice_name'] = sample_name
                adata.obs['sample_name'] = sample_name

                # Make obs names globally unique by suffixing the slice name.
                adata.obs_names_make_unique()
                adata.obs_names = adata.obs_names.astype(str) +'|'+ adata.obs['slice_name'].astype(str)

                # Apply the pre-computed filter mask from the first pass.
                keep_mask = sample_keep_indices[sample_name]
                adata = adata[keep_mask].copy()

                if gene_list is None:
                    # First section: fix the gene list and allocate outputs.
                    gene_list = adata.var_names.tolist()
                    n_genes = len(gene_list)
                    # Dense rank memmap in float16 for ~50% space savings; log
                    # ranks typically retain sufficient precision at float16.
                    rank_memmap = MemMapDense(
                        str(rank_memmap_path),
                        shape=(total_cells_expected, n_genes),
                        dtype=np.float16,
                        mode='w',
                        num_write_workers=self.config.mkscore_write_workers
                    )
                    # Global accumulators in float64 to avoid drift over many cells.
                    sum_log_ranks = np.zeros(n_genes, dtype=np.float64)
                    sum_frac = np.zeros(n_genes, dtype=np.float64)
                else:
                    # Later sections must share the exact same gene ordering.
                    assert adata.var_names.tolist() == gene_list, \
                        f"Gene list mismatch in section {st_id}"

                # Pick the expression source: the requested layer, else .X.
                if data_layer in adata.layers:
                    X = adata.layers[data_layer]
                else:
                    X = adata.X

                # Normalize to CSR float32 for the ranking kernel.
                if not hasattr(X, 'tocsr'):
                    X = csr_matrix(X, dtype=np.float32)
                else:
                    X = X.tocsr()
                    if X.dtype != np.float32:
                        X = X.astype(np.float32)

                X.sort_indices()  # sorted indices improve cache behavior

                n_cells = X.shape[0]

                # Nested progress bar for chunk-level ranking of this section.
                with Progress(
                    SpinnerColumn(),
                    TextColumn(f"[bold blue]Ranking {sample_name} ({{task.fields[cells]}} cells)"),
                    BarColumn(bar_width=None),
                    MofNCompleteColumn(),
                    TaskProgressColumn(),
                    TextColumn("[bold green]{task.fields[speed]} cells/s"),
                    TimeRemainingColumn(),
                    refresh_per_second=2,
                    transient=True
                ) as chunk_progress:
                    chunk_task = chunk_progress.add_task(
                        "Processing chunks",
                        total=n_cells,
                        cells=n_cells,
                        speed="0"
                    )

                    metadata = {'name': sample_name, 'cells': n_cells, 'study_id': st_id}

                    # Rank this section's cells; log-ranks stream into the
                    # memmap at rows [current_row_offset, current_row_offset + n_cells).
                    batch_sum_log_ranks, batch_frac = rank_data_jax(
                        X,
                        n_genes,
                        memmap_dense=rank_memmap,
                        metadata=metadata,
                        chunk_size=self.config.rank_batch_size,
                        write_interval=self.config.rank_write_interval,
                        current_row_offset=current_row_offset,
                        progress=chunk_progress,
                        progress_task=chunk_task
                    )

                # Fold this section into the global sums and advance the offset.
                sum_log_ranks += batch_sum_log_ranks
                sum_frac += batch_frac
                total_cells += n_cells
                current_row_offset += n_cells
                processed_cells_total += n_cells

                section_progress.update(section_task, advance=1)

                # Keep obs/obsm but drop expression: an all-zero sparse X keeps
                # the concatenated file small while preserving latent spaces.
                minimal_adata = ad.AnnData(
                    X=csr_matrix((adata.n_obs, n_genes), dtype=np.float32),
                    obs=adata.obs.copy(),
                    var=pd.DataFrame(index=gene_list),
                    obsm=adata.obsm.copy()
                )

                adata_list.append(minimal_adata)
                n_total_cells += n_cells

                # Release the section's memory before loading the next one.
                del adata, X, minimal_adata
                gc.collect()

        # Close the rank memory map (flushes pending writes).
        if rank_memmap is not None:
            rank_memmap.close()
            logger.info(f"Saved rank matrix to {rank_memmap_path}")

        # Per-gene means over all retained cells.
        # NOTE(review): if sample_h5ad_dict is empty these accumulators are
        # still None and this division would raise — assumed non-empty input.
        mean_log_ranks = sum_log_ranks / total_cells
        mean_frac = sum_frac / total_cells

        # Persist mean log-rank and expression fraction per gene.
        mean_frac_df = pd.DataFrame(
            data=dict(
                G_Mean=mean_log_ranks,
                frac=mean_frac,
                gene_name=gene_list,
            ),
            index=gene_list,
        )
        # Save outside the progress context
        mean_frac_df.to_parquet(
            mean_frac_path,
            index=True,
            compression="gzip",
        )
        logger.info(f"Mean fraction data saved to {mean_frac_path}")

        # Concatenate all sections into one AnnData and write it out.
        if adata_list:
            with self.console.status("[bold blue]Concatenating and saving latent representations..."):
                concatenated_adata = ad.concat(adata_list, axis=0, join='outer', merge='same')

                # Ensure the var_names are the common genes
                concatenated_adata.var_names = gene_list

                concatenated_adata.write_h5ad(concat_adata_path)
                logger.info(f"Saved concatenated latent representations to {concat_adata_path}")
                logger.info(f" - Total cells: {concatenated_adata.n_obs}")
                logger.info(f" - Total genes: {concatenated_adata.n_vars}")
                logger.info(f" - Latent representations in obsm: {list(concatenated_adata.obsm.keys())}")
                if 'slice_id' in concatenated_adata.obs.columns:
                    logger.info(f" - Number of slices: {concatenated_adata.obs['slice_id'].nunique()}")

            del adata_list, concatenated_adata
            gc.collect()

        logger.info("Rank calculation and concatenation completed successfully")

        return {
            "concatenated_latent_adata": str(concat_adata_path),
            "rank_memmap": str(rank_memmap_path),
            "mean_frac": str(mean_frac_path)
        }
+ }