sgn-drift 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,150 @@
1
+ """Core PSD estimation logic classes (Math only).
2
+ Refactored to remove invalid boundary zeroing and enforce input shape.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from abc import ABC, abstractmethod
8
+ from collections import deque
9
+ from dataclasses import dataclass, field
10
+ from typing import Optional
11
+
12
+ import numpy as np
13
+ from scipy.special import loggamma
14
+ from sympy import EulerGamma
15
+
16
# Euler-Mascheroni constant as a plain Python float.  NumPy ships it as a
# ready-made double, so no symbolic expression has to be evaluated when this
# module is imported.
EULERGAMMA = float(np.euler_gamma)
17
+
18
+
19
@dataclass
class BaseEstimator(ABC):
    """Base class for PSD estimation logic.

    Subclasses implement :meth:`update`; the latest estimate is read back
    through :meth:`get_psd`.
    """

    # Expected length of the last axis of every array passed to the estimator.
    size: int
    # Scale factor; how it is applied is decided by the subclass.
    normalization: float = 1.0

    # Internal State
    # Sample counter; initialised to 0 here and only advanced by subclasses.
    n_samples: int = field(init=False, default=0)
    # Latest PSD estimate; replaced by an array of ones in __post_init__.
    current_psd: np.ndarray = field(init=False, repr=False, default=None)

    def __post_init__(self):
        # Initialize with ones to avoid divide-by-zero
        self.current_psd = np.ones(self.size)

    def _check_shape(self, data: np.ndarray) -> None:
        """Validate that the last axis of ``data`` has length ``self.size``.

        Raises:
            ValueError: if ``data`` has no axes at all (scalar / 0-d array)
                or if its last axis does not have ``self.size`` entries.
        """
        # np.shape also works for scalars and plain sequences, so malformed
        # input is reported as ValueError instead of IndexError/AttributeError.
        shape = np.shape(data)
        if not shape:
            raise ValueError(
                f"Input data must be at least 1-D to match estimator size {self.size}"
            )
        if shape[-1] != self.size:
            raise ValueError(
                f"Input data size {shape[-1]} does not match estimator size {self.size}"
            )

    @abstractmethod
    def update(self, data: np.ndarray) -> None:
        """Update state with new frequency-domain data."""

    def get_psd(self) -> np.ndarray:
        """Return the current PSD estimate (the internal array, not a copy)."""
        return self.current_psd
48
+
49
+
50
@dataclass
class MGMEstimator(BaseEstimator):
    """
    Median-Geometric-Mean Estimator (Standard LIGO).

    Every update stores the new power spectrum in a bounded history, takes the
    per-bin median of that history and folds its bias-corrected logarithm into
    a running average of at most ``n_average`` samples.  The PSD is the
    exponential of that average, shifted by the Euler-Mascheroni constant and
    scaled by ``normalization``.
    """

    # Maximum number of recent power spectra kept for the median.
    n_median: int = 7
    # Cap on the effective number of samples in the running log-average.
    n_average: int = 64

    # Internal State
    # Most recent power spectra (bounded deque created in __post_init__).
    history: deque = field(init=False, repr=False, default=None)
    # Running average of the bias-corrected log median; None until first use.
    geo_mean_log: Optional[np.ndarray] = field(init=False, repr=False, default=None)

    # Smallest power value ever passed to np.log.  Deliberately left
    # un-annotated so that the dataclass machinery does not treat it as a field.
    _POWER_FLOOR = 1e-300

    def __post_init__(self):
        super().__post_init__()
        self.history = deque(maxlen=self.n_median)

    @staticmethod
    def _median_bias(nn):
        """XLALMedianBias

        Alternating partial sum 1 - 1/2 + 1/3 - ... with ``(nn - 1) // 2``
        pairs of terms after the leading 1.
        """
        ans = 1.0
        n = (nn - 1) // 2
        for i in range(1, n + 1):
            ans -= 1.0 / (2 * i)
            ans += 1.0 / (2 * i + 1)
        return ans

    @staticmethod
    def _log_median_bias_geometric(nn):
        """XLALLogMedianBiasGeometric

        Log-domain bias term subtracted from the log median of ``nn`` samples.
        """
        return np.log(MGMEstimator._median_bias(nn)) - nn * (
            loggamma(1.0 / nn) - np.log(nn)
        )

    def update(self, data: np.ndarray) -> None:
        """Fold one frame of frequency-domain data into the estimate.

        Args:
            data: complex spectrum (converted to ``|data|**2``) or real power
                values; the last axis must have ``self.size`` entries.

        Raises:
            ValueError: if the shape check fails.
        """
        self._check_shape(data)

        if np.iscomplexobj(data):
            power = np.abs(data) ** 2
        else:
            power = data

        # Keep np.log finite.  Without the floor a single zero-power bin puts
        # -inf into geo_mean_log, and the running average can never recover
        # from it, so that bin's PSD would stay at 0 forever.  set_reference
        # applies the same floor.  NaNs are intentionally left to propagate.
        power = np.maximum(power, self._POWER_FLOOR)

        self.history.append(power)

        if self.n_samples == 0:
            # First sample (or reset): start the average from this spectrum.
            self.geo_mean_log = np.log(power)
            self.n_samples += 1
        else:
            self.n_samples = min(self.n_samples + 1, self.n_average)

            bias = self._log_median_bias_geometric(len(self.history))

            # Match Legacy: use sort and integer index
            stacked = np.array(self.history)
            sorted_bins = np.sort(stacked, axis=0)
            idx = len(self.history) // 2
            log_bin_median = np.log(sorted_bins[idx])

            # Running mean over n_samples entries: the old mean carries
            # weight (n_samples - 1), the new bias-corrected median weight 1.
            self.geo_mean_log = (
                self.geo_mean_log * (self.n_samples - 1) + log_bin_median - bias
            ) / self.n_samples

        self.current_psd = np.exp(self.geo_mean_log + EULERGAMMA) * self.normalization

    def set_reference(self, psd: np.ndarray, weight: int):
        """Seed the estimator with a known PSD.

        Args:
            psd: reference PSD; the last axis must have ``self.size`` entries.
            weight: number of samples the reference counts for in the running
                average.  Clamped to the range ``[0, n_average]``.

        Raises:
            ValueError: if the shape check fails.
        """
        self._check_shape(psd)

        raw = psd / self.normalization
        # Avoid log(0)
        raw = np.where(raw > 0, raw, self._POWER_FLOOR)

        # Fill the whole median window with the reference spectrum.
        self.history.clear()
        for _ in range(self.n_median):
            self.history.append(raw)

        # Inverse of the mapping applied at the end of update().
        self.geo_mean_log = np.log(raw) - EULERGAMMA
        # A negative weight would make update() divide by zero one call later.
        self.n_samples = max(0, min(weight, self.n_average))
        self.current_psd = psd.copy()
128
+
129
+
130
@dataclass
class RecursiveEstimator(BaseEstimator):
    """
    Exponential Moving Average Estimator.

    The first frame becomes the estimate as-is; each later frame is blended
    in with weight ``alpha`` while the previous estimate keeps ``1 - alpha``.
    """

    # Weight given to the newest frame in the moving average.
    alpha: float = 0.1
    # False until the first frame has been stored.
    _initialized: bool = field(init=False, default=False)

    def update(self, data: np.ndarray) -> None:
        """Blend one frame of frequency-domain data into the estimate."""
        self._check_shape(data)

        # Complex spectra are converted to power; real input is used directly.
        if np.iscomplexobj(data):
            raw_power = np.abs(data) ** 2
        else:
            raw_power = data
        scaled = raw_power * self.normalization

        if self._initialized:
            self.current_psd = (1 - self.alpha) * self.current_psd + self.alpha * scaled
            return

        # Very first frame: adopt it unchanged.
        self.current_psd = scaled
        self._initialized = True
File without changes
@@ -0,0 +1,154 @@
1
+ """
2
+ Sinks for Drift Events.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import csv
8
+ import os
9
+ from dataclasses import dataclass, field
10
+ from typing import Any, ClassVar, Optional, TextIO
11
+
12
+ from sgn.base import SinkPad
13
+ from sgnts.base import EventFrame, TSFrame, TSSink
14
+ from sgndrift.transforms.drift import DriftEvent
15
+
16
+
17
@dataclass
class DriftCSVSink(TSSink):
    """
    Writes DriftEvent data to a CSV file.
    Inherits from TSSink to integrate with sgn-ts pipelines.

    The output file is opened lazily in append mode when the first event
    arrives.  Its columns are ``time`` followed by the sorted keys of that
    first event's data.
    """

    # Path of the CSV file that rows are appended to.
    filename: str = "drift.csv"

    # Mark 'in' as unaligned to prevent Audioadapter creation for discrete events
    static_unaligned_sink_pads: ClassVar[list[str]] = ["in"]

    # Internal state
    # Open file handle, or None while nothing has been written / after cleanup.
    _file: Optional[TextIO] = field(init=False, repr=False, default=None)
    # csv.DictWriter bound to _file.
    _writer: Any = field(init=False, repr=False, default=None)

    def __post_init__(self):
        # Force all input pads to be unaligned to prevent Audioadapter creation.
        # This is necessary because EventFrames are discrete and lack sample rates.
        # We set this before super().__post_init__() so TimeSeriesMixin uses it.
        self.unaligned = list(self.sink_pad_names)
        super().__post_init__()

    def configure(self) -> None:
        """Configure input frame types to expect EventFrame."""
        for name in self.sink_pad_names:
            self.input_frame_types[name] = EventFrame

    @property
    def min_latest(self) -> int:
        """
        Override min_latest to handle the case where all inputs are unaligned.
        Base implementation crashes if self.inbufs is empty.

        With no aligned buffers, returns the largest ``noffset`` of the last
        buffer of each pending unaligned frame, or 0 if nothing is pending.
        """
        if not self.inbufs:
            latest_offsets = []
            for pad in self.unaligned_sink_pads:
                frame = self.unaligned_data.get(pad)
                if frame and hasattr(frame, "data") and frame.data:
                    # Assuming frame.data is list of buffers
                    latest_offsets.append(frame.data[-1].noffset)
            return max(latest_offsets) if latest_offsets else 0
        return super().min_latest

    @property
    def earliest(self) -> int:
        """
        Override earliest to handle the case where all inputs are unaligned.

        With no aligned buffers, returns the smallest ``offset`` of the first
        buffer of each pending unaligned frame, or 0 if nothing is pending.
        """
        if not self.inbufs:
            earliest_offsets = []
            for pad in self.unaligned_sink_pads:
                frame = self.unaligned_data.get(pad)
                if frame and hasattr(frame, "data") and frame.data:
                    earliest_offsets.append(frame.data[0].offset)
            return min(earliest_offsets) if earliest_offsets else 0
        return super().earliest

    def _align(self) -> None:
        """
        Override alignment logic.
        Since input is unaligned, base class _align() would fail.
        We simply check if unaligned data is present.
        """
        # Assume alignment is satisfied if we have data on the first pad
        # For multiple pads, we might want to check all, but TSSink usually has one.
        if not self.sink_pads:
            self._is_aligned = False
            return

        sink_pad = self.sink_pads[0]
        if self.unaligned_data.get(sink_pad) is not None:
            self._is_aligned = True
        else:
            self._is_aligned = False

    def process(self, input_frames: dict[SinkPad, TSFrame]) -> None:
        """
        Process incoming frames and write to CSV.
        TSSink.internal() calls this method with frames collected from all pads.

        Only the first sink pad is read.  Each DriftEvent found in the frame's
        buffers becomes one row; anything else is skipped.
        """
        # We assume a single sink pad named "in"
        # Since we configured the pad to expect EventFrame, input_frames contains EventFrames.
        if not self.sink_pads:
            return

        sink_pad = self.sink_pads[0]
        frame = input_frames.get(sink_pad)

        if frame is None:
            return

        # EOS is recorded before the gap check so that a gap frame carrying
        # EOS still marks the pad.
        if frame.EOS:
            self.mark_eos(sink_pad)

        if frame.is_gap:
            return

        # Check for data
        if not hasattr(frame, "data") or not frame.data:
            return

        for buf in frame.data:
            if not hasattr(buf, "data") or not buf.data:
                continue

            for event in buf.data:
                if not isinstance(event, DriftEvent):
                    continue

                row = {"time": event.epoch}
                row.update(event.data)

                # Lazily open the file; the first row fixes the column set.
                if self._file is None:
                    self._open_file(row.keys())

                self._writer.writerow(row)

        # Flush once per frame so rows reach disk without waiting for cleanup.
        if self._file:
            self._file.flush()

    def _open_file(self, keys):
        """Open ``self.filename`` for appending and prepare the CSV writer.

        Args:
            keys: column names of the first row (including ``"time"``).

        The header row is written when the file is new or still empty; a file
        that already has content is appended to without another header.
        """
        # An existing but zero-length file has no header yet either, so an
        # existence check alone would leave it header-less.
        needs_header = (
            not os.path.exists(self.filename)
            or os.path.getsize(self.filename) == 0
        )
        self._file = open(self.filename, "a", newline="")
        # Ensure deterministic column order with 'time' first
        data_keys = sorted([k for k in keys if k != "time"])
        fieldnames = ["time"] + data_keys
        self._writer = csv.DictWriter(self._file, fieldnames=fieldnames)
        if needs_header:
            self._writer.writeheader()

    def cleanup(self):
        """Close the output file if it is open; safe to call repeatedly."""
        if self._file:
            self._file.close()
            self._file = None

    def __del__(self):
        # Last-resort close for instances dropped without an explicit cleanup().
        self.cleanup()
File without changes
@@ -0,0 +1,145 @@
1
+ """
2
+ Geometric Diagnostics: Elements for tracking the manifold velocity of detector noise.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from dataclasses import dataclass, field
8
+ from typing import ClassVar, Dict, Tuple
9
+
10
+ import numpy as np
11
+
12
+ from sgn.base import SourcePad
13
+ from sgndrift.transforms.psd import PSDEvent
14
+ from sgndrift.psd.drift import calculate_fisher_velocity
15
+ from sgnts.base import EventBuffer, EventFrame, TSTransform
16
+
17
+
18
@dataclass
class DriftEvent:
    """
    One Fisher Information Velocity (Drift) measurement.

    Plain value object: a timestamp together with a mapping of named
    drift values.
    """

    # Time the measurement belongs to.
    epoch: float
    # Drift value per name.
    data: Dict[str, float]
26
+
27
+
28
@dataclass
class FisherVelocity(TSTransform):
    """
    Fisher Information Velocity (Geometric Drift) between consecutive PSDs.

    Thin pipeline wrapper around
    sgndrift.psd.drift.calculate_fisher_velocity.

    Inputs:
        EventFrame containing PSDEvent objects.

    Outputs:
        EventFrame containing DriftEvent objects.
    """

    # Mark 'in' as unaligned to prevent TimeSeriesMixin from creating an Audioadapter.
    static_unaligned_sink_pads: ClassVar[list[str]] = ["in"]

    # Band definitions forwarded unchanged to calculate_fisher_velocity.
    bands: Dict[str, Tuple[float, float]] = field(default_factory=dict)

    # PSD and epoch of the previously processed event (None before the first).
    _prev_data: np.ndarray = field(init=False, repr=False, default=None)
    _prev_epoch: float = field(init=False, repr=False, default=None)

    def configure(self) -> None:
        """Declare EventFrame as the frame type on every sink and source pad."""
        for pad_name in self.sink_pad_names:
            self.input_frame_types[pad_name] = EventFrame
        for pad_name in self.source_pad_names:
            self.output_frame_types[pad_name] = EventFrame

    @property
    def min_latest(self) -> int:
        """Largest ``noffset`` among pending unaligned frames when there are
        no aligned buffers (0 if none); otherwise the base-class value."""
        if self.inbufs:
            return super().min_latest

        ends = []
        for pad in self.unaligned_sink_pads:
            pending = self.unaligned_data.get(pad)
            if not (pending and pending.data):
                continue
            ends.append(pending.data[-1].noffset)
        if not ends:
            return 0
        return max(ends)

    @property
    def earliest(self) -> int:
        """Smallest ``offset`` among pending unaligned frames when there are
        no aligned buffers (0 if none); otherwise the base-class value."""
        if self.inbufs:
            return super().earliest

        starts = []
        for pad in self.unaligned_sink_pads:
            pending = self.unaligned_data.get(pad)
            if not (pending and pending.data):
                continue
            starts.append(pending.data[0].offset)
        if not starts:
            return 0
        return min(starts)

    def _align(self) -> None:
        """Treat the element as aligned whenever the first sink pad has data."""
        first_pad = self.sink_pads[0]
        self._is_aligned = self.unaligned_data.get(first_pad) is not None

    def new(self, pad: SourcePad) -> EventFrame:
        """Consume the pending input frame and emit one DriftEvent frame.

        A gap frame is returned when there is no input, the input is a gap,
        it carries no events, or its first event is not a PSDEvent.  Only the
        first event of the first buffer is used.
        """
        first_pad = self.sink_pads[0]
        in_frame = self.unaligned_data.get(first_pad)
        # Mark the frame as consumed.
        self.unaligned_data[first_pad] = None

        if in_frame is None or in_frame.is_gap:
            eos = in_frame.EOS if in_frame else False
            return EventFrame(is_gap=True, EOS=eos)

        nothing_to_read = (
            not hasattr(in_frame, "data")
            or not in_frame.data
            or not in_frame.data[0].data
        )
        if nothing_to_read:
            return EventFrame(is_gap=True, EOS=in_frame.EOS)

        first_buf = in_frame.data[0]
        psd_event = first_buf.data[0]
        if not isinstance(psd_event, PSDEvent):
            return EventFrame(is_gap=True, EOS=in_frame.EOS)

        psd_now = psd_event.data
        epoch_now = psd_event.epoch

        # The velocity needs a previous PSD and a strictly positive time step.
        drift_results = {}
        if self._prev_data is not None:
            dt = epoch_now - self._prev_epoch
            if dt > 0:
                drift_results = calculate_fisher_velocity(
                    current_psd=psd_now,
                    previous_psd=self._prev_data,
                    dt=dt,
                    frequencies=psd_event.frequencies,
                    delta_f=psd_event.delta_f,
                    bands=self.bands,
                )

        # History is refreshed on every event, including ones with dt <= 0.
        self._prev_data = psd_now.copy()
        self._prev_epoch = epoch_now

        # Startup transient (or empty result): report zeros instead of nothing.
        if not drift_results:
            drift_results = dict.fromkeys(self.bands if self.bands else ["total"], 0.0)

        out_event = DriftEvent(epoch=epoch_now, data=drift_results)

        # NOTE(review): PSDEstimator.new converts the offset with Offset.tons
        # before calling EventBuffer.from_span, whereas the raw buffer offset
        # is used here -- confirm which unit from_span expects.
        ts = first_buf.offset
        if hasattr(first_buf, "duration") and first_buf.duration:
            dur = first_buf.duration
        else:
            dur = 1_000_000_000
        te = ts + dur

        out_buf = EventBuffer.from_span(ts, te, [out_event])
        return EventFrame(data=[out_buf], EOS=in_frame.EOS)
@@ -0,0 +1,190 @@
1
+ """
2
+ SGN Elements for PSD Estimation.
3
+ Wraps sgndrift.psd.estimators logic into TSTransform elements.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass, field
9
+ from typing import Optional
10
+
11
+ import numpy as np
12
+ import scipy.signal
13
+ from sgn.base import SourcePad
14
+ from sgnts.base import (
15
+ AdapterConfig,
16
+ EventBuffer,
17
+ EventFrame,
18
+ Offset,
19
+ TSTransform,
20
+ )
21
+
22
+ from sgndrift.psd.estimators import BaseEstimator, MGMEstimator, RecursiveEstimator
23
+
24
+ # Optional LAL import for conversion methods
25
+ try:
26
+ import lal
27
+ except ImportError:
28
+ lal = None
29
+
30
+
31
@dataclass
class PSDEvent:
    """
    Container for a PSD estimate event (Pure Python/NumPy).
    Decoupled from LAL to ensure stability in non-LAL environments.
    """

    # PSD values, one per frequency bin.
    data: np.ndarray
    # Frequency of each bin.
    frequencies: np.ndarray
    # Time stamp of the estimate; handed to lal.LIGOTimeGPS in to_lal().
    epoch: float
    # Frequency spacing between adjacent bins.
    delta_f: float

    def to_lal(self) -> Optional[object]:
        """
        Convert to LAL REAL8FrequencySeries if LAL is available.
        Uses standard 'strain^2 s' unit definition.

        Returns:
            The LAL series, or ``None`` when LAL is not installed or the
            conversion fails.  A failed conversion is reported with a
            ``RuntimeWarning`` instead of being dropped silently.
        """
        if lal is None:
            return None

        try:
            # Standard unit construction used in sgnligo.psd.psd
            unit = lal.Unit("strain^2 s")

            series = lal.CreateREAL8FrequencySeries(
                "psd",
                lal.LIGOTimeGPS(self.epoch),
                0.0,
                self.delta_f,
                unit,
                len(self.data),
            )
            series.data.data = self.data
        except Exception as exc:
            # The conversion stays best-effort (callers get None), but the
            # reason is surfaced so a broken LAL setup is not invisible.
            import warnings

            warnings.warn(
                f"PSDEvent.to_lal failed: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )
            return None
        return series
67
+
68
+
69
@dataclass
class PSDEstimator(TSTransform):
    """
    Common TSTransform machinery for PSD estimation.

    Each full-length input window is windowed, FFT'd and handed to an
    estimator object; the resulting PSD is emitted as a PSDEvent inside an
    EventFrame.  Subclasses supply the estimator via ``_init_estimator``.
    """

    # Window length in seconds.
    fft_length: float = 4.0
    # Fraction of the window shared with the next one.
    overlap: float = 0.5
    # Samples per second of the input.
    sample_rate: int = 16384
    # Window name passed to scipy.signal.get_window.
    window_type: str = "hann"

    # Internal state, all filled in by __post_init__ / _init_estimator.
    _estimator: BaseEstimator = field(init=False, repr=False, default=None)
    _window: np.ndarray = field(init=False, repr=False, default=None)
    _freqs: np.ndarray = field(init=False, repr=False, default=None)
    _norm_factor: float = field(init=False, repr=False, default=1.0)
    _delta_f: float = field(init=False, repr=False, default=0.0)

    def __post_init__(self):
        window_len = int(self.fft_length * self.sample_rate)
        hop = int(window_len * (1 - self.overlap))
        shared = window_len - hop

        # The adapter is configured before the base-class initialisation runs.
        self.adapter_config = AdapterConfig()
        self.adapter_config.stride = Offset.fromsamples(hop, self.sample_rate)
        self.adapter_config.overlap = (
            0,
            Offset.fromsamples(shared, self.sample_rate),
        )
        self.adapter_config.skip_gaps = True

        super().__post_init__()

        self._window = scipy.signal.get_window(self.window_type, window_len)
        # 2 / (sample_rate * sum(window**2)); presumably the one-sided PSD
        # normalisation -- confirm.
        window_energy = np.sum(self._window**2)
        self._norm_factor = 2.0 / (self.sample_rate * window_energy)

        self._freqs = np.fft.rfftfreq(window_len, d=1 / self.sample_rate)
        self._delta_f = self._freqs[1] - self._freqs[0]

        self._init_estimator(len(self._freqs))

    def _init_estimator(self, size: int):
        """Create ``self._estimator`` for ``size`` frequency bins (subclass hook)."""
        raise NotImplementedError

    def new(self, pad: SourcePad) -> EventFrame:
        """Turn the prepared input frame into a PSDEvent frame.

        A gap frame is returned when the input is a gap, has no buffers, or
        its first buffer is not exactly one window long.
        """
        in_frame = self.preparedframes[self.sink_pads[0]]

        if in_frame.is_gap or not in_frame.buffers:
            return EventFrame(is_gap=True, EOS=in_frame.EOS)

        buf = in_frame.buffers[0]
        samples = buf.data
        n_in = len(samples)

        if n_in != len(self._window):
            return EventFrame(is_gap=True, EOS=in_frame.EOS)

        # Window, transform, and feed the estimator.
        spectrum = np.fft.rfft(samples * self._window)
        self._estimator.update(spectrum)
        psd_data = self._estimator.get_psd().copy()

        # Span covered by this buffer, in nanoseconds for EventBuffer.
        start_offset = buf.offset
        end_offset = start_offset + Offset.fromsamples(n_in, self.sample_rate)
        ts = Offset.tons(start_offset)
        te = Offset.tons(end_offset)

        # The PSD is stamped with the start of the window, in seconds.
        event = PSDEvent(
            data=psd_data,
            frequencies=self._freqs,
            epoch=Offset.tosec(start_offset),
            delta_f=self._delta_f,
        )

        # Use factory method to avoid constructor signature issues
        out_buf = EventBuffer.from_span(ts, te, [event])

        # Carry the input metadata along and attach the PSD in both forms.
        if in_frame.metadata:
            meta = in_frame.metadata.copy()
        else:
            meta = {}

        lal_obj = event.to_lal()
        if lal_obj:
            meta["psd"] = lal_obj

        meta["psd_numpy"] = psd_data
        meta["psd_freqs"] = self._freqs

        return EventFrame(data=[out_buf], metadata=meta, EOS=in_frame.EOS)
163
+
164
+
165
@dataclass
class RecursivePSD(PSDEstimator):
    """PSD element backed by RecursiveEstimator (fast, low latency)."""

    # Passed through as the estimator's ``alpha``.
    alpha: float = 0.1

    def _init_estimator(self, size: int):
        """Build the RecursiveEstimator for ``size`` frequency bins."""
        estimator_kwargs = {
            "size": size,
            "normalization": self._norm_factor,
            "alpha": self.alpha,
        }
        self._estimator = RecursiveEstimator(**estimator_kwargs)
175
+
176
+
177
@dataclass
class MGMPSD(PSDEstimator):
    """PSD element backed by MGMEstimator (standard Median-Geometric-Mean)."""

    # Both values are passed through to the estimator unchanged.
    n_median: int = 7
    n_average: int = 64

    def _init_estimator(self, size: int):
        """Build the MGMEstimator for ``size`` frequency bins."""
        estimator_kwargs = {
            "size": size,
            "normalization": self._norm_factor,
            "n_median": self.n_median,
            "n_average": self.n_average,
        }
        self._estimator = MGMEstimator(**estimator_kwargs)