torchflat 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchflat/__init__.py +15 -0
- torchflat/_kernel_loader.py +289 -0
- torchflat/_utils.py +99 -0
- torchflat/batching.py +238 -0
- torchflat/cli.py +535 -0
- torchflat/clipping.py +85 -0
- torchflat/csrc/build/test_combined.cpp +29 -0
- torchflat/csrc/build/test_error_check.cpp +47 -0
- torchflat/csrc/build/test_kernel.cpp +29 -0
- torchflat/csrc/build/umi_kernel_hip.cpp +258 -0
- torchflat/csrc/masked_median_kernel_hip.cpp +202 -0
- torchflat/csrc/umi_ext.cpp +24 -0
- torchflat/csrc/umi_kernel.cu +490 -0
- torchflat/gaps.py +146 -0
- torchflat/highpass.py +146 -0
- torchflat/normalize.py +52 -0
- torchflat/pipeline.py +604 -0
- torchflat/py.typed +0 -0
- torchflat/quality.py +30 -0
- torchflat/umi.py +185 -0
- torchflat/windows.py +87 -0
- torchflat-0.8.0.dist-info/METADATA +234 -0
- torchflat-0.8.0.dist-info/RECORD +27 -0
- torchflat-0.8.0.dist-info/WHEEL +5 -0
- torchflat-0.8.0.dist-info/entry_points.txt +2 -0
- torchflat-0.8.0.dist-info/licenses/LICENSE +21 -0
- torchflat-0.8.0.dist-info/top_level.txt +1 -0
torchflat/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""TorchFlat: GPU-native photometric preprocessing for exoplanet transit searches."""

__version__ = "0.8.0"

# Re-export the public API at package level so callers can write
# `from torchflat import umi_detrend` etc.
from torchflat.umi import umi_detrend
from torchflat.pipeline import preprocess_sector, preprocess_track_a, preprocess_track_b
from torchflat.windows import DEFAULT_WINDOW_SCALES

# Names exported by `from torchflat import *`.
__all__ = [
    "umi_detrend",
    "preprocess_sector",
    "preprocess_track_a",
    "preprocess_track_b",
    "DEFAULT_WINDOW_SCALES",
]
|
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
"""JIT-compile and load UMI CUDA/HIP kernels.
|
|
2
|
+
|
|
3
|
+
Provides two kernels:
|
|
4
|
+
- ``masked_median``: O(n) median via quickselect (legacy, used by rolling_clip)
|
|
5
|
+
- ``umi_median_mad``: O(n) median + MAD in single call (used by umi_detrend)
|
|
6
|
+
|
|
7
|
+
Kernels are compiled on first use and cached for subsequent imports.
|
|
8
|
+
Falls back gracefully if compilation fails (no GPU, no toolkit, etc.).
|
|
9
|
+
Set TORCHFLAT_NO_KERNEL=1 to disable kernels entirely.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import ctypes
|
|
15
|
+
import importlib.util
|
|
16
|
+
import logging
|
|
17
|
+
import os
|
|
18
|
+
import subprocess
|
|
19
|
+
import sys
|
|
20
|
+
import sysconfig
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
|
|
23
|
+
# Package-wide logger; handlers/levels are expected to be configured by the
# application, not by this module.
logger = logging.getLogger("torchflat")

# Cached handle to the compiled kernel extension module (None until a
# successful load) plus a flag ensuring the load/compile is attempted once.
_umi_kernel_module = None
_umi_kernel_load_attempted = False
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _short_path(p: str) -> str:
|
|
30
|
+
"""Get Windows 8.3 short path (avoids spaces breaking compiler args)."""
|
|
31
|
+
if sys.platform != "win32":
|
|
32
|
+
return p
|
|
33
|
+
buf = ctypes.create_unicode_buffer(260)
|
|
34
|
+
ctypes.windll.kernel32.GetShortPathNameW(str(p), buf, 260)
|
|
35
|
+
return buf.value or str(p)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# ---------------------------------------------------------------------------
|
|
39
|
+
# UMI median+MAD kernel (used by umi_detrend)
|
|
40
|
+
# ---------------------------------------------------------------------------
|
|
41
|
+
|
|
42
|
+
def _get_umi_kernel():
    """Load the UMI median+MAD kernel. Returns None if unavailable.

    The result (module or None) is cached in module globals, so the
    compile/load is attempted at most once per process.  Returns None when
    CUDA/ROCm is unavailable, when TORCHFLAT_NO_KERNEL=1 is set, or when
    compilation fails; callers then fall back to the torch.sort path.
    """
    global _umi_kernel_module, _umi_kernel_load_attempted

    # Fast path: a previous call already loaded (or failed to load) the kernel.
    if _umi_kernel_load_attempted:
        return _umi_kernel_module

    _umi_kernel_load_attempted = True

    # Deferred import keeps module import cheap when torch isn't needed yet.
    import torch

    if not torch.cuda.is_available():
        return None
    if os.environ.get("TORCHFLAT_NO_KERNEL", "0") == "1":
        logger.warning(
            "UMI kernel disabled by TORCHFLAT_NO_KERNEL=1. "
            "Using torch.sort fallback (6x slower). "
            "Unset the variable to enable the kernel."
        )
        return None

    csrc_dir = Path(__file__).parent / "csrc"
    build_dir = csrc_dir / "build"
    pyd_name = "torchflat_umi_ext"
    # NOTE(review): the .pyd cache path is only produced by the Windows/HIP
    # build below; on the CUDA path torch caches the build itself.
    pyd_path = build_dir / f"{pyd_name}.pyd"

    # Ensure ROCm DLLs are findable (amdhip64.dll etc.)
    _add_rocm_dll_dirs()

    # Try the cached prebuilt extension first; a corrupt/stale cache file is
    # deleted so the compile path below gets a clean retry.
    if pyd_path.exists():
        try:
            spec = importlib.util.spec_from_file_location(pyd_name, str(pyd_path))
            mod = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(mod)
            logger.info("Loaded cached UMI kernel from %s", pyd_path)
            _umi_kernel_module = mod
            return _umi_kernel_module
        except Exception as e:
            logger.warning("Failed to load cached UMI kernel: %s", e)
            pyd_path.unlink(missing_ok=True)

    # Compile from source: ROCm/HIP on AMD builds of torch, NVCC otherwise.
    try:
        if torch.version.hip:
            _umi_kernel_module = _compile_hip_rocm72(
                csrc_dir, pyd_name,
                csrc_dir / "build" / "umi_kernel_hip.cpp",
                csrc_dir / "umi_ext.cpp",
            )
        else:
            _umi_kernel_module = _compile_cuda_umi(csrc_dir)
        if _umi_kernel_module is not None:
            logger.info("UMI kernel compiled and loaded successfully")
    except Exception as e:
        # Graceful degradation: any compile/load failure leaves the module
        # None so callers use the pure-torch fallback.
        logger.warning("Failed to compile UMI kernel: %s", e)
        _umi_kernel_module = None

    return _umi_kernel_module
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
# ---------------------------------------------------------------------------
|
|
102
|
+
# Compilation backends
|
|
103
|
+
# ---------------------------------------------------------------------------
|
|
104
|
+
|
|
105
|
+
def _compile_hip_rocm72(csrc_dir: Path, pyd_name: str, hip_src: Path, ext_src: Path):
    """Compile a HIP kernel on Windows using ROCm 7.2 SDK from pip.

    Three-step build driven directly with amdclang++: (1) compile the
    hipified kernel source as HIP, (2) compile the pybind/extension binding
    as plain C++, (3) link both objects with torch/python/ROCm libraries
    into a .pyd next to the sources.

    Args:
        csrc_dir: Directory containing the kernel sources; output goes to
            ``csrc_dir / "build"``.
        pyd_name: Extension module name (also the TORCH_EXTENSION_NAME).
        hip_src: Pre-hipified kernel .cpp file.
        ext_src: C++ binding source.

    Returns:
        The imported, freshly built extension module.

    Raises:
        RuntimeError: if the SDK, compiler, or kernel source is missing, or
            any build step fails.
    """
    import torch

    build_dir = csrc_dir / "build"
    build_dir.mkdir(exist_ok=True)

    pyd_path = build_dir / f"{pyd_name}.pyd"

    rocm_sdk_path = _find_rocm72_sdk()
    if rocm_sdk_path is None:
        raise RuntimeError(
            "ROCm 7.2 SDK not found. Install it with:\n"
            " pip install https://repo.radeon.com/rocm/windows/rocm-rel-7.2/"
            "rocm_sdk_devel-7.2.0.dev0-py3-none-win_amd64.whl\n"
            " pip install https://repo.radeon.com/rocm/windows/rocm-rel-7.2/"
            "rocm_sdk_core-7.2.0.dev0-py3-none-win_amd64.whl"
        )

    clang = str(rocm_sdk_path / "lib" / "llvm" / "bin" / "amdclang++.exe")
    if not os.path.exists(clang):
        raise RuntimeError(f"amdclang++ not found at {clang}")

    if not hip_src.exists():
        raise RuntimeError(f"Hipified kernel source not found at {hip_src}")

    # 8.3 short paths: amdclang++ args break on spaces in install paths.
    rocm_sp = _short_path(str(rocm_sdk_path))
    device_lib = _short_path(str(rocm_sdk_path / "lib" / "llvm" / "amdgcn" / "bitcode"))

    # MSVC headers/libs are needed on Windows even with clang as the driver.
    _setup_msvc_env()

    torch_dir = Path(torch.__file__).parent
    python_inc = _short_path(sysconfig.get_path("include"))
    python_lib = _short_path(str(Path(sysconfig.get_path("stdlib")).parent / "libs"))
    torch_lib = _short_path(str(torch_dir / "lib"))

    inc = [
        f"-I{_short_path(str(torch_dir / 'include'))}",
        f"-I{_short_path(str(torch_dir / 'include' / 'torch' / 'csrc' / 'api' / 'include'))}",
        f"-I{python_inc}",
        f"-I{rocm_sp}/include",
    ]
    defs = [
        "-D__HIP_PLATFORM_AMD__",
        f"-DTORCH_EXTENSION_NAME={pyd_name}",
        "-DTORCH_API_INCLUDE_EXTENSION_H",
    ]

    # Step 1: Compile kernel (.cpp as HIP).
    # NOTE(review): --offload-arch is hard-coded to gfx1200 (RDNA4); other
    # GPUs would need this parameterized — confirm intended target set.
    kernel_obj = build_dir / f"{pyd_name}_kernel.o"
    logger.info("Compiling %s kernel with amdclang++...", pyd_name)
    _run_cmd([
        clang, "-O3", "-c", "-x", "hip",
        f"--rocm-path={rocm_sp}",
        f"--rocm-device-lib-path={device_lib}",
        "--offload-arch=gfx1200",
        *defs, *inc, "-std=c++17", "-w",
        str(hip_src), "-o", str(kernel_obj),
    ])

    # Step 2: Compile binding (.cpp as C++)
    ext_obj = build_dir / f"{pyd_name}_ext.o"
    logger.info("Compiling %s binding...", pyd_name)
    _run_cmd([
        clang, "-O3", "-c",
        *defs, *inc, "-std=c++17", "-w",
        str(ext_src), "-o", str(ext_obj),
    ])

    # Step 3: Link into .pyd
    logger.info("Linking %s...", pyd_name)
    _run_cmd([
        clang, "-shared",
        str(kernel_obj), str(ext_obj),
        f"-L{torch_lib}", "-ltorch", "-ltorch_cpu", "-ltorch_python", "-lc10", "-lc10_hip",
        f"-L{python_lib}", f"-lpython{sys.version_info.major}{sys.version_info.minor}",
        f"-L{rocm_sp}/lib", "-lamdhip64",
        "-o", str(pyd_path),
    ])

    # Import the freshly linked extension from its file location.
    spec = importlib.util.spec_from_file_location(pyd_name, str(pyd_path))
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _compile_cuda_umi(csrc_dir: Path):
    """Compile the UMI median+MAD CUDA kernel via torch's JIT extension loader.

    Args:
        csrc_dir: Directory holding ``umi_ext.cpp`` and ``umi_kernel.cu``.

    Returns:
        The loaded extension module.
    """
    from torch.utils import cpp_extension

    kernel_sources = [
        str(csrc_dir / "umi_ext.cpp"),
        str(csrc_dir / "umi_kernel.cu"),
    ]
    return cpp_extension.load(
        name="torchflat_umi_ext",
        sources=kernel_sources,
        extra_cuda_cflags=["-O3", "--use_fast_math"],
        verbose=False,
    )
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
# ---------------------------------------------------------------------------
|
|
207
|
+
# Shared utilities
|
|
208
|
+
# ---------------------------------------------------------------------------
|
|
209
|
+
|
|
210
|
+
def _add_rocm_dll_dirs():
|
|
211
|
+
"""Add ROCm SDK lib directories to DLL search path (Windows).
|
|
212
|
+
|
|
213
|
+
Required so the compiled .pyd can find amdhip64.dll at load time.
|
|
214
|
+
"""
|
|
215
|
+
if sys.platform != "win32":
|
|
216
|
+
return
|
|
217
|
+
sdk = _find_rocm72_sdk()
|
|
218
|
+
if sdk is None:
|
|
219
|
+
return
|
|
220
|
+
lib_dir = sdk / "lib"
|
|
221
|
+
bin_dir = sdk / "bin"
|
|
222
|
+
for d in [lib_dir, bin_dir]:
|
|
223
|
+
if d.exists():
|
|
224
|
+
try:
|
|
225
|
+
os.add_dll_directory(str(d))
|
|
226
|
+
except OSError:
|
|
227
|
+
pass
|
|
228
|
+
# Also add to PATH as fallback
|
|
229
|
+
if str(d) not in os.environ.get("PATH", ""):
|
|
230
|
+
os.environ["PATH"] = str(d) + os.pathsep + os.environ.get("PATH", "")
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _find_rocm72_sdk() -> Path | None:
|
|
234
|
+
"""Find the ROCm 7.2 SDK installed via pip (rocm-sdk-core package)."""
|
|
235
|
+
candidates = [
|
|
236
|
+
Path(sys.prefix) / "Lib" / "site-packages" / "_rocm_sdk_core",
|
|
237
|
+
Path(os.path.expanduser("~")) / "AppData" / "Roaming" / "Python"
|
|
238
|
+
/ f"Python{sys.version_info.major}{sys.version_info.minor}"
|
|
239
|
+
/ "site-packages" / "_rocm_sdk_core",
|
|
240
|
+
]
|
|
241
|
+
try:
|
|
242
|
+
import _rocm_sdk_core
|
|
243
|
+
candidates.insert(0, Path(_rocm_sdk_core.__file__).parent)
|
|
244
|
+
except ImportError:
|
|
245
|
+
pass
|
|
246
|
+
|
|
247
|
+
for p in candidates:
|
|
248
|
+
clang = p / "lib" / "llvm" / "bin" / "amdclang++.exe"
|
|
249
|
+
if clang.exists():
|
|
250
|
+
return p
|
|
251
|
+
return None
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def _setup_msvc_env():
|
|
255
|
+
"""Setup MSVC compiler environment on Windows."""
|
|
256
|
+
if sys.platform != "win32":
|
|
257
|
+
return
|
|
258
|
+
if os.environ.get("VSINSTALLDIR"):
|
|
259
|
+
return
|
|
260
|
+
|
|
261
|
+
vcvars_candidates = [
|
|
262
|
+
r"C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools\VC\Auxiliary\Build\vcvarsall.bat",
|
|
263
|
+
r"C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Auxiliary\Build\vcvarsall.bat",
|
|
264
|
+
r"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvarsall.bat",
|
|
265
|
+
]
|
|
266
|
+
for vcvars in vcvars_candidates:
|
|
267
|
+
if os.path.exists(vcvars):
|
|
268
|
+
result = subprocess.run(
|
|
269
|
+
f'cmd /c ""{vcvars}" x64 && set"',
|
|
270
|
+
capture_output=True, text=True, shell=True,
|
|
271
|
+
)
|
|
272
|
+
for line in result.stdout.splitlines():
|
|
273
|
+
if "=" in line:
|
|
274
|
+
k, _, v = line.partition("=")
|
|
275
|
+
os.environ[k] = v
|
|
276
|
+
return
|
|
277
|
+
|
|
278
|
+
logger.warning("MSVC not found, linking may fail")
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def _run_cmd(cmd: list[str]):
|
|
282
|
+
"""Run a command, raising RuntimeError on failure."""
|
|
283
|
+
result = subprocess.run(cmd, capture_output=True, text=True)
|
|
284
|
+
if result.returncode != 0:
|
|
285
|
+
raise RuntimeError(
|
|
286
|
+
f"Command failed (return {result.returncode}):\n"
|
|
287
|
+
f" {' '.join(cmd[:5])}...\n"
|
|
288
|
+
f" {result.stderr[-500:]}"
|
|
289
|
+
)
|
torchflat/_utils.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
"""Shared utilities: masked median, padding helpers, constants."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
# ---------------------------------------------------------------------------
|
|
8
|
+
# Constants
|
|
9
|
+
# ---------------------------------------------------------------------------
|
|
10
|
+
|
|
11
|
+
# Mission-specific quality bitmasks: a cadence is rejected when any of
# these bits is set in its quality flag.
# NOTE(review): verify the TESS pattern matches the intended default flag
# set for the mission — cannot be confirmed from this file alone.
TESS_QUALITY_BITMASK: int = 0b0000110101111111  # = 3455
KEPLER_QUALITY_BITMASK: int = 0b0001111111111111  # = 8191 (all Kepler quality flags)

# Default is TESS (most common use case)
QUALITY_BITMASK: int = TESS_QUALITY_BITMASK

MIN_POINTS: int = 100  # Minimum valid points for a star to be processed
GAP_THRESHOLD: float = 5.0  # Gap ratio above this = large gap (segment boundary)
MIN_SEGMENT_LENGTH: int = 50  # Minimum valid points in a biweight window's segment
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# ---------------------------------------------------------------------------
|
|
24
|
+
# Masked median
|
|
25
|
+
# ---------------------------------------------------------------------------
|
|
26
|
+
|
|
27
|
+
def masked_median(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Median along the last dimension, restricted to positions where *mask* is True.

    Follows the numpy convention for an even number of valid entries: the
    two central order statistics are averaged.

    Args:
        x: Tensor of any shape ``[..., N]``.
        mask: Boolean tensor, same shape as *x*. True = valid.

    Returns:
        Tensor of shape ``[...]`` holding the masked median along the last
        dimension; NaN wherever a slice has no valid entries.
    """
    nan = float("nan")
    if x.shape[-1] == 0:
        return torch.full(x.shape[:-1], nan, dtype=x.dtype, device=x.device)

    # Push invalid entries to +inf so an ascending sort leaves the first
    # `count` slots occupied by the real data (O(n log n) torch.sort path).
    filled = torch.where(mask, x, torch.full_like(x, float("inf")))
    ordered = filled.sort(dim=-1).values

    count = mask.sum(dim=-1)  # valid entries per slice, shape [...]

    # Central order-statistic indices (equal when count is odd); clamp so
    # count == 0 still yields a legal gather index.
    lo_idx = torch.clamp((count - 1) // 2, min=0)
    hi_idx = torch.clamp(count // 2, min=0)

    lo = ordered.gather(-1, lo_idx.unsqueeze(-1)).squeeze(-1)
    hi = ordered.gather(-1, hi_idx.unsqueeze(-1)).squeeze(-1)
    center = (lo + hi) / 2.0

    # Guard: slices without any valid data become NaN.
    return center.where(count > 0, torch.tensor(nan, dtype=center.dtype, device=center.device))
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# ---------------------------------------------------------------------------
|
|
72
|
+
# Padding helper
|
|
73
|
+
# ---------------------------------------------------------------------------
|
|
74
|
+
|
|
75
|
+
def pad_to_length(
    tensors: list[torch.Tensor],
    target_len: int,
    pad_value: float = 0.0,
) -> torch.Tensor:
    """Right-pad a list of 1-D tensors to *target_len* and stack into 2-D.

    Args:
        tensors: List of 1-D tensors (lengths ``<= target_len``).
        target_len: Length to pad to.
        pad_value: Fill value for padding positions.

    Returns:
        Tensor of shape ``[len(tensors), target_len]`` with dtype/device
        taken from the first tensor.  An empty input list yields an empty
        default-dtype tensor of shape ``[0, target_len]``.
    """
    # Fix: the general path reads tensors[0] for dtype/device and raised
    # IndexError on an empty list; return a well-shaped empty tensor instead.
    if not tensors:
        return torch.empty(0, target_len)

    out = torch.full(
        (len(tensors), target_len),
        fill_value=pad_value,
        dtype=tensors[0].dtype,
        device=tensors[0].device,
    )
    for i, t in enumerate(tensors):
        out[i, : t.shape[0]] = t
    return out
|
torchflat/batching.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
"""CPU pre-scan, length-bucketed batch assembly, VRAM estimation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from collections import defaultdict
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from torchflat._utils import GAP_THRESHOLD, MIN_POINTS, QUALITY_BITMASK
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger("torchflat")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# ---------------------------------------------------------------------------
|
|
17
|
+
# CPU pre-scan
|
|
18
|
+
# ---------------------------------------------------------------------------
|
|
19
|
+
|
|
20
|
+
def cpu_prescan(
    times: list[np.ndarray],
    fluxes: list[np.ndarray],
    qualities: list[np.ndarray],
    bitmask: int = QUALITY_BITMASK,
    gap_threshold: float = GAP_THRESHOLD,
    min_points: int = MIN_POINTS,
    window_samples: int = 360,
) -> list[dict]:
    """CPU pre-scan: compute post-filter length and flag degenerate stars.

    Runs a fast O(N) scan per star on CPU to determine quality counts, gap
    insertions, and whether the star is degenerate (too few points or
    segments too short for the biweight window).

    Args:
        times: Per-star time arrays.
        fluxes: Per-star flux arrays, same lengths as *times*.
        qualities: Per-star integer quality-flag arrays.
        bitmask: Quality bits that mark a cadence as invalid.
        gap_threshold: Gap/median-cadence ratio at or above which a gap is
            treated as a segment boundary.
        min_points: Minimum valid points for a non-degenerate star.
        window_samples: Detrending window length the longest segment must reach.

    Returns:
        One dict per star with keys ``index``, ``n_valid``, ``n_insertions``,
        ``post_filter_length``, ``max_segment_length``, ``degenerate``,
        ``degenerate_reason``.
    """
    results: list[dict] = []

    for i, (t, f, q) in enumerate(zip(times, fluxes, qualities)):
        # A cadence is valid when no masked quality bit is set and both
        # flux and time are finite.
        valid = ((q & bitmask) == 0) & np.isfinite(f) & np.isfinite(t)
        n_valid = int(valid.sum())

        # With fewer than 2 valid points no cadence statistics are possible.
        if n_valid < 2:
            results.append({
                "index": i,
                "n_valid": n_valid,
                "n_insertions": 0,
                "post_filter_length": n_valid,
                "max_segment_length": n_valid,
                "degenerate": True,
                "degenerate_reason": "too_few_valid_points",
            })
            continue

        t_valid = t[valid]
        dt = np.diff(t_valid)
        med_cadence = float(np.median(dt))
        if med_cadence <= 0:
            # Degenerate cadence (duplicate/unsorted timestamps); tiny
            # positive value avoids division by zero below.
            med_cadence = 1e-10

        # Gap sizes expressed in units of the median cadence.
        gap_ratio = dt / med_cadence

        # Count small-gap interpolation insertions: a gap of k cadences
        # will later be filled with k-1 interpolated samples.
        small_gaps = (gap_ratio > 1.5) & (gap_ratio < gap_threshold)
        n_insertions = 0
        for gr in gap_ratio[small_gaps]:
            n_insertions += int(round(gr)) - 1

        # Longest contiguous segment (in valid points) between large gaps.
        # Boundaries are sentinel -1, each large-gap position in dt-index
        # space, and len(dt); diffs sum to n_valid.
        large_gap_pos = np.where(gap_ratio >= gap_threshold)[0]
        boundaries = np.concatenate([[-1], large_gap_pos, [len(dt)]])
        segment_lengths = np.diff(boundaries)
        max_segment = int(segment_lengths.max()) if len(segment_lengths) > 0 else n_valid

        post_filter_length = n_valid + n_insertions

        degenerate = False
        reason = None
        if n_valid < min_points:
            degenerate = True
            reason = "too_few_valid_points"
        elif max_segment < window_samples:
            degenerate = True
            reason = "segment_too_short"

        results.append({
            "index": i,
            "n_valid": n_valid,
            "n_insertions": n_insertions,
            "post_filter_length": post_filter_length,
            "max_segment_length": max_segment,
            "degenerate": degenerate,
            "degenerate_reason": reason,
        })

    return results
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# ---------------------------------------------------------------------------
|
|
98
|
+
# Bucketing
|
|
99
|
+
# ---------------------------------------------------------------------------
|
|
100
|
+
|
|
101
|
+
def bucket_stars(
    prescan_results: list[dict],
    bucket_width: int = 1000,
) -> list[dict]:
    """Group non-degenerate stars into length buckets for batched GPU processing.

    Each star goes into the smallest multiple of *bucket_width* that can
    hold its post-filter length, so similarly sized stars share one padded
    batch.

    Args:
        prescan_results: Output of :func:`cpu_prescan`.
        bucket_width: Granularity of the length buckets.

    Returns:
        List of dicts sorted by ``pad_length``, each with keys
        ``star_indices`` and ``pad_length``.
    """
    buckets_map: dict[int, list[int]] = defaultdict(list)

    for info in prescan_results:
        if info["degenerate"]:
            continue
        pfl = info["post_filter_length"]
        # Fix: the old formula `(pfl // w) * w + w` padded a star whose
        # length is an exact multiple of bucket_width by a whole extra
        # bucket (e.g. 1000 -> 2000), wasting padding and VRAM.  Use a
        # ceiling-to-multiple instead, with a floor of one bucket.
        bucket_key = max(bucket_width, -(-pfl // bucket_width) * bucket_width)
        buckets_map[bucket_key].append(info["index"])

    buckets: list[dict] = []
    for pad_length in sorted(buckets_map.keys()):
        buckets.append({
            "star_indices": buckets_map[pad_length],
            "pad_length": pad_length,
        })
    return buckets
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
# ---------------------------------------------------------------------------
|
|
125
|
+
# Batch assembly
|
|
126
|
+
# ---------------------------------------------------------------------------
|
|
127
|
+
|
|
128
|
+
def assemble_batch(
    star_indices: list[int],
    times: list[np.ndarray],
    fluxes: list[np.ndarray],
    qualities: list[np.ndarray],
    pad_length: int,
    device: torch.device,
) -> dict:
    """Pad, stack, and transfer a batch of stars to GPU.

    Small gaps of invalid cadences (length <= 4, not touching either end)
    are linearly interpolated in flux on CPU before transfer.

    Args:
        star_indices: Indices into the per-star lists to include.
        times: Per-star time arrays (converted to float64).
        fluxes: Per-star flux arrays (copied to float32; originals untouched).
        qualities: Per-star quality-flag arrays (converted to int32).
        pad_length: Common padded length; longer stars are truncated to it.
        device: Target device for the returned tensors.

    Returns:
        Dict with tensors ``time``, ``flux``, ``quality``, ``lengths``,
        and ``valid_mask`` on *device*.  Note ``valid_mask`` marks
        in-bounds (non-padding) positions only, not quality-valid ones.
    """
    B = len(star_indices)

    time_batch = torch.zeros(B, pad_length, dtype=torch.float64)
    flux_batch = torch.zeros(B, pad_length, dtype=torch.float32)
    quality_batch = torch.zeros(B, pad_length, dtype=torch.int32)
    lengths = torch.zeros(B, dtype=torch.long)
    valid_mask = torch.zeros(B, pad_length, dtype=torch.bool)

    for j, idx in enumerate(star_indices):
        t = times[idx]
        # astype makes a copy, so the in-place interpolation below does not
        # mutate the caller's flux array.
        f = fluxes[idx].astype(np.float32)
        q = qualities[idx]
        L = len(t)
        n = min(L, pad_length)

        # CPU gap interpolation (faster than GPU cummax/cummin, <1% of points)
        v = ((q[:n] & QUALITY_BITMASK) == 0) & np.isfinite(f[:n]) & np.isfinite(t[:n])
        i = 0
        while i < n:
            if not v[i]:
                # Scan the run of invalid points [gs, ge).
                gs = i
                while i < n and not v[i]:
                    i += 1
                ge = i
                # Only interior runs of <= 4 points are filled; endpoints
                # f[gs-1] and f[ge] are both valid by construction.
                if gs > 0 and ge < n and (ge - gs) <= 4:
                    for k in range(ge - gs):
                        frac = (k + 1) / (ge - gs + 1)
                        f[gs + k] = f[gs - 1] + frac * (f[ge] - f[gs - 1])
                        v[gs + k] = True
            else:
                i += 1

        time_batch[j, :n] = torch.from_numpy(t[:n].astype(np.float64))
        flux_batch[j, :n] = torch.from_numpy(f[:n])
        quality_batch[j, :n] = torch.from_numpy(q[:n].astype(np.int32))
        lengths[j] = n
        valid_mask[j, :n] = True

    return {
        "time": time_batch.to(device),
        "flux": flux_batch.to(device),
        "quality": quality_batch.to(device),
        "lengths": lengths.to(device),
        "valid_mask": valid_mask.to(device),
    }
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
# ---------------------------------------------------------------------------
|
|
189
|
+
# VRAM estimation
|
|
190
|
+
# ---------------------------------------------------------------------------
|
|
191
|
+
|
|
192
|
+
def estimate_peak_vram(L: int, win: int, dtype_bytes: int = 4) -> int:
    """Estimate peak VRAM per star (bytes) during biweight detrending.

    Accounts for persistent window tensors, per-iteration intermediates,
    and temporary topk/sort int64 indices.

    Args:
        L: Padded light-curve length in samples.
        win: Sliding-window length in samples.
        dtype_bytes: Bytes per element of the working dtype (float32 = 4).

    Returns:
        Estimated peak bytes (always >= 0).
    """
    # Fix: `L - win + 1` went negative when the series is shorter than the
    # window, producing a nonsensical negative byte estimate; clamp at 0 so
    # only the base per-sample cost remains in that case.
    N_pos = max(L - win + 1, 0)
    window_bytes = N_pos * win * dtype_bytes
    bool_window_bytes = N_pos * win * 1
    indices_bytes = N_pos * win * 8
    seg_window_bytes = N_pos * win * 4  # int32 segment IDs

    base = L * 17  # flux(4) + time(8) + valid(1) + seg_id(4)
    persistent = window_bytes + seg_window_bytes + bool_window_bytes
    per_iter = 4 * window_bytes  # topk_clone + abs_dev + u + weights
    peak_temp = indices_bytes

    return base + persistent + per_iter + peak_temp
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def compute_max_batch(
    pad_length: int,
    win: int = 360,
    device: torch.device | None = None,
    vram_budget_gb: float | None = None,
    max_batch_override: int | None = None,
    safety_factor: float = 0.8,
) -> int:
    """Dynamic max_batch with 3-tier priority: override > budget > auto-detect.

    Args:
        pad_length: Padded light-curve length for the bucket.
        win: Biweight window length in samples.
        device: Target device; a CUDA device enables VRAM auto-detection.
        vram_budget_gb: Explicit VRAM budget in GiB (overrides detection).
        max_batch_override: Hard override of the batch size.
        safety_factor: Fraction of the budget actually committed.

    Returns:
        Batch size, at least 1 and capped at 50.
    """
    if max_batch_override is not None:
        return max(1, max_batch_override)

    if vram_budget_gb is not None:
        available = int(vram_budget_gb * 1024**3)
    elif device is not None and device.type == "cuda":
        # Auto-detect: total VRAM minus a fixed 4 GiB headroom for the
        # framework, allocator caches, and other processes.
        props = torch.cuda.get_device_properties(device)
        available = props.total_memory - 4 * 1024**3
    else:
        # No GPU information available: process one star at a time.
        return 1

    per_star = estimate_peak_vram(pad_length, win)
    if per_star <= 0:
        return 1
    fit = int(available * safety_factor / per_star)
    return max(1, min(fit, 50))
|