torchflat 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torchflat/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ """TorchFlat: GPU-native photometric preprocessing for exoplanet transit searches."""
2
+
3
+ __version__ = "0.8.0"
4
+
5
+ from torchflat.umi import umi_detrend
6
+ from torchflat.pipeline import preprocess_sector, preprocess_track_a, preprocess_track_b
7
+ from torchflat.windows import DEFAULT_WINDOW_SCALES
8
+
9
# Public API: the names re-exported at the package top level.
__all__ = [
    "umi_detrend",
    "preprocess_sector",
    "preprocess_track_a",
    "preprocess_track_b",
    "DEFAULT_WINDOW_SCALES",
]
@@ -0,0 +1,289 @@
1
+ """JIT-compile and load UMI CUDA/HIP kernels.
2
+
3
+ Provides two kernels:
4
+ - ``masked_median``: O(n) median via quickselect (legacy, used by rolling_clip)
5
+ - ``umi_median_mad``: O(n) median + MAD in single call (used by umi_detrend)
6
+
7
+ Kernels are compiled on first use and cached for subsequent imports.
8
+ Falls back gracefully if compilation fails (no GPU, no toolkit, etc.).
9
+ Set TORCHFLAT_NO_KERNEL=1 to disable kernels entirely.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import ctypes
15
+ import importlib.util
16
+ import logging
17
+ import os
18
+ import subprocess
19
+ import sys
20
+ import sysconfig
21
+ from pathlib import Path
22
+
23
+ logger = logging.getLogger("torchflat")
24
+
25
+ _umi_kernel_module = None
26
+ _umi_kernel_load_attempted = False
27
+
28
+
29
+ def _short_path(p: str) -> str:
30
+ """Get Windows 8.3 short path (avoids spaces breaking compiler args)."""
31
+ if sys.platform != "win32":
32
+ return p
33
+ buf = ctypes.create_unicode_buffer(260)
34
+ ctypes.windll.kernel32.GetShortPathNameW(str(p), buf, 260)
35
+ return buf.value or str(p)
36
+
37
+
38
+ # ---------------------------------------------------------------------------
39
+ # UMI median+MAD kernel (used by umi_detrend)
40
+ # ---------------------------------------------------------------------------
41
+
42
def _get_umi_kernel():
    """Load the UMI median+MAD kernel module, compiling it on first use.

    Returns the compiled extension module, or ``None`` when the kernel is
    unavailable (disabled via ``TORCHFLAT_NO_KERNEL=1``, no GPU, or
    compilation failed).  The outcome -- including a failed attempt -- is
    cached in module globals so the load is only ever tried once per process.
    """
    global _umi_kernel_module, _umi_kernel_load_attempted

    if _umi_kernel_load_attempted:
        return _umi_kernel_module

    _umi_kernel_load_attempted = True

    # Honor the opt-out *before* touching torch.cuda: querying CUDA
    # availability initializes the driver, which the env var (documented as
    # "disable kernels entirely") is meant to let users skip.  Previously
    # this check ran after the CUDA probe and was silently bypassed on
    # GPU-less machines.
    if os.environ.get("TORCHFLAT_NO_KERNEL", "0") == "1":
        logger.warning(
            "UMI kernel disabled by TORCHFLAT_NO_KERNEL=1. "
            "Using torch.sort fallback (6x slower). "
            "Unset the variable to enable the kernel."
        )
        return None

    import torch

    if not torch.cuda.is_available():
        return None

    csrc_dir = Path(__file__).parent / "csrc"
    build_dir = csrc_dir / "build"
    pyd_name = "torchflat_umi_ext"
    pyd_path = build_dir / f"{pyd_name}.pyd"

    # Ensure ROCm DLLs are findable (amdhip64.dll etc.)
    _add_rocm_dll_dirs()

    # Fast path: reuse a previously compiled extension if one is cached.
    if pyd_path.exists():
        try:
            spec = importlib.util.spec_from_file_location(pyd_name, str(pyd_path))
            mod = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(mod)
            logger.info("Loaded cached UMI kernel from %s", pyd_path)
            _umi_kernel_module = mod
            return _umi_kernel_module
        except Exception as e:
            # Stale or ABI-incompatible binary: delete it and rebuild below.
            logger.warning("Failed to load cached UMI kernel: %s", e)
            pyd_path.unlink(missing_ok=True)

    try:
        if torch.version.hip:
            _umi_kernel_module = _compile_hip_rocm72(
                csrc_dir, pyd_name,
                csrc_dir / "build" / "umi_kernel_hip.cpp",
                csrc_dir / "umi_ext.cpp",
            )
        else:
            _umi_kernel_module = _compile_cuda_umi(csrc_dir)
        if _umi_kernel_module is not None:
            logger.info("UMI kernel compiled and loaded successfully")
    except Exception as e:
        # Compilation failure is non-fatal: callers fall back to torch.sort.
        logger.warning("Failed to compile UMI kernel: %s", e)
        _umi_kernel_module = None

    return _umi_kernel_module
99
+
100
+
101
+ # ---------------------------------------------------------------------------
102
+ # Compilation backends
103
+ # ---------------------------------------------------------------------------
104
+
105
def _compile_hip_rocm72(csrc_dir: Path, pyd_name: str, hip_src: Path, ext_src: Path):
    """Compile a HIP kernel on Windows using ROCm 7.2 SDK from pip.

    Three-step build: compile *hip_src* as HIP device code, compile
    *ext_src* as the C++ Python binding, then link both into a ``.pyd``
    under ``csrc_dir/build`` and import it.

    Args:
        csrc_dir: Directory holding the kernel sources; build artifacts go
            into ``csrc_dir/build``.
        pyd_name: Extension module name (also used for TORCH_EXTENSION_NAME).
        hip_src: Pre-hipified kernel translation unit.
        ext_src: C++ binding translation unit.

    Returns:
        The freshly imported extension module.

    Raises:
        RuntimeError: SDK/compiler/source missing, or any build step fails
            (via ``_run_cmd``).
    """
    import torch

    build_dir = csrc_dir / "build"
    build_dir.mkdir(exist_ok=True)

    pyd_path = build_dir / f"{pyd_name}.pyd"

    rocm_sdk_path = _find_rocm72_sdk()
    if rocm_sdk_path is None:
        raise RuntimeError(
            "ROCm 7.2 SDK not found. Install it with:\n"
            " pip install https://repo.radeon.com/rocm/windows/rocm-rel-7.2/"
            "rocm_sdk_devel-7.2.0.dev0-py3-none-win_amd64.whl\n"
            " pip install https://repo.radeon.com/rocm/windows/rocm-rel-7.2/"
            "rocm_sdk_core-7.2.0.dev0-py3-none-win_amd64.whl"
        )

    clang = str(rocm_sdk_path / "lib" / "llvm" / "bin" / "amdclang++.exe")
    if not os.path.exists(clang):
        raise RuntimeError(f"amdclang++ not found at {clang}")

    if not hip_src.exists():
        raise RuntimeError(f"Hipified kernel source not found at {hip_src}")

    # 8.3 short paths: compiler argument lists break on embedded spaces.
    rocm_sp = _short_path(str(rocm_sdk_path))
    device_lib = _short_path(str(rocm_sdk_path / "lib" / "llvm" / "amdgcn" / "bitcode"))

    # amdclang++ drives the MSVC-compatible link step, so the VS env must be set.
    _setup_msvc_env()

    torch_dir = Path(torch.__file__).parent
    python_inc = _short_path(sysconfig.get_path("include"))
    python_lib = _short_path(str(Path(sysconfig.get_path("stdlib")).parent / "libs"))
    torch_lib = _short_path(str(torch_dir / "lib"))

    inc = [
        f"-I{_short_path(str(torch_dir / 'include'))}",
        f"-I{_short_path(str(torch_dir / 'include' / 'torch' / 'csrc' / 'api' / 'include'))}",
        f"-I{python_inc}",
        f"-I{rocm_sp}/include",
    ]
    defs = [
        "-D__HIP_PLATFORM_AMD__",
        f"-DTORCH_EXTENSION_NAME={pyd_name}",
        "-DTORCH_API_INCLUDE_EXTENSION_H",
    ]

    # Step 1: Compile kernel (.cpp as HIP)
    # NOTE(review): --offload-arch is hard-coded to gfx1200, so the binary
    # only targets that GPU architecture -- confirm whether more archs are needed.
    kernel_obj = build_dir / f"{pyd_name}_kernel.o"
    logger.info("Compiling %s kernel with amdclang++...", pyd_name)
    _run_cmd([
        clang, "-O3", "-c", "-x", "hip",
        f"--rocm-path={rocm_sp}",
        f"--rocm-device-lib-path={device_lib}",
        "--offload-arch=gfx1200",
        *defs, *inc, "-std=c++17", "-w",
        str(hip_src), "-o", str(kernel_obj),
    ])

    # Step 2: Compile binding (.cpp as C++)
    ext_obj = build_dir / f"{pyd_name}_ext.o"
    logger.info("Compiling %s binding...", pyd_name)
    _run_cmd([
        clang, "-O3", "-c",
        *defs, *inc, "-std=c++17", "-w",
        str(ext_src), "-o", str(ext_obj),
    ])

    # Step 3: Link into .pyd
    logger.info("Linking %s...", pyd_name)
    _run_cmd([
        clang, "-shared",
        str(kernel_obj), str(ext_obj),
        f"-L{torch_lib}", "-ltorch", "-ltorch_cpu", "-ltorch_python", "-lc10", "-lc10_hip",
        f"-L{python_lib}", f"-lpython{sys.version_info.major}{sys.version_info.minor}",
        f"-L{rocm_sp}/lib", "-lamdhip64",
        "-o", str(pyd_path),
    ])

    # Import the freshly linked extension directly from its file path.
    spec = importlib.util.spec_from_file_location(pyd_name, str(pyd_path))
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod
189
+
190
+
191
+
192
def _compile_cuda_umi(csrc_dir: Path):
    """JIT-compile the UMI median+MAD CUDA kernel via torch's cpp_extension."""
    from torch.utils import cpp_extension

    source_files = [
        str(csrc_dir / "umi_ext.cpp"),
        str(csrc_dir / "umi_kernel.cu"),
    ]
    return cpp_extension.load(
        name="torchflat_umi_ext",
        sources=source_files,
        extra_cuda_cflags=["-O3", "--use_fast_math"],
        verbose=False,
    )
204
+
205
+
206
+ # ---------------------------------------------------------------------------
207
+ # Shared utilities
208
+ # ---------------------------------------------------------------------------
209
+
210
+ def _add_rocm_dll_dirs():
211
+ """Add ROCm SDK lib directories to DLL search path (Windows).
212
+
213
+ Required so the compiled .pyd can find amdhip64.dll at load time.
214
+ """
215
+ if sys.platform != "win32":
216
+ return
217
+ sdk = _find_rocm72_sdk()
218
+ if sdk is None:
219
+ return
220
+ lib_dir = sdk / "lib"
221
+ bin_dir = sdk / "bin"
222
+ for d in [lib_dir, bin_dir]:
223
+ if d.exists():
224
+ try:
225
+ os.add_dll_directory(str(d))
226
+ except OSError:
227
+ pass
228
+ # Also add to PATH as fallback
229
+ if str(d) not in os.environ.get("PATH", ""):
230
+ os.environ["PATH"] = str(d) + os.pathsep + os.environ.get("PATH", "")
231
+
232
+
233
+ def _find_rocm72_sdk() -> Path | None:
234
+ """Find the ROCm 7.2 SDK installed via pip (rocm-sdk-core package)."""
235
+ candidates = [
236
+ Path(sys.prefix) / "Lib" / "site-packages" / "_rocm_sdk_core",
237
+ Path(os.path.expanduser("~")) / "AppData" / "Roaming" / "Python"
238
+ / f"Python{sys.version_info.major}{sys.version_info.minor}"
239
+ / "site-packages" / "_rocm_sdk_core",
240
+ ]
241
+ try:
242
+ import _rocm_sdk_core
243
+ candidates.insert(0, Path(_rocm_sdk_core.__file__).parent)
244
+ except ImportError:
245
+ pass
246
+
247
+ for p in candidates:
248
+ clang = p / "lib" / "llvm" / "bin" / "amdclang++.exe"
249
+ if clang.exists():
250
+ return p
251
+ return None
252
+
253
+
254
+ def _setup_msvc_env():
255
+ """Setup MSVC compiler environment on Windows."""
256
+ if sys.platform != "win32":
257
+ return
258
+ if os.environ.get("VSINSTALLDIR"):
259
+ return
260
+
261
+ vcvars_candidates = [
262
+ r"C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools\VC\Auxiliary\Build\vcvarsall.bat",
263
+ r"C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Auxiliary\Build\vcvarsall.bat",
264
+ r"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvarsall.bat",
265
+ ]
266
+ for vcvars in vcvars_candidates:
267
+ if os.path.exists(vcvars):
268
+ result = subprocess.run(
269
+ f'cmd /c ""{vcvars}" x64 && set"',
270
+ capture_output=True, text=True, shell=True,
271
+ )
272
+ for line in result.stdout.splitlines():
273
+ if "=" in line:
274
+ k, _, v = line.partition("=")
275
+ os.environ[k] = v
276
+ return
277
+
278
+ logger.warning("MSVC not found, linking may fail")
279
+
280
+
281
+ def _run_cmd(cmd: list[str]):
282
+ """Run a command, raising RuntimeError on failure."""
283
+ result = subprocess.run(cmd, capture_output=True, text=True)
284
+ if result.returncode != 0:
285
+ raise RuntimeError(
286
+ f"Command failed (return {result.returncode}):\n"
287
+ f" {' '.join(cmd[:5])}...\n"
288
+ f" {result.stderr[-500:]}"
289
+ )
torchflat/_utils.py ADDED
@@ -0,0 +1,99 @@
1
+ """Shared utilities: masked median, padding helpers, constants."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+
7
+ # ---------------------------------------------------------------------------
8
+ # Constants
9
+ # ---------------------------------------------------------------------------
10
+
11
# Mission-specific quality bitmasks.  A cadence is treated as invalid when
# (quality & bitmask) != 0 (see the valid-point filters in batching.py).
TESS_QUALITY_BITMASK: int = 0b0000110101111111  # = 3455
KEPLER_QUALITY_BITMASK: int = 0b0001111111111111  # = 8191 (all Kepler quality flags)

# Default is TESS (most common use case)
QUALITY_BITMASK: int = TESS_QUALITY_BITMASK

MIN_POINTS: int = 100  # Minimum valid points for a star to be processed
GAP_THRESHOLD: float = 5.0  # Gap ratio above this = large gap (segment boundary)
MIN_SEGMENT_LENGTH: int = 50  # Minimum valid points in a biweight window's segment
21
+
22
+
23
+ # ---------------------------------------------------------------------------
24
+ # Masked median
25
+ # ---------------------------------------------------------------------------
26
+
27
def masked_median(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Median along the last dimension, restricted to positions where *mask* is True.

    Follows numpy.median's even-count convention (mean of the two middle
    values).  Rows with no valid values yield NaN.

    Args:
        x: Tensor of any shape ``[..., N]``.
        mask: Boolean tensor, same shape as *x*; True marks valid entries.

    Returns:
        Tensor of shape ``[...]`` holding the masked median along the last
        dimension (NaN where nothing is valid).
    """
    if x.shape[-1] == 0:
        return torch.full(x.shape[:-1], float("nan"), dtype=x.dtype, device=x.device)

    # Replace invalid entries with +inf so an ascending sort pushes them
    # past every valid value (torch.sort is O(n log n)).
    filled = x.masked_fill(~mask, float("inf"))
    ordered = filled.sort(dim=-1).values

    count = mask.sum(dim=-1)  # valid entries per row, shape [...]

    # Indices of the two "middle" elements of the valid prefix; they
    # coincide when count is odd (numpy even-length convention).
    lo_idx = torch.clamp((count - 1) // 2, min=0)
    hi_idx = torch.clamp(count // 2, min=0)

    lo = ordered.gather(-1, lo_idx.unsqueeze(-1)).squeeze(-1)
    hi = ordered.gather(-1, hi_idx.unsqueeze(-1)).squeeze(-1)
    center = (lo + hi) / 2.0

    # All-invalid rows would otherwise report inf arithmetic -> force NaN.
    nan_val = torch.tensor(float("nan"), dtype=center.dtype, device=center.device)
    return torch.where(count > 0, center, nan_val)
69
+
70
+
71
+ # ---------------------------------------------------------------------------
72
+ # Padding helper
73
+ # ---------------------------------------------------------------------------
74
+
75
def pad_to_length(
    tensors: list[torch.Tensor],
    target_len: int,
    pad_value: float = 0.0,
) -> torch.Tensor:
    """Right-pad a list of 1-D tensors to *target_len* and stack into 2-D.

    Args:
        tensors: List of 1-D tensors (lengths ``<= target_len``).
        target_len: Length to pad to.
        pad_value: Fill value for padding positions.

    Returns:
        Tensor of shape ``[len(tensors), target_len]``.  An empty input list
        yields an empty ``[0, target_len]`` tensor (default dtype/device)
        rather than raising IndexError on ``tensors[0]``.
    """
    if not tensors:
        # Nothing to stack and no reference tensor to take dtype/device from.
        return torch.empty(0, target_len)

    out = torch.full(
        (len(tensors), target_len),
        fill_value=pad_value,
        dtype=tensors[0].dtype,
        device=tensors[0].device,
    )
    for row, t in enumerate(tensors):
        out[row, : t.shape[0]] = t
    return out
torchflat/batching.py ADDED
@@ -0,0 +1,238 @@
1
+ """CPU pre-scan, length-bucketed batch assembly, VRAM estimation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from collections import defaultdict
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from torchflat._utils import GAP_THRESHOLD, MIN_POINTS, QUALITY_BITMASK
12
+
13
+ logger = logging.getLogger("torchflat")
14
+
15
+
16
+ # ---------------------------------------------------------------------------
17
+ # CPU pre-scan
18
+ # ---------------------------------------------------------------------------
19
+
20
def cpu_prescan(
    times: list[np.ndarray],
    fluxes: list[np.ndarray],
    qualities: list[np.ndarray],
    bitmask: int = QUALITY_BITMASK,
    gap_threshold: float = GAP_THRESHOLD,
    min_points: int = MIN_POINTS,
    window_samples: int = 360,
) -> list[dict]:
    """CPU pre-scan: compute post-filter length and flag degenerate stars.

    A fast O(N) pass per star that counts quality-filtered points, estimates
    small-gap interpolation insertions, and flags stars that are degenerate
    (too few valid points, or no segment long enough for the biweight
    window).
    """
    scan: list[dict] = []

    for star_idx, (t, f, q) in enumerate(zip(times, fluxes, qualities)):
        good = ((q & bitmask) == 0) & np.isfinite(f) & np.isfinite(t)
        n_good = int(good.sum())

        # With fewer than two valid points the cadence is undefined.
        if n_good < 2:
            scan.append({
                "index": star_idx,
                "n_valid": n_good,
                "n_insertions": 0,
                "post_filter_length": n_good,
                "max_segment_length": n_good,
                "degenerate": True,
                "degenerate_reason": "too_few_valid_points",
            })
            continue

        dt = np.diff(t[good])
        cadence = float(np.median(dt))
        if cadence <= 0:
            cadence = 1e-10  # guard against duplicate timestamps

        ratio = dt / cadence

        # Small gaps (1.5x .. gap_threshold x cadence) get interpolated
        # later; count how many points that will insert.
        n_insertions = sum(
            int(round(r)) - 1
            for r in ratio[(ratio > 1.5) & (ratio < gap_threshold)]
        )

        # Longest contiguous stretch between large gaps.
        large_gaps = np.where(ratio >= gap_threshold)[0]
        edges = np.concatenate([[-1], large_gaps, [len(dt)]])
        seg_lens = np.diff(edges)
        max_segment = int(seg_lens.max()) if len(seg_lens) > 0 else n_good

        if n_good < min_points:
            degenerate, reason = True, "too_few_valid_points"
        elif max_segment < window_samples:
            degenerate, reason = True, "segment_too_short"
        else:
            degenerate, reason = False, None

        scan.append({
            "index": star_idx,
            "n_valid": n_good,
            "n_insertions": n_insertions,
            "post_filter_length": n_good + n_insertions,
            "max_segment_length": max_segment,
            "degenerate": degenerate,
            "degenerate_reason": reason,
        })

    return scan
95
+
96
+
97
+ # ---------------------------------------------------------------------------
98
+ # Bucketing
99
+ # ---------------------------------------------------------------------------
100
+
101
def bucket_stars(
    prescan_results: list[dict],
    bucket_width: int = 1000,
) -> list[dict]:
    """Group non-degenerate stars into length buckets for batched GPU processing.

    Each bucket's ``pad_length`` is ``post_filter_length`` rounded *up* to a
    multiple of *bucket_width*, so stars of similar length share one padded
    batch.  Buckets are returned sorted by ``pad_length``.

    Args:
        prescan_results: Dicts from ``cpu_prescan`` (reads ``degenerate``,
            ``post_filter_length``, ``index``).
        bucket_width: Granularity of padded lengths (minimum pad length).

    Returns:
        List of ``{"star_indices": [...], "pad_length": int}`` dicts.
    """
    buckets_map: dict[int, list[int]] = defaultdict(list)

    for info in prescan_results:
        if info["degenerate"]:
            continue
        pfl = info["post_filter_length"]
        # Ceiling division (with a one-bucket floor) instead of
        # floor-then-add: the old formula padded a star whose length was an
        # exact multiple of bucket_width to *double* the needed length
        # (e.g. pfl=1000 -> pad 2000).
        bucket_key = max(1, -(-pfl // bucket_width)) * bucket_width
        buckets_map[bucket_key].append(info["index"])

    return [
        {"star_indices": buckets_map[pad_length], "pad_length": pad_length}
        for pad_length in sorted(buckets_map)
    ]
122
+
123
+
124
+ # ---------------------------------------------------------------------------
125
+ # Batch assembly
126
+ # ---------------------------------------------------------------------------
127
+
128
def assemble_batch(
    star_indices: list[int],
    times: list[np.ndarray],
    fluxes: list[np.ndarray],
    qualities: list[np.ndarray],
    pad_length: int,
    device: torch.device,
) -> dict:
    """Pad, stack, and transfer a batch of stars to GPU.

    Light curves are truncated/right-padded to *pad_length*; short internal
    gaps (<= 4 consecutive invalid points with valid neighbors on both
    sides) are linearly interpolated on CPU before the transfer.

    Returns dict with tensors ``time``, ``flux``, ``quality``, ``lengths``,
    and ``valid_mask`` on *device*.  Note ``valid_mask`` marks in-bounds
    (non-padding) samples, not quality-valid ones -- quality flags are
    returned separately for downstream filtering.
    """
    B = len(star_indices)

    # Zero-initialized batch buffers; positions past each star's length stay 0.
    time_batch = torch.zeros(B, pad_length, dtype=torch.float64)
    flux_batch = torch.zeros(B, pad_length, dtype=torch.float32)
    quality_batch = torch.zeros(B, pad_length, dtype=torch.int32)
    lengths = torch.zeros(B, dtype=torch.long)
    valid_mask = torch.zeros(B, pad_length, dtype=torch.bool)

    for j, idx in enumerate(star_indices):
        t = times[idx]
        # astype() copies, so the in-place gap fill below never mutates the
        # caller's flux array.
        f = fluxes[idx].astype(np.float32)
        q = qualities[idx]
        L = len(t)
        n = min(L, pad_length)  # truncate stars longer than the pad length

        # CPU gap interpolation (faster than GPU cummax/cummin, <1% of points)
        v = ((q[:n] & QUALITY_BITMASK) == 0) & np.isfinite(f[:n]) & np.isfinite(t[:n])
        i = 0
        while i < n:
            if not v[i]:
                # Scan the full run of invalid points [gs, ge).
                gs = i
                while i < n and not v[i]:
                    i += 1
                ge = i
                # Only fill short interior gaps: both anchors f[gs-1] and
                # f[ge] must exist and be valid, and the gap is <= 4 points.
                if gs > 0 and ge < n and (ge - gs) <= 4:
                    for k in range(ge - gs):
                        # Linear interpolation between the two anchors.
                        frac = (k + 1) / (ge - gs + 1)
                        f[gs + k] = f[gs - 1] + frac * (f[ge] - f[gs - 1])
                        v[gs + k] = True
            else:
                i += 1

        time_batch[j, :n] = torch.from_numpy(t[:n].astype(np.float64))
        flux_batch[j, :n] = torch.from_numpy(f[:n])
        quality_batch[j, :n] = torch.from_numpy(q[:n].astype(np.int32))
        lengths[j] = n
        valid_mask[j, :n] = True

    # Single host->device transfer per tensor at the end.
    return {
        "time": time_batch.to(device),
        "flux": flux_batch.to(device),
        "quality": quality_batch.to(device),
        "lengths": lengths.to(device),
        "valid_mask": valid_mask.to(device),
    }
186
+
187
+
188
+ # ---------------------------------------------------------------------------
189
+ # VRAM estimation
190
+ # ---------------------------------------------------------------------------
191
+
192
def estimate_peak_vram(L: int, win: int, dtype_bytes: int = 4) -> int:
    """Estimate peak per-star VRAM (bytes) during biweight detrending.

    Sums the persistent window tensors, the per-iteration intermediates,
    and the transient int64 index tensor produced by topk/sort.

    Args:
        L: Padded light-curve length in samples.
        win: Sliding-window size in samples.
        dtype_bytes: Bytes per flux element (4 for float32).

    Returns:
        Estimated peak bytes for one star.
    """
    n_windows = L - win + 1
    cells = n_windows * win  # total elements across all window positions

    flux_windows = cells * dtype_bytes   # value windows
    mask_windows = cells                 # bool validity windows (1 byte each)
    index_windows = cells * 8            # int64 indices from topk/sort
    segment_windows = cells * 4          # int32 segment IDs

    # Per-sample base arrays: flux(4) + time(8) + valid(1) + seg_id(4)
    base_arrays = L * 17
    resident = flux_windows + segment_windows + mask_windows
    iteration_scratch = 4 * flux_windows  # topk_clone + abs_dev + u + weights
    transient_peak = index_windows

    return base_arrays + resident + iteration_scratch + transient_peak
210
+
211
+
212
def compute_max_batch(
    pad_length: int,
    win: int = 360,
    device: torch.device | None = None,
    vram_budget_gb: float | None = None,
    max_batch_override: int | None = None,
    safety_factor: float = 0.8,
) -> int:
    """Pick the largest safe batch size for a given padded length.

    Priority order: explicit override > user-specified VRAM budget >
    auto-detected device memory (total minus 4 GiB headroom).  Falls back
    to 1 without a CUDA device, and caps the result at 50.
    """
    if max_batch_override is not None:
        return max(1, max_batch_override)

    if vram_budget_gb is not None:
        budget_bytes = int(vram_budget_gb * 1024**3)
    elif device is not None and device.type == "cuda":
        # Leave 4 GiB headroom for the allocator and other resident tensors.
        total_bytes = torch.cuda.get_device_properties(device).total_memory
        budget_bytes = total_bytes - 4 * 1024**3
    else:
        # No GPU information available: process one star at a time.
        return 1

    per_star = estimate_peak_vram(pad_length, win)
    if per_star <= 0:
        return 1
    fit = int(budget_bytes * safety_factor / per_star)
    return max(1, min(fit, 50))