jacscanomaly 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,20 @@
1
# scanomaly/__init__.py
"""Package root: switch JAX to 64-bit mode and re-export the public API."""
from __future__ import annotations

from jax import config as jax_config

# This update runs before the submodule imports below.
# NOTE(review): presumably so that JAX code in those modules already sees
# the x64 setting -- keep this ordering unless that is confirmed unnecessary.
jax_config.update("jax_enable_x64", True)

from .config import FinderConfig
from .finder import Finder
from .plot import AnomalyPlotter
from .pspl import PSPLFitResult, PSPLFitter

__version__ = "0.1.0"

# Public names of the package.
__all__ = [
    "FinderConfig",
    "Finder",
    "AnomalyPlotter",
    "PSPLFitter",
    "PSPLFitResult",
]
jacscanomaly/config.py ADDED
@@ -0,0 +1,85 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+
6
@dataclass(frozen=True)
class FinderConfig:
    """
    Immutable bundle of hyperparameters for the anomaly ``Finder``.

    Only plain Python scalars are stored, so this class depends on nothing
    beyond the standard library (no NumPy/JAX), and ``frozen=True`` prevents
    fields from being reassigned after construction (for reproducibility).

    The fields are organised by the stage of the search they control:

    1) Season splitting
        ``gap`` -- the time series is split into seasons at large time gaps.

    2) Grid construction
        ``teff_init``, ``common_ratio``, ``teff_grid_n``, ``dt0_coeff`` --
        a (t0, teff) grid is built per season.

    3) Grid scan
        ``sigma``, ``teff_coeff``, ``min_pts_in_window`` -- delta-chi2 is
        evaluated on the grid within a local window.

    4) Cluster extraction
        ``overlap_sigma``, ``min_cluster_points`` -- overlapping candidates
        are grouped and the best one per cluster is kept.
    """

    # ---- 1) Season splitting --------------------------------------------
    # Time-gap threshold: a new season starts when dt > gap.
    gap: float = 100.0

    # ---- 2) Grid construction (t0, teff) --------------------------------
    # First element of the geometric series of teff values.
    teff_init: float = 0.03

    # Common ratio of the geometric series of teff values.
    common_ratio: float = 4.0 / 3.0

    # Number of teff values in the grid.
    teff_grid_n: int = 5

    # Grid spacing coefficient for t0:  dt0 = dt0_coeff * teff
    dt0_coeff: float = 0.17

    # ---- 3) Grid scan (local evaluation window) -------------------------
    # Threshold used when counting per-point chi2 improvement
    # (kept for compatibility with the original `n_out` logic).
    sigma: float = 3.0

    # Window half-width multiplier in units of teff:
    #   window = [t0 - teff_coeff*teff, t0 + teff_coeff*teff]
    teff_coeff: float = 3.0

    # Minimum number of data points that must fall inside the window for a
    # grid point to be evaluated.
    min_pts_in_window: int = 4

    # ---- 4) Cluster extraction ------------------------------------------
    # Overlap threshold multiplier for grouping nearby grid points:
    #   |t0_i - t0_j| < overlap_sigma * (teff_i + teff_j)
    overlap_sigma: float = 3.0

    # Cluster extraction stops once fewer than this many grid points remain.
    min_cluster_points: int = 3
@@ -0,0 +1,134 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Tuple
5
+
6
+ import numpy as np
7
+
8
+
9
@dataclass
class ResultExtractor:
    """
    Cluster extractor for grid-scan candidates.

    Given arrays of (t0, teff, delta_chi2) evaluated on a grid, this class
    groups overlapping candidates and returns one representative (the
    maximum delta_chi2 point) per cluster.

    Overlap definition
    ------------------
    Two candidates i and j are considered overlapping if:

        |t0_i - t0_j| < sigma_overlap * (teff_i + teff_j)

    Attributes
    ----------
    sigma_overlap : float
        Multiplier in the overlap criterion above.
    min_points : int
        Extraction stops once fewer than this many candidates remain.

    Notes
    -----
    - This operates on NumPy arrays (no JAX).
    - Returned `clusters` always has shape (K, 3) with rows
      [t0_best, teff_best, dchi2_best]; K may be 0.
    """

    sigma_overlap: float = 3.0
    min_points: int = 3

    def _overlap_with_max(
        self,
        t0: np.ndarray,
        teff: np.ndarray,
        dchi2: np.ndarray,
    ) -> Tuple[np.ndarray, int]:
        """
        Compute the overlap mask around the current maximum dchi2 point.

        NaN entries of `dchi2` are ignored when locating the maximum.

        Returns
        -------
        overlap_mask : np.ndarray of bool, shape (N,)
            Mask selecting points overlapping with the maximum point.
        i_max : int
            Index of the maximum point within the provided arrays.
        """
        i_max = int(np.nanargmax(dchi2))
        t0_max = t0[i_max]
        teff_max = teff[i_max]
        overlap_mask = np.abs(t0 - t0_max) < self.sigma_overlap * (teff + teff_max)
        return overlap_mask, i_max

    def iterative_anomaly_extraction(
        self,
        t0_list,
        teff_list,
        dchi2_list,
    ) -> np.ndarray:
        """
        Iteratively extract non-overlapping clusters from grid results.

        On each pass the best remaining candidate seeds a cluster made of all
        remaining candidates that overlap with it; the seed is recorded and
        the whole cluster is removed from further consideration.

        Parameters
        ----------
        t0_list, teff_list, dchi2_list
            1D arrays (or array-like) of equal length. NaN entries in
            `dchi2_list` are treated as "no value": they are never selected
            as a cluster representative and do not stop the extraction.

        Returns
        -------
        clusters : np.ndarray, shape (K, 3)
            Each row is [t0, teff, dchi2] for the best (max dchi2) point
            in each extracted cluster, in order of extraction.
            Has shape (0, 3) if nothing is extractable.

        Raises
        ------
        ValueError
            If the three (non-empty) inputs do not share the same shape.

        Stopping conditions
        -------------------
        - No remaining candidates.
        - The best remaining candidate is non-finite.
        - Remaining candidate count drops below `min_points`.
        """
        t0 = np.asarray(t0_list, dtype=float)
        teff = np.asarray(teff_list, dtype=float)
        dchi2 = np.asarray(dchi2_list, dtype=float)

        if t0.size == 0:
            return np.zeros((0, 3), dtype=float)

        if not (t0.shape == teff.shape == dchi2.shape):
            raise ValueError(
                f"Input arrays must have the same shape, got "
                f"t0={t0.shape}, teff={teff.shape}, dchi2={dchi2.shape}"
            )

        # np.argmax returns the index of the first NaN, so one NaN grid point
        # would end the extraction immediately. Rank NaNs as -inf instead so
        # they can never be picked as the best candidate.
        score = np.where(np.isnan(dchi2), -np.inf, dchi2)

        clusters: List[List[float]] = []
        remaining = np.ones_like(dchi2, dtype=bool)

        while np.any(remaining):
            # Best remaining point; already-extracted points are masked out.
            i_best = int(np.argmax(np.where(remaining, score, -np.inf)))
            if not np.isfinite(score[i_best]):
                break

            # Overlap is evaluated on the compressed "remaining" arrays and
            # then mapped back to indices of the full arrays.
            idx_remaining = np.flatnonzero(remaining)
            overlap_mask, _ = self._overlap_with_max(
                t0[idx_remaining], teff[idx_remaining], score[idx_remaining]
            )
            full_mask = np.zeros_like(remaining)
            full_mask[idx_remaining[overlap_mask]] = True
            # The seed always belongs to its own cluster. Without this, a
            # zero overlap width (sigma_overlap == 0 or teff == 0) would
            # yield an empty cluster.
            full_mask[i_best] = True

            # The seed is the maximum over all remaining points, hence also
            # the maximum of its cluster: record it as the representative.
            clusters.append(
                [float(t0[i_best]), float(teff[i_best]), float(dchi2[i_best])]
            )

            # Remove this cluster from the remaining candidates.
            remaining &= ~full_mask

            if int(np.sum(remaining)) < self.min_points:
                break

        # reshape keeps the documented (K, 3) shape even when K == 0.
        return np.asarray(clusters, dtype=float).reshape(-1, 3)
jacscanomaly/finder.py ADDED
@@ -0,0 +1,225 @@
1
+ # scanomaly/finder.py
2
+ from __future__ import annotations
3
+
4
import logging
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import jax
import jax.numpy as jnp

from .config import FinderConfig
from .pspl import PSPLFitter
from .plot import AnomalyPlotter
from .seasons import SeasonSplitter
from .extract import ResultExtractor
from .runner import SeasonGridRunner
from .models import AnomalyResult, BestCandidate
18
+
19
+
20
@dataclass
class Finder:
    """
    Main entry point of the anomaly search.

    :meth:`run` performs, in order:
        1) PSPL fit on (time, flux, ferr)
        2) season splitting
        3) grid scan on PSPL residuals
        4) cluster extraction
        5) selection of the best anomaly candidate

    Steps 2-4 are delegated to a ``SeasonGridRunner`` that ``__post_init__``
    builds from ``config`` and stores as ``self.runner``.

    Users typically call :meth:`run` and then either pass the returned
    ``AnomalyResult`` to an ``AnomalyPlotter`` or use the ``plot_*``
    shortcuts below, which operate on the most recent result.

    Attributes
    ----------
    config : FinderConfig
        Hyperparameters of the search; a default instance is created if omitted.
    fitter : PSPLFitter or None
        PSPL fitter to use; a default ``PSPLFitter()`` is created if None.
    plotter : AnomalyPlotter or None
        Plotter used by the ``plot_*`` shortcuts; a default
        ``AnomalyPlotter()`` is created if None.
    """

    # default_factory gives every Finder its own config instance
    config: FinderConfig = field(default_factory=FinderConfig)

    # Allow dependency injection, but create defaults if None
    fitter: Optional[PSPLFitter] = None
    plotter: Optional[AnomalyPlotter] = None

    # Result of the most recent run(); not settable through __init__
    _last_result: Optional[AnomalyResult] = field(default=None, init=False)

    def __post_init__(self) -> None:
        """Fill in default collaborators and build the season/grid runner."""
        if self.fitter is None:
            self.fitter = PSPLFitter()
        if self.plotter is None:
            self.plotter = AnomalyPlotter()

        splitter = SeasonSplitter(gap=self.config.gap)
        extractor = ResultExtractor(
            sigma_overlap=self.config.overlap_sigma,
            min_points=self.config.min_cluster_points,
        )
        self.runner = SeasonGridRunner(
            splitter=splitter,
            extractor=extractor,
            config=self.config,
        )

    # ----------------------------
    # Public APIs
    # ----------------------------
    def fit_pspl(self, time, flux, ferr, p0):
        """
        Convenience method: run PSPL fit only.

        The inputs go through the same validation as in :meth:`run`.

        Returns
        -------
        PSPLFitResult
            The PSPL fitting result (JAX arrays inside).
        """
        time_j, flux_j, ferr_j, p0_j, *_ = self._to_arrays(time, flux, ferr, p0)
        return self.fitter.fit(time_j, flux_j, ferr_j, p0_j)

    def run(
        self,
        time,
        flux,
        ferr,
        p0,
        *,
        verbose: bool = True,
        log: Optional[logging.Logger] = None,
    ) -> AnomalyResult:
        """
        Run the full anomaly finding pipeline.

        Parameters
        ----------
        time, flux, ferr : array-like
            1D arrays. Stored in the output as NumPy arrays on CPU for fast plotting.
        p0 : array-like
            Initial PSPL parameters (t0, tE, u0).
        verbose : bool, keyword-only
            Forwarded unchanged to the season/grid runner.
        log : logging.Logger or None, keyword-only
            Forwarded unchanged to the season/grid runner.

        Returns
        -------
        AnomalyResult
            Includes PSPL fit, residuals, per-season cluster summaries,
            flattened clusters, and the best candidate (if any).
            The result is also remembered for the ``plot_*`` shortcuts.

        Raises
        ------
        ValueError
            If time/flux/ferr fail the checks in :meth:`_to_arrays`.
        """
        time_j, flux_j, ferr_j, p0_j, time_np, flux_np, ferr_np = self._to_arrays(
            time, flux, ferr, p0
        )

        # 1) PSPL fit (JAX)
        fit = self.fitter.fit(time_j, flux_j, ferr_j, p0_j)
        residual_j = fit.residual
        model_flux_j = fit.model_flux

        # bring to CPU for plotting/analysis
        residual_np, model_flux_np, chi2_dof = jax.device_get(
            (residual_j, model_flux_j, fit.chi2_dof)
        )
        residual_np = np.asarray(residual_np, dtype=float)
        model_flux_np = np.asarray(model_flux_np, dtype=float)
        chi2_dof = float(chi2_dof)

        # 2-4) season loop & grid scan & extraction
        seasons, clusters_all = self.runner.run(
            time_j=time_j,
            residual_j=residual_j,
            ferr_j=ferr_j,
            time_np=time_np,
            verbose=verbose,
            log=log,
        )

        # 5) best candidate selection
        best_obj = self._pick_best_candidate(clusters_all)

        result = AnomalyResult(
            time=time_np,
            flux=flux_np,
            ferr=ferr_np,
            fit=fit,
            residual=residual_np,
            model_flux=model_flux_np,
            chi2_dof=chi2_dof,
            seasons=seasons,
            clusters_all=clusters_all,
            best=best_obj,
        )

        self._last_result = result
        return result

    # ----------------------------
    # Internal helpers
    # ----------------------------
    def _to_arrays(self, time, flux, ferr, p0):
        """
        Validate the light curve and convert inputs into NumPy and JAX arrays.

        Returns
        -------
        tuple
            ``(time_j, flux_j, ferr_j, p0_j, time_np, flux_np, ferr_np)``:
            the JAX arrays first, then the NumPy (float) copies.
            ``p0_j`` is created with the dtype of ``time_j``.

        Raises
        ------
        ValueError
            If time/flux/ferr are not 1D, differ in length, contain
            non-finite values, or if any ``ferr`` is not strictly positive.
        """
        time_np = np.asarray(time, dtype=float)
        flux_np = np.asarray(flux, dtype=float)
        ferr_np = np.asarray(ferr, dtype=float)
        light_curve = (time_np, flux_np, ferr_np)

        if any(arr.ndim != 1 for arr in light_curve):
            raise ValueError("time/flux/ferr must be 1D arrays.")
        if not (len(time_np) == len(flux_np) == len(ferr_np)):
            raise ValueError("time/flux/ferr must have the same length.")
        if not all(np.all(np.isfinite(arr)) for arr in light_curve):
            raise ValueError("time/flux/ferr must be finite.")
        if np.any(ferr_np <= 0):
            raise ValueError("ferr must be positive.")

        time_j = jnp.asarray(time_np)
        flux_j = jnp.asarray(flux_np)
        ferr_j = jnp.asarray(ferr_np)
        p0_j = jnp.asarray(p0, dtype=time_j.dtype)

        return time_j, flux_j, ferr_j, p0_j, time_np, flux_np, ferr_np

    def _pick_best_candidate(self, clusters_all: np.ndarray) -> Optional[BestCandidate]:
        """
        Pick the single best candidate from flattened clusters and compute a standardized score.

        Parameters
        ----------
        clusters_all : np.ndarray or None
            Rows of [t0, teff, dchi2].

        Returns
        -------
        BestCandidate or None
            None when there are no clusters. Otherwise the row with the
            largest dchi2, together with the median and standard deviation of
            dchi2 over the *other* rows and
            ``score = (dchi2_best - median) / std``.
            With fewer than two other rows, median/std/score are NaN;
            if the standard deviation is not positive, the score is inf.
        """
        # size == 0 also covers arrays with zero rows
        if clusters_all is None or clusters_all.size == 0:
            return None

        # clusters_all rows: [t0, teff, dchi2]
        max_ind = int(np.argmax(clusters_all[:, 2]))
        best = clusters_all[max_ind]
        others = np.delete(clusters_all, max_ind, axis=0)

        if others.shape[0] >= 2:
            med = float(np.median(others[:, 2]))
            std = float(np.std(others[:, 2]))
            score = float((best[2] - med) / std) if std > 0 else float("inf")
        else:
            # too few other candidates for a meaningful spread
            med, std, score = float("nan"), float("nan"), float("nan")

        return BestCandidate(
            t0=float(best[0]),
            teff=float(best[1]),
            dchi2=float(best[2]),
            med_others=med,
            std_others=std,
            score=score,
        )

    # ----------------------------
    # Plot sugar APIs
    # ----------------------------
    def _require_result(self) -> AnomalyResult:
        """Return the last :meth:`run` result, or raise RuntimeError if there is none."""
        if self._last_result is None:
            raise RuntimeError("Finder.run() has not been called yet.")
        return self._last_result

    def plot_lc(self, **kwargs):
        """
        Plot light curve with PSPL model using the last result.

        Keyword arguments are forwarded to ``plotter.plot_lc``.
        """
        result = self._require_result()
        return self.plotter.plot_lc(result, **kwargs)

    def plot_residual(self, **kwargs):
        """
        Plot residuals using the last result.

        Keyword arguments are forwarded to ``plotter.plot_residual``.
        """
        result = self._require_result()
        return self.plotter.plot_residual(result, **kwargs)

    def plot_anomaly_window(self, **kwargs):
        """
        Plot residuals around the best anomaly window.

        Keyword arguments are forwarded to ``plotter.plot_anomaly_window``.
        """
        result = self._require_result()
        return self.plotter.plot_anomaly_window(result, **kwargs)

    def plot_result(self, **kwargs):
        """
        Full 3-panel diagnostic plot.

        Keyword arguments are forwarded to ``plotter.plot_result``.
        """
        result = self._require_result()
        return self.plotter.plot_result(result, **kwargs)
jacscanomaly/models.py ADDED
@@ -0,0 +1,107 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional
5
+
6
+ import numpy as np
7
+
8
+ from .pspl import PSPLFitResult
9
+
10
+
11
@dataclass(frozen=True)
class BestCandidate:
    """
    The top-ranked anomaly candidate among all extracted clusters.

    ``score`` standardizes the winner against the remaining candidates::

        score = (dchi2 - med_others) / std_others

    and may be NaN or inf, depending on how many other candidates exist and
    on the value of ``std_others``.
    """

    # Candidate center time.
    t0: float
    # Candidate effective timescale.
    teff: float
    # Improvement in chi-square, chi2_null - chi2_anom (larger is better).
    dchi2: float
    # Median dchi2 among all other candidates (the best one excluded).
    med_others: float
    # Standard deviation of dchi2 among all other candidates (best excluded).
    std_others: float
    # Standardized score of this candidate (see class docstring).
    score: float
39
+
40
+
41
@dataclass(frozen=True)
class SeasonSummary:
    """
    Outcome of the anomaly scan for one season.

    NOTE: the dataclass-generated ``__eq__``/``__hash__`` include the
    ``clusters`` array, so comparing or hashing instances may fail when
    ``clusters`` is a NumPy array.
    """

    # 0-based season index.
    season_idx: int
    # Start of the season's time range.
    t_start: float
    # End of the season's time range.
    t_end: float
    # Number of grid points evaluated in this season.
    n_grid: int
    # Extracted clusters for this season, shape (K, 3): rows [t0, teff, dchi2].
    clusters: np.ndarray
62
+
63
+
64
@dataclass(frozen=True)
class AnomalyResult:
    """
    Everything returned by ``Finder.run``.

    The object is meant to be convenient for plotting and downstream
    analysis; its arrays are stored on CPU as NumPy arrays.

    NOTE: the dataclass-generated ``__eq__``/``__hash__`` include the array
    fields, so comparing or hashing instances may fail.
    """

    # ---- input light curve (1D arrays) -----------------------------------
    time: np.ndarray
    flux: np.ndarray
    ferr: np.ndarray

    # ---- PSPL fit ---------------------------------------------------------
    # Fitting result (contains params, fs, fb, chi2, model_flux, residual, etc.).
    fit: PSPLFitResult
    # Flux residuals on CPU: flux - model_flux.
    residual: np.ndarray
    # PSPL model flux on CPU.
    model_flux: np.ndarray
    # Reduced chi-square of the PSPL fit.
    chi2_dof: float

    # ---- grid / clusters --------------------------------------------------
    # Per-season summaries, each including its clusters.
    seasons: List[SeasonSummary]
    # Clusters of all seasons flattened, shape (N, 3): rows [t0, teff, dchi2].
    clusters_all: np.ndarray

    # ---- best candidate ---------------------------------------------------
    # Best candidate over all clusters, or None if no candidate exists.
    best: Optional[BestCandidate]