pymagical 0.1.0__tar.gz

@@ -0,0 +1,7 @@
.venv/
__pycache__/
*.py[cod]
.magical_cache/
outputs/
logs/
.DS_Store
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 C. Sun

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -0,0 +1,96 @@
Metadata-Version: 2.4
Name: pymagical
Version: 0.1.0
Summary: A Python port of the MAGICAL hierarchical Bayesian Gibbs sampler for regulatory circuit inference.
Author-email: "C. Sun" <cs9095@princeton.edu>
License-File: LICENSE
Requires-Python: >=3.10
Requires-Dist: numpy>=1.22.0
Requires-Dist: pandas>=1.4.0
Requires-Dist: pyarrow>=8.0.0
Requires-Dist: scipy>=1.8.0
Requires-Dist: statsmodels>=0.13.0
Description-Content-Type: text/markdown

# pymagical

`pymagical` is a high-performance Python port of the **MAGICAL** (Multiome Accessibility Gene Integration Calling and Looping) algorithm. It provides an automated pipeline for inferring functional regulatory circuits—triads of Transcription Factors (TFs), cis-regulatory elements (Peaks), and target Genes—from paired single-cell RNA-seq and ATAC-seq data.

## Key Features

* **Fast & Optimized:** Significantly faster than the original MATLAB implementation (~1.9x speedup in sampling, up to 15x speedup in data loading).
* **Intelligent IO Caching:** Automatically caches large sparse matrices and genomic metadata into PyArrow-backed Parquet and NumPy formats for near-instant subsequent loads.
* **Biological Directionality:** Unlike the original version, `pymagical` automatically classifies inferred circuits as **activators (+)** or **repressors (-)** by analyzing the continuous regression weights calculated during Gibbs sampling.
* **HPC Ready:** Includes built-in support for high-memory Slurm environments and allows for detailed weight-history dumping for downstream statistical distribution analysis.

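The caching bullet above amounts to a freshness-checked load: parse the slow text format once, persist the parsed array, and reuse it while the cache is at least as new as the source. A minimal NumPy-only sketch (the function name and cache layout here are illustrative assumptions, not the package's actual internals, which also use Parquet for metadata):

```python
import os

import numpy as np


def load_with_cache(txt_path, cache_dir=".magical_cache"):
    """Parse a whitespace-delimited count matrix, caching the parsed
    array as .npy so later loads skip the slow text parse."""
    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(cache_dir, os.path.basename(txt_path) + ".npy")
    # Reuse the cache only if it is at least as fresh as the source file.
    if (os.path.exists(cache_path)
            and os.path.getmtime(cache_path) >= os.path.getmtime(txt_path)):
        return np.load(cache_path)
    matrix = np.loadtxt(txt_path)
    np.save(cache_path, matrix)
    return matrix
```

The first call pays the `np.loadtxt` cost; subsequent calls hit the binary cache, which is where the data-loading speedup comes from.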
## Installation

This project uses `uv` for environment management.

```bash
# Clone the repository
git clone <repo-url>
cd pymagical

# Sync the environment
uv sync
```

## Quick Start

### 1. Command Line Usage

Run the circuit inference directly from your terminal:

```bash
# Run with default demo data (astrocytes) for 500 iterations
uv run python main.py --iter 500 --outdir results/

# Run with custom data and dump weight history
uv run python main.py \
    --iter 1000 \
    --prefix my_sample \
    --rna-counts path/to/rna.txt \
    --atac-counts path/to/atac.txt \
    --dump-weights
```

### 2. Programmatic Usage

Integrate `pymagical` into your own Python pipelines:

```python
from pymagical import run_magical

run_magical(
    cand_gene_file="genes.txt",
    cand_peak_file="peaks.txt",
    # ... other file paths ...
    iteration_num=500,
    output_file="my_results.txt"
)
```

## Output Notation

The final triad list includes a biological effect annotation for every identified TF:

`TF_Name (Confidence_Probability, Overall_Effect [L_dir, B_dir])`

* **Overall Effect:** `+` (Activator) or `-` (Repressor).
* **L_dir (Looping):** Direction of the Peak-to-Gene effect.
* **B_dir (Binding):** Direction of the TF-to-Peak effect.

*Example:* `STAT5B (0.85, + [+,+])` denotes an activator called with 85% confidence: the TF opens a peak, which in turn increases expression of the target gene.

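For downstream scripts, the annotation string above can be unpacked with a short regex (a sketch; `parse_tf_annotation` is not part of the package API):

```python
import re

# Matches the documented form: TF_Name (prob, effect [L_dir,B_dir])
_ANNOTATION_RE = re.compile(
    r"(?P<tf>\S+) \((?P<prob>[0-9.]+), (?P<effect>[+-]) "
    r"\[(?P<l_dir>[+-]),(?P<b_dir>[+-])\]\)"
)


def parse_tf_annotation(text):
    """Split a TF annotation string into its named fields."""
    m = _ANNOTATION_RE.match(text)
    if m is None:
        raise ValueError(f"unrecognized annotation: {text!r}")
    fields = m.groupdict()
    fields["prob"] = float(fields["prob"])  # confidence as a number
    return fields
```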
## Documentation

* [Methodology and Notation Details](docs/methodology.md)
* [Design Decisions and Benchmarks](docs/decisions.md)

## Evaluation and Comparison

The `eval/` directory contains tools for comparing `pymagical` against the original MATLAB implementation and profiling performance.

* `eval/tests/compare_results.py`: Compare fidelity and performance across implementations.
* `eval/benchmarks/profile_run.py`: Profile the runtime of different execution stages.
@@ -0,0 +1,82 @@
# pymagical

`pymagical` is a high-performance Python port of the **MAGICAL** (Multiome Accessibility Gene Integration Calling and Looping) algorithm. It provides an automated pipeline for inferring functional regulatory circuits—triads of Transcription Factors (TFs), cis-regulatory elements (Peaks), and target Genes—from paired single-cell RNA-seq and ATAC-seq data.

## Key Features

* **Fast & Optimized:** Significantly faster than the original MATLAB implementation (~1.9x speedup in sampling, up to 15x speedup in data loading).
* **Intelligent IO Caching:** Automatically caches large sparse matrices and genomic metadata into PyArrow-backed Parquet and NumPy formats for near-instant subsequent loads.
* **Biological Directionality:** Unlike the original version, `pymagical` automatically classifies inferred circuits as **activators (+)** or **repressors (-)** by analyzing the continuous regression weights calculated during Gibbs sampling.
* **HPC Ready:** Includes built-in support for high-memory Slurm environments and allows for detailed weight-history dumping for downstream statistical distribution analysis.

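The directionality bullet reduces to a sign test on the sampled regression weights. A minimal sketch under two stated assumptions — the link sign comes from the mean of its weight samples, and the overall effect is the product of the looping and binding signs — with both function names being illustrative rather than the package's internals:

```python
import numpy as np


def classify_effect(weight_samples):
    """Sign of the mean of a link's sampled regression weights:
    positive means the upstream element increases the downstream
    readout (activator), negative means it decreases it (repressor)."""
    return "+" if float(np.mean(weight_samples)) > 0 else "-"


def overall_effect(l_dir, b_dir):
    """Net TF-on-gene sign, assumed to be the product of the binding
    (TF-to-Peak) and looping (Peak-to-Gene) signs: e.g. a TF that
    closes a peak (-) which activates its gene (+) is a net repressor."""
    return "+" if l_dir == b_dir else "-"
```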
## Installation

This project uses `uv` for environment management.

```bash
# Clone the repository
git clone <repo-url>
cd pymagical

# Sync the environment
uv sync
```

## Quick Start

### 1. Command Line Usage

Run the circuit inference directly from your terminal:

```bash
# Run with default demo data (astrocytes) for 500 iterations
uv run python main.py --iter 500 --outdir results/

# Run with custom data and dump weight history
uv run python main.py \
    --iter 1000 \
    --prefix my_sample \
    --rna-counts path/to/rna.txt \
    --atac-counts path/to/atac.txt \
    --dump-weights
```

### 2. Programmatic Usage

Integrate `pymagical` into your own Python pipelines:

```python
from pymagical import run_magical

run_magical(
    cand_gene_file="genes.txt",
    cand_peak_file="peaks.txt",
    # ... other file paths ...
    iteration_num=500,
    output_file="my_results.txt"
)
```

## Output Notation

The final triad list includes a biological effect annotation for every identified TF:

`TF_Name (Confidence_Probability, Overall_Effect [L_dir, B_dir])`

* **Overall Effect:** `+` (Activator) or `-` (Repressor).
* **L_dir (Looping):** Direction of the Peak-to-Gene effect.
* **B_dir (Binding):** Direction of the TF-to-Peak effect.

*Example:* `STAT5B (0.85, + [+,+])` denotes an activator called with 85% confidence: the TF opens a peak, which in turn increases expression of the target gene.

## Documentation

* [Methodology and Notation Details](docs/methodology.md)
* [Design Decisions and Benchmarks](docs/decisions.md)

## Evaluation and Comparison

The `eval/` directory contains tools for comparing `pymagical` against the original MATLAB implementation and profiling performance.

* `eval/tests/compare_results.py`: Compare fidelity and performance across implementations.
* `eval/benchmarks/profile_run.py`: Profile the runtime of different execution stages.
@@ -0,0 +1,51 @@
# Evaluation and Benchmarking Suite

This directory contains tools for verifying the statistical fidelity and performance of the `pymagical` Python port compared to the original MATLAB implementation.

## Directory Structure

* `tests/`: Correctness and fidelity verification.
  * `test_data_loader.py`: Functional test for the caching data loader.
  * `compare_results.py`: Main tool for comparing Python and MATLAB output matrices and circuits.
* `benchmarks/`: Performance profiling and resource scaling.
  * `profile_run.py`: Breaks down execution time by stage (Loading, Construction, Init, Sampling).
  * `plot_comparison.py`: Generates comparative bar charts between implementations.
  * `scripts/`: Generic Slurm (`.sh`) and MATLAB (`.m`) runners for launching cluster jobs.

## Usage Examples

### 1. Compare Fidelity and Performance

After running both the Python and MATLAB versions for 500 iterations, use this script to calculate correlations and speedup:

```bash
uv run python eval/tests/compare_results.py \
    --ml-dir path/to/matlab/outputs \
    --py-dir path/to/python/outputs \
    --iter 500 \
    --prefix astrocytes
```

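The fidelity side of this comparison boils down to correlating matched output vectors from the two implementations. A standalone sketch of that check (illustrative only, not the actual `compare_results.py` logic):

```python
import numpy as np


def pearson_fidelity(py_values, ml_values):
    """Pearson correlation between matched Python and MATLAB outputs;
    values near 1.0 indicate the port reproduces the reference run."""
    py = np.asarray(py_values, dtype=float)
    ml = np.asarray(ml_values, dtype=float)
    if py.shape != ml.shape:
        raise ValueError("outputs must be aligned to the same circuits")
    # np.corrcoef returns the 2x2 correlation matrix; take the off-diagonal.
    return float(np.corrcoef(py, ml)[0, 1])
```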
### 2. Profile Python Execution Stages

To see a granular breakdown of where time is spent in the Python pipeline:

```bash
uv run python eval/benchmarks/profile_run.py --iter 100 --output my_profile.png
```

### 3. Launching Cluster Jobs

Use the provided Slurm wrappers to submit jobs to the queue:

```bash
# Submit Python job for 1000 iterations
sbatch eval/benchmarks/scripts/run_pymagical.sh 1000 /path/to/output astrocytes

# Submit MATLAB job for 1000 iterations
sbatch eval/benchmarks/scripts/run_matlab.sh 1000 astrocytes_ml_1000
```

## Metrics and Validation

For a detailed history of the validation results (Pearson correlations, triad recovery rates, and wall-clock speedups), see [docs/decisions.md](../docs/decisions.md).
@@ -0,0 +1,3 @@
from .magical import run_magical

__all__ = ["run_magical"]
@@ -0,0 +1,4 @@
from .cli import main

if __name__ == "__main__":
    main()
@@ -0,0 +1,271 @@
import numpy as np
import pandas as pd
from scipy import sparse

def construct_candidate_circuits_with_tad(
    common_samples,
    cand_genes, cand_peaks,
    rna_genes, rna_cells, rna_count_matrix,
    atac_peaks, atac_cells, atac_count_matrix,
    motifs, tf_peak_binding_matrix,
    refseq, tad_regions
):
    print("Starting candidate circuits construction...")

    # --- 1. TF-peak binding ---
    # Intersect candidate peaks with scATAC peaks
    # Merge on chr_num, point1, point2
    atac_peaks_merge = atac_peaks[['peak_index', 'chr_num', 'point1', 'point2']].copy()
    cand_peaks_merge = cand_peaks[['chr_num', 'point1', 'point2']].copy()
    cand_peaks_merge['cand_idx'] = np.arange(len(cand_peaks))

    merged_peaks = pd.merge(atac_peaks_merge, cand_peaks_merge, on=['chr_num', 'point1', 'point2'], how='inner')
    # Match MATLAB intersect(..., 'stable'): keep the first occurrence of each
    # peak and preserve the candidate order by sorting on cand_idx.
    merged_peaks = merged_peaks.drop_duplicates(subset=['chr_num', 'point1', 'point2']).sort_values('cand_idx')

    bidex = merged_peaks['peak_index'].values - 1  # 0-indexed ATAC peak indices
    cidex = merged_peaks['cand_idx'].values  # 0-indexed candidate peak indices

    curr_cand_peaks = cand_peaks.iloc[cidex].copy().reset_index(drop=True)
    curr_cand_tf_binding = tf_peak_binding_matrix[bidex, :]

    num_peaks = curr_cand_tf_binding.shape[0]
    total_atac_peaks = len(atac_peaks)

    # TF filtering
    tf_num = np.array(curr_cand_tf_binding.sum(axis=0)).flatten()
    tf_pct = tf_num / num_peaks

    # TF enrichment: (tf_num / cand_peaks) / (total_tf_bindings / total_atac_peaks)
    total_tf_bindings = np.array(tf_peak_binding_matrix.sum(axis=0)).flatten()
    bg_freq = total_tf_bindings / total_atac_peaks
    bg_freq[bg_freq == 0] = 1.0  # avoid division by zero for TFs with no background bindings
    tf_enrichment_fc = (tf_num / num_peaks) / bg_freq

    pct_threshold = 0.05
    num_threshold = 30
    enrichment_threshold = 0.8

    tf_index = np.where((tf_pct > pct_threshold) & (tf_num > num_threshold) & (tf_enrichment_fc > enrichment_threshold))[0]

    if len(tf_index) > 0:
        cand_tfs = motifs.iloc[tf_index]['name'].values
        curr_cand_tf_binding = curr_cand_tf_binding[:, tf_index]
    else:
        print("Too few peaks with TF binding sites. MAGICAL not applicable to this cell type!")
        return None

    # Filter peaks with >0 TF binding
    peak_sums = np.array(curr_cand_tf_binding.sum(axis=1)).flatten()
    peak_idx = np.where(peak_sums > 0)[0]

    curr_cand_peaks = curr_cand_peaks.iloc[peak_idx].reset_index(drop=True)
    curr_cand_tf_binding = curr_cand_tf_binding[peak_idx, :]

    # --- 2. Peak-gene looping ---
    # Intersect candidate genes with Refseq
    refseq_dedup = refseq.drop_duplicates(subset=['gene_name'])
    cand_genes_df = pd.DataFrame({'gene_symbol': cand_genes})
    merged_genes = pd.merge(cand_genes_df, refseq_dedup, left_on='gene_symbol', right_on='gene_name', how='inner')

    # Calculate TSS
    merged_genes['gene_TSS'] = np.where(merged_genes['strand'] == '+', merged_genes['start'], merged_genes['end'])
    curr_cand_genes = merged_genes['gene_symbol'].values
    gene_tss = merged_genes[['chr_num', 'gene_TSS']].values

    num_cand_peaks = len(curr_cand_peaks)
    num_cand_genes = len(curr_cand_genes)

    peak_gene_looping_tad = np.zeros((num_cand_peaks, num_cand_genes), dtype=int)

    peak_chr = curr_cand_peaks['chr_num'].values
    peak_center = (curr_cand_peaks['point1'].values + curr_cand_peaks['point2'].values) / 2
    peak_p1 = curr_cand_peaks['point1'].values
    peak_p2 = curr_cand_peaks['point2'].values

    # Build TAD looping mask
    for _, tad in tad_regions.iterrows():
        tad_chr = tad['chr_num']
        t_left = tad['left_boundary']
        t_right = tad['right_boundary']

        # Peaks in TAD
        p_idx = np.where((peak_chr == tad_chr) & (peak_p1 > t_left) & (peak_p2 < t_right))[0]
        # Genes in TAD
        g_idx = np.where((gene_tss[:, 0] == tad_chr) & (gene_tss[:, 1] > t_left) & (gene_tss[:, 1] < t_right))[0]

        if len(p_idx) > 0 and len(g_idx) > 0:
            for p in p_idx:
                peak_gene_looping_tad[p, g_idx] = 1

    # Build distance looping mask (peak center < 1e6 bp from TSS)
    peak_gene_looping_dist = np.zeros((num_cand_peaks, num_cand_genes), dtype=int)
    for g in range(num_cand_genes):
        g_chr = gene_tss[g, 0]
        g_tss = gene_tss[g, 1]
        dist_mask = (peak_chr == g_chr) & (np.abs(peak_center - g_tss) < 1e6)
        peak_gene_looping_dist[dist_mask, g] = 1

    curr_cand_peak_gene_looping = peak_gene_looping_tad * peak_gene_looping_dist

    # Filter peaks and genes with >0 looping
    p_sums = curr_cand_peak_gene_looping.sum(axis=1)
    g_sums = curr_cand_peak_gene_looping.sum(axis=0)

    p_idx = np.where(p_sums > 0)[0]
    g_idx = np.where(g_sums > 0)[0]

    curr_cand_peak_gene_looping = curr_cand_peak_gene_looping[p_idx, :][:, g_idx]
    curr_cand_peaks = curr_cand_peaks.iloc[p_idx].reset_index(drop=True)
    curr_cand_genes = curr_cand_genes[g_idx]
    gene_tss = gene_tss[g_idx, :]
    curr_cand_tf_binding = curr_cand_tf_binding[p_idx, :]

    # Filter TFs
    tf_sums = np.array(curr_cand_tf_binding.sum(axis=0)).flatten()
    tf_idx = np.where(tf_sums > 0)[0]
    cand_tfs = cand_tfs[tf_idx]
    curr_cand_tf_binding = curr_cand_tf_binding[:, tf_idx]

    # --- 3. Pseudo-bulk calculation ---
    num_samples = len(common_samples)
    atac_cell_vector = np.zeros(len(atac_cells), dtype=int)
    rna_cell_vector = np.zeros(len(rna_cells), dtype=int)

    atac_counts = np.zeros((atac_count_matrix.shape[0], num_samples))
    rna_counts = np.zeros((rna_count_matrix.shape[0], num_samples))

    for s_idx, sample in enumerate(common_samples):
        # ATAC
        a_mask = (atac_cells['subject_ID'] == sample).values
        atac_cell_vector[a_mask] = s_idx + 1  # stored 1-based here; shifted to 0-based below
        if np.any(a_mask):
            atac_counts[:, s_idx] = np.array(atac_count_matrix[:, a_mask].sum(axis=1)).flatten()

        # RNA
        r_mask = (rna_cells['subject_ID'] == sample).values
        rna_cell_vector[r_mask] = s_idx + 1
        if np.any(r_mask):
            rna_counts[:, s_idx] = np.array(rna_count_matrix[:, r_mask].sum(axis=1)).flatten()

    # Shift to 0-based sample indices (cells not matched to any sample become -1)
    atac_cell_vector = atac_cell_vector - 1
    rna_cell_vector = rna_cell_vector - 1

    total_atac_reads = 5e6
    atac_raw_count_sum = atac_counts.sum(axis=0) + 1
    for s in range(num_samples):
        atac_counts[:, s] = atac_counts[:, s] / atac_raw_count_sum[s] * total_atac_reads
    atac_log2 = np.log2(atac_counts + 1)

    total_rna_reads = 5e6
    rna_raw_count_sum = rna_counts.sum(axis=0) + 1
    for s in range(num_samples):
        rna_counts[:, s] = rna_counts[:, s] / rna_raw_count_sum[s] * total_rna_reads
    rna_log2 = np.log2(rna_counts + 1)

    # --- 4. Select actively accessible peaks ---
    # The peaks in curr_cand_peaks need their original index in the ATAC matrix:
    # re-merge to find the original ATAC peak indices.
    atac_peaks_merge = atac_peaks[['peak_index', 'chr_num', 'point1', 'point2']].copy()
    curr_cand_peaks_merge = curr_cand_peaks[['chr_num', 'point1', 'point2']].copy()
    curr_cand_peaks_merge['cand_idx'] = np.arange(len(curr_cand_peaks))

    merged = pd.merge(curr_cand_peaks_merge, atac_peaks_merge, on=['chr_num', 'point1', 'point2'], how='inner')
    merged = merged.drop_duplicates(subset=['chr_num', 'point1', 'point2']).sort_values('cand_idx')

    bidex = merged['cand_idx'].values
    cidex = merged['peak_index'].values - 1

    curr_cand_peaks = curr_cand_peaks.iloc[bidex].reset_index(drop=True)
    curr_cand_tf_binding = curr_cand_tf_binding[bidex, :]
    curr_cand_peak_gene_looping = curr_cand_peak_gene_looping[bidex, :]
    cand_peak_log2count = atac_log2[cidex, :]
    scatac_read_count_matrix = atac_count_matrix[cidex, :]

    active_mask = cand_peak_log2count.sum(axis=1) > 0
    curr_cand_peaks = curr_cand_peaks[active_mask].reset_index(drop=True)
    curr_cand_tf_binding = curr_cand_tf_binding[active_mask, :]
    curr_cand_peak_gene_looping = curr_cand_peak_gene_looping[active_mask, :]
    cand_peak_log2count = cand_peak_log2count[active_mask, :]
    scatac_read_count_matrix = scatac_read_count_matrix[active_mask, :]

    # Mean-center each peak's log2 profile across samples
    cand_peak_log2count = cand_peak_log2count - cand_peak_log2count.mean(axis=1, keepdims=True)

    # --- 5. Select actively expressed genes ---
    rna_genes_df = rna_genes.copy()
    rna_genes_df['rna_idx'] = np.arange(len(rna_genes_df))
    curr_cand_genes_df = pd.DataFrame({'gene_symbol': curr_cand_genes, 'cand_idx': np.arange(len(curr_cand_genes))})

    merged = pd.merge(curr_cand_genes_df, rna_genes_df, on='gene_symbol', how='inner').sort_values('cand_idx')
    bidex = merged['cand_idx'].values
    cidex = merged['rna_idx'].values

    curr_cand_genes = curr_cand_genes[bidex]
    gene_tss = gene_tss[bidex, :]
    curr_cand_peak_gene_looping = curr_cand_peak_gene_looping[:, bidex]
    cand_gene_log2count = rna_log2[cidex, :]
    scrna_read_count_matrix = rna_count_matrix[cidex, :]

    active_mask = cand_gene_log2count.sum(axis=1) > 0
    curr_cand_genes = curr_cand_genes[active_mask]
    gene_tss = gene_tss[active_mask, :]
    curr_cand_peak_gene_looping = curr_cand_peak_gene_looping[:, active_mask]
    cand_gene_log2count = cand_gene_log2count[active_mask, :]
    scrna_read_count_matrix = scrna_read_count_matrix[active_mask, :]

    cand_gene_log2count = cand_gene_log2count - cand_gene_log2count.mean(axis=1, keepdims=True)

    # --- 6. Select actively expressed TFs ---
    cand_tfs_df = pd.DataFrame({'gene_symbol': cand_tfs, 'tf_idx': np.arange(len(cand_tfs))})
    merged = pd.merge(cand_tfs_df, rna_genes_df, on='gene_symbol', how='inner').sort_values('tf_idx')

    bidex = merged['tf_idx'].values
    cidex = merged['rna_idx'].values

    cand_tfs = cand_tfs[bidex]
    curr_cand_tf_binding = curr_cand_tf_binding[:, bidex]
    cand_tf_log2count = rna_log2[cidex, :]

    active_mask = cand_tf_log2count.sum(axis=1) > 0
    cand_tfs = cand_tfs[active_mask]
    curr_cand_tf_binding = curr_cand_tf_binding[:, active_mask]
    cand_tf_log2count = cand_tf_log2count[active_mask, :]

    cand_tf_log2count = cand_tf_log2count - cand_tf_log2count.mean(axis=1, keepdims=True)

    # --- 7. Final filter ---
    p_mask = (curr_cand_peak_gene_looping.sum(axis=1) > 0) & (np.array(curr_cand_tf_binding.sum(axis=1)).flatten() > 0)
    curr_cand_peaks = curr_cand_peaks[p_mask].reset_index(drop=True)
    curr_cand_peak_gene_looping = curr_cand_peak_gene_looping[p_mask, :]
    curr_cand_tf_binding = curr_cand_tf_binding[p_mask, :]
    cand_peak_log2count = cand_peak_log2count[p_mask, :]
    scatac_read_count_matrix = scatac_read_count_matrix[p_mask, :]

    g_mask = curr_cand_peak_gene_looping.sum(axis=0) > 0
    curr_cand_genes = curr_cand_genes[g_mask]
    gene_tss = gene_tss[g_mask, :]
    curr_cand_peak_gene_looping = curr_cand_peak_gene_looping[:, g_mask]
    cand_gene_log2count = cand_gene_log2count[g_mask, :]
    scrna_read_count_matrix = scrna_read_count_matrix[g_mask, :]

    tf_mask = np.array(curr_cand_tf_binding.sum(axis=0)).flatten() > 0
    cand_tfs = cand_tfs[tf_mask]
    curr_cand_tf_binding = curr_cand_tf_binding[:, tf_mask]
    cand_tf_log2count = cand_tf_log2count[tf_mask, :]

    print(f"MAGICAL initially selected {len(cand_tfs)} TFs, {len(curr_cand_peaks)} peaks, and {len(curr_cand_genes)} genes for circuit inference.")

    # Add TSS back to genes for returning
    genes_out = {'symbols': curr_cand_genes, 'tss': gene_tss}

    return (
        cand_tfs, cand_tf_log2count,
        curr_cand_peaks, cand_peak_log2count,
        genes_out, cand_gene_log2count,
        curr_cand_tf_binding, curr_cand_peak_gene_looping,
        atac_cell_vector, scatac_read_count_matrix,
        rna_cell_vector, scrna_read_count_matrix
    )
@@ -0,0 +1,61 @@
import argparse
import os
from .magical import run_magical

def main():
    parser = argparse.ArgumentParser(description="pymagical - Python port of MAGICAL regulatory circuit inference.")

    # Core parameters
    parser.add_argument("--iter", type=int, default=500, help="Number of Gibbs sampling iterations (default: 500)")
    parser.add_argument("--outdir", type=str, default="outputs", help="Output directory for results")
    parser.add_argument("--prefix", type=str, default="astrocytes", help="Prefix for output filenames (default: astrocytes)")
    parser.add_argument("--dump-weights", action="store_true", help="Dump history of continuous B and L weights as .npy files")

    # Input file overrides (defaults to the astrocytes demo data)
    test_dir = "/mnt/ceph/users/agebrain/anderson/snmulti_data/pymagical/test_data"
    astrocytes_dir = os.path.join(test_dir, "astrocytes")

    parser.add_argument("--cand-genes", type=str, default=os.path.join(astrocytes_dir, "sig_cr_genes.txt"))
    parser.add_argument("--cand-peaks", type=str, default=os.path.join(astrocytes_dir, "sig_cr_peaks.txt"))
    parser.add_argument("--rna-counts", type=str, default=os.path.join(astrocytes_dir, "rna_counts.txt"))
    parser.add_argument("--rna-genes", type=str, default=os.path.join(astrocytes_dir, "rna_genes.txt"))
    parser.add_argument("--rna-meta", type=str, default=os.path.join(astrocytes_dir, "rna_meta.txt"))
    parser.add_argument("--atac-counts", type=str, default=os.path.join(astrocytes_dir, "atac_counts.txt"))
    parser.add_argument("--atac-peaks", type=str, default=os.path.join(astrocytes_dir, "atac_peaks.txt"))
    parser.add_argument("--atac-meta", type=str, default=os.path.join(astrocytes_dir, "atac_meta.txt"))
    parser.add_argument("--motif-mapping", type=str, default=os.path.join(test_dir, "motif_prior.txt"))
    parser.add_argument("--motif-info", type=str, default=os.path.join(test_dir, "motif_info.txt"))
    parser.add_argument("--tad-file", type=str, default=os.path.join(test_dir, "tad_regions.txt"))
    parser.add_argument("--refseq-file", type=str, default=os.path.join(test_dir, "rhemac10_refseq.txt"))

    args = parser.parse_args()

    os.makedirs(args.outdir, exist_ok=True)
    # Name the result file from the prefix and iteration count, consistent with the MATLAB naming
    out_file = os.path.join(args.outdir, f"{args.prefix}_py_{args.iter}.txt")

    print(f"Running pymagical for {args.iter} iterations...")
    if args.dump_weights:
        print("Weight history dump enabled.")

    run_magical(
        cand_gene_file=args.cand_genes,
        cand_peak_file=args.cand_peaks,
        rna_counts_file=args.rna_counts,
        rna_genes_file=args.rna_genes,
        rna_meta_file=args.rna_meta,
        atac_counts_file=args.atac_counts,
        atac_peaks_file=args.atac_peaks,
        atac_meta_file=args.atac_meta,
        motif_mapping_file=args.motif_mapping,
        motif_name_file=args.motif_info,
        tad_flag=1,
        tad_file=args.tad_file,
        refseq_file=args.refseq_file,
        output_file=out_file,
        iteration_num=args.iter,
        dump_weight_history=args.dump_weights
    )

if __name__ == "__main__":
    main()