torchflat 0.8.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. torchflat-0.8.0/LICENSE +21 -0
  2. torchflat-0.8.0/PKG-INFO +234 -0
  3. torchflat-0.8.0/README.md +201 -0
  4. torchflat-0.8.0/pyproject.toml +56 -0
  5. torchflat-0.8.0/setup.cfg +4 -0
  6. torchflat-0.8.0/tests/test_batching.py +143 -0
  7. torchflat-0.8.0/tests/test_clipping.py +122 -0
  8. torchflat-0.8.0/tests/test_degenerate.py +93 -0
  9. torchflat-0.8.0/tests/test_determinism.py +58 -0
  10. torchflat-0.8.0/tests/test_gaps.py +181 -0
  11. torchflat-0.8.0/tests/test_highpass.py +177 -0
  12. torchflat-0.8.0/tests/test_injection.py +233 -0
  13. torchflat-0.8.0/tests/test_kernel.py +193 -0
  14. torchflat-0.8.0/tests/test_normalize.py +105 -0
  15. torchflat-0.8.0/tests/test_pipeline.py +164 -0
  16. torchflat-0.8.0/tests/test_quality.py +74 -0
  17. torchflat-0.8.0/tests/test_umi.py +214 -0
  18. torchflat-0.8.0/tests/test_utils.py +229 -0
  19. torchflat-0.8.0/tests/test_windows.py +130 -0
  20. torchflat-0.8.0/torchflat/__init__.py +15 -0
  21. torchflat-0.8.0/torchflat/_kernel_loader.py +289 -0
  22. torchflat-0.8.0/torchflat/_utils.py +99 -0
  23. torchflat-0.8.0/torchflat/batching.py +238 -0
  24. torchflat-0.8.0/torchflat/cli.py +535 -0
  25. torchflat-0.8.0/torchflat/clipping.py +85 -0
  26. torchflat-0.8.0/torchflat/csrc/build/test_combined.cpp +29 -0
  27. torchflat-0.8.0/torchflat/csrc/build/test_error_check.cpp +47 -0
  28. torchflat-0.8.0/torchflat/csrc/build/test_kernel.cpp +29 -0
  29. torchflat-0.8.0/torchflat/csrc/build/umi_kernel_hip.cpp +258 -0
  30. torchflat-0.8.0/torchflat/csrc/masked_median_kernel_hip.cpp +202 -0
  31. torchflat-0.8.0/torchflat/csrc/umi_ext.cpp +24 -0
  32. torchflat-0.8.0/torchflat/csrc/umi_kernel.cu +490 -0
  33. torchflat-0.8.0/torchflat/gaps.py +146 -0
  34. torchflat-0.8.0/torchflat/highpass.py +146 -0
  35. torchflat-0.8.0/torchflat/normalize.py +52 -0
  36. torchflat-0.8.0/torchflat/pipeline.py +604 -0
  37. torchflat-0.8.0/torchflat/py.typed +0 -0
  38. torchflat-0.8.0/torchflat/quality.py +30 -0
  39. torchflat-0.8.0/torchflat/umi.py +185 -0
  40. torchflat-0.8.0/torchflat/windows.py +87 -0
  41. torchflat-0.8.0/torchflat.egg-info/PKG-INFO +234 -0
  42. torchflat-0.8.0/torchflat.egg-info/SOURCES.txt +44 -0
  43. torchflat-0.8.0/torchflat.egg-info/dependency_links.txt +1 -0
  44. torchflat-0.8.0/torchflat.egg-info/entry_points.txt +2 -0
  45. torchflat-0.8.0/torchflat.egg-info/requires.txt +12 -0
  46. torchflat-0.8.0/torchflat.egg-info/top_level.txt +1 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Omar Khan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,234 @@
1
+ Metadata-Version: 2.4
2
+ Name: torchflat
3
+ Version: 0.8.0
4
+ Summary: GPU-accelerated photometric preprocessing with UMI detrending for exoplanet transit searches
5
+ Author: Omar Khan
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/omarkhan2217/TorchFlat
8
+ Project-URL: Repository, https://github.com/omarkhan2217/TorchFlat
9
+ Keywords: astronomy,exoplanets,transits,TESS,GPU,PyTorch,detrending,UMI
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Intended Audience :: Science/Research
12
+ Classifier: License :: OSI Approved :: MIT License
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Programming Language :: Python :: 3.13
18
+ Classifier: Topic :: Scientific/Engineering :: Astronomy
19
+ Requires-Python: >=3.10
20
+ Description-Content-Type: text/markdown
21
+ License-File: LICENSE
22
+ Requires-Dist: torch>=2.1.0
23
+ Requires-Dist: numpy>=1.24.0
24
+ Requires-Dist: scipy>=1.10.0
25
+ Provides-Extra: test
26
+ Requires-Dist: pytest>=7.0; extra == "test"
27
+ Requires-Dist: pytest-benchmark; extra == "test"
28
+ Requires-Dist: wotan; extra == "test"
29
+ Provides-Extra: dev
30
+ Requires-Dist: torchflat[test]; extra == "dev"
31
+ Requires-Dist: astropy; extra == "dev"
32
+ Dynamic: license-file
33
+
34
+ # TorchFlat
35
+
36
+ **GPU-native photometric preprocessing pipeline for exoplanet transit searches.**
37
+
38
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
39
+ [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/)
40
+
41
+ TorchFlat replaces the standard CPU preprocessing workflow (quality filtering, gap handling, sigma clipping, detrending, normalization, and windowing) with a GPU-accelerated pipeline. It uses **UMI** (Unified Median Iterative), a novel asymmetric robust location estimator implemented as a fused HIP/CUDA kernel, to detrend light curves faster and more accurately than existing methods.
42
+
43
+ ## Performance
44
+
45
+ Benchmarked on AMD Radeon RX 9060 XT (16 GB VRAM) with real TESS sector 6 data (19,618 stars):
46
+
47
+ | Method | Rate | Full Sector | Speedup |
48
+ |--------|------|-------------|---------|
49
+ | Celix wotan 12-worker | 4.2 stars/sec | ~78 min | baseline |
50
+ | TorchFlat v0.5.0 (hybrid) | 59.3 stars/sec | ~5.5 min | 14.2x |
51
+ | **TorchFlat v0.8.0 + UMI kernel** | **139 stars/sec** | **~2.4 min** | **33x** |
52
+
53
+ ### Transit Depth Recovery Accuracy
54
+
55
+ Injection-recovery test on 200 real TESS stars, median per-star error (lower = better):
56
+
57
+ | Depth | wotan biweight | TorchFlat UMI | Winner |
58
+ |-------|---------------|---------------|--------|
59
+ | 0.1% (super-Earths) | 19.8% | **14.7%** | TorchFlat |
60
+ | 0.3% (sub-Neptunes) | 8.4% | **3.7%** | TorchFlat |
61
+ | 0.5% (Neptunes) | 1.5% | **1.5%** | tie |
62
+ | 1.0% (hot Jupiters) | **0.4%** | 0.7% | wotan |
63
+ | 5.0% (deep transits) | **0.0%** | 0.1% | both perfect |
64
+
65
+ TorchFlat is more accurate at the transit depths where most detectable planets live (0.1-0.5%).
66
+
67
+ Validated on 3 TESS sectors (6, 7, 12), 42 confirmed planets, and 1000-star train/test split. Results in `results/`.
68
+
69
+ ## The UMI Algorithm
70
+
71
+ UMI (Unified Median Iterative) is a three-phase robust location estimator:
72
+
73
+ 1. **Quickselect median** -- exact median via O(n) selection algorithm, computed per-thread on GPU
74
+ 2. **Upper-RMS scale** -- RMS of points above the median only. Transit dips never contaminate the scale estimate, giving a tighter and more accurate noise measurement than standard MAD
75
+ 3. **Asymmetric bisquare iterations** -- weighted location refinement where downward deviations (transit dips) are penalized more than upward ones by the configurable asymmetry factor (default 2.0)
76
+
77
+ The asymmetric weight function exploits the fact that **transits are always below the continuum**. Standard biweight treats dips and spikes equally. UMI penalizes dips more aggressively, so the trend stays above the transit and transit depth is preserved.
78
+
79
+ All three phases run in a single fused GPU kernel call -- median, upper-RMS, and 5 iterations happen per-thread with zero global memory traffic between steps.
80
+
81
+ When the GPU kernel is not available (no ROCm/CUDA toolkit), UMI falls back to a pure-PyTorch path using torch.sort for median + upper-RMS scale.
82
+
83
+ ## Installation
84
+
85
+ ```bash
86
+ git clone https://github.com/omarkhan2217/TorchFlat.git
87
+ cd TorchFlat
88
+ pip install -e .
89
+ ```
90
+
91
+ **Requirements:** PyTorch >= 2.1.0, NumPy >= 1.24.0, SciPy >= 1.10.0
92
+
93
+ Works with both **NVIDIA CUDA** and **AMD ROCm** (via PyTorch's unified CUDA API). The UMI kernel compiles automatically on first use via JIT (requires ROCm SDK or CUDA toolkit).
94
+
95
+ ## Quick Start
96
+
97
+ ### Process a TESS sector
98
+
99
+ ```python
100
+ import numpy as np
101
+ import torchflat
102
+
103
+ star_data = [
104
+ {
105
+ "time": np.load("star_001_time.npy"),
106
+ "pdcsap_flux": np.load("star_001_pdcsap.npy"),
107
+ "sap_flux": np.load("star_001_sap.npy"),
108
+ "quality": np.load("star_001_quality.npy"),
109
+ }
110
+ # ... for each star in the sector
111
+ ]
112
+
113
+ results, skipped = torchflat.preprocess_sector(
114
+ star_data,
115
+ device="cuda",
116
+ )
117
+
118
+ for i, result in enumerate(results):
119
+ if not result:
120
+ continue
121
+ windows = result["windows_2048"]
122
+ trend = result["trend"]
123
+ ```
124
+
125
+ ### Standalone UMI detrending
126
+
127
+ ```python
128
+ import torch
129
+ from torchflat import umi_detrend
130
+
131
+ # flux, time, valid_mask, segment_id are [B, L] tensors on GPU
132
+ detrended, trend = umi_detrend(
133
+ flux, time, valid_mask, segment_id,
134
+ window_length_days=0.5,
135
+ asymmetry=2.0, # 2.0=best accuracy, 1.0=variable stars, 1.5=mixed
136
+ )
137
+ ```
138
+
139
+ ## Architecture
140
+
141
+ TorchFlat implements two processing tracks:
142
+
143
+ - **Track A (Transit Search):** Quality filter > gap interpolation > sigma clipping > UMI detrending > normalization > multi-scale window extraction
144
+ - **Track B (Anomaly Detection):** Quality filter > gap interpolation > conservative clipping > FFT highpass filter > MAD normalization > fixed-length padding
145
+
146
+ ### UMI kernel
147
+
148
+ The direct HIP/CUDA kernel (`torchflat/csrc/umi_kernel.cu`) runs one thread per (star, window position) pair. Each thread reads directly from the raw `[B, L]` flux array, no unfold or tensor copies needed:
149
+
150
+ 1. Reads W values from raw flux, checks segment validity inline
151
+ 2. Quickselect for exact median (O(n))
152
+ 3. Absolute deviations + quickselect for exact MAD
153
+ 4. 5 asymmetric bisquare iterations
154
+ 5. Writes the final location estimate
155
+
156
+ VRAM usage: 319 MB for a 50-star batch. The kernel compiles via JIT on first import and is cached for subsequent runs.
157
+
158
+ ## Validation
159
+
160
+ All validation results are saved as JSON in `results/`:
161
+
162
+ | Validation | Result | File |
163
+ |-----------|--------|------|
164
+ | Asymmetry train/test split | optimal=1.5, generalizes across held-out stars | `asymmetry_validation.json` |
165
+ | Known planet recovery | TorchFlat wins 24/41 (59%) confirmed planets | `known_planet_recovery.json` |
166
+ | Multi-sector consistency | UMI wins 10/15 depth-sector combos across sectors 6,7,12 | `multisector_validation.json` |
167
+
168
+ 135/135 unit tests passing.
169
+
170
+ ## Benchmarks
171
+
172
+ ```bash
173
+ # Full sector speed benchmark
174
+ python benchmarks/bench_real_tess.py --data-dir /path/to/fits/sector_6 --n-stars 19618
175
+
176
+ # Asymmetry parameter validation
177
+ python benchmarks/validate_asymmetry.py
178
+
179
+ # Known planet recovery
180
+ python benchmarks/validate_known_planets.py
181
+
182
+ # Multi-sector validation
183
+ python benchmarks/validate_multisector.py
184
+ ```
185
+
186
+ **Note:** Set `$env:TORCHFLAT_NO_KERNEL = "0"` (PowerShell) or `export TORCHFLAT_NO_KERNEL=0` (bash) to enable the UMI kernel.
187
+
188
+ ## API Reference
189
+
190
+ ### Main Entry Points
191
+
192
+ - **`torchflat.preprocess_sector(star_data, ...)`** -- Full pipeline (Track A + Track B).
193
+ - **`torchflat.preprocess_track_a(times, fluxes, qualities, ...)`** -- Track A only.
194
+ - **`torchflat.preprocess_track_b(times, sap_fluxes, qualities, ...)`** -- Track B only.
195
+ - **`torchflat.umi_detrend(flux, time, valid_mask, segment_id, ...)`** -- Standalone UMI kernel.
196
+
197
+ ### Key Parameters
198
+
199
+ | Parameter | Default | Description |
200
+ |-----------|---------|-------------|
201
+ | `device` | `"cuda"` | Torch device |
202
+ | `window_length_days` | `0.5` | Sliding window width (days) |
203
+ | `asymmetry` | `2.0` | Dip penalty: 2.0 (quiet stars), 1.5 (mixed), 1.0 (variable stars) |
204
+ | `n_iter` | `5` | Number of bisquare iterations |
205
+ | `cval` | `5.0` | Rejection threshold in MAD units |
206
+ | `skip_track_b` | `False` | Skip Track B (FFT highpass) |
207
+ | `window_scales` | 4 scales | `[(256,128), (512,256), (2048,512), (8192,2048)]` |
208
+ | `dtype` | `float32` | Computation precision |
209
+
210
+ ## Development
211
+
212
+ ```bash
213
+ git clone https://github.com/omarkhan2217/TorchFlat.git
214
+ cd TorchFlat
215
+ pip install -e ".[dev]"
216
+ pytest tests/ -v
217
+ ```
218
+
219
+ ## Citation
220
+
221
+ If you use TorchFlat in your research, please cite:
222
+
223
+ ```bibtex
224
+ @software{torchflat,
225
+ author = {Khan, Omar},
226
+ title = {TorchFlat: GPU-Accelerated Photometric Preprocessing with UMI Detrending},
227
+ year = {2026},
228
+ url = {https://github.com/omarkhan2217/TorchFlat}
229
+ }
230
+ ```
231
+
232
+ ## License
233
+
234
+ MIT License. See [LICENSE](LICENSE) for details.
@@ -0,0 +1,201 @@
1
+ # TorchFlat
2
+
3
+ **GPU-native photometric preprocessing pipeline for exoplanet transit searches.**
4
+
5
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
6
+ [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/)
7
+
8
+ TorchFlat replaces the standard CPU preprocessing workflow (quality filtering, gap handling, sigma clipping, detrending, normalization, and windowing) with a GPU-accelerated pipeline. It uses **UMI** (Unified Median Iterative), a novel asymmetric robust location estimator implemented as a fused HIP/CUDA kernel, to detrend light curves faster and more accurately than existing methods.
9
+
10
+ ## Performance
11
+
12
+ Benchmarked on AMD Radeon RX 9060 XT (16 GB VRAM) with real TESS sector 6 data (19,618 stars):
13
+
14
+ | Method | Rate | Full Sector | Speedup |
15
+ |--------|------|-------------|---------|
16
+ | Celix wotan 12-worker | 4.2 stars/sec | ~78 min | baseline |
17
+ | TorchFlat v0.5.0 (hybrid) | 59.3 stars/sec | ~5.5 min | 14.2x |
18
+ | **TorchFlat v0.8.0 + UMI kernel** | **139 stars/sec** | **~2.4 min** | **33x** |
19
+
20
+ ### Transit Depth Recovery Accuracy
21
+
22
+ Injection-recovery test on 200 real TESS stars, median per-star error (lower = better):
23
+
24
+ | Depth | wotan biweight | TorchFlat UMI | Winner |
25
+ |-------|---------------|---------------|--------|
26
+ | 0.1% (super-Earths) | 19.8% | **14.7%** | TorchFlat |
27
+ | 0.3% (sub-Neptunes) | 8.4% | **3.7%** | TorchFlat |
28
+ | 0.5% (Neptunes) | 1.5% | **1.5%** | tie |
29
+ | 1.0% (hot Jupiters) | **0.4%** | 0.7% | wotan |
30
+ | 5.0% (deep transits) | **0.0%** | 0.1% | both perfect |
31
+
32
+ TorchFlat is more accurate at the transit depths where most detectable planets live (0.1-0.5%).
33
+
34
+ Validated on 3 TESS sectors (6, 7, 12), 42 confirmed planets, and 1000-star train/test split. Results in `results/`.
35
+
36
+ ## The UMI Algorithm
37
+
38
+ UMI (Unified Median Iterative) is a three-phase robust location estimator:
39
+
40
+ 1. **Quickselect median** -- exact median via O(n) selection algorithm, computed per-thread on GPU
41
+ 2. **Upper-RMS scale** -- RMS of points above the median only. Transit dips never contaminate the scale estimate, giving a tighter and more accurate noise measurement than standard MAD
42
+ 3. **Asymmetric bisquare iterations** -- weighted location refinement where downward deviations (transit dips) are penalized more than upward ones by the configurable asymmetry factor (default 2.0)
43
+
44
+ The asymmetric weight function exploits the fact that **transits are always below the continuum**. Standard biweight treats dips and spikes equally. UMI penalizes dips more aggressively, so the trend stays above the transit and transit depth is preserved.
45
+
46
+ All three phases run in a single fused GPU kernel call -- median, upper-RMS, and 5 iterations happen per-thread with zero global memory traffic between steps.
47
+
48
+ When the GPU kernel is not available (no ROCm/CUDA toolkit), UMI falls back to a pure-PyTorch path using torch.sort for median + upper-RMS scale.
49
+
50
+ ## Installation
51
+
52
+ ```bash
53
+ git clone https://github.com/omarkhan2217/TorchFlat.git
54
+ cd TorchFlat
55
+ pip install -e .
56
+ ```
57
+
58
+ **Requirements:** PyTorch >= 2.1.0, NumPy >= 1.24.0, SciPy >= 1.10.0
59
+
60
+ Works with both **NVIDIA CUDA** and **AMD ROCm** (via PyTorch's unified CUDA API). The UMI kernel compiles automatically on first use via JIT (requires ROCm SDK or CUDA toolkit).
61
+
62
+ ## Quick Start
63
+
64
+ ### Process a TESS sector
65
+
66
+ ```python
67
+ import numpy as np
68
+ import torchflat
69
+
70
+ star_data = [
71
+ {
72
+ "time": np.load("star_001_time.npy"),
73
+ "pdcsap_flux": np.load("star_001_pdcsap.npy"),
74
+ "sap_flux": np.load("star_001_sap.npy"),
75
+ "quality": np.load("star_001_quality.npy"),
76
+ }
77
+ # ... for each star in the sector
78
+ ]
79
+
80
+ results, skipped = torchflat.preprocess_sector(
81
+ star_data,
82
+ device="cuda",
83
+ )
84
+
85
+ for i, result in enumerate(results):
86
+ if not result:
87
+ continue
88
+ windows = result["windows_2048"]
89
+ trend = result["trend"]
90
+ ```
91
+
92
+ ### Standalone UMI detrending
93
+
94
+ ```python
95
+ import torch
96
+ from torchflat import umi_detrend
97
+
98
+ # flux, time, valid_mask, segment_id are [B, L] tensors on GPU
99
+ detrended, trend = umi_detrend(
100
+ flux, time, valid_mask, segment_id,
101
+ window_length_days=0.5,
102
+ asymmetry=2.0, # 2.0=best accuracy, 1.0=variable stars, 1.5=mixed
103
+ )
104
+ ```
105
+
106
+ ## Architecture
107
+
108
+ TorchFlat implements two processing tracks:
109
+
110
+ - **Track A (Transit Search):** Quality filter > gap interpolation > sigma clipping > UMI detrending > normalization > multi-scale window extraction
111
+ - **Track B (Anomaly Detection):** Quality filter > gap interpolation > conservative clipping > FFT highpass filter > MAD normalization > fixed-length padding
112
+
113
+ ### UMI kernel
114
+
115
+ The direct HIP/CUDA kernel (`torchflat/csrc/umi_kernel.cu`) runs one thread per (star, window position) pair. Each thread reads directly from the raw `[B, L]` flux array, no unfold or tensor copies needed:
116
+
117
+ 1. Reads W values from raw flux, checks segment validity inline
118
+ 2. Quickselect for exact median (O(n))
119
+ 3. Absolute deviations + quickselect for exact MAD
120
+ 4. 5 asymmetric bisquare iterations
121
+ 5. Writes the final location estimate
122
+
123
+ VRAM usage: 319 MB for a 50-star batch. The kernel compiles via JIT on first import and is cached for subsequent runs.
124
+
125
+ ## Validation
126
+
127
+ All validation results are saved as JSON in `results/`:
128
+
129
+ | Validation | Result | File |
130
+ |-----------|--------|------|
131
+ | Asymmetry train/test split | optimal=1.5, generalizes across held-out stars | `asymmetry_validation.json` |
132
+ | Known planet recovery | TorchFlat wins 24/41 (59%) confirmed planets | `known_planet_recovery.json` |
133
+ | Multi-sector consistency | UMI wins 10/15 depth-sector combos across sectors 6,7,12 | `multisector_validation.json` |
134
+
135
+ 135/135 unit tests passing.
136
+
137
+ ## Benchmarks
138
+
139
+ ```bash
140
+ # Full sector speed benchmark
141
+ python benchmarks/bench_real_tess.py --data-dir /path/to/fits/sector_6 --n-stars 19618
142
+
143
+ # Asymmetry parameter validation
144
+ python benchmarks/validate_asymmetry.py
145
+
146
+ # Known planet recovery
147
+ python benchmarks/validate_known_planets.py
148
+
149
+ # Multi-sector validation
150
+ python benchmarks/validate_multisector.py
151
+ ```
152
+
153
+ **Note:** Set `$env:TORCHFLAT_NO_KERNEL = "0"` (PowerShell) or `export TORCHFLAT_NO_KERNEL=0` (bash) to enable the UMI kernel.
154
+
155
+ ## API Reference
156
+
157
+ ### Main Entry Points
158
+
159
+ - **`torchflat.preprocess_sector(star_data, ...)`** -- Full pipeline (Track A + Track B).
160
+ - **`torchflat.preprocess_track_a(times, fluxes, qualities, ...)`** -- Track A only.
161
+ - **`torchflat.preprocess_track_b(times, sap_fluxes, qualities, ...)`** -- Track B only.
162
+ - **`torchflat.umi_detrend(flux, time, valid_mask, segment_id, ...)`** -- Standalone UMI kernel.
163
+
164
+ ### Key Parameters
165
+
166
+ | Parameter | Default | Description |
167
+ |-----------|---------|-------------|
168
+ | `device` | `"cuda"` | Torch device |
169
+ | `window_length_days` | `0.5` | Sliding window width (days) |
170
+ | `asymmetry` | `2.0` | Dip penalty: 2.0 (quiet stars), 1.5 (mixed), 1.0 (variable stars) |
171
+ | `n_iter` | `5` | Number of bisquare iterations |
172
+ | `cval` | `5.0` | Rejection threshold in MAD units |
173
+ | `skip_track_b` | `False` | Skip Track B (FFT highpass) |
174
+ | `window_scales` | 4 scales | `[(256,128), (512,256), (2048,512), (8192,2048)]` |
175
+ | `dtype` | `float32` | Computation precision |
176
+
177
+ ## Development
178
+
179
+ ```bash
180
+ git clone https://github.com/omarkhan2217/TorchFlat.git
181
+ cd TorchFlat
182
+ pip install -e ".[dev]"
183
+ pytest tests/ -v
184
+ ```
185
+
186
+ ## Citation
187
+
188
+ If you use TorchFlat in your research, please cite:
189
+
190
+ ```bibtex
191
+ @software{torchflat,
192
+ author = {Khan, Omar},
193
+ title = {TorchFlat: GPU-Accelerated Photometric Preprocessing with UMI Detrending},
194
+ year = {2026},
195
+ url = {https://github.com/omarkhan2217/TorchFlat}
196
+ }
197
+ ```
198
+
199
+ ## License
200
+
201
+ MIT License. See [LICENSE](LICENSE) for details.
@@ -0,0 +1,56 @@
1
+ [build-system]
2
+ requires = ["setuptools>=68.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "torchflat"
7
+ version = "0.8.0"
8
+ description = "GPU-accelerated photometric preprocessing with UMI detrending for exoplanet transit searches"
9
+ readme = "README.md"
10
+ license = {text = "MIT"}
11
+ authors = [{name = "Omar Khan"}]
12
+ requires-python = ">=3.10"
13
+ classifiers = [
14
+ "Development Status :: 4 - Beta",
15
+ "Intended Audience :: Science/Research",
16
+ "License :: OSI Approved :: MIT License",
17
+ "Programming Language :: Python :: 3",
18
+ "Programming Language :: Python :: 3.10",
19
+ "Programming Language :: Python :: 3.11",
20
+ "Programming Language :: Python :: 3.12",
21
+ "Programming Language :: Python :: 3.13",
22
+ "Topic :: Scientific/Engineering :: Astronomy",
23
+ ]
24
+ keywords = ["astronomy", "exoplanets", "transits", "TESS", "GPU", "PyTorch", "detrending", "UMI"]
25
+ dependencies = [
26
+ "torch>=2.1.0",
27
+ "numpy>=1.24.0",
28
+ "scipy>=1.10.0",
29
+ ]
30
+
31
+ [project.scripts]
32
+ torchflat = "torchflat.cli:main"
33
+
34
+ [project.urls]
35
+ Homepage = "https://github.com/omarkhan2217/TorchFlat"
36
+ Repository = "https://github.com/omarkhan2217/TorchFlat"
37
+
38
+ [project.optional-dependencies]
39
+ test = [
40
+ "pytest>=7.0",
41
+ "pytest-benchmark",
42
+ "wotan",
43
+ ]
44
+ dev = [
45
+ "torchflat[test]",
46
+ "astropy",
47
+ ]
48
+
49
+ [tool.setuptools.packages.find]
50
+ include = ["torchflat*"]
51
+
52
+ [tool.setuptools.package-data]
53
+ torchflat = ["csrc/*.cu", "csrc/*.cpp", "csrc/build/*.cpp"]
54
+
55
+ [tool.pytest.ini_options]
56
+ testpaths = ["tests"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,143 @@
1
+ """Tests for torchflat.batching."""
2
+
3
+ import numpy as np
4
+ import pytest
5
+ import torch
6
+
7
+ from torchflat.batching import (
8
+ assemble_batch,
9
+ bucket_stars,
10
+ compute_max_batch,
11
+ cpu_prescan,
12
+ estimate_peak_vram,
13
+ )
14
+
15
+ CADENCE = 2.0 / 1440.0
16
+
17
+
18
+ def _make_star(n_points: int = 18000, seed: int = 0):
19
+ """Create a clean synthetic star as numpy arrays."""
20
+ rng = np.random.default_rng(seed)
21
+ time = np.arange(n_points, dtype=np.float64) * CADENCE
22
+ flux = (1.0 + rng.normal(0, 0.001, n_points)).astype(np.float32)
23
+ quality = np.zeros(n_points, dtype=np.int32)
24
+ return time, flux, quality
25
+
26
+
27
+ class TestCpuPrescan:
28
+
29
+ def test_clean_star(self):
30
+ t, f, q = _make_star(18000)
31
+ results = cpu_prescan([t], [f], [q])
32
+ assert len(results) == 1
33
+ r = results[0]
34
+ assert r["index"] == 0
35
+ assert r["n_valid"] == 18000
36
+ assert r["degenerate"] is False
37
+ assert r["post_filter_length"] >= r["n_valid"]
38
+
39
+ def test_degenerate_too_few_points(self):
40
+ t, f, q = _make_star(50)
41
+ results = cpu_prescan([t], [f], [q])
42
+ assert results[0]["degenerate"] is True
43
+ assert results[0]["degenerate_reason"] == "too_few_valid_points"
44
+
45
+ def test_degenerate_segment_too_short(self):
46
+ # Star with many small segments, none long enough for biweight window
47
+ n = 5000
48
+ t = np.arange(n, dtype=np.float64) * CADENCE
49
+ # Insert large gaps every 100 points -> segments of ~100, window=360
50
+ for i in range(100, n, 100):
51
+ t[i:] += 20 * CADENCE
52
+ f = np.ones(n, dtype=np.float32)
53
+ q = np.zeros(n, dtype=np.int32)
54
+ results = cpu_prescan([t], [f], [q], window_samples=360)
55
+ assert results[0]["degenerate"] is True
56
+ assert results[0]["degenerate_reason"] == "segment_too_short"
57
+
58
+ def test_multiple_stars(self):
59
+ stars = [_make_star(n, seed=i) for i, n in enumerate([18000, 50, 15000])]
60
+ ts, fs, qs = zip(*stars)
61
+ results = cpu_prescan(list(ts), list(fs), list(qs))
62
+ assert len(results) == 3
63
+ assert results[0]["degenerate"] is False
64
+ assert results[1]["degenerate"] is True
65
+ assert results[2]["degenerate"] is False
66
+
67
+
68
+ class TestBucketStars:
69
+
70
+ def test_bucketing_correctness(self):
71
+ prescan = [
72
+ {"index": 0, "post_filter_length": 14500, "degenerate": False},
73
+ {"index": 1, "post_filter_length": 15500, "degenerate": False},
74
+ {"index": 2, "post_filter_length": 14800, "degenerate": False},
75
+ {"index": 3, "post_filter_length": 50, "degenerate": True, "degenerate_reason": "x"},
76
+ ]
77
+ buckets = bucket_stars(prescan, bucket_width=1000)
78
+ # Star 3 is degenerate -> excluded
79
+ total_stars = sum(len(b["star_indices"]) for b in buckets)
80
+ assert total_stars == 3
81
+
82
+ def test_bucket_width(self):
83
+ prescan = [
84
+ {"index": i, "post_filter_length": 14000 + i * 200, "degenerate": False}
85
+ for i in range(10)
86
+ ]
87
+ buckets = bucket_stars(prescan, bucket_width=1000)
88
+ for b in buckets:
89
+ # All stars in bucket should fit within pad_length
90
+ for idx in b["star_indices"]:
91
+ pfl = prescan[idx]["post_filter_length"]
92
+ assert pfl <= b["pad_length"]
93
+
94
+
95
+ class TestAssembleBatch:
96
+
97
+ def test_padding_no_data_leakage(self):
98
+ t1, f1, q1 = _make_star(100, seed=1)
99
+ t2, f2, q2 = _make_star(80, seed=2)
100
+ batch = assemble_batch([0, 1], [t1, t2], [f1, f2], [q1, q2], 120, torch.device("cpu"))
101
+ # Padded positions should be zero
102
+ assert (batch["flux"][0, 100:] == 0).all()
103
+ assert (batch["flux"][1, 80:] == 0).all()
104
+
105
+ def test_padding_mask(self):
106
+ t, f, q = _make_star(100)
107
+ batch = assemble_batch([0], [t], [f], [q], 150, torch.device("cpu"))
108
+ assert batch["valid_mask"][0, :100].all()
109
+ assert not batch["valid_mask"][0, 100:].any()
110
+
111
+ def test_lengths(self):
112
+ stars = [_make_star(n) for n in [100, 200, 150]]
113
+ ts, fs, qs = zip(*stars)
114
+ batch = assemble_batch([0, 1, 2], list(ts), list(fs), list(qs), 250, torch.device("cpu"))
115
+ assert batch["lengths"].tolist() == [100, 200, 150]
116
+
117
+
118
+ class TestVramEstimation:
119
+
120
+ def test_monotonic(self):
121
+ v1 = estimate_peak_vram(15000, 360)
122
+ v2 = estimate_peak_vram(20000, 360)
123
+ assert v2 > v1
124
+
125
+ def test_reasonable_range(self):
126
+ v = estimate_peak_vram(20000, 360)
127
+ # Should be between 100MB and 500MB per star
128
+ assert 100 * 1024**2 < v < 500 * 1024**2
129
+
130
+
131
+ class TestComputeMaxBatch:
132
+
133
+ def test_override(self):
134
+ assert compute_max_batch(20000, max_batch_override=10) == 10
135
+
136
+ def test_budget(self):
137
+ mb_12 = compute_max_batch(20000, vram_budget_gb=12.0)
138
+ mb_6 = compute_max_batch(20000, vram_budget_gb=6.0)
139
+ assert mb_12 > mb_6
140
+ assert mb_6 >= 1
141
+
142
+ def test_cpu_fallback(self):
143
+ assert compute_max_batch(20000, device=torch.device("cpu")) == 1