crispyx 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- crispyx/__init__.py +155 -0
- crispyx/_checkpoint.py +252 -0
- crispyx/_kernels.py +2034 -0
- crispyx/_memory.py +246 -0
- crispyx/_namespaces.py +852 -0
- crispyx/_size_factors.py +376 -0
- crispyx/_statistics.py +313 -0
- crispyx/data.py +3857 -0
- crispyx/de.py +4545 -0
- crispyx/dimred.py +909 -0
- crispyx/glm.py +4795 -0
- crispyx/plotting.py +1424 -0
- crispyx/profiling.py +642 -0
- crispyx/pseudobulk.py +207 -0
- crispyx/qc.py +1638 -0
- crispyx-0.0.1.dist-info/METADATA +32 -0
- crispyx-0.0.1.dist-info/RECORD +20 -0
- crispyx-0.0.1.dist-info/WHEEL +5 -0
- crispyx-0.0.1.dist-info/licenses/LICENSE +21 -0
- crispyx-0.0.1.dist-info/top_level.txt +1 -0
crispyx/__init__.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
"""Streamlined CRISPR screen analysis toolkit with Scanpy-style entry points."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
6
|
+
|
|
7
|
+
# Look up the version of the "crispyx" distribution via importlib.metadata.
try:
    __version__ = version("crispyx")
except PackageNotFoundError:
    # No distribution metadata was found for "crispyx"; fall back to a
    # placeholder version string instead of failing at import time.
    __version__ = "0.0.0.dev"
|
|
11
|
+
|
|
12
|
+
# ---------------------------------------------------------------------------
|
|
13
|
+
# Public API re-exports
|
|
14
|
+
# ---------------------------------------------------------------------------
|
|
15
|
+
|
|
16
|
+
from .data import (
|
|
17
|
+
AnnData,
|
|
18
|
+
OverlapResult,
|
|
19
|
+
compute_overlap,
|
|
20
|
+
convert_to_csc,
|
|
21
|
+
convert_to_csr,
|
|
22
|
+
detect_gene_symbol_column,
|
|
23
|
+
detect_perturbation_column,
|
|
24
|
+
ensure_gene_symbol_column,
|
|
25
|
+
infer_columns,
|
|
26
|
+
load_obs,
|
|
27
|
+
load_var,
|
|
28
|
+
normalize_total_log1p,
|
|
29
|
+
normalise_perturbation_labels,
|
|
30
|
+
read_h5ad_ondisk,
|
|
31
|
+
read_backed,
|
|
32
|
+
resolve_data_path,
|
|
33
|
+
standardise_gene_names,
|
|
34
|
+
write_obs,
|
|
35
|
+
write_var,
|
|
36
|
+
)
|
|
37
|
+
from .de import (
|
|
38
|
+
RankGenesGroupsResult,
|
|
39
|
+
nb_glm_test,
|
|
40
|
+
shrink_lfc,
|
|
41
|
+
t_test,
|
|
42
|
+
wilcoxon_test,
|
|
43
|
+
)
|
|
44
|
+
from .profiling import (
|
|
45
|
+
Profiler,
|
|
46
|
+
MemoryProfiler,
|
|
47
|
+
TimingProfiler,
|
|
48
|
+
plot_benchmark_comparison,
|
|
49
|
+
)
|
|
50
|
+
from .plotting import (
|
|
51
|
+
materialize_rank_genes_groups,
|
|
52
|
+
plot_ma,
|
|
53
|
+
plot_overlap_heatmap,
|
|
54
|
+
plot_pca,
|
|
55
|
+
plot_pca_loadings,
|
|
56
|
+
plot_pca_variance_ratio,
|
|
57
|
+
plot_qc_perturbation_counts,
|
|
58
|
+
plot_qc_summary,
|
|
59
|
+
plot_top_genes_bar,
|
|
60
|
+
plot_umap,
|
|
61
|
+
plot_volcano,
|
|
62
|
+
rank_genes_groups_df,
|
|
63
|
+
)
|
|
64
|
+
from .pseudobulk import (
|
|
65
|
+
compute_average_log_expression,
|
|
66
|
+
compute_pseudobulk_expression,
|
|
67
|
+
)
|
|
68
|
+
from .qc import (
|
|
69
|
+
filter_cells_by_gene_count,
|
|
70
|
+
filter_genes_by_cell_count,
|
|
71
|
+
filter_perturbations_by_cell_count,
|
|
72
|
+
quality_control_summary,
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
# ---------------------------------------------------------------------------
|
|
76
|
+
# Scanpy-style namespace singletons: cx.pp, cx.pb, cx.tl, cx.pl
|
|
77
|
+
# ---------------------------------------------------------------------------
|
|
78
|
+
|
|
79
|
+
from ._namespaces import (
|
|
80
|
+
_PlottingNamespace,
|
|
81
|
+
_PreprocessingNamespace,
|
|
82
|
+
_PseudobulkNamespace,
|
|
83
|
+
_ToolsNamespace,
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
# One module-level instance per namespace class, created at import time and
# exposed as attributes of the package (listed in __all__ below).
pp = _PreprocessingNamespace()  # preprocessing namespace
pb = _PseudobulkNamespace()  # pseudobulk namespace
tl = _ToolsNamespace()  # tools namespace
pl = _PlottingNamespace()  # plotting namespace
|
|
90
|
+
|
|
91
|
+
# ---------------------------------------------------------------------------
|
|
92
|
+
# __all__
|
|
93
|
+
# ---------------------------------------------------------------------------
|
|
94
|
+
|
|
95
|
+
# Names exported by a star-import of this package. Entries are grouped by
# topic; every public name re-exported by the imports above appears here once.
__all__ = [
    "__version__",
    # Namespace singletons
    "pp",
    "pb",
    "tl",
    "pl",
    # Quality control
    "filter_cells_by_gene_count",
    "filter_genes_by_cell_count",
    "filter_perturbations_by_cell_count",
    "quality_control_summary",
    # Pseudo-bulk
    "compute_average_log_expression",
    "compute_pseudobulk_expression",
    # Differential expression
    "RankGenesGroupsResult",
    "t_test",
    "wilcoxon_test",
    "nb_glm_test",
    "shrink_lfc",
    # Data utilities
    "AnnData",
    "ensure_gene_symbol_column",
    "read_h5ad_ondisk",
    "read_backed",
    "resolve_data_path",
    "normalize_total_log1p",
    "convert_to_csc",
    "convert_to_csr",
    "load_obs",
    "load_var",
    "write_obs",
    "write_var",
    "standardise_gene_names",
    "normalise_perturbation_labels",
    "detect_perturbation_column",
    "detect_gene_symbol_column",
    "infer_columns",
    "OverlapResult",
    "compute_overlap",
    # Profiling
    "Profiler",
    "MemoryProfiler",
    "TimingProfiler",
    "plot_benchmark_comparison",
    # Plotting
    "materialize_rank_genes_groups",
    "rank_genes_groups_df",
    "plot_pca",
    "plot_pca_variance_ratio",
    "plot_pca_loadings",
    "plot_umap",
    "plot_volcano",
    "plot_ma",
    "plot_top_genes_bar",
    "plot_qc_perturbation_counts",
    "plot_qc_summary",
    "plot_overlap_heatmap",
]
|
|
155
|
+
|
crispyx/_checkpoint.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
"""Checkpoint and progress utilities for streaming DE tests.
|
|
2
|
+
|
|
3
|
+
This module provides atomic checkpointing and progress tracking for
|
|
4
|
+
resumable differential expression tests.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import logging
|
|
11
|
+
import os
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import TYPE_CHECKING
|
|
14
|
+
|
|
15
|
+
import h5py
|
|
16
|
+
import numpy as np
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from tqdm import tqdm
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
# Check for tqdm availability
|
|
24
|
+
try:
|
|
25
|
+
from tqdm import tqdm as _tqdm
|
|
26
|
+
HAS_TQDM = True
|
|
27
|
+
except ImportError:
|
|
28
|
+
HAS_TQDM = False
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _write_checkpoint_atomic(
    checkpoint_path: Path,
    data: dict,
) -> None:
    """Write checkpoint data atomically using temp file + replace.

    The payload is serialised as JSON into a sibling ``.tmp`` file, flushed
    to disk, and then moved over ``checkpoint_path`` in one step. A reader
    therefore sees either the previous checkpoint or the complete new one,
    never a partially written file.

    Parameters
    ----------
    checkpoint_path
        Destination of the checkpoint JSON file. Missing parent directories
        are created.
    data
        JSON-serialisable checkpoint payload.

    Raises
    ------
    Exception
        Any error raised while serialising, writing or moving the file is
        re-raised after a best-effort removal of the temporary file. An
        existing checkpoint is left untouched in that case.
    """
    # Ensure parent directory exists
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

    # The temp file lives next to the target so the final move never has to
    # cross a filesystem boundary.
    tmp_path = checkpoint_path.with_suffix(".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
            # Push the bytes to disk before the move so a crash right after
            # it cannot leave an empty or truncated checkpoint behind.
            f.flush()
            os.fsync(f.fileno())
        # os.replace overwrites an existing destination on every platform,
        # whereas os.rename raises FileExistsError on Windows as soon as a
        # previous checkpoint is present.
        os.replace(tmp_path, checkpoint_path)
    except BaseException:
        # Best-effort cleanup (also on KeyboardInterrupt); the original error
        # is always re-raised.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _read_checkpoint(checkpoint_path: Path) -> dict | None:
    """Read checkpoint file, returning None if missing or corrupted.

    Parameters
    ----------
    checkpoint_path
        Path to the checkpoint JSON file.

    Returns
    -------
    dict or None
        The parsed checkpoint if it is a JSON object containing both the
        ``"completed"`` and ``"total"`` keys. None if the file is missing,
        unreadable, not decodable as text, not valid JSON, or lacks those
        keys.
    """
    try:
        with open(checkpoint_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: file missing or unreadable (no separate exists() check, so
        # there is no window between checking and opening).
        # ValueError: covers json.JSONDecodeError as well as the
        # UnicodeDecodeError raised when the file holds non-text bytes.
        return None

    # Validate required fields
    if not isinstance(data, dict):
        return None
    if "completed" not in data or "total" not in data:
        return None
    return data
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _scan_h5ad_completed(
    h5ad_path: Path,
    all_candidates: list[str],
    result_dataset: str = "uns/rank_genes_groups/full/scores",
) -> list[str]:
    """Infer finished perturbations from the rows already written to an h5ad file.

    Used as a fallback when the checkpoint file is missing or corrupted.
    Row ``i`` of the inspected matrix is matched with ``all_candidates[i]``;
    a row counts as completed when it holds at least one finite, non-zero
    value.

    Parameters
    ----------
    h5ad_path
        Path to the output h5ad file.
    all_candidates
        All perturbation labels, in row order.
    result_dataset
        HDF5 dataset path holding the results, expected to be shaped
        (n_groups, n_genes). If it is absent, a two-dimensional ``X`` entry
        is inspected instead.

    Returns
    -------
    list[str]
        Labels detected as completed. Empty if the file does not exist.
        If scanning fails part-way, a warning is logged and the labels
        found up to that point are returned.
    """
    finished: list[str] = []
    if not h5ad_path.exists():
        return finished

    def _collect_filled_rows(matrix) -> None:
        # Only rows that have a matching label are inspected; rows are read
        # one at a time rather than loading the whole matrix.
        n_rows = min(matrix.shape[0], len(all_candidates))
        for row_idx, label in enumerate(all_candidates[:n_rows]):
            values = matrix[row_idx, :]
            if np.any(np.isfinite(values) & (values != 0)):
                finished.append(label)

    try:
        with h5py.File(h5ad_path, "r") as handle:
            if result_dataset in handle:
                _collect_filled_rows(handle[result_dataset])
            elif "X" in handle:
                # Alternative source: the X entry, but only when it exposes
                # a two-dimensional shape.
                fallback = handle["X"]
                if hasattr(fallback, "shape") and len(fallback.shape) == 2:
                    _collect_filled_rows(fallback)
    except Exception as e:
        logger.warning(f"Failed to scan h5ad for completed perturbations: {e}")

    return finished
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _get_resumable_candidates(
    checkpoint_path: Path,
    h5ad_path: Path,
    all_candidates: list[str],
    retry_failed: bool = True,
) -> tuple[list[str], list[str], list[str]]:
    """Work out which perturbations still have to run after earlier progress.

    Progress is taken from the checkpoint file when it can be read. If it
    cannot and the h5ad output exists, the output file is scanned instead.
    With neither available, nothing is considered done.

    Parameters
    ----------
    checkpoint_path
        Path to the progress JSON file.
    h5ad_path
        Path to the output h5ad file.
    all_candidates
        All perturbation labels to process.
    retry_failed
        If True, previously failed perturbations are scheduled again;
        if False they are left out of the returned work list.

    Returns
    -------
    tuple[list[str], list[str], list[str]]
        ``(candidates_to_run, completed, failed)``:

        - candidates_to_run: labels that still need processing, in the
          order given by ``all_candidates``
        - completed: labels already completed
        - failed: labels recorded as failed (for logging)
    """
    state = _read_checkpoint(checkpoint_path)

    if state is None:
        # No usable checkpoint: start from nothing unless the output file
        # lets us recover what was already written.
        completed = []
        failed = []
        if h5ad_path.exists():
            logger.warning(
                f"Checkpoint file missing or corrupted at {checkpoint_path}. "
                f"Scanning h5ad file to detect completed perturbations..."
            )
            completed = _scan_h5ad_completed(h5ad_path, all_candidates)
            if completed:
                logger.info(f"Detected {len(completed)} completed perturbations from h5ad scan")
    else:
        completed = state.get("completed", [])
        failed = state.get("failed", [])
        logger.info(f"Resuming: {len(completed)}/{len(all_candidates)} already completed")
        if failed:
            logger.info(f" {len(failed)} previously failed perturbations")

    # Labels to leave out of this run: everything completed, plus the failed
    # ones when they should not be retried.
    skip = set(completed)
    if not retry_failed:
        skip.update(set(failed))

    pending = [label for label in all_candidates if label not in skip]
    return pending, completed, failed
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _get_checkpoint_interval(n_perturbations: int, checkpoint_interval: int | None) -> int:
    """Pick how many perturbations to process between checkpoint writes.

    Parameters
    ----------
    n_perturbations
        Total number of perturbations.
    checkpoint_interval
        User-specified interval, or None to choose one automatically.

    Returns
    -------
    int
        The user-specified interval, raised to at least 1. In automatic
        mode: 1 for fewer than 100 perturbations, 10 for fewer than 1000,
        and 50 otherwise.
    """
    if checkpoint_interval is None:
        # Automatic mode: (exclusive upper bound, interval) tiers, smallest first.
        for upper_bound, interval in ((100, 1), (1000, 10)):
            if n_perturbations < upper_bound:
                return interval
        return 50
    return max(1, checkpoint_interval)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
class _DummyProgress:
    """No-op stand-in for a progress bar (for when verbose=False).

    Offers the context-manager protocol plus ``update`` and ``set_postfix``
    so calling code can treat it like a real progress bar; every call is
    ignored.
    """

    def __enter__(self):
        """Enter the context and hand back this placeholder."""
        return self

    def __exit__(self, *exc_info):
        """Leave the context; returns None, so exceptions are not suppressed."""
        return None

    def update(self, n: int = 1):
        """Accept a progress increment and discard it."""
        return None

    def set_postfix(self, **kwargs):
        """Accept postfix fields and discard them."""
        return None
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def _create_progress_context(
    total: int,
    desc: str,
    verbose: bool,
) -> "_tqdm | _DummyProgress":
    """Build the progress-bar context manager for a run.

    Parameters
    ----------
    total
        Number of items the bar should count up to.
    desc
        Description shown next to the bar.
    verbose
        Whether progress should be displayed.

    Returns
    -------
    tqdm or _DummyProgress
        A tqdm bar (unit "perturbation") when ``verbose`` is true, tqdm is
        available and ``total`` is positive; otherwise a no-op placeholder.
    """
    show_bar = verbose and HAS_TQDM and total > 0
    if not show_bar:
        return _DummyProgress()
    return _tqdm(total=total, desc=desc, unit="perturbation")
|