crispyx 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
crispyx/__init__.py ADDED
@@ -0,0 +1,155 @@
1
+ """Streamlined CRISPR screen analysis toolkit with Scanpy-style entry points."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from importlib.metadata import PackageNotFoundError, version
6
+
7
# Resolve the package version from the installed distribution metadata.
try:
    __version__ = version("crispyx")
except PackageNotFoundError:
    # No installed "crispyx" distribution was found by importlib.metadata;
    # fall back to a placeholder so that ``__version__`` is always defined.
    __version__ = "0.0.0.dev"
11
+
12
+ # ---------------------------------------------------------------------------
13
+ # Public API re-exports
14
+ # ---------------------------------------------------------------------------
15
+
16
+ from .data import (
17
+ AnnData,
18
+ OverlapResult,
19
+ compute_overlap,
20
+ convert_to_csc,
21
+ convert_to_csr,
22
+ detect_gene_symbol_column,
23
+ detect_perturbation_column,
24
+ ensure_gene_symbol_column,
25
+ infer_columns,
26
+ load_obs,
27
+ load_var,
28
+ normalize_total_log1p,
29
+ normalise_perturbation_labels,
30
+ read_h5ad_ondisk,
31
+ read_backed,
32
+ resolve_data_path,
33
+ standardise_gene_names,
34
+ write_obs,
35
+ write_var,
36
+ )
37
+ from .de import (
38
+ RankGenesGroupsResult,
39
+ nb_glm_test,
40
+ shrink_lfc,
41
+ t_test,
42
+ wilcoxon_test,
43
+ )
44
+ from .profiling import (
45
+ Profiler,
46
+ MemoryProfiler,
47
+ TimingProfiler,
48
+ plot_benchmark_comparison,
49
+ )
50
+ from .plotting import (
51
+ materialize_rank_genes_groups,
52
+ plot_ma,
53
+ plot_overlap_heatmap,
54
+ plot_pca,
55
+ plot_pca_loadings,
56
+ plot_pca_variance_ratio,
57
+ plot_qc_perturbation_counts,
58
+ plot_qc_summary,
59
+ plot_top_genes_bar,
60
+ plot_umap,
61
+ plot_volcano,
62
+ rank_genes_groups_df,
63
+ )
64
+ from .pseudobulk import (
65
+ compute_average_log_expression,
66
+ compute_pseudobulk_expression,
67
+ )
68
+ from .qc import (
69
+ filter_cells_by_gene_count,
70
+ filter_genes_by_cell_count,
71
+ filter_perturbations_by_cell_count,
72
+ quality_control_summary,
73
+ )
74
+
75
+ # ---------------------------------------------------------------------------
76
+ # Scanpy-style namespace singletons: cx.pp, cx.pb, cx.tl, cx.pl
77
+ # ---------------------------------------------------------------------------
78
+
79
+ from ._namespaces import (
80
+ _PlottingNamespace,
81
+ _PreprocessingNamespace,
82
+ _PseudobulkNamespace,
83
+ _ToolsNamespace,
84
+ )
85
+
86
# One module-level instance per namespace class, created at import time and
# exposed as ``pp`` / ``pb`` / ``tl`` / ``pl``.
# NOTE(review): the namespace classes live in ``._namespaces`` (not visible
# here); presumably each groups the related public functions -- confirm there.
pp = _PreprocessingNamespace()
pb = _PseudobulkNamespace()
tl = _ToolsNamespace()
pl = _PlottingNamespace()
90
+
91
+ # ---------------------------------------------------------------------------
92
+ # __all__
93
+ # ---------------------------------------------------------------------------
94
+
95
# Names exported by ``from crispyx import *``. Every name re-exported by the
# import statements in this module is listed here, grouped by area.
__all__ = [
    "__version__",
    # Namespace singletons
    "pp",
    "pb",
    "tl",
    "pl",
    # Quality control
    "filter_cells_by_gene_count",
    "filter_genes_by_cell_count",
    "filter_perturbations_by_cell_count",
    "quality_control_summary",
    # Pseudo-bulk
    "compute_average_log_expression",
    "compute_pseudobulk_expression",
    # Differential expression
    "RankGenesGroupsResult",
    "t_test",
    "wilcoxon_test",
    "nb_glm_test",
    "shrink_lfc",
    # Data utilities
    "AnnData",
    "ensure_gene_symbol_column",
    "read_h5ad_ondisk",
    "read_backed",
    "resolve_data_path",
    "normalize_total_log1p",
    "convert_to_csc",
    "convert_to_csr",
    "load_obs",
    "load_var",
    "write_obs",
    "write_var",
    "standardise_gene_names",
    "normalise_perturbation_labels",
    "detect_perturbation_column",
    "detect_gene_symbol_column",
    "infer_columns",
    "OverlapResult",
    "compute_overlap",
    # Profiling
    "Profiler",
    "MemoryProfiler",
    "TimingProfiler",
    "plot_benchmark_comparison",
    # Plotting
    "materialize_rank_genes_groups",
    "rank_genes_groups_df",
    "plot_pca",
    "plot_pca_variance_ratio",
    "plot_pca_loadings",
    "plot_umap",
    "plot_volcano",
    "plot_ma",
    "plot_top_genes_bar",
    "plot_qc_perturbation_counts",
    "plot_qc_summary",
    "plot_overlap_heatmap",
]
155
+
crispyx/_checkpoint.py ADDED
@@ -0,0 +1,252 @@
1
+ """Checkpoint and progress utilities for streaming DE tests.
2
+
3
+ This module provides atomic checkpointing and progress tracking for
4
+ resumable differential expression tests.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ import logging
11
+ import os
12
+ from pathlib import Path
13
+ from typing import TYPE_CHECKING
14
+
15
+ import h5py
16
+ import numpy as np
17
+
18
+ if TYPE_CHECKING:
19
+ from tqdm import tqdm
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # Check for tqdm availability
24
+ try:
25
+ from tqdm import tqdm as _tqdm
26
+ HAS_TQDM = True
27
+ except ImportError:
28
+ HAS_TQDM = False
29
+
30
+
31
+ def _write_checkpoint_atomic(
32
+ checkpoint_path: Path,
33
+ data: dict,
34
+ ) -> None:
35
+ """Write checkpoint data atomically using temp file + rename.
36
+
37
+ This ensures checkpoint file is never corrupted on crash.
38
+ """
39
+ # Ensure parent directory exists
40
+ checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
41
+
42
+ # Write to a temporary file in the same directory
43
+ tmp_path = checkpoint_path.with_suffix(".tmp")
44
+ try:
45
+ with open(tmp_path, "w") as f:
46
+ json.dump(data, f, indent=2)
47
+ # Atomic rename
48
+ os.rename(tmp_path, checkpoint_path)
49
+ except Exception:
50
+ # Clean up temp file on error
51
+ if tmp_path.exists():
52
+ try:
53
+ tmp_path.unlink()
54
+ except Exception:
55
+ pass
56
+ raise
57
+
58
+
59
+ def _read_checkpoint(checkpoint_path: Path) -> dict | None:
60
+ """Read checkpoint file, returning None if missing or corrupted.
61
+
62
+ Returns
63
+ -------
64
+ dict or None
65
+ Checkpoint data if valid, None if file is missing or corrupted.
66
+ """
67
+ if not checkpoint_path.exists():
68
+ return None
69
+ try:
70
+ with open(checkpoint_path, "r") as f:
71
+ data = json.load(f)
72
+ # Validate required fields
73
+ if not isinstance(data, dict):
74
+ return None
75
+ if "completed" not in data or "total" not in data:
76
+ return None
77
+ return data
78
+ except (json.JSONDecodeError, IOError, OSError):
79
+ return None
80
+
81
+
82
+ def _scan_h5ad_completed(
83
+ h5ad_path: Path,
84
+ all_candidates: list[str],
85
+ result_dataset: str = "uns/rank_genes_groups/full/scores",
86
+ ) -> list[str]:
87
+ """Scan h5ad file to detect completed perturbations by non-zero/non-NaN rows.
88
+
89
+ This is a fallback when checkpoint file is missing or corrupted.
90
+
91
+ Parameters
92
+ ----------
93
+ h5ad_path
94
+ Path to the output h5ad file.
95
+ all_candidates
96
+ List of all perturbation labels (in order).
97
+ result_dataset
98
+ HDF5 dataset path to check for results. Should have shape (n_groups, n_genes).
99
+
100
+ Returns
101
+ -------
102
+ list[str]
103
+ List of perturbation labels that have been completed.
104
+ """
105
+ completed = []
106
+ if not h5ad_path.exists():
107
+ return completed
108
+
109
+ try:
110
+ with h5py.File(h5ad_path, "r") as f:
111
+ # Try to access the result dataset
112
+ if result_dataset in f:
113
+ ds = f[result_dataset]
114
+ n_groups = ds.shape[0]
115
+ for idx in range(min(n_groups, len(all_candidates))):
116
+ row = ds[idx, :]
117
+ # Check if row has any non-NaN, non-zero values
118
+ if np.any(np.isfinite(row) & (row != 0)):
119
+ completed.append(all_candidates[idx])
120
+ else:
121
+ # Try alternative: check layers in X matrix
122
+ if "X" in f:
123
+ X = f["X"]
124
+ if hasattr(X, "shape") and len(X.shape) == 2:
125
+ n_groups = X.shape[0]
126
+ for idx in range(min(n_groups, len(all_candidates))):
127
+ row = X[idx, :]
128
+ if np.any(np.isfinite(row) & (row != 0)):
129
+ completed.append(all_candidates[idx])
130
+ except Exception as e:
131
+ logger.warning(f"Failed to scan h5ad for completed perturbations: {e}")
132
+
133
+ return completed
134
+
135
+
136
def _get_resumable_candidates(
    checkpoint_path: Path,
    h5ad_path: Path,
    all_candidates: list[str],
    retry_failed: bool = True,
) -> tuple[list[str], list[str], list[str]]:
    """Work out which perturbations still have to be processed.

    Progress is taken from the checkpoint file when it can be read. If it
    cannot, and the output h5ad file exists, the h5ad file is scanned
    instead. With neither available, nothing is considered done.

    Parameters
    ----------
    checkpoint_path
        Path to the progress JSON file.
    h5ad_path
        Path to the output h5ad file.
    all_candidates
        All perturbation labels to process.
    retry_failed
        If True, previously failed perturbations are scheduled again;
        if False, they are skipped.

    Returns
    -------
    tuple[list[str], list[str], list[str]]
        ``(candidates_to_run, completed, failed)``:
        - candidates_to_run: labels from ``all_candidates`` still to process,
          in their original order
        - completed: labels already completed
        - failed: labels recorded as failed (for logging)
    """
    state = _read_checkpoint(checkpoint_path)

    if state is None:
        # No usable checkpoint: failures are unknown, completions may still
        # be recoverable from the output file.
        failed = []
        if h5ad_path.exists():
            logger.warning(
                f"Checkpoint file missing or corrupted at {checkpoint_path}. "
                f"Scanning h5ad file to detect completed perturbations..."
            )
            completed = _scan_h5ad_completed(h5ad_path, all_candidates)
            if completed:
                logger.info(f"Detected {len(completed)} completed perturbations from h5ad scan")
        else:
            completed = []
    else:
        completed = state.get("completed", [])
        failed = state.get("failed", [])
        logger.info(f"Resuming: {len(completed)}/{len(all_candidates)} already completed")
        if failed:
            logger.info(f" {len(failed)} previously failed perturbations")

    # Labels to leave out: everything completed, plus the failed ones unless
    # they are to be retried.
    done = set(completed)
    skipped_failures = set() if retry_failed else set(failed)
    excluded = done | skipped_failures

    pending = [label for label in all_candidates if label not in excluded]
    return pending, completed, failed
196
+
197
+
198
+ def _get_checkpoint_interval(n_perturbations: int, checkpoint_interval: int | None) -> int:
199
+ """Determine checkpoint interval based on dataset size.
200
+
201
+ Parameters
202
+ ----------
203
+ n_perturbations
204
+ Total number of perturbations.
205
+ checkpoint_interval
206
+ User-specified interval, or None for auto.
207
+
208
+ Returns
209
+ -------
210
+ int
211
+ Number of perturbations to process between checkpoints.
212
+ """
213
+ if checkpoint_interval is not None:
214
+ return max(1, checkpoint_interval)
215
+ # Auto: every 1 for small datasets, every 10 for larger ones
216
+ if n_perturbations < 100:
217
+ return 1
218
+ elif n_perturbations < 1000:
219
+ return 10
220
+ else:
221
+ return 50
222
+
223
+
224
+ class _DummyProgress:
225
+ """Dummy progress bar that does nothing (for when verbose=False)."""
226
+
227
+ def __enter__(self):
228
+ return self
229
+
230
+ def __exit__(self, *args):
231
+ pass
232
+
233
+ def update(self, n: int = 1):
234
+ pass
235
+
236
+ def set_postfix(self, **kwargs):
237
+ pass
238
+
239
+
240
def _create_progress_context(
    total: int,
    desc: str,
    verbose: bool,
) -> "_tqdm | _DummyProgress":
    """Build the progress-bar context manager for a run.

    Parameters
    ----------
    total
        Number of items the progress bar should count up to.
    desc
        Description passed to the progress bar.
    verbose
        Whether a visible progress bar is wanted.

    Returns
    -------
    _tqdm or _DummyProgress
        A tqdm bar (unit ``"perturbation"``) when ``verbose`` is true, tqdm
        could be imported and ``total`` is positive; otherwise a no-op
        ``_DummyProgress``.
    """
    show_bar = verbose and HAS_TQDM and total > 0
    if not show_bar:
        return _DummyProgress()
    return _tqdm(total=total, desc=desc, unit="perturbation")