crispyx 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- crispyx-0.0.1/LICENSE +21 -0
- crispyx-0.0.1/PKG-INFO +32 -0
- crispyx-0.0.1/README.md +75 -0
- crispyx-0.0.1/pyproject.toml +60 -0
- crispyx-0.0.1/setup.cfg +4 -0
- crispyx-0.0.1/src/crispyx/__init__.py +155 -0
- crispyx-0.0.1/src/crispyx/_checkpoint.py +252 -0
- crispyx-0.0.1/src/crispyx/_kernels.py +2034 -0
- crispyx-0.0.1/src/crispyx/_memory.py +246 -0
- crispyx-0.0.1/src/crispyx/_namespaces.py +852 -0
- crispyx-0.0.1/src/crispyx/_size_factors.py +376 -0
- crispyx-0.0.1/src/crispyx/_statistics.py +313 -0
- crispyx-0.0.1/src/crispyx/data.py +3857 -0
- crispyx-0.0.1/src/crispyx/de.py +4545 -0
- crispyx-0.0.1/src/crispyx/dimred.py +909 -0
- crispyx-0.0.1/src/crispyx/glm.py +4795 -0
- crispyx-0.0.1/src/crispyx/plotting.py +1424 -0
- crispyx-0.0.1/src/crispyx/profiling.py +642 -0
- crispyx-0.0.1/src/crispyx/pseudobulk.py +207 -0
- crispyx-0.0.1/src/crispyx/qc.py +1638 -0
- crispyx-0.0.1/src/crispyx.egg-info/PKG-INFO +32 -0
- crispyx-0.0.1/src/crispyx.egg-info/SOURCES.txt +38 -0
- crispyx-0.0.1/src/crispyx.egg-info/dependency_links.txt +1 -0
- crispyx-0.0.1/src/crispyx.egg-info/requires.txt +28 -0
- crispyx-0.0.1/src/crispyx.egg-info/top_level.txt +1 -0
- crispyx-0.0.1/tests/test_benchmarking.py +191 -0
- crispyx-0.0.1/tests/test_convert_to_csc.py +221 -0
- crispyx-0.0.1/tests/test_convert_to_csr.py +281 -0
- crispyx-0.0.1/tests/test_data_helpers.py +526 -0
- crispyx-0.0.1/tests/test_dimred.py +819 -0
- crispyx-0.0.1/tests/test_memory_dispatch.py +324 -0
- crispyx-0.0.1/tests/test_nb_glm.py +753 -0
- crispyx-0.0.1/tests/test_normalisation.py +481 -0
- crispyx-0.0.1/tests/test_plotting.py +177 -0
- crispyx-0.0.1/tests/test_profiling.py +342 -0
- crispyx-0.0.1/tests/test_qc_parity.py +426 -0
- crispyx-0.0.1/tests/test_scanpy_parity.py +429 -0
- crispyx-0.0.1/tests/test_streaming_control.py +416 -0
- crispyx-0.0.1/tests/test_wilcoxon_dispatch.py +896 -0
- crispyx-0.0.1/tests/test_workflow.py +572 -0
crispyx-0.0.1/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Du Jinhong
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
crispyx-0.0.1/PKG-INFO
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: crispyx
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Tools for scalable CRISPR screen analysis on on-disk AnnData objects
|
|
5
|
+
Author: Streamlining CRISPR Team
|
|
6
|
+
License-File: LICENSE
|
|
7
|
+
Requires-Dist: anndata>=0.9
|
|
8
|
+
Requires-Dist: numpy>=1.23
|
|
9
|
+
Requires-Dist: numba>=0.59
|
|
10
|
+
Requires-Dist: pandas>=1.5
|
|
11
|
+
Requires-Dist: scipy>=1.10
|
|
12
|
+
Requires-Dist: h5py>=3.0
|
|
13
|
+
Requires-Dist: joblib>=1.0
|
|
14
|
+
Requires-Dist: scikit-learn>=1.0
|
|
15
|
+
Requires-Dist: scanpy>=1.9.2
|
|
16
|
+
Requires-Dist: seaborn>=0.12
|
|
17
|
+
Requires-Dist: matplotlib>=3.5
|
|
18
|
+
Requires-Dist: tqdm>=4.50
|
|
19
|
+
Provides-Extra: test
|
|
20
|
+
Requires-Dist: filelock; extra == "test"
|
|
21
|
+
Requires-Dist: pytest; extra == "test"
|
|
22
|
+
Requires-Dist: statsmodels>=0.14; extra == "test"
|
|
23
|
+
Requires-Dist: pydeseq2>=0.4; extra == "test"
|
|
24
|
+
Provides-Extra: benchmark
|
|
25
|
+
Requires-Dist: pertpy>=0.4; extra == "benchmark"
|
|
26
|
+
Requires-Dist: pyyaml>=6.0; extra == "benchmark"
|
|
27
|
+
Requires-Dist: tqdm>=4.65; extra == "benchmark"
|
|
28
|
+
Requires-Dist: psutil>=5.9; extra == "benchmark"
|
|
29
|
+
Provides-Extra: docs
|
|
30
|
+
Requires-Dist: sphinx>=6.0; extra == "docs"
|
|
31
|
+
Requires-Dist: sphinx-rtd-theme>=1.2; extra == "docs"
|
|
32
|
+
Dynamic: license-file
|
crispyx-0.0.1/README.md
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# crispyx
|
|
2
|
+
|
|
3
|
+
## Motivation
|
|
4
|
+
|
|
5
|
+
Genome-wide CRISPR screens routinely produce datasets with hundreds of thousands of cells and tens of thousands of genes. Standard single-cell analysis toolkits (Scanpy, Pertpy) load the entire count matrix into memory, which can require 30–100+ GB of RAM and makes many screens impractical to analyse on commodity hardware or shared HPC nodes with per-job memory limits.
|
|
6
|
+
|
|
7
|
+
**crispyx** solves this by streaming data directly from on-disk AnnData (`.h5ad`) files. Quality control, normalisation, pseudo-bulk aggregation, and differential expression all operate without materialising the full matrix in memory, so even the largest screens can be processed with modest resources.
|
|
8
|
+
|
|
9
|
+
## Features
|
|
10
|
+
|
|
11
|
+
- **Streaming QC & preprocessing** – Filter cells, perturbations, and genes; normalise and log-transform; all without loading the full matrix into memory
|
|
12
|
+
- **Pseudo-bulk aggregation** – Average log expression and pseudo-bulk count matrices for effect size estimation
|
|
13
|
+
- **Differential expression** – t-test, Wilcoxon rank-sum, and negative binomial GLM with apeGLM LFC shrinkage; multi-core support and adaptive memory management
|
|
14
|
+
- **Dimension reduction** – Memory-efficient PCA and KNN graph construction on backed data
|
|
15
|
+
- **Scanpy-compatible API & plotting** – Familiar `cx.pp`, `cx.pb`, `cx.tl`, and `cx.pl` namespaces; Scanpy-style rank genes plots, volcano, MA, PCA, UMAP, QC summaries, and overlap heatmaps
|
|
16
|
+
- **Data preparation utilities** – Edit backed metadata without loading X; standardise gene names; normalise perturbation labels; auto-detect metadata columns
|
|
17
|
+
- **HPC-ready** – Resume/checkpoint for long-running jobs; configurable `memory_limit_gb`; Docker and Singularity support
|
|
18
|
+
|
|
19
|
+
## Quick Start
|
|
20
|
+
|
|
21
|
+
```python
|
|
22
|
+
import crispyx as cx
|
|
23
|
+
|
|
24
|
+
# Open dataset without loading into memory
|
|
25
|
+
adata = cx.read_h5ad_ondisk("data/demo_benchmark.h5ad")
|
|
26
|
+
|
|
27
|
+
# Quality control with adaptive thresholds
|
|
28
|
+
adata = cx.pp.qc_summary(
|
|
29
|
+
adata,
|
|
30
|
+
perturbation_column="perturbation",
|
|
31
|
+
min_genes=5,
|
|
32
|
+
min_cells_per_perturbation=5,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
# Differential expression
|
|
36
|
+
adata = cx.tl.rank_genes_groups(
|
|
37
|
+
adata,
|
|
38
|
+
perturbation_column="perturbation",
|
|
39
|
+
method="wilcoxon", # or "t-test", "nb_glm"
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
# Access results
|
|
43
|
+
print(adata.uns["rank_genes_groups"])
|
|
44
|
+
de_results = adata.uns["rank_genes_groups"].load()
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
For the full workflow (normalisation, PCA, pseudo-bulk, NB-GLM, LFC shrinkage, plotting, data preparation utilities), see the [Usage Guide](docs/usage.rst) and the [tutorial notebook](docs/crispyx_tutorial.ipynb).
|
|
48
|
+
|
|
49
|
+
## Installation
|
|
50
|
+
|
|
51
|
+
```bash
|
|
52
|
+
pip install -e .
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
## Benchmarking
|
|
56
|
+
|
|
57
|
+
```bash
|
|
58
|
+
cd benchmarking
|
|
59
|
+
./run_benchmark.sh config/Adamson.yaml # single dataset
|
|
60
|
+
./run_benchmark.sh config/*.yaml # all datasets
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
See [benchmarking/README.md](benchmarking/README.md) for configuration options and output structure.
|
|
64
|
+
|
|
65
|
+
## Testing
|
|
66
|
+
|
|
67
|
+
```bash
|
|
68
|
+
pytest
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
## Documentation
|
|
72
|
+
|
|
73
|
+
```bash
|
|
74
|
+
sphinx-build docs docs/_build
|
|
75
|
+
```
|
crispyx-0.0.1/pyproject.toml
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "crispyx"
|
|
7
|
+
version = "0.0.1"
|
|
8
|
+
description = "Tools for scalable CRISPR screen analysis on on-disk AnnData objects"
|
|
9
|
+
authors = [{name = "Streamlining CRISPR Team"}]
|
|
10
|
+
dependencies = [
|
|
11
|
+
"anndata>=0.9",
|
|
12
|
+
"numpy>=1.23",
|
|
13
|
+
"numba>=0.59",
|
|
14
|
+
"pandas>=1.5",
|
|
15
|
+
"scipy>=1.10",
|
|
16
|
+
"h5py>=3.0",
|
|
17
|
+
"joblib>=1.0",
|
|
18
|
+
"scikit-learn>=1.0",
|
|
19
|
+
"scanpy>=1.9.2",
|
|
20
|
+
"seaborn>=0.12",
|
|
21
|
+
"matplotlib>=3.5",
|
|
22
|
+
"tqdm>=4.50",
|
|
23
|
+
]
|
|
24
|
+
|
|
25
|
+
[project.optional-dependencies]
|
|
26
|
+
test = [
|
|
27
|
+
"filelock",
|
|
28
|
+
"pytest",
|
|
29
|
+
"statsmodels>=0.14",
|
|
30
|
+
"pydeseq2>=0.4",
|
|
31
|
+
]
|
|
32
|
+
benchmark = [
|
|
33
|
+
"pertpy>=0.4",
|
|
34
|
+
"pyyaml>=6.0",
|
|
35
|
+
"tqdm>=4.65",
|
|
36
|
+
"psutil>=5.9",
|
|
37
|
+
]
|
|
38
|
+
docs = [
|
|
39
|
+
"sphinx>=6.0",
|
|
40
|
+
"sphinx-rtd-theme>=1.2",
|
|
41
|
+
]
|
|
42
|
+
|
|
43
|
+
[tool.pytest.ini_options]
|
|
44
|
+
addopts = "-q"
|
|
45
|
+
filterwarnings = [
|
|
46
|
+
# Ignore h5py deprecation warnings about numpy.product (third-party issue)
|
|
47
|
+
"ignore:`product` is deprecated as of NumPy 1.25.0:DeprecationWarning",
|
|
48
|
+
# Ignore pydeseq2 FutureWarnings about dtype incompatibility (third-party issue)
|
|
49
|
+
"ignore:Setting an item of incompatible dtype is deprecated:FutureWarning",
|
|
50
|
+
# Ignore UserWarnings from third-party libraries during tests
|
|
51
|
+
"ignore:Some cells have zero counts:UserWarning",
|
|
52
|
+
"ignore:Every gene contains at least one zero:UserWarning",
|
|
53
|
+
"ignore:The dispersion trend curve fitting did not converge:UserWarning",
|
|
54
|
+
# Ignore AnnData implicit modification warnings (third-party issue)
|
|
55
|
+
"ignore:Transforming to str index:anndata.ImplicitModificationWarning",
|
|
56
|
+
# Ignore AnnData old-format warnings from test fixture files written without encoding metadata
|
|
57
|
+
"ignore:.*was written without encoding metadata:anndata.OldFormatWarning",
|
|
58
|
+
# Ignore matplotlib layout warnings in plotting tests
|
|
59
|
+
"ignore:Tight layout not applied:UserWarning",
|
|
60
|
+
]
|
crispyx-0.0.1/setup.cfg
ADDED
|
crispyx-0.0.1/src/crispyx/__init__.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
"""Streamlined CRISPR screen analysis toolkit with Scanpy-style entry points."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
__version__ = version("crispyx")
|
|
9
|
+
except PackageNotFoundError:
|
|
10
|
+
__version__ = "0.0.0.dev"
|
|
11
|
+
|
|
12
|
+
# ---------------------------------------------------------------------------
|
|
13
|
+
# Public API re-exports
|
|
14
|
+
# ---------------------------------------------------------------------------
|
|
15
|
+
|
|
16
|
+
from .data import (
|
|
17
|
+
AnnData,
|
|
18
|
+
OverlapResult,
|
|
19
|
+
compute_overlap,
|
|
20
|
+
convert_to_csc,
|
|
21
|
+
convert_to_csr,
|
|
22
|
+
detect_gene_symbol_column,
|
|
23
|
+
detect_perturbation_column,
|
|
24
|
+
ensure_gene_symbol_column,
|
|
25
|
+
infer_columns,
|
|
26
|
+
load_obs,
|
|
27
|
+
load_var,
|
|
28
|
+
normalize_total_log1p,
|
|
29
|
+
normalise_perturbation_labels,
|
|
30
|
+
read_h5ad_ondisk,
|
|
31
|
+
read_backed,
|
|
32
|
+
resolve_data_path,
|
|
33
|
+
standardise_gene_names,
|
|
34
|
+
write_obs,
|
|
35
|
+
write_var,
|
|
36
|
+
)
|
|
37
|
+
from .de import (
|
|
38
|
+
RankGenesGroupsResult,
|
|
39
|
+
nb_glm_test,
|
|
40
|
+
shrink_lfc,
|
|
41
|
+
t_test,
|
|
42
|
+
wilcoxon_test,
|
|
43
|
+
)
|
|
44
|
+
from .profiling import (
|
|
45
|
+
Profiler,
|
|
46
|
+
MemoryProfiler,
|
|
47
|
+
TimingProfiler,
|
|
48
|
+
plot_benchmark_comparison,
|
|
49
|
+
)
|
|
50
|
+
from .plotting import (
|
|
51
|
+
materialize_rank_genes_groups,
|
|
52
|
+
plot_ma,
|
|
53
|
+
plot_overlap_heatmap,
|
|
54
|
+
plot_pca,
|
|
55
|
+
plot_pca_loadings,
|
|
56
|
+
plot_pca_variance_ratio,
|
|
57
|
+
plot_qc_perturbation_counts,
|
|
58
|
+
plot_qc_summary,
|
|
59
|
+
plot_top_genes_bar,
|
|
60
|
+
plot_umap,
|
|
61
|
+
plot_volcano,
|
|
62
|
+
rank_genes_groups_df,
|
|
63
|
+
)
|
|
64
|
+
from .pseudobulk import (
|
|
65
|
+
compute_average_log_expression,
|
|
66
|
+
compute_pseudobulk_expression,
|
|
67
|
+
)
|
|
68
|
+
from .qc import (
|
|
69
|
+
filter_cells_by_gene_count,
|
|
70
|
+
filter_genes_by_cell_count,
|
|
71
|
+
filter_perturbations_by_cell_count,
|
|
72
|
+
quality_control_summary,
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
# ---------------------------------------------------------------------------
|
|
76
|
+
# Scanpy-style namespace singletons: cx.pp, cx.pb, cx.tl, cx.pl
|
|
77
|
+
# ---------------------------------------------------------------------------
|
|
78
|
+
|
|
79
|
+
from ._namespaces import (
|
|
80
|
+
_PlottingNamespace,
|
|
81
|
+
_PreprocessingNamespace,
|
|
82
|
+
_PseudobulkNamespace,
|
|
83
|
+
_ToolsNamespace,
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
pp = _PreprocessingNamespace()
|
|
87
|
+
pb = _PseudobulkNamespace()
|
|
88
|
+
tl = _ToolsNamespace()
|
|
89
|
+
pl = _PlottingNamespace()
|
|
90
|
+
|
|
91
|
+
# ---------------------------------------------------------------------------
|
|
92
|
+
# __all__
|
|
93
|
+
# ---------------------------------------------------------------------------
|
|
94
|
+
|
|
95
|
+
__all__ = [
|
|
96
|
+
"__version__",
|
|
97
|
+
# Namespace singletons
|
|
98
|
+
"pp",
|
|
99
|
+
"pb",
|
|
100
|
+
"tl",
|
|
101
|
+
"pl",
|
|
102
|
+
# Quality control
|
|
103
|
+
"filter_cells_by_gene_count",
|
|
104
|
+
"filter_genes_by_cell_count",
|
|
105
|
+
"filter_perturbations_by_cell_count",
|
|
106
|
+
"quality_control_summary",
|
|
107
|
+
# Pseudo-bulk
|
|
108
|
+
"compute_average_log_expression",
|
|
109
|
+
"compute_pseudobulk_expression",
|
|
110
|
+
# Differential expression
|
|
111
|
+
"RankGenesGroupsResult",
|
|
112
|
+
"t_test",
|
|
113
|
+
"wilcoxon_test",
|
|
114
|
+
"nb_glm_test",
|
|
115
|
+
"shrink_lfc",
|
|
116
|
+
# Data utilities
|
|
117
|
+
"AnnData",
|
|
118
|
+
"ensure_gene_symbol_column",
|
|
119
|
+
"read_h5ad_ondisk",
|
|
120
|
+
"read_backed",
|
|
121
|
+
"resolve_data_path",
|
|
122
|
+
"normalize_total_log1p",
|
|
123
|
+
"convert_to_csc",
|
|
124
|
+
"convert_to_csr",
|
|
125
|
+
"load_obs",
|
|
126
|
+
"load_var",
|
|
127
|
+
"write_obs",
|
|
128
|
+
"write_var",
|
|
129
|
+
"standardise_gene_names",
|
|
130
|
+
"normalise_perturbation_labels",
|
|
131
|
+
"detect_perturbation_column",
|
|
132
|
+
"detect_gene_symbol_column",
|
|
133
|
+
"infer_columns",
|
|
134
|
+
"OverlapResult",
|
|
135
|
+
"compute_overlap",
|
|
136
|
+
# Profiling
|
|
137
|
+
"Profiler",
|
|
138
|
+
"MemoryProfiler",
|
|
139
|
+
"TimingProfiler",
|
|
140
|
+
"plot_benchmark_comparison",
|
|
141
|
+
# Plotting
|
|
142
|
+
"materialize_rank_genes_groups",
|
|
143
|
+
"rank_genes_groups_df",
|
|
144
|
+
"plot_pca",
|
|
145
|
+
"plot_pca_variance_ratio",
|
|
146
|
+
"plot_pca_loadings",
|
|
147
|
+
"plot_umap",
|
|
148
|
+
"plot_volcano",
|
|
149
|
+
"plot_ma",
|
|
150
|
+
"plot_top_genes_bar",
|
|
151
|
+
"plot_qc_perturbation_counts",
|
|
152
|
+
"plot_qc_summary",
|
|
153
|
+
"plot_overlap_heatmap",
|
|
154
|
+
]
|
|
155
|
+
|
crispyx-0.0.1/src/crispyx/_checkpoint.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
"""Checkpoint and progress utilities for streaming DE tests.
|
|
2
|
+
|
|
3
|
+
This module provides atomic checkpointing and progress tracking for
|
|
4
|
+
resumable differential expression tests.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import logging
|
|
11
|
+
import os
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import TYPE_CHECKING
|
|
14
|
+
|
|
15
|
+
import h5py
|
|
16
|
+
import numpy as np
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from tqdm import tqdm
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
# Check for tqdm availability
|
|
24
|
+
try:
|
|
25
|
+
from tqdm import tqdm as _tqdm
|
|
26
|
+
HAS_TQDM = True
|
|
27
|
+
except ImportError:
|
|
28
|
+
HAS_TQDM = False
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _write_checkpoint_atomic(
    checkpoint_path: Path,
    data: dict,
) -> None:
    """Write checkpoint data atomically using a temp file plus replace.

    The JSON payload is first written to a sibling ``.tmp`` file and then
    moved over ``checkpoint_path``, so a reader never observes a partially
    written checkpoint.

    Parameters
    ----------
    checkpoint_path
        Destination of the checkpoint JSON file. Missing parent directories
        are created.
    data
        JSON-serialisable checkpoint payload.

    Raises
    ------
    Exception
        Any error raised while serialising or writing is re-raised after a
        best-effort removal of the temporary file.
    """
    # Ensure parent directory exists
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

    # Write to a temporary file in the same directory so the final move
    # never crosses a filesystem boundary.
    tmp_path = checkpoint_path.with_suffix(".tmp")
    try:
        with open(tmp_path, "w") as f:
            json.dump(data, f, indent=2)
        # os.replace overwrites an existing destination on every platform.
        # os.rename raises FileExistsError on Windows when the checkpoint
        # already exists, which would break every write after the first.
        os.replace(tmp_path, checkpoint_path)
    except Exception:
        # Best-effort cleanup of the temp file; the original error is what
        # the caller needs to see, so cleanup failures are ignored.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _read_checkpoint(checkpoint_path: Path) -> dict | None:
|
|
60
|
+
"""Read checkpoint file, returning None if missing or corrupted.
|
|
61
|
+
|
|
62
|
+
Returns
|
|
63
|
+
-------
|
|
64
|
+
dict or None
|
|
65
|
+
Checkpoint data if valid, None if file is missing or corrupted.
|
|
66
|
+
"""
|
|
67
|
+
if not checkpoint_path.exists():
|
|
68
|
+
return None
|
|
69
|
+
try:
|
|
70
|
+
with open(checkpoint_path, "r") as f:
|
|
71
|
+
data = json.load(f)
|
|
72
|
+
# Validate required fields
|
|
73
|
+
if not isinstance(data, dict):
|
|
74
|
+
return None
|
|
75
|
+
if "completed" not in data or "total" not in data:
|
|
76
|
+
return None
|
|
77
|
+
return data
|
|
78
|
+
except (json.JSONDecodeError, IOError, OSError):
|
|
79
|
+
return None
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _scan_h5ad_completed(
    h5ad_path: Path,
    all_candidates: list[str],
    result_dataset: str = "uns/rank_genes_groups/full/scores",
) -> list[str]:
    """Detect completed perturbations by scanning result rows in an h5ad file.

    Fallback used when the checkpoint file is missing or corrupted. A row
    counts as completed when it contains at least one finite, non-zero value.

    Parameters
    ----------
    h5ad_path
        Path to the output h5ad file.
    all_candidates
        Perturbation labels, in the same order as the rows being scanned.
    result_dataset
        HDF5 path of the dataset to inspect; expected to be indexable as
        ``[row, :]``. When it is absent, a 2-D ``X`` entry is scanned instead.

    Returns
    -------
    list[str]
        Labels whose row holds results. Empty when the file does not exist.
        If scanning fails part-way, a warning is logged and the labels
        collected so far are returned.
    """
    completed = []
    if not h5ad_path.exists():
        return completed

    try:
        with h5py.File(h5ad_path, "r") as f:
            # Pick the matrix to scan: the result dataset when present,
            # otherwise X, but only if X exposes a 2-D shape.
            matrix = None
            if result_dataset in f:
                matrix = f[result_dataset]
            elif "X" in f:
                fallback = f["X"]
                if hasattr(fallback, "shape") and len(fallback.shape) == 2:
                    matrix = fallback

            if matrix is not None:
                # Never index past the labels we can actually name.
                n_rows = min(matrix.shape[0], len(all_candidates))
                for row_idx in range(n_rows):
                    values = matrix[row_idx, :]
                    has_signal = np.any(np.isfinite(values) & (values != 0))
                    if has_signal:
                        completed.append(all_candidates[row_idx])
    except Exception as e:
        logger.warning(f"Failed to scan h5ad for completed perturbations: {e}")

    return completed
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _get_resumable_candidates(
    checkpoint_path: Path,
    h5ad_path: Path,
    all_candidates: list[str],
    retry_failed: bool = True,
) -> tuple[list[str], list[str], list[str]]:
    """Work out which perturbations still need processing on resume.

    Progress is taken from the checkpoint file when it can be read. If not,
    and the output h5ad exists, progress is recovered by scanning that file.
    Otherwise nothing is treated as done.

    Parameters
    ----------
    checkpoint_path
        Path to the progress JSON file.
    h5ad_path
        Path to the output h5ad file.
    all_candidates
        List of all perturbation labels to process.
    retry_failed
        If True, previously failed perturbations are scheduled again; if
        False, they are left out of the returned work list.

    Returns
    -------
    tuple[list[str], list[str], list[str]]
        ``(candidates_to_run, completed, failed)``:

        - candidates_to_run: perturbations that need to be processed,
          in the order of ``all_candidates``
        - completed: perturbations already completed
        - failed: perturbations recorded as failed (for logging)
    """
    state = _read_checkpoint(checkpoint_path)

    if state is not None:
        completed = state.get("completed", [])
        failed = state.get("failed", [])
        logger.info(f"Resuming: {len(completed)}/{len(all_candidates)} already completed")
        if failed:
            logger.info(f" {len(failed)} previously failed perturbations")
    elif h5ad_path.exists():
        # Checkpoint missing or corrupted - recover progress from the output.
        logger.warning(
            f"Checkpoint file missing or corrupted at {checkpoint_path}. "
            f"Scanning h5ad file to detect completed perturbations..."
        )
        completed = _scan_h5ad_completed(h5ad_path, all_candidates)
        failed = []
        if completed:
            logger.info(f"Detected {len(completed)} completed perturbations from h5ad scan")
    else:
        completed, failed = [], []

    # Completed labels are always skipped; failed ones only when the caller
    # opted out of retrying them.
    skip = set(completed)
    if not retry_failed:
        skip.update(failed)

    candidates_to_run = [label for label in all_candidates if label not in skip]
    return candidates_to_run, completed, failed
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _get_checkpoint_interval(n_perturbations: int, checkpoint_interval: int | None) -> int:
|
|
199
|
+
"""Determine checkpoint interval based on dataset size.
|
|
200
|
+
|
|
201
|
+
Parameters
|
|
202
|
+
----------
|
|
203
|
+
n_perturbations
|
|
204
|
+
Total number of perturbations.
|
|
205
|
+
checkpoint_interval
|
|
206
|
+
User-specified interval, or None for auto.
|
|
207
|
+
|
|
208
|
+
Returns
|
|
209
|
+
-------
|
|
210
|
+
int
|
|
211
|
+
Number of perturbations to process between checkpoints.
|
|
212
|
+
"""
|
|
213
|
+
if checkpoint_interval is not None:
|
|
214
|
+
return max(1, checkpoint_interval)
|
|
215
|
+
# Auto: every 1 for small datasets, every 10 for larger ones
|
|
216
|
+
if n_perturbations < 100:
|
|
217
|
+
return 1
|
|
218
|
+
elif n_perturbations < 1000:
|
|
219
|
+
return 10
|
|
220
|
+
else:
|
|
221
|
+
return 50
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
class _DummyProgress:
    """No-op stand-in for a progress bar (used when verbose=False).

    Provides the context-manager protocol plus ``update`` and
    ``set_postfix`` so calling code can treat it like a real progress bar.
    """

    def __enter__(self):
        # Hand back the instance so ``with ... as bar`` binds something usable.
        return self

    def __exit__(self, *args):
        # Returning None means exceptions from the ``with`` body propagate.
        return None

    def update(self, n: int = 1):
        """Accept a progress increment and ignore it."""
        return None

    def set_postfix(self, **kwargs):
        """Accept postfix fields and ignore them."""
        return None
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def _create_progress_context(
    total: int,
    desc: str,
    verbose: bool,
) -> "_tqdm | _DummyProgress":
    """Build the progress-bar context manager for a run.

    Parameters
    ----------
    total
        Number of items the bar should count up to.
    desc
        Label shown next to the bar.
    verbose
        Whether progress output is wanted.

    Returns
    -------
    _tqdm or _DummyProgress
        A tqdm bar (unit ``"perturbation"``) when ``verbose`` is true, tqdm
        was importable and ``total`` is positive; otherwise a no-op
        ``_DummyProgress``.
    """
    show_bar = verbose and HAS_TQDM and total > 0
    if not show_bar:
        return _DummyProgress()
    return _tqdm(total=total, desc=desc, unit="perturbation")
|