cuperiod 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cuperiod/__init__.py ADDED
@@ -0,0 +1,85 @@
1
+ """cuPeriod: optimized, GPU-accelerated periodograms for astronomy.
2
+
3
+ Quick start
4
+ -----------
5
+ >>> import cuperiod as cup # doctest: +SKIP
6
+ >>> pg = cup.periodogram((time, mag, mag_err), "GLS") # doctest: +SKIP
7
+ >>> pg.best_period() # doctest: +SKIP
8
+ >>> for peak in pg.best_periods(10): # doctest: +SKIP
9
+ ... print(peak.period, peak.power)
10
+
11
+ The recommended import alias is ``cup``. The two entry points are
12
+ :func:`periodogram` (a single light curve) and :func:`batch_periodograms` (many).
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ # Importing the methods package registers the built-in methods (GLS, BLS, ...).
18
+ from cuperiod import methods as methods
19
+ from cuperiod.api import best_periods, periodogram, to_input
20
+ from cuperiod.batch import BatchSummary, batch_periodograms
21
+ from cuperiod.core.columns import ColumnMap, Domain
22
+ from cuperiod.core.config import (
23
+ BatchSettings,
24
+ BLSSettings,
25
+ CESettings,
26
+ GLSSettings,
27
+ MHAOVSettings,
28
+ PDMSettings,
29
+ StringLengthSettings,
30
+ TLSSettings,
31
+ )
32
+ from cuperiod.core.device import GpuInfo, free_gpu_memory, gpu_info, suggest_gpu_workers
33
+ from cuperiod.core.errors import (
34
+ BackendUnavailableError,
35
+ ColumnResolutionError,
36
+ CuPeriodError,
37
+ InsufficientDataError,
38
+ UnknownMethodError,
39
+ )
40
+ from cuperiod.core.grid import GridSpec, log_period_grid, uniform_frequency_grid
41
+ from cuperiod.core.lightcurve import LightCurve, MultiBandLightCurve
42
+ from cuperiod.core.result import MultiResult, Peak, Periodogram
43
+ from cuperiod.methods.base import MethodInfo, get_method, list_methods, method_names
44
+
45
+ __version__ = "1.0.0"
46
+
47
+ __all__ = [
48
+ "BLSSettings",
49
+ "BackendUnavailableError",
50
+ "BatchSettings",
51
+ "BatchSummary",
52
+ "CESettings",
53
+ "ColumnMap",
54
+ "ColumnResolutionError",
55
+ "CuPeriodError",
56
+ "Domain",
57
+ "GLSSettings",
58
+ "GpuInfo",
59
+ "GridSpec",
60
+ "InsufficientDataError",
61
+ "LightCurve",
62
+ "MHAOVSettings",
63
+ "MethodInfo",
64
+ "MultiBandLightCurve",
65
+ "MultiResult",
66
+ "PDMSettings",
67
+ "Peak",
68
+ "Periodogram",
69
+ "StringLengthSettings",
70
+ "TLSSettings",
71
+ "UnknownMethodError",
72
+ "__version__",
73
+ "batch_periodograms",
74
+ "best_periods",
75
+ "free_gpu_memory",
76
+ "get_method",
77
+ "gpu_info",
78
+ "list_methods",
79
+ "log_period_grid",
80
+ "method_names",
81
+ "periodogram",
82
+ "suggest_gpu_workers",
83
+ "to_input",
84
+ "uniform_frequency_grid",
85
+ ]
cuperiod/api.py ADDED
@@ -0,0 +1,187 @@
1
+ """The single-light-curve API: :func:`periodogram`.
2
+
3
+ This is the one entry point most users need. It accepts a light curve in any convenient
4
+ form (a :class:`LightCurve`/:class:`MultiBandLightCurve`, a ``(t, y[, dy])`` tuple, a
5
+ pandas/astropy/pyarrow table, or a dict of arrays), one or several method names, and an
6
+ optional settings/grid/backend, and returns a :class:`Periodogram` (single method) or a
7
+ :class:`MultiResult` (several methods).
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from collections.abc import Mapping, Sequence
13
+ from typing import Any
14
+
15
+ from pydantic_settings import BaseSettings
16
+
17
+ from cuperiod.core.columns import ColumnMap, Domain
18
+ from cuperiod.core.grid import GridSpec
19
+ from cuperiod.core.lightcurve import LightCurve, MultiBandLightCurve
20
+ from cuperiod.core.result import MultiResult, Periodogram
21
+ from cuperiod.methods.base import PeriodogramMethod, get_method
22
+
23
+ #: A single method's settings, a ``{method: settings}`` mapping, or ``None`` (defaults).
24
+ SettingsInput = BaseSettings | Mapping[str, BaseSettings] | None
25
+
26
+
27
+ def _looks_like_arrays(obj: Sequence[Any]) -> bool:
28
+ """Whether ``obj`` is a ``(t, y)`` / ``(t, y, dy)`` tuple of array-likes."""
29
+ if not isinstance(obj, (tuple, list)) or not (2 <= len(obj) <= 3):
30
+ return False
31
+ first = obj[0]
32
+ return hasattr(first, "__len__") and not isinstance(first, (str, bytes))
33
+
34
+
35
def to_input(
    data: Any,
    *,
    columns: ColumnMap | None = None,
    domain: Domain | None = None,
) -> LightCurve | MultiBandLightCurve:
    """Normalize any accepted user input into a light-curve object.

    Parameters
    ----------
    data : various
        One of: a :class:`LightCurve` or :class:`MultiBandLightCurve`; a
        ``(t, y)`` / ``(t, y, dy)`` tuple or list of array-likes; a non-empty
        mapping whose values are all :class:`LightCurve` (one per band); any
        other mapping (treated as a ``{column: array}`` table); or a
        pandas/astropy/pyarrow table.
    columns : ColumnMap, optional
        Column overrides; only consulted for table-like inputs.
    domain : Domain, optional
        Brightness-domain override. An existing :class:`LightCurve` is
        converted with ``in_domain``; ``(t, y[, dy])`` input falls back to
        ``Domain.MAGNITUDE`` when no domain is given.

    Returns
    -------
    LightCurve or MultiBandLightCurve
    """
    # Already a light curve: at most a domain conversion is needed.
    if isinstance(data, LightCurve):
        return data if domain is None else data.in_domain(domain)
    if isinstance(data, MultiBandLightCurve):
        return data

    # Bare arrays: (t, y) with an optional third element for the errors.
    if _looks_like_arrays(data):
        time, value, *rest = data
        error = rest[0] if rest else None
        return LightCurve.from_arrays(
            time, value, error, domain=domain or Domain.MAGNITUDE
        )

    if isinstance(data, Mapping):
        bands = list(data.values())
        if bands and all(isinstance(item, LightCurve) for item in bands):
            return MultiBandLightCurve.from_light_curves(data)
        # Anything else keyed by name is read as a table of columns.
        names, getter = _mapping_table(data)
        return LightCurve._from_table(names, getter, columns, domain, None)

    # Fallback: hand the object to the dataframe/table loader.
    return LightCurve.from_dataframe(data, columns=columns, domain=domain)
75
+
76
+
77
def _mapping_table(data: Mapping[str, Any]) -> tuple[list[str], Any]:
    """Adapt a ``{column: array}`` mapping with ``lightcurve._adapt_table``.

    The light-curve module is imported at call time rather than at module
    import, and the helper's result is returned untouched.
    """
    import cuperiod.core.lightcurve as _lightcurve

    return _lightcurve._adapt_table(data)
81
+
82
+
83
+ def _select_settings(
84
+ method: PeriodogramMethod, settings: SettingsInput
85
+ ) -> BaseSettings:
86
+ """Pick the settings object that applies to ``method``."""
87
+ if settings is None:
88
+ return method.coerce_settings(None)
89
+ if isinstance(settings, BaseSettings):
90
+ if isinstance(settings, method.settings_cls):
91
+ return settings
92
+ return method.coerce_settings(None)
93
+ # mapping of {method_name: settings}
94
+ for key, value in settings.items():
95
+ if key.upper() == method.name:
96
+ return method.coerce_settings(value)
97
+ return method.coerce_settings(None)
98
+
99
+
100
def _multiband_grid(
    method: PeriodogramMethod, mblc: MultiBandLightCurve, settings: BaseSettings
) -> GridSpec:
    """Derive ``method``'s default grid for a multi-band light curve.

    The finite samples of all bands are stacked, wrapped in a temporary
    single-band :class:`LightCurve`, and passed to ``method.default_grid``.
    The fourth value returned by ``stacked()`` is not needed here.
    """
    times, values, errors, _unused = mblc.finite().stacked()
    return method.default_grid(
        LightCurve.from_arrays(times, values, errors), settings
    )
107
+
108
+
109
def _run_one(
    method: PeriodogramMethod,
    lc: LightCurve | MultiBandLightCurve,
    settings: BaseSettings,
    backend: str,
    grid: GridSpec | None,
) -> Periodogram:
    """Run a single method on a single- or multi-band light curve.

    Parameters
    ----------
    method : PeriodogramMethod
        The resolved method implementation.
    lc : LightCurve or MultiBandLightCurve
        The input light curve.
    settings : BaseSettings
        Settings already selected for ``method``.
    backend : str
        Requested backend; resolved through ``method.resolve_backend``.
    grid : GridSpec or None
        Trial grid. Only ``None`` selects the method's default grid.

    Returns
    -------
    Periodogram

    Raises
    ------
    ValueError
        If ``lc`` is multi-band and ``method.supports_multiband`` is false.
    """
    resolved_backend = method.resolve_backend(backend)
    if isinstance(lc, MultiBandLightCurve):
        if not method.supports_multiband:
            raise ValueError(f"{method.name} does not support multi-band input")
        # Compare against None explicitly: a caller-supplied grid must never be
        # swapped for the default just because it happens to evaluate as falsy.
        g = grid if grid is not None else _multiband_grid(method, lc, settings)
        return method.multiband_power(g, lc, settings, resolved_backend)
    # Convert to the method's natural domain when it declares one.
    single = lc.in_domain(method.natural_domain) if method.natural_domain else lc
    g = grid if grid is not None else method.default_grid(single, settings)
    return method.power(g, single, settings, resolved_backend, engine=None)
125
+
126
+
127
def periodogram(
    data: Any,
    method: str | Sequence[str] = "GLS",
    *,
    backend: str = "auto",
    grid: GridSpec | None = None,
    settings: SettingsInput = None,
    columns: ColumnMap | None = None,
    domain: Domain | None = None,
) -> Periodogram | MultiResult:
    """Run one or several periodogram methods on a single light curve.

    Parameters
    ----------
    data : various
        The light curve, in any form accepted by :func:`to_input`.
    method : str or sequence of str, default "GLS"
        A single method name or several. Case-insensitive
        (``"gls"`` == ``"GLS"``).
    backend : str, default "auto"
        ``"auto"`` (GPU when available, else CPU), ``"cpu"``, ``"gpu"``, or a
        concrete backend name.
    grid : GridSpec, optional
        Custom trial grid, shared by every requested method; when omitted each
        method builds its own default grid for this light curve.
    settings : settings model or mapping, optional
        Settings for one method, or a ``{method: settings}`` mapping.
    columns : ColumnMap, optional
        Column overrides for table inputs.
    domain : Domain, optional
        Brightness-domain override.

    Returns
    -------
    Periodogram or MultiResult
        A :class:`Periodogram` when ``method`` is a plain string; otherwise a
        :class:`MultiResult` keyed by each method's registered name.

    Examples
    --------
    >>> pg = periodogram((t, mag, err), "GLS")  # doctest: +SKIP
    >>> pg.best_period()  # doctest: +SKIP
    >>> res = periodogram(df, ["GLS", "BLS"])  # doctest: +SKIP
    """
    light_curve = to_input(data, columns=columns, domain=domain)
    single_method = isinstance(method, str)
    requested = [method] if single_method else list(method)

    computed: dict[str, Periodogram] = {}
    for requested_name in requested:
        impl = get_method(requested_name)
        chosen = _select_settings(impl, settings)
        # Keyed by the registered name, so differently-cased spellings of the
        # same method end up in one entry.
        computed[impl.name] = _run_one(impl, light_curve, chosen, backend, grid)

    if not single_method:
        return MultiResult(computed)
    return computed[get_method(method).name]
178
+
179
+
180
def best_periods(
    result: Periodogram | MultiResult, n: int = 10, **kwargs: Any
) -> Any:
    """Return ``result.best_periods(n, **kwargs)``.

    A function-style shortcut for the method of the same name; extra keyword
    arguments are passed through unchanged.
    """
    finder = result.best_periods
    return finder(n, **kwargs)
185
+
186
+
187
+ __all__ = ["best_periods", "periodogram", "to_input"]
cuperiod/batch/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """Batch processing: scale periodograms over many light curves."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from cuperiod.batch.runner import BatchSummary, batch_periodograms
6
+
7
+ __all__ = ["BatchSummary", "batch_periodograms"]
cuperiod/batch/io.py ADDED
@@ -0,0 +1,232 @@
1
+ """Batch input resolution and result sinks.
2
+
3
+ Inputs can be an iterable of light curves, a glob pattern, a directory of light-curve
4
+ files, or a ``(DataFrame, group_column)`` pair; :func:`resolve_inputs` normalizes any of
5
+ them to ``(key, source)`` items, where ``source`` is loaded lazily in the worker so file
6
+ paths (not big arrays) cross the process boundary. Results are flattened to one row per
7
+ light curve and written as Parquet (default) or CSV.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import glob as _glob
13
+ from collections.abc import Iterable, Mapping, Sequence
14
+ from pathlib import Path
15
+ from typing import Any
16
+
17
+ from cuperiod.core.columns import ColumnMap, Domain
18
+ from cuperiod.core.lightcurve import LightCurve, MultiBandLightCurve
19
+ from cuperiod.core.result import Periodogram
20
+
21
+ #: File extensions treated as light-curve tables when scanning a directory.
22
+ LIGHTCURVE_SUFFIXES = (
23
+ ".parquet",
24
+ ".pq",
25
+ ".csv",
26
+ ".ecsv",
27
+ ".fits",
28
+ ".fit",
29
+ ".dat",
30
+ ".txt",
31
+ ".tsv",
32
+ )
33
+
34
+ Source = LightCurve | MultiBandLightCurve | Path
35
+ InputItem = tuple[str, Source]
36
+
37
+
38
+ def _is_dataframe(obj: Any) -> bool:
39
+ return hasattr(obj, "groupby") and hasattr(obj, "columns")
40
+
41
+
42
+ def resolve_inputs(
43
+ inputs: Any,
44
+ *,
45
+ columns: ColumnMap | None = None,
46
+ domain: Domain | None = None,
47
+ band_column: str | None = None,
48
+ ) -> list[InputItem]:
49
+ """Normalize a batch input into ``(key, source)`` items.
50
+
51
+ Parameters
52
+ ----------
53
+ inputs : various
54
+ Iterable of light curves / ``(key, lc)`` pairs / paths, a glob string, a
55
+ directory path, or a ``(DataFrame, group_column)`` tuple.
56
+ columns, domain : optional
57
+ Forwarded to file/table loading.
58
+ band_column : str, optional
59
+ If given with a ``(DataFrame, group_column)`` input, each group is loaded as a
60
+ :class:`MultiBandLightCurve` split on this column.
61
+
62
+ Returns
63
+ -------
64
+ list of (str, Source)
65
+ Stable keys paired with a light curve or a path to load lazily.
66
+ """
67
+ # (DataFrame, group_column)
68
+ if isinstance(inputs, tuple) and len(inputs) == 2 and _is_dataframe(inputs[0]):
69
+ df, group_col = inputs
70
+ items: list[InputItem] = []
71
+ for key, sub in df.groupby(group_col, sort=False):
72
+ if band_column is not None:
73
+ lc: Source = MultiBandLightCurve.from_dataframe(
74
+ sub, band_column=band_column, columns=columns, domain=domain
75
+ )
76
+ else:
77
+ lc = LightCurve.from_dataframe(sub, columns=columns, domain=domain)
78
+ items.append((str(key), lc))
79
+ return items
80
+
81
+ # Glob / directory / single file given as a string or Path.
82
+ if isinstance(inputs, (str, Path)):
83
+ return _resolve_path_like(inputs)
84
+
85
+ # An iterable of light curves / (key, lc) pairs / paths.
86
+ if isinstance(inputs, Iterable):
87
+ return _resolve_iterable(inputs, columns=columns, domain=domain)
88
+
89
+ raise TypeError(f"unsupported batch input type: {type(inputs)!r}")
90
+
91
+
92
+ def _resolve_path_like(inputs: str | Path) -> list[InputItem]:
93
+ text = str(inputs)
94
+ if any(ch in text for ch in "*?[") or "**" in text:
95
+ paths = sorted(Path(p) for p in _glob.glob(text, recursive=True))
96
+ return [(p.stem, p) for p in paths]
97
+ path = Path(inputs)
98
+ if path.is_dir():
99
+ paths = sorted(
100
+ p for p in path.iterdir() if p.suffix.lower() in LIGHTCURVE_SUFFIXES
101
+ )
102
+ return [(p.stem, p) for p in paths]
103
+ if path.is_file():
104
+ return [(path.stem, path)]
105
+ raise FileNotFoundError(f"no such file, directory, or glob match: {inputs!r}")
106
+
107
+
108
+ def _resolve_iterable(
109
+ inputs: Iterable[Any], *, columns: ColumnMap | None, domain: Domain | None
110
+ ) -> list[InputItem]:
111
+ items: list[InputItem] = []
112
+ for i, entry in enumerate(inputs):
113
+ if (
114
+ isinstance(entry, tuple)
115
+ and len(entry) == 2
116
+ and isinstance(entry[0], (str, int))
117
+ and isinstance(entry[1], (LightCurve, MultiBandLightCurve, str, Path))
118
+ ):
119
+ key, src = entry
120
+ src = Path(src) if isinstance(src, str) else src
121
+ items.append((str(key), src))
122
+ elif isinstance(entry, (LightCurve, MultiBandLightCurve)):
123
+ key = str(entry.meta.get("id", i)) if entry.meta else str(i)
124
+ items.append((key, entry))
125
+ elif isinstance(entry, (str, Path)):
126
+ items.append((Path(entry).stem, Path(entry)))
127
+ else:
128
+ raise TypeError(f"unsupported batch entry at index {i}: {type(entry)!r}")
129
+ return items
130
+
131
+
132
+ def load_source(
133
+ source: Source,
134
+ *,
135
+ columns: ColumnMap | None = None,
136
+ domain: Domain | None = None,
137
+ ) -> LightCurve | MultiBandLightCurve:
138
+ """Materialize a source: return a light curve unchanged, or load it from a path."""
139
+ if isinstance(source, (LightCurve, MultiBandLightCurve)):
140
+ return source
141
+ return LightCurve.from_file(source, columns=columns, domain=domain)
142
+
143
+
144
+ def periodogram_to_row(
145
+ key: str, pg: Periodogram, *, n_best: int, store_raw: bool
146
+ ) -> dict[str, Any]:
147
+ """Flatten a :class:`Periodogram` into one result row (peaks + optional raw)."""
148
+ row: dict[str, Any] = {
149
+ "key": key,
150
+ "method": pg.method,
151
+ "backend": pg.backend,
152
+ "n_samples": pg.n_samples,
153
+ "baseline": pg.baseline,
154
+ "best_period": pg.best_period() if pg.size else float("nan"),
155
+ }
156
+ peaks = pg.best_periods(n_best, alias_diverse=(pg.method == "BLS"))
157
+ for rank in range(1, n_best + 1):
158
+ peak = peaks[rank - 1] if rank <= len(peaks) else None
159
+ row[f"period_{rank}"] = peak.period if peak else float("nan")
160
+ row[f"power_{rank}"] = peak.power if peak else float("nan")
161
+ if peak is not None:
162
+ for ekey, eval_ in peak.extra.items():
163
+ row[f"{ekey}_{rank}"] = eval_
164
+ if store_raw:
165
+ freq, power = pg.downsample()
166
+ row["pgram_frequency"] = freq.tolist()
167
+ row["pgram_power"] = power.tolist()
168
+ return row
169
+
170
+
171
+ def write_rows(rows: Sequence[Mapping[str, Any]], path: Path) -> None:
172
+ """Write result rows to ``path`` as Parquet (``.parquet``) or CSV (``.csv``)."""
173
+ path = Path(path)
174
+ path.parent.mkdir(parents=True, exist_ok=True)
175
+ if path.suffix.lower() in {".parquet", ".pq"}:
176
+ _write_parquet(rows, path)
177
+ elif path.suffix.lower() == ".csv":
178
+ _write_csv(rows, path)
179
+ else:
180
+ raise ValueError(
181
+ f"unsupported sink format {path.suffix!r}; use .parquet or .csv"
182
+ )
183
+
184
+
185
+ def _union_columns(rows: Sequence[Mapping[str, Any]]) -> list[str]:
186
+ seen: dict[str, None] = {}
187
+ for row in rows:
188
+ for key in row:
189
+ seen.setdefault(key, None)
190
+ return list(seen)
191
+
192
+
193
+ def _write_parquet(rows: Sequence[Mapping[str, Any]], path: Path) -> None:
194
+ import pyarrow as pa
195
+ import pyarrow.parquet as pq
196
+
197
+ columns = _union_columns(rows)
198
+ data = {col: [row.get(col) for row in rows] for col in columns}
199
+ table = pa.table(data)
200
+ tmp = path.with_suffix(path.suffix + ".tmp")
201
+ pq.write_table(table, tmp)
202
+ tmp.replace(path)
203
+
204
+
205
+ def _write_csv(rows: Sequence[Mapping[str, Any]], path: Path) -> None:
206
+ import csv
207
+
208
+ columns = _union_columns(rows)
209
+ list_cols = [
210
+ col for col in columns if any(isinstance(row.get(col), list) for row in rows)
211
+ ]
212
+ if list_cols:
213
+ raise ValueError(
214
+ f"CSV sink cannot store array-valued columns {list_cols} (e.g. from "
215
+ "store_raw); use a .parquet sink."
216
+ )
217
+ tmp = path.with_suffix(path.suffix + ".tmp")
218
+ with tmp.open("w", newline="", encoding="utf-8") as fh:
219
+ writer = csv.DictWriter(fh, fieldnames=columns, extrasaction="ignore")
220
+ writer.writeheader()
221
+ writer.writerows(rows)
222
+ tmp.replace(path)
223
+
224
+
225
+ __all__ = [
226
+ "InputItem",
227
+ "Source",
228
+ "load_source",
229
+ "periodogram_to_row",
230
+ "resolve_inputs",
231
+ "write_rows",
232
+ ]