torchflat 0.8.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchflat-0.8.0/LICENSE +21 -0
- torchflat-0.8.0/PKG-INFO +234 -0
- torchflat-0.8.0/README.md +201 -0
- torchflat-0.8.0/pyproject.toml +56 -0
- torchflat-0.8.0/setup.cfg +4 -0
- torchflat-0.8.0/tests/test_batching.py +143 -0
- torchflat-0.8.0/tests/test_clipping.py +122 -0
- torchflat-0.8.0/tests/test_degenerate.py +93 -0
- torchflat-0.8.0/tests/test_determinism.py +58 -0
- torchflat-0.8.0/tests/test_gaps.py +181 -0
- torchflat-0.8.0/tests/test_highpass.py +177 -0
- torchflat-0.8.0/tests/test_injection.py +233 -0
- torchflat-0.8.0/tests/test_kernel.py +193 -0
- torchflat-0.8.0/tests/test_normalize.py +105 -0
- torchflat-0.8.0/tests/test_pipeline.py +164 -0
- torchflat-0.8.0/tests/test_quality.py +74 -0
- torchflat-0.8.0/tests/test_umi.py +214 -0
- torchflat-0.8.0/tests/test_utils.py +229 -0
- torchflat-0.8.0/tests/test_windows.py +130 -0
- torchflat-0.8.0/torchflat/__init__.py +15 -0
- torchflat-0.8.0/torchflat/_kernel_loader.py +289 -0
- torchflat-0.8.0/torchflat/_utils.py +99 -0
- torchflat-0.8.0/torchflat/batching.py +238 -0
- torchflat-0.8.0/torchflat/cli.py +535 -0
- torchflat-0.8.0/torchflat/clipping.py +85 -0
- torchflat-0.8.0/torchflat/csrc/build/test_combined.cpp +29 -0
- torchflat-0.8.0/torchflat/csrc/build/test_error_check.cpp +47 -0
- torchflat-0.8.0/torchflat/csrc/build/test_kernel.cpp +29 -0
- torchflat-0.8.0/torchflat/csrc/build/umi_kernel_hip.cpp +258 -0
- torchflat-0.8.0/torchflat/csrc/masked_median_kernel_hip.cpp +202 -0
- torchflat-0.8.0/torchflat/csrc/umi_ext.cpp +24 -0
- torchflat-0.8.0/torchflat/csrc/umi_kernel.cu +490 -0
- torchflat-0.8.0/torchflat/gaps.py +146 -0
- torchflat-0.8.0/torchflat/highpass.py +146 -0
- torchflat-0.8.0/torchflat/normalize.py +52 -0
- torchflat-0.8.0/torchflat/pipeline.py +604 -0
- torchflat-0.8.0/torchflat/py.typed +0 -0
- torchflat-0.8.0/torchflat/quality.py +30 -0
- torchflat-0.8.0/torchflat/umi.py +185 -0
- torchflat-0.8.0/torchflat/windows.py +87 -0
- torchflat-0.8.0/torchflat.egg-info/PKG-INFO +234 -0
- torchflat-0.8.0/torchflat.egg-info/SOURCES.txt +44 -0
- torchflat-0.8.0/torchflat.egg-info/dependency_links.txt +1 -0
- torchflat-0.8.0/torchflat.egg-info/entry_points.txt +2 -0
- torchflat-0.8.0/torchflat.egg-info/requires.txt +12 -0
- torchflat-0.8.0/torchflat.egg-info/top_level.txt +1 -0
torchflat-0.8.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Omar Khan
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
torchflat-0.8.0/PKG-INFO
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: torchflat
|
|
3
|
+
Version: 0.8.0
|
|
4
|
+
Summary: GPU-accelerated photometric preprocessing with UMI detrending for exoplanet transit searches
|
|
5
|
+
Author: Omar Khan
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/omarkhan2217/TorchFlat
|
|
8
|
+
Project-URL: Repository, https://github.com/omarkhan2217/TorchFlat
|
|
9
|
+
Keywords: astronomy,exoplanets,transits,TESS,GPU,PyTorch,detrending,UMI
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: Intended Audience :: Science/Research
|
|
12
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
13
|
+
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
18
|
+
Classifier: Topic :: Scientific/Engineering :: Astronomy
|
|
19
|
+
Requires-Python: >=3.10
|
|
20
|
+
Description-Content-Type: text/markdown
|
|
21
|
+
License-File: LICENSE
|
|
22
|
+
Requires-Dist: torch>=2.1.0
|
|
23
|
+
Requires-Dist: numpy>=1.24.0
|
|
24
|
+
Requires-Dist: scipy>=1.10.0
|
|
25
|
+
Provides-Extra: test
|
|
26
|
+
Requires-Dist: pytest>=7.0; extra == "test"
|
|
27
|
+
Requires-Dist: pytest-benchmark; extra == "test"
|
|
28
|
+
Requires-Dist: wotan; extra == "test"
|
|
29
|
+
Provides-Extra: dev
|
|
30
|
+
Requires-Dist: torchflat[test]; extra == "dev"
|
|
31
|
+
Requires-Dist: astropy; extra == "dev"
|
|
32
|
+
Dynamic: license-file
|
|
33
|
+
|
|
34
|
+
# TorchFlat
|
|
35
|
+
|
|
36
|
+
**GPU-native photometric preprocessing pipeline for exoplanet transit searches.**
|
|
37
|
+
|
|
38
|
+
[License: MIT](https://opensource.org/licenses/MIT)
|
|
39
|
+
[Python 3.10+](https://www.python.org/downloads/)
|
|
40
|
+
|
|
41
|
+
TorchFlat replaces the standard CPU preprocessing workflow (quality filtering, gap handling, sigma clipping, detrending, normalization, and windowing) with a GPU-accelerated pipeline. It uses **UMI** (Unified Median Iterative), a novel asymmetric robust location estimator implemented as a fused HIP/CUDA kernel, to detrend light curves faster and more accurately than existing methods.
|
|
42
|
+
|
|
43
|
+
## Performance
|
|
44
|
+
|
|
45
|
+
Benchmarked on AMD Radeon RX 9060 XT (16 GB VRAM) with real TESS sector 6 data (19,618 stars):
|
|
46
|
+
|
|
47
|
+
| Method | Rate | Full Sector | Speedup |
|
|
48
|
+
|--------|------|-------------|---------|
|
|
49
|
+
| Celix wotan 12-worker | 4.2 stars/sec | ~78 min | baseline |
|
|
50
|
+
| TorchFlat v0.5.0 (hybrid) | 59.3 stars/sec | ~5.5 min | 14.2x |
|
|
51
|
+
| **TorchFlat v0.8.0 + UMI kernel** | **139 stars/sec** | **~2.4 min** | **33x** |
|
|
52
|
+
|
|
53
|
+
### Transit Depth Recovery Accuracy
|
|
54
|
+
|
|
55
|
+
Injection-recovery test on 200 real TESS stars, median per-star error (lower = better):
|
|
56
|
+
|
|
57
|
+
| Depth | wotan biweight | TorchFlat UMI | Winner |
|
|
58
|
+
|-------|---------------|---------------|--------|
|
|
59
|
+
| 0.1% (super-Earths) | 19.8% | **14.7%** | TorchFlat |
|
|
60
|
+
| 0.3% (sub-Neptunes) | 8.4% | **3.7%** | TorchFlat |
|
|
61
|
+
| 0.5% (Neptunes) | 1.5% | **1.5%** | tie |
|
|
62
|
+
| 1.0% (hot Jupiters) | **0.4%** | 0.7% | wotan |
|
|
63
|
+
| 5.0% (deep transits) | **0.0%** | 0.1% | both perfect |
|
|
64
|
+
|
|
65
|
+
TorchFlat is more accurate at the transit depths where most detectable planets live (0.1-0.5%).
|
|
66
|
+
|
|
67
|
+
Validated on 3 TESS sectors (6, 7, 12), 42 confirmed planets, and a 1000-star train/test split. Results in `results/`.
|
|
68
|
+
|
|
69
|
+
## The UMI Algorithm
|
|
70
|
+
|
|
71
|
+
UMI (Unified Median Iterative) is a three-phase robust location estimator:
|
|
72
|
+
|
|
73
|
+
1. **Quickselect median** -- exact median via O(n) selection algorithm, computed per-thread on GPU
|
|
74
|
+
2. **Upper-RMS scale** -- RMS of points above the median only. Transit dips never contaminate the scale estimate, giving a tighter and more accurate noise measurement than standard MAD
|
|
75
|
+
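3. **Asymmetric bisquare iterations** -- weighted location refinement where downward deviations (transit dips) are penalized more heavily than upward ones, by the configurable `asymmetry` factor (2.0 by default; 1.5 for mixed samples, 1.0 for variable stars)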
|
|
76
|
+
|
|
77
|
+
The asymmetric weight function exploits the fact that **transits are always below the continuum**. Standard biweight treats dips and spikes equally. UMI penalizes dips more aggressively, so the trend stays above the transit and transit depth is preserved.
|
|
78
|
+
|
|
79
|
+
All three phases run in a single fused GPU kernel call -- median, upper-RMS, and 5 iterations happen per-thread with zero global memory traffic between steps.
|
|
80
|
+
|
|
81
|
+
When the GPU kernel is not available (no ROCm/CUDA toolkit), UMI falls back to a pure-PyTorch path using torch.sort for median + upper-RMS scale.
|
|
82
|
+
|
|
83
|
+
## Installation
|
|
84
|
+
|
|
85
|
+
```bash
|
|
86
|
+
git clone https://github.com/omarkhan2217/TorchFlat.git
|
|
87
|
+
cd TorchFlat
|
|
88
|
+
pip install -e .
|
|
89
|
+
```
|
|
90
|
+
|
|
91
|
+
**Requirements:** PyTorch >= 2.1.0, NumPy >= 1.24.0, SciPy >= 1.10.0
|
|
92
|
+
|
|
93
|
+
Works with both **NVIDIA CUDA** and **AMD ROCm** (via PyTorch's unified CUDA API). The UMI kernel compiles automatically on first use via JIT (requires ROCm SDK or CUDA toolkit).
|
|
94
|
+
|
|
95
|
+
## Quick Start
|
|
96
|
+
|
|
97
|
+
### Process a TESS sector
|
|
98
|
+
|
|
99
|
+
```python
|
|
100
|
+
import numpy as np
|
|
101
|
+
import torchflat
|
|
102
|
+
|
|
103
|
+
star_data = [
|
|
104
|
+
{
|
|
105
|
+
"time": np.load("star_001_time.npy"),
|
|
106
|
+
"pdcsap_flux": np.load("star_001_pdcsap.npy"),
|
|
107
|
+
"sap_flux": np.load("star_001_sap.npy"),
|
|
108
|
+
"quality": np.load("star_001_quality.npy"),
|
|
109
|
+
}
|
|
110
|
+
# ... for each star in the sector
|
|
111
|
+
]
|
|
112
|
+
|
|
113
|
+
results, skipped = torchflat.preprocess_sector(
|
|
114
|
+
star_data,
|
|
115
|
+
device="cuda",
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
for i, result in enumerate(results):
|
|
119
|
+
if not result:
|
|
120
|
+
continue
|
|
121
|
+
windows = result["windows_2048"]
|
|
122
|
+
trend = result["trend"]
|
|
123
|
+
```
|
|
124
|
+
|
|
125
|
+
### Standalone UMI detrending
|
|
126
|
+
|
|
127
|
+
```python
|
|
128
|
+
import torch
|
|
129
|
+
from torchflat import umi_detrend
|
|
130
|
+
|
|
131
|
+
# flux, time, valid_mask, segment_id are [B, L] tensors on GPU
|
|
132
|
+
detrended, trend = umi_detrend(
|
|
133
|
+
flux, time, valid_mask, segment_id,
|
|
134
|
+
window_length_days=0.5,
|
|
135
|
+
asymmetry=2.0, # 2.0=best accuracy, 1.0=variable stars, 1.5=mixed
|
|
136
|
+
)
|
|
137
|
+
```
|
|
138
|
+
|
|
139
|
+
## Architecture
|
|
140
|
+
|
|
141
|
+
TorchFlat implements two processing tracks:
|
|
142
|
+
|
|
143
|
+
- **Track A (Transit Search):** Quality filter > gap interpolation > sigma clipping > UMI detrending > normalization > multi-scale window extraction
|
|
144
|
+
- **Track B (Anomaly Detection):** Quality filter > gap interpolation > conservative clipping > FFT highpass filter > MAD normalization > fixed-length padding
|
|
145
|
+
|
|
146
|
+
### UMI kernel
|
|
147
|
+
|
|
148
|
+
The direct HIP/CUDA kernel (`torchflat/csrc/umi_kernel.cu`) runs one thread per (star, window position) pair. Each thread reads directly from the raw `[B, L]` flux array, so no unfold or tensor copies are needed:
|
|
149
|
+
|
|
150
|
+
1. Reads W values from raw flux, checks segment validity inline
|
|
151
|
+
2. Quickselect for exact median (O(n))
|
|
152
|
+
3. Absolute deviations + quickselect for exact MAD
|
|
153
|
+
4. 5 asymmetric bisquare iterations
|
|
154
|
+
5. Writes the final location estimate
|
|
155
|
+
|
|
156
|
+
VRAM usage: 319 MB for a 50-star batch. The kernel compiles via JIT on first import and is cached for subsequent runs.
|
|
157
|
+
|
|
158
|
+
## Validation
|
|
159
|
+
|
|
160
|
+
All validation results are saved as JSON in `results/`:
|
|
161
|
+
|
|
162
|
+
| Validation | Result | File |
|
|
163
|
+
|-----------|--------|------|
|
|
164
|
+
| Asymmetry train/test split | optimal=1.5, generalizes across held-out stars | `asymmetry_validation.json` |
|
|
165
|
+
| Known planet recovery | TorchFlat wins 24/41 (59%) confirmed planets | `known_planet_recovery.json` |
|
|
166
|
+
| Multi-sector consistency | UMI wins 10/15 depth-sector combos across sectors 6,7,12 | `multisector_validation.json` |
|
|
167
|
+
|
|
168
|
+
135/135 unit tests passing.
|
|
169
|
+
|
|
170
|
+
## Benchmarks
|
|
171
|
+
|
|
172
|
+
```bash
|
|
173
|
+
# Full sector speed benchmark
|
|
174
|
+
python benchmarks/bench_real_tess.py --data-dir /path/to/fits/sector_6 --n-stars 19618
|
|
175
|
+
|
|
176
|
+
# Asymmetry parameter validation
|
|
177
|
+
python benchmarks/validate_asymmetry.py
|
|
178
|
+
|
|
179
|
+
# Known planet recovery
|
|
180
|
+
python benchmarks/validate_known_planets.py
|
|
181
|
+
|
|
182
|
+
# Multi-sector validation
|
|
183
|
+
python benchmarks/validate_multisector.py
|
|
184
|
+
```
|
|
185
|
+
|
|
186
|
+
**Note:** Set `$env:TORCHFLAT_NO_KERNEL = "0"` (PowerShell) or `export TORCHFLAT_NO_KERNEL=0` (bash) to enable the UMI kernel.
|
|
187
|
+
|
|
188
|
+
## API Reference
|
|
189
|
+
|
|
190
|
+
### Main Entry Points
|
|
191
|
+
|
|
192
|
+
- **`torchflat.preprocess_sector(star_data, ...)`** -- Full pipeline (Track A + Track B).
|
|
193
|
+
- **`torchflat.preprocess_track_a(times, fluxes, qualities, ...)`** -- Track A only.
|
|
194
|
+
- **`torchflat.preprocess_track_b(times, sap_fluxes, qualities, ...)`** -- Track B only.
|
|
195
|
+
- **`torchflat.umi_detrend(flux, time, valid_mask, segment_id, ...)`** -- Standalone UMI kernel.
|
|
196
|
+
|
|
197
|
+
### Key Parameters
|
|
198
|
+
|
|
199
|
+
| Parameter | Default | Description |
|
|
200
|
+
|-----------|---------|-------------|
|
|
201
|
+
| `device` | `"cuda"` | Torch device |
|
|
202
|
+
| `window_length_days` | `0.5` | Sliding window width (days) |
|
|
203
|
+
| `asymmetry` | `2.0` | Dip penalty: 2.0 (quiet stars), 1.5 (mixed), 1.0 (variable stars) |
|
|
204
|
+
| `n_iter` | `5` | Number of bisquare iterations |
|
|
205
|
+
| `cval` | `5.0` | Rejection threshold in MAD units |
|
|
206
|
+
| `skip_track_b` | `False` | Skip Track B (FFT highpass) |
|
|
207
|
+
| `window_scales` | 4 scales | `[(256,128), (512,256), (2048,512), (8192,2048)]` |
|
|
208
|
+
| `dtype` | `float32` | Computation precision |
|
|
209
|
+
|
|
210
|
+
## Development
|
|
211
|
+
|
|
212
|
+
```bash
|
|
213
|
+
git clone https://github.com/omarkhan2217/TorchFlat.git
|
|
214
|
+
cd TorchFlat
|
|
215
|
+
pip install -e ".[dev]"
|
|
216
|
+
pytest tests/ -v
|
|
217
|
+
```
|
|
218
|
+
|
|
219
|
+
## Citation
|
|
220
|
+
|
|
221
|
+
If you use TorchFlat in your research, please cite:
|
|
222
|
+
|
|
223
|
+
```bibtex
|
|
224
|
+
@software{torchflat,
|
|
225
|
+
author = {Khan, Omar},
|
|
226
|
+
title = {TorchFlat: GPU-Accelerated Photometric Preprocessing with UMI Detrending},
|
|
227
|
+
year = {2026},
|
|
228
|
+
url = {https://github.com/omarkhan2217/TorchFlat}
|
|
229
|
+
}
|
|
230
|
+
```
|
|
231
|
+
|
|
232
|
+
## License
|
|
233
|
+
|
|
234
|
+
MIT License. See [LICENSE](LICENSE) for details.
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
# TorchFlat
|
|
2
|
+
|
|
3
|
+
**GPU-native photometric preprocessing pipeline for exoplanet transit searches.**
|
|
4
|
+
|
|
5
|
+
[License: MIT](https://opensource.org/licenses/MIT)
|
|
6
|
+
[Python 3.10+](https://www.python.org/downloads/)
|
|
7
|
+
|
|
8
|
+
TorchFlat replaces the standard CPU preprocessing workflow (quality filtering, gap handling, sigma clipping, detrending, normalization, and windowing) with a GPU-accelerated pipeline. It uses **UMI** (Unified Median Iterative), a novel asymmetric robust location estimator implemented as a fused HIP/CUDA kernel, to detrend light curves faster and more accurately than existing methods.
|
|
9
|
+
|
|
10
|
+
## Performance
|
|
11
|
+
|
|
12
|
+
Benchmarked on AMD Radeon RX 9060 XT (16 GB VRAM) with real TESS sector 6 data (19,618 stars):
|
|
13
|
+
|
|
14
|
+
| Method | Rate | Full Sector | Speedup |
|
|
15
|
+
|--------|------|-------------|---------|
|
|
16
|
+
| Celix wotan 12-worker | 4.2 stars/sec | ~78 min | baseline |
|
|
17
|
+
| TorchFlat v0.5.0 (hybrid) | 59.3 stars/sec | ~5.5 min | 14.2x |
|
|
18
|
+
| **TorchFlat v0.8.0 + UMI kernel** | **139 stars/sec** | **~2.4 min** | **33x** |
|
|
19
|
+
|
|
20
|
+
### Transit Depth Recovery Accuracy
|
|
21
|
+
|
|
22
|
+
Injection-recovery test on 200 real TESS stars, median per-star error (lower = better):
|
|
23
|
+
|
|
24
|
+
| Depth | wotan biweight | TorchFlat UMI | Winner |
|
|
25
|
+
|-------|---------------|---------------|--------|
|
|
26
|
+
| 0.1% (super-Earths) | 19.8% | **14.7%** | TorchFlat |
|
|
27
|
+
| 0.3% (sub-Neptunes) | 8.4% | **3.7%** | TorchFlat |
|
|
28
|
+
| 0.5% (Neptunes) | 1.5% | **1.5%** | tie |
|
|
29
|
+
| 1.0% (hot Jupiters) | **0.4%** | 0.7% | wotan |
|
|
30
|
+
| 5.0% (deep transits) | **0.0%** | 0.1% | both perfect |
|
|
31
|
+
|
|
32
|
+
TorchFlat is more accurate at the transit depths where most detectable planets live (0.1-0.5%).
|
|
33
|
+
|
|
34
|
+
Validated on 3 TESS sectors (6, 7, 12), 42 confirmed planets, and a 1000-star train/test split. Results in `results/`.
|
|
35
|
+
|
|
36
|
+
## The UMI Algorithm
|
|
37
|
+
|
|
38
|
+
UMI (Unified Median Iterative) is a three-phase robust location estimator:
|
|
39
|
+
|
|
40
|
+
1. **Quickselect median** -- exact median via O(n) selection algorithm, computed per-thread on GPU
|
|
41
|
+
2. **Upper-RMS scale** -- RMS of points above the median only. Transit dips never contaminate the scale estimate, giving a tighter and more accurate noise measurement than standard MAD
|
|
42
|
+
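3. **Asymmetric bisquare iterations** -- weighted location refinement where downward deviations (transit dips) are penalized more heavily than upward ones, by the configurable `asymmetry` factor (2.0 by default; 1.5 for mixed samples, 1.0 for variable stars)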
|
|
43
|
+
|
|
44
|
+
The asymmetric weight function exploits the fact that **transits are always below the continuum**. Standard biweight treats dips and spikes equally. UMI penalizes dips more aggressively, so the trend stays above the transit and transit depth is preserved.
|
|
45
|
+
|
|
46
|
+
All three phases run in a single fused GPU kernel call -- median, upper-RMS, and 5 iterations happen per-thread with zero global memory traffic between steps.
|
|
47
|
+
|
|
48
|
+
When the GPU kernel is not available (no ROCm/CUDA toolkit), UMI falls back to a pure-PyTorch path using torch.sort for median + upper-RMS scale.
|
|
49
|
+
|
|
50
|
+
## Installation
|
|
51
|
+
|
|
52
|
+
```bash
|
|
53
|
+
git clone https://github.com/omarkhan2217/TorchFlat.git
|
|
54
|
+
cd TorchFlat
|
|
55
|
+
pip install -e .
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
**Requirements:** PyTorch >= 2.1.0, NumPy >= 1.24.0, SciPy >= 1.10.0
|
|
59
|
+
|
|
60
|
+
Works with both **NVIDIA CUDA** and **AMD ROCm** (via PyTorch's unified CUDA API). The UMI kernel compiles automatically on first use via JIT (requires ROCm SDK or CUDA toolkit).
|
|
61
|
+
|
|
62
|
+
## Quick Start
|
|
63
|
+
|
|
64
|
+
### Process a TESS sector
|
|
65
|
+
|
|
66
|
+
```python
|
|
67
|
+
import numpy as np
|
|
68
|
+
import torchflat
|
|
69
|
+
|
|
70
|
+
star_data = [
|
|
71
|
+
{
|
|
72
|
+
"time": np.load("star_001_time.npy"),
|
|
73
|
+
"pdcsap_flux": np.load("star_001_pdcsap.npy"),
|
|
74
|
+
"sap_flux": np.load("star_001_sap.npy"),
|
|
75
|
+
"quality": np.load("star_001_quality.npy"),
|
|
76
|
+
}
|
|
77
|
+
# ... for each star in the sector
|
|
78
|
+
]
|
|
79
|
+
|
|
80
|
+
results, skipped = torchflat.preprocess_sector(
|
|
81
|
+
star_data,
|
|
82
|
+
device="cuda",
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
for i, result in enumerate(results):
|
|
86
|
+
if not result:
|
|
87
|
+
continue
|
|
88
|
+
windows = result["windows_2048"]
|
|
89
|
+
trend = result["trend"]
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
### Standalone UMI detrending
|
|
93
|
+
|
|
94
|
+
```python
|
|
95
|
+
import torch
|
|
96
|
+
from torchflat import umi_detrend
|
|
97
|
+
|
|
98
|
+
# flux, time, valid_mask, segment_id are [B, L] tensors on GPU
|
|
99
|
+
detrended, trend = umi_detrend(
|
|
100
|
+
flux, time, valid_mask, segment_id,
|
|
101
|
+
window_length_days=0.5,
|
|
102
|
+
asymmetry=2.0, # 2.0=best accuracy, 1.0=variable stars, 1.5=mixed
|
|
103
|
+
)
|
|
104
|
+
```
|
|
105
|
+
|
|
106
|
+
## Architecture
|
|
107
|
+
|
|
108
|
+
TorchFlat implements two processing tracks:
|
|
109
|
+
|
|
110
|
+
- **Track A (Transit Search):** Quality filter > gap interpolation > sigma clipping > UMI detrending > normalization > multi-scale window extraction
|
|
111
|
+
- **Track B (Anomaly Detection):** Quality filter > gap interpolation > conservative clipping > FFT highpass filter > MAD normalization > fixed-length padding
|
|
112
|
+
|
|
113
|
+
### UMI kernel
|
|
114
|
+
|
|
115
|
+
The direct HIP/CUDA kernel (`torchflat/csrc/umi_kernel.cu`) runs one thread per (star, window position) pair. Each thread reads directly from the raw `[B, L]` flux array, so no unfold or tensor copies are needed:
|
|
116
|
+
|
|
117
|
+
1. Reads W values from raw flux, checks segment validity inline
|
|
118
|
+
2. Quickselect for exact median (O(n))
|
|
119
|
+
3. Absolute deviations + quickselect for exact MAD
|
|
120
|
+
4. 5 asymmetric bisquare iterations
|
|
121
|
+
5. Writes the final location estimate
|
|
122
|
+
|
|
123
|
+
VRAM usage: 319 MB for a 50-star batch. The kernel compiles via JIT on first import and is cached for subsequent runs.
|
|
124
|
+
|
|
125
|
+
## Validation
|
|
126
|
+
|
|
127
|
+
All validation results are saved as JSON in `results/`:
|
|
128
|
+
|
|
129
|
+
| Validation | Result | File |
|
|
130
|
+
|-----------|--------|------|
|
|
131
|
+
| Asymmetry train/test split | optimal=1.5, generalizes across held-out stars | `asymmetry_validation.json` |
|
|
132
|
+
| Known planet recovery | TorchFlat wins 24/41 (59%) confirmed planets | `known_planet_recovery.json` |
|
|
133
|
+
| Multi-sector consistency | UMI wins 10/15 depth-sector combos across sectors 6,7,12 | `multisector_validation.json` |
|
|
134
|
+
|
|
135
|
+
135/135 unit tests passing.
|
|
136
|
+
|
|
137
|
+
## Benchmarks
|
|
138
|
+
|
|
139
|
+
```bash
|
|
140
|
+
# Full sector speed benchmark
|
|
141
|
+
python benchmarks/bench_real_tess.py --data-dir /path/to/fits/sector_6 --n-stars 19618
|
|
142
|
+
|
|
143
|
+
# Asymmetry parameter validation
|
|
144
|
+
python benchmarks/validate_asymmetry.py
|
|
145
|
+
|
|
146
|
+
# Known planet recovery
|
|
147
|
+
python benchmarks/validate_known_planets.py
|
|
148
|
+
|
|
149
|
+
# Multi-sector validation
|
|
150
|
+
python benchmarks/validate_multisector.py
|
|
151
|
+
```
|
|
152
|
+
|
|
153
|
+
**Note:** Set `$env:TORCHFLAT_NO_KERNEL = "0"` (PowerShell) or `export TORCHFLAT_NO_KERNEL=0` (bash) to enable the UMI kernel.
|
|
154
|
+
|
|
155
|
+
## API Reference
|
|
156
|
+
|
|
157
|
+
### Main Entry Points
|
|
158
|
+
|
|
159
|
+
- **`torchflat.preprocess_sector(star_data, ...)`** -- Full pipeline (Track A + Track B).
|
|
160
|
+
- **`torchflat.preprocess_track_a(times, fluxes, qualities, ...)`** -- Track A only.
|
|
161
|
+
- **`torchflat.preprocess_track_b(times, sap_fluxes, qualities, ...)`** -- Track B only.
|
|
162
|
+
- **`torchflat.umi_detrend(flux, time, valid_mask, segment_id, ...)`** -- Standalone UMI kernel.
|
|
163
|
+
|
|
164
|
+
### Key Parameters
|
|
165
|
+
|
|
166
|
+
| Parameter | Default | Description |
|
|
167
|
+
|-----------|---------|-------------|
|
|
168
|
+
| `device` | `"cuda"` | Torch device |
|
|
169
|
+
| `window_length_days` | `0.5` | Sliding window width (days) |
|
|
170
|
+
| `asymmetry` | `2.0` | Dip penalty: 2.0 (quiet stars), 1.5 (mixed), 1.0 (variable stars) |
|
|
171
|
+
| `n_iter` | `5` | Number of bisquare iterations |
|
|
172
|
+
| `cval` | `5.0` | Rejection threshold in MAD units |
|
|
173
|
+
| `skip_track_b` | `False` | Skip Track B (FFT highpass) |
|
|
174
|
+
| `window_scales` | 4 scales | `[(256,128), (512,256), (2048,512), (8192,2048)]` |
|
|
175
|
+
| `dtype` | `float32` | Computation precision |
|
|
176
|
+
|
|
177
|
+
## Development
|
|
178
|
+
|
|
179
|
+
```bash
|
|
180
|
+
git clone https://github.com/omarkhan2217/TorchFlat.git
|
|
181
|
+
cd TorchFlat
|
|
182
|
+
pip install -e ".[dev]"
|
|
183
|
+
pytest tests/ -v
|
|
184
|
+
```
|
|
185
|
+
|
|
186
|
+
## Citation
|
|
187
|
+
|
|
188
|
+
If you use TorchFlat in your research, please cite:
|
|
189
|
+
|
|
190
|
+
```bibtex
|
|
191
|
+
@software{torchflat,
|
|
192
|
+
author = {Khan, Omar},
|
|
193
|
+
title = {TorchFlat: GPU-Accelerated Photometric Preprocessing with UMI Detrending},
|
|
194
|
+
year = {2026},
|
|
195
|
+
url = {https://github.com/omarkhan2217/TorchFlat}
|
|
196
|
+
}
|
|
197
|
+
```
|
|
198
|
+
|
|
199
|
+
## License
|
|
200
|
+
|
|
201
|
+
MIT License. See [LICENSE](LICENSE) for details.
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=68.0", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "torchflat"
|
|
7
|
+
version = "0.8.0"
|
|
8
|
+
description = "GPU-accelerated photometric preprocessing with UMI detrending for exoplanet transit searches"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
license = {text = "MIT"}
|
|
11
|
+
authors = [{name = "Omar Khan"}]
|
|
12
|
+
requires-python = ">=3.10"
|
|
13
|
+
classifiers = [
|
|
14
|
+
"Development Status :: 4 - Beta",
|
|
15
|
+
"Intended Audience :: Science/Research",
|
|
16
|
+
"License :: OSI Approved :: MIT License",
|
|
17
|
+
"Programming Language :: Python :: 3",
|
|
18
|
+
"Programming Language :: Python :: 3.10",
|
|
19
|
+
"Programming Language :: Python :: 3.11",
|
|
20
|
+
"Programming Language :: Python :: 3.12",
|
|
21
|
+
"Programming Language :: Python :: 3.13",
|
|
22
|
+
"Topic :: Scientific/Engineering :: Astronomy",
|
|
23
|
+
]
|
|
24
|
+
keywords = ["astronomy", "exoplanets", "transits", "TESS", "GPU", "PyTorch", "detrending", "UMI"]
|
|
25
|
+
dependencies = [
|
|
26
|
+
"torch>=2.1.0",
|
|
27
|
+
"numpy>=1.24.0",
|
|
28
|
+
"scipy>=1.10.0",
|
|
29
|
+
]
|
|
30
|
+
|
|
31
|
+
[project.scripts]
|
|
32
|
+
torchflat = "torchflat.cli:main"
|
|
33
|
+
|
|
34
|
+
[project.urls]
|
|
35
|
+
Homepage = "https://github.com/omarkhan2217/TorchFlat"
|
|
36
|
+
Repository = "https://github.com/omarkhan2217/TorchFlat"
|
|
37
|
+
|
|
38
|
+
[project.optional-dependencies]
|
|
39
|
+
test = [
|
|
40
|
+
"pytest>=7.0",
|
|
41
|
+
"pytest-benchmark",
|
|
42
|
+
"wotan",
|
|
43
|
+
]
|
|
44
|
+
dev = [
|
|
45
|
+
"torchflat[test]",
|
|
46
|
+
"astropy",
|
|
47
|
+
]
|
|
48
|
+
|
|
49
|
+
[tool.setuptools.packages.find]
|
|
50
|
+
include = ["torchflat*"]
|
|
51
|
+
|
|
52
|
+
[tool.setuptools.package-data]
|
|
53
|
+
torchflat = ["csrc/*.cu", "csrc/*.cpp", "csrc/build/*.cpp"]
|
|
54
|
+
|
|
55
|
+
[tool.pytest.ini_options]
|
|
56
|
+
testpaths = ["tests"]
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
"""Tests for torchflat.batching."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pytest
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from torchflat.batching import (
|
|
8
|
+
assemble_batch,
|
|
9
|
+
bucket_stars,
|
|
10
|
+
compute_max_batch,
|
|
11
|
+
cpu_prescan,
|
|
12
|
+
estimate_peak_vram,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
CADENCE = 2.0 / 1440.0
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _make_star(n_points: int = 18000, seed: int = 0):
|
|
19
|
+
"""Create a clean synthetic star as numpy arrays."""
|
|
20
|
+
rng = np.random.default_rng(seed)
|
|
21
|
+
time = np.arange(n_points, dtype=np.float64) * CADENCE
|
|
22
|
+
flux = (1.0 + rng.normal(0, 0.001, n_points)).astype(np.float32)
|
|
23
|
+
quality = np.zeros(n_points, dtype=np.int32)
|
|
24
|
+
return time, flux, quality
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class TestCpuPrescan:
|
|
28
|
+
|
|
29
|
+
def test_clean_star(self):
|
|
30
|
+
t, f, q = _make_star(18000)
|
|
31
|
+
results = cpu_prescan([t], [f], [q])
|
|
32
|
+
assert len(results) == 1
|
|
33
|
+
r = results[0]
|
|
34
|
+
assert r["index"] == 0
|
|
35
|
+
assert r["n_valid"] == 18000
|
|
36
|
+
assert r["degenerate"] is False
|
|
37
|
+
assert r["post_filter_length"] >= r["n_valid"]
|
|
38
|
+
|
|
39
|
+
def test_degenerate_too_few_points(self):
|
|
40
|
+
t, f, q = _make_star(50)
|
|
41
|
+
results = cpu_prescan([t], [f], [q])
|
|
42
|
+
assert results[0]["degenerate"] is True
|
|
43
|
+
assert results[0]["degenerate_reason"] == "too_few_valid_points"
|
|
44
|
+
|
|
45
|
+
def test_degenerate_segment_too_short(self):
|
|
46
|
+
# Star with many small segments, none long enough for biweight window
|
|
47
|
+
n = 5000
|
|
48
|
+
t = np.arange(n, dtype=np.float64) * CADENCE
|
|
49
|
+
# Insert large gaps every 100 points -> segments of ~100, window=360
|
|
50
|
+
for i in range(100, n, 100):
|
|
51
|
+
t[i:] += 20 * CADENCE
|
|
52
|
+
f = np.ones(n, dtype=np.float32)
|
|
53
|
+
q = np.zeros(n, dtype=np.int32)
|
|
54
|
+
results = cpu_prescan([t], [f], [q], window_samples=360)
|
|
55
|
+
assert results[0]["degenerate"] is True
|
|
56
|
+
assert results[0]["degenerate_reason"] == "segment_too_short"
|
|
57
|
+
|
|
58
|
+
def test_multiple_stars(self):
|
|
59
|
+
stars = [_make_star(n, seed=i) for i, n in enumerate([18000, 50, 15000])]
|
|
60
|
+
ts, fs, qs = zip(*stars)
|
|
61
|
+
results = cpu_prescan(list(ts), list(fs), list(qs))
|
|
62
|
+
assert len(results) == 3
|
|
63
|
+
assert results[0]["degenerate"] is False
|
|
64
|
+
assert results[1]["degenerate"] is True
|
|
65
|
+
assert results[2]["degenerate"] is False
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class TestBucketStars:
|
|
69
|
+
|
|
70
|
+
def test_bucketing_correctness(self):
|
|
71
|
+
prescan = [
|
|
72
|
+
{"index": 0, "post_filter_length": 14500, "degenerate": False},
|
|
73
|
+
{"index": 1, "post_filter_length": 15500, "degenerate": False},
|
|
74
|
+
{"index": 2, "post_filter_length": 14800, "degenerate": False},
|
|
75
|
+
{"index": 3, "post_filter_length": 50, "degenerate": True, "degenerate_reason": "x"},
|
|
76
|
+
]
|
|
77
|
+
buckets = bucket_stars(prescan, bucket_width=1000)
|
|
78
|
+
# Star 3 is degenerate -> excluded
|
|
79
|
+
total_stars = sum(len(b["star_indices"]) for b in buckets)
|
|
80
|
+
assert total_stars == 3
|
|
81
|
+
|
|
82
|
+
def test_bucket_width(self):
|
|
83
|
+
prescan = [
|
|
84
|
+
{"index": i, "post_filter_length": 14000 + i * 200, "degenerate": False}
|
|
85
|
+
for i in range(10)
|
|
86
|
+
]
|
|
87
|
+
buckets = bucket_stars(prescan, bucket_width=1000)
|
|
88
|
+
for b in buckets:
|
|
89
|
+
# All stars in bucket should fit within pad_length
|
|
90
|
+
for idx in b["star_indices"]:
|
|
91
|
+
pfl = prescan[idx]["post_filter_length"]
|
|
92
|
+
assert pfl <= b["pad_length"]
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class TestAssembleBatch:
|
|
96
|
+
|
|
97
|
+
def test_padding_no_data_leakage(self):
|
|
98
|
+
t1, f1, q1 = _make_star(100, seed=1)
|
|
99
|
+
t2, f2, q2 = _make_star(80, seed=2)
|
|
100
|
+
batch = assemble_batch([0, 1], [t1, t2], [f1, f2], [q1, q2], 120, torch.device("cpu"))
|
|
101
|
+
# Padded positions should be zero
|
|
102
|
+
assert (batch["flux"][0, 100:] == 0).all()
|
|
103
|
+
assert (batch["flux"][1, 80:] == 0).all()
|
|
104
|
+
|
|
105
|
+
def test_padding_mask(self):
|
|
106
|
+
t, f, q = _make_star(100)
|
|
107
|
+
batch = assemble_batch([0], [t], [f], [q], 150, torch.device("cpu"))
|
|
108
|
+
assert batch["valid_mask"][0, :100].all()
|
|
109
|
+
assert not batch["valid_mask"][0, 100:].any()
|
|
110
|
+
|
|
111
|
+
def test_lengths(self):
|
|
112
|
+
stars = [_make_star(n) for n in [100, 200, 150]]
|
|
113
|
+
ts, fs, qs = zip(*stars)
|
|
114
|
+
batch = assemble_batch([0, 1, 2], list(ts), list(fs), list(qs), 250, torch.device("cpu"))
|
|
115
|
+
assert batch["lengths"].tolist() == [100, 200, 150]
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class TestVramEstimation:
|
|
119
|
+
|
|
120
|
+
def test_monotonic(self):
|
|
121
|
+
v1 = estimate_peak_vram(15000, 360)
|
|
122
|
+
v2 = estimate_peak_vram(20000, 360)
|
|
123
|
+
assert v2 > v1
|
|
124
|
+
|
|
125
|
+
def test_reasonable_range(self):
|
|
126
|
+
v = estimate_peak_vram(20000, 360)
|
|
127
|
+
# Should be between 100MB and 500MB per star
|
|
128
|
+
assert 100 * 1024**2 < v < 500 * 1024**2
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class TestComputeMaxBatch:
|
|
132
|
+
|
|
133
|
+
def test_override(self):
|
|
134
|
+
assert compute_max_batch(20000, max_batch_override=10) == 10
|
|
135
|
+
|
|
136
|
+
def test_budget(self):
|
|
137
|
+
mb_12 = compute_max_batch(20000, vram_budget_gb=12.0)
|
|
138
|
+
mb_6 = compute_max_batch(20000, vram_budget_gb=6.0)
|
|
139
|
+
assert mb_12 > mb_6
|
|
140
|
+
assert mb_6 >= 1
|
|
141
|
+
|
|
142
|
+
def test_cpu_fallback(self):
|
|
143
|
+
assert compute_max_batch(20000, device=torch.device("cpu")) == 1
|