sgn-drift 0.1.0 (sgn_drift-0.1.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sgn_drift-0.1.0.dist-info/METADATA +91 -0
- sgn_drift-0.1.0.dist-info/RECORD +22 -0
- sgn_drift-0.1.0.dist-info/WHEEL +5 -0
- sgn_drift-0.1.0.dist-info/entry_points.txt +7 -0
- sgn_drift-0.1.0.dist-info/top_level.txt +1 -0
- sgndrift/__init__.py +0 -0
- sgndrift/_version.py +34 -0
- sgndrift/bin/__init__.py +0 -0
- sgndrift/bin/estimate_drift.py +278 -0
- sgndrift/bin/plot_drift.py +177 -0
- sgndrift/bin/plot_drift_comparison.py +211 -0
- sgndrift/bin/plot_drift_super.py +272 -0
- sgndrift/bin/plot_drift_super_comp.py +360 -0
- sgndrift/bin/plot_drift_time.py +210 -0
- sgndrift/psd/__init__.py +0 -0
- sgndrift/psd/drift.py +73 -0
- sgndrift/psd/estimators.py +150 -0
- sgndrift/sinks/__init__.py +0 -0
- sgndrift/sinks/drift_sink.py +154 -0
- sgndrift/transforms/__init__.py +0 -0
- sgndrift/transforms/drift.py +145 -0
- sgndrift/transforms/psd.py +190 -0
sgndrift/psd/estimators.py
@@ -0,0 +1,150 @@
"""Core PSD estimation logic classes (Math only).
Refactored to remove invalid boundary zeroing and enforce input shape.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections import deque
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
from scipy.special import loggamma
from sympy import EulerGamma

EULERGAMMA = float(EulerGamma.evalf())


@dataclass
class BaseEstimator(ABC):
    """Base class for PSD estimation logic."""

    size: int
    normalization: float = 1.0

    # Internal State
    n_samples: int = field(init=False, default=0)
    current_psd: np.ndarray = field(init=False, repr=False, default=None)

    def __post_init__(self):
        # Initialize with ones to avoid divide-by-zero
        self.current_psd = np.ones(self.size)

    def _check_shape(self, data: np.ndarray) -> None:
        """Validate input data shape matches estimator configuration."""
        if data.shape[-1] != self.size:
            raise ValueError(
                f"Input data size {data.shape[-1]} does not match estimator size {self.size}"
            )

    @abstractmethod
    def update(self, data: np.ndarray) -> None:
        """Update state with new frequency-domain data."""
        pass

    def get_psd(self) -> np.ndarray:
        return self.current_psd


@dataclass
class MGMEstimator(BaseEstimator):
    """
    Median-Geometric-Mean Estimator (Standard LIGO).
    """

    n_median: int = 7
    n_average: int = 64

    # Internal State
    history: deque = field(init=False, repr=False, default=None)
    geo_mean_log: Optional[np.ndarray] = field(init=False, repr=False, default=None)

    def __post_init__(self):
        super().__post_init__()
        self.history = deque(maxlen=self.n_median)

    @staticmethod
    def _median_bias(nn):
        """XLALMedianBias"""
        ans = 1.0
        n = (nn - 1) // 2
        for i in range(1, n + 1):
            ans -= 1.0 / (2 * i)
            ans += 1.0 / (2 * i + 1)
        return ans

    @staticmethod
    def _log_median_bias_geometric(nn):
        """XLALLogMedianBiasGeometric"""
        return np.log(MGMEstimator._median_bias(nn)) - nn * (
            loggamma(1.0 / nn) - np.log(nn)
        )

    def update(self, data: np.ndarray) -> None:
        self._check_shape(data)

        if np.iscomplexobj(data):
            power = np.abs(data) ** 2
        else:
            power = data

        self.history.append(power)

        if self.n_samples == 0:
            self.geo_mean_log = np.log(power)
            self.n_samples += 1
        else:
            self.n_samples = min(self.n_samples + 1, self.n_average)

            bias = self._log_median_bias_geometric(len(self.history))

            # Match Legacy: use sort and integer index
            stacked = np.array(self.history)
            sorted_bins = np.sort(stacked, axis=0)
            idx = len(self.history) // 2
            log_bin_median = np.log(sorted_bins[idx])

            self.geo_mean_log = (
                self.geo_mean_log * (self.n_samples - 1) + log_bin_median - bias
            ) / self.n_samples

        self.current_psd = np.exp(self.geo_mean_log + EULERGAMMA) * self.normalization

    def set_reference(self, psd: np.ndarray, weight: int):
        self._check_shape(psd)

        raw = psd / self.normalization
        # Avoid log(0)
        raw = np.where(raw > 0, raw, 1e-300)

        self.history.clear()
        for _ in range(self.n_median):
            self.history.append(raw)

        self.geo_mean_log = np.log(raw) - EULERGAMMA
        self.n_samples = min(weight, self.n_average)
        self.current_psd = psd.copy()


@dataclass
class RecursiveEstimator(BaseEstimator):
    """
    Exponential Moving Average Estimator.
    """

    alpha: float = 0.1
    _initialized: bool = field(init=False, default=False)

    def update(self, data: np.ndarray) -> None:
        self._check_shape(data)

        power = (
            np.abs(data) ** 2 if np.iscomplexobj(data) else data
        ) * self.normalization

        if not self._initialized:
            self.current_psd = power
            self._initialized = True
        else:
            self.current_psd = (1 - self.alpha) * self.current_psd + self.alpha * power
sgndrift/sinks/__init__.py
File without changes
sgndrift/sinks/drift_sink.py
@@ -0,0 +1,154 @@
"""
Sinks for Drift Events.
"""

from __future__ import annotations

import csv
import os
from dataclasses import dataclass, field
from typing import Any, ClassVar, Optional, TextIO

from sgn.base import SinkPad
from sgnts.base import EventFrame, TSFrame, TSSink
from sgndrift.transforms.drift import DriftEvent


@dataclass
class DriftCSVSink(TSSink):
    """
    Writes DriftEvent data to a CSV file.
    Inherits from TSSink to integrate with sgn-ts pipelines.
    """

    filename: str = "drift.csv"

    # Mark 'in' as unaligned to prevent Audioadapter creation for discrete events
    static_unaligned_sink_pads: ClassVar[list[str]] = ["in"]

    # Internal state
    _file: Optional[TextIO] = field(init=False, repr=False, default=None)
    _writer: Any = field(init=False, repr=False, default=None)

    def __post_init__(self):
        # Force all input pads to be unaligned to prevent Audioadapter creation.
        # This is necessary because EventFrames are discrete and lack sample rates.
        # We set this before super().__post_init__() so TimeSeriesMixin uses it.
        self.unaligned = list(self.sink_pad_names)
        super().__post_init__()

    def configure(self) -> None:
        """Configure input frame types to expect EventFrame."""
        for name in self.sink_pad_names:
            self.input_frame_types[name] = EventFrame

    @property
    def min_latest(self) -> int:
        """
        Override min_latest to handle the case where all inputs are unaligned.
        Base implementation crashes if self.inbufs is empty.
        """
        if not self.inbufs:
            latest_offsets = []
            for pad in self.unaligned_sink_pads:
                frame = self.unaligned_data.get(pad)
                if frame and hasattr(frame, "data") and frame.data:
                    # Assuming frame.data is list of buffers
                    latest_offsets.append(frame.data[-1].noffset)
            return max(latest_offsets) if latest_offsets else 0
        return super().min_latest

    @property
    def earliest(self) -> int:
        """
        Override earliest to handle the case where all inputs are unaligned.
        """
        if not self.inbufs:
            earliest_offsets = []
            for pad in self.unaligned_sink_pads:
                frame = self.unaligned_data.get(pad)
                if frame and hasattr(frame, "data") and frame.data:
                    earliest_offsets.append(frame.data[0].offset)
            return min(earliest_offsets) if earliest_offsets else 0
        return super().earliest

    def _align(self) -> None:
        """
        Override alignment logic.
        Since input is unaligned, base class _align() would fail.
        We simply check if unaligned data is present.
        """
        # Assume alignment is satisfied if we have data on the first pad
        # For multiple pads, we might want to check all, but TSSink usually has one.
        if not self.sink_pads:
            self._is_aligned = False
            return

        sink_pad = self.sink_pads[0]
        if self.unaligned_data.get(sink_pad) is not None:
            self._is_aligned = True
        else:
            self._is_aligned = False

    def process(self, input_frames: dict[SinkPad, TSFrame]) -> None:
        """
        Process incoming frames and write to CSV.
        TSSink.internal() calls this method with frames collected from all pads.
        """
        # We assume a single sink pad named "in"
        # Since we configured the pad to expect EventFrame, input_frames contains EventFrames.
        if not self.sink_pads:
            return

        sink_pad = self.sink_pads[0]
        frame = input_frames.get(sink_pad)

        if frame is None:
            return

        if frame.EOS:
            self.mark_eos(sink_pad)

        if frame.is_gap:
            return

        # Check for data
        if not hasattr(frame, "data") or not frame.data:
            return

        for buf in frame.data:
            if not hasattr(buf, "data") or not buf.data:
                continue

            for event in buf.data:
                if not isinstance(event, DriftEvent):
                    continue

                row = {"time": event.epoch}
                row.update(event.data)

                if self._file is None:
                    self._open_file(row.keys())

                self._writer.writerow(row)

        if self._file:
            self._file.flush()

    def _open_file(self, keys):
        exists = os.path.exists(self.filename)
        self._file = open(self.filename, "a", newline="")
        # Ensure deterministic column order with 'time' first
        data_keys = sorted([k for k in keys if k != "time"])
        fieldnames = ["time"] + data_keys
        self._writer = csv.DictWriter(self._file, fieldnames=fieldnames)
        if not exists:
            self._writer.writeheader()

    def cleanup(self):
        if self._file:
            self._file.close()
            self._file = None

    def __del__(self):
        self.cleanup()
sgndrift/transforms/__init__.py
File without changes
sgndrift/transforms/drift.py
@@ -0,0 +1,145 @@
"""
Geometric Diagnostics: Elements for tracking the manifold velocity of detector noise.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import ClassVar, Dict, Tuple

import numpy as np

from sgn.base import SourcePad
from sgndrift.transforms.psd import PSDEvent
from sgndrift.psd.drift import calculate_fisher_velocity
from sgnts.base import EventBuffer, EventFrame, TSTransform


@dataclass
class DriftEvent:
    """
    Container for Fisher Information Velocity (Drift) data.
    """

    epoch: float
    data: Dict[str, float]


@dataclass
class FisherVelocity(TSTransform):
    """
    Computes Fisher Information Velocity (Geometric Drift) between consecutive PSDs.

    Wraps sgndrift.psd.drift.calculate_fisher_velocity.

    Inputs:
        EventFrame containing PSDEvent objects.

    Outputs:
        EventFrame containing DriftEvent objects.
    """

    # Mark 'in' as unaligned to prevent TimeSeriesMixin from creating an Audioadapter.
    static_unaligned_sink_pads: ClassVar[list[str]] = ["in"]

    bands: Dict[str, Tuple[float, float]] = field(default_factory=dict)

    _prev_data: np.ndarray = field(init=False, repr=False, default=None)
    _prev_epoch: float = field(init=False, repr=False, default=None)

    def configure(self) -> None:
        """Configure element-specific attributes."""
        # Inform the element that it handles EventFrames
        for name in self.sink_pad_names:
            self.input_frame_types[name] = EventFrame
        for name in self.source_pad_names:
            self.output_frame_types[name] = EventFrame

    @property
    def min_latest(self) -> int:
        if not self.inbufs:
            latest_offsets = []
            for pad in self.unaligned_sink_pads:
                frame = self.unaligned_data.get(pad)
                if frame and frame.data:
                    latest_offsets.append(frame.data[-1].noffset)
            return max(latest_offsets) if latest_offsets else 0
        return super().min_latest

    @property
    def earliest(self) -> int:
        if not self.inbufs:
            earliest_offsets = []
            for pad in self.unaligned_sink_pads:
                frame = self.unaligned_data.get(pad)
                if frame and frame.data:
                    earliest_offsets.append(frame.data[0].offset)
            return min(earliest_offsets) if earliest_offsets else 0
        return super().earliest

    def _align(self) -> None:
        sink_pad = self.sink_pads[0]
        if self.unaligned_data.get(sink_pad) is not None:
            self._is_aligned = True
        else:
            self._is_aligned = False

    def new(self, pad: SourcePad) -> EventFrame:
        sink_pad = self.sink_pads[0]
        in_frame = self.unaligned_data.get(sink_pad)
        self.unaligned_data[sink_pad] = None

        if in_frame is None or in_frame.is_gap:
            return EventFrame(is_gap=True, EOS=in_frame.EOS if in_frame else False)

        if not hasattr(in_frame, "data") or not in_frame.data:
            return EventFrame(is_gap=True, EOS=in_frame.EOS)
        if not in_frame.data[0].data:
            return EventFrame(is_gap=True, EOS=in_frame.EOS)

        psd_event = in_frame.data[0].data[0]

        if not isinstance(psd_event, PSDEvent):
            return EventFrame(is_gap=True, EOS=in_frame.EOS)

        current_data = psd_event.data
        current_epoch = psd_event.epoch
        freqs = psd_event.frequencies
        df = psd_event.delta_f

        drift_results = {}

        # Only calculate if we have history
        if self._prev_data is not None:
            dt = current_epoch - self._prev_epoch
            if dt > 0:
                drift_results = calculate_fisher_velocity(
                    current_psd=current_data,
                    previous_psd=self._prev_data,
                    dt=dt,
                    frequencies=freqs,
                    delta_f=df,
                    bands=self.bands,
                )

        # Update History
        self._prev_data = current_data.copy()
        self._prev_epoch = current_epoch

        # Handle startup transient (return zeros instead of empty)
        if not drift_results:
            bands_keys = self.bands.keys() if self.bands else ["total"]
            drift_results = {k: 0.0 for k in bands_keys}

        out_event = DriftEvent(epoch=current_epoch, data=drift_results)

        buf = in_frame.data[0]
        ts = buf.offset
        dur = (
            buf.duration if hasattr(buf, "duration") and buf.duration else 1_000_000_000
        )
        te = ts + dur

        out_buf = EventBuffer.from_span(ts, te, [out_event])

        return EventFrame(data=[out_buf], EOS=in_frame.EOS)
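The metric itself lives in calculate_fisher_velocity (sgndrift/psd/drift.py), whose body is not part of this excerpt; only its keyword signature and its dict-of-floats return shape are visible above. The sketch below is a hypothetical stand-in that matches that interface, integrating the squared rate of change of the log-PSD over each band; it is an assumption for illustration, not the package's actual computation.

import numpy as np

def fisher_velocity_sketch(current_psd, previous_psd, dt, frequencies, delta_f, bands):
    """Hypothetical stand-in for sgndrift.psd.drift.calculate_fisher_velocity.

    Returns one scalar per band: the band-integrated RMS rate of change of
    the log-PSD. Interface only; the real function may compute something else.
    """
    dlog = (np.log(current_psd) - np.log(previous_psd)) / dt
    if not bands:
        bands = {"total": (frequencies[0], frequencies[-1])}
    results = {}
    for name, (f_lo, f_hi) in bands.items():
        mask = (frequencies >= f_lo) & (frequencies <= f_hi)
        results[name] = float(np.sqrt(np.sum(dlog[mask] ** 2) * delta_f))
    return results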
sgndrift/transforms/psd.py
@@ -0,0 +1,190 @@
"""
SGN Elements for PSD Estimation.
wraps sgnligo.psd.estimators logic into TSTransform elements.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import scipy.signal
from sgn.base import SourcePad
from sgnts.base import (
    AdapterConfig,
    EventBuffer,
    EventFrame,
    Offset,
    TSTransform,
)

from sgndrift.psd.estimators import BaseEstimator, MGMEstimator, RecursiveEstimator

# Optional LAL import for conversion methods
try:
    import lal
except ImportError:
    lal = None


@dataclass
class PSDEvent:
    """
    Container for a PSD estimate event (Pure Python/NumPy).
    Decoupled from LAL to ensure stability in non-LAL environments.
    """

    data: np.ndarray
    frequencies: np.ndarray
    epoch: float
    delta_f: float

    def to_lal(self) -> Optional[object]:
        """
        Convert to LAL REAL8FrequencySeries if LAL is available.
        Uses standard 'strain^2 s' unit definition.
        """
        if lal is None:
            return None

        try:
            # Standard unit construction used in sgnligo.psd.psd
            unit = lal.Unit("strain^2 s")

            series = lal.CreateREAL8FrequencySeries(
                "psd",
                lal.LIGOTimeGPS(self.epoch),
                0.0,
                self.delta_f,
                unit,
                len(self.data),
            )
            series.data.data = self.data
            return series
        except Exception:
            return None


@dataclass
class PSDEstimator(TSTransform):
    """
    Base TSTransform for PSD Estimation.
    Outputs EventFrame containing PSDEvent objects.
    """

    fft_length: float = 4.0
    overlap: float = 0.5
    sample_rate: int = 16384
    window_type: str = "hann"

    # Internal state
    _estimator: BaseEstimator = field(init=False, repr=False, default=None)
    _window: np.ndarray = field(init=False, repr=False, default=None)
    _freqs: np.ndarray = field(init=False, repr=False, default=None)
    _norm_factor: float = field(init=False, repr=False, default=1.0)
    _delta_f: float = field(init=False, repr=False, default=0.0)

    def __post_init__(self):
        n_samples = int(self.fft_length * self.sample_rate)
        stride = int(n_samples * (1 - self.overlap))
        overlap_samples = n_samples - stride

        self.adapter_config = AdapterConfig()
        self.adapter_config.stride = Offset.fromsamples(stride, self.sample_rate)
        self.adapter_config.overlap = (
            0,
            Offset.fromsamples(overlap_samples, self.sample_rate),
        )
        self.adapter_config.skip_gaps = True

        super().__post_init__()

        self._window = scipy.signal.get_window(self.window_type, n_samples)
        s2 = np.sum(self._window**2)
        self._norm_factor = 2.0 / (self.sample_rate * s2)

        self._freqs = np.fft.rfftfreq(n_samples, d=1 / self.sample_rate)
        self._delta_f = self._freqs[1] - self._freqs[0]

        self._init_estimator(len(self._freqs))

    def _init_estimator(self, size: int):
        raise NotImplementedError

    def new(self, pad: SourcePad) -> EventFrame:
        in_frame = self.preparedframes[self.sink_pads[0]]

        if in_frame.is_gap or not in_frame.buffers:
            return EventFrame(is_gap=True, EOS=in_frame.EOS)

        buf = in_frame.buffers[0]
        data = buf.data

        if len(data) != len(self._window):
            return EventFrame(is_gap=True, EOS=in_frame.EOS)

        # 1. Compute FFT
        windowed = data * self._window
        fft_data = np.fft.rfft(windowed)

        # 2. Update Estimator
        self._estimator.update(fft_data)
        psd_data = self._estimator.get_psd().copy()

        # 3. Create Output Event
        # Calculate timestamps in nanoseconds for EventBuffer
        ts = Offset.tons(buf.offset)
        # Duration is derived from buffer length
        duration_samples = len(data)
        duration_offset = Offset.fromsamples(duration_samples, self.sample_rate)
        te = Offset.tons(buf.offset + duration_offset)

        # Epoch for PSD metadata (start of window)
        epoch = Offset.tosec(buf.offset)

        event = PSDEvent(
            data=psd_data, frequencies=self._freqs, epoch=epoch, delta_f=self._delta_f
        )

        # Use factory method to avoid constructor signature issues
        out_buf = EventBuffer.from_span(ts, te, [event])

        meta = in_frame.metadata.copy() if in_frame.metadata else {}

        lal_obj = event.to_lal()
        if lal_obj:
            meta["psd"] = lal_obj

        meta["psd_numpy"] = psd_data
        meta["psd_freqs"] = self._freqs

        return EventFrame(data=[out_buf], metadata=meta, EOS=in_frame.EOS)


@dataclass
class RecursivePSD(PSDEstimator):
    """Fast, Low-Latency PSD Estimator."""

    alpha: float = 0.1

    def _init_estimator(self, size: int):
        self._estimator = RecursiveEstimator(
            size=size, normalization=self._norm_factor, alpha=self.alpha
        )


@dataclass
class MGMPSD(PSDEstimator):
    """Standard Median-Geometric-Mean PSD Estimator."""

    n_median: int = 7
    n_average: int = 64

    def _init_estimator(self, size: int):
        self._estimator = MGMEstimator(
            size=size,
            normalization=self._norm_factor,
            n_median=self.n_median,
            n_average=self.n_average,
        )
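
The factor 2.0 / (sample_rate * sum(window**2)) applied above is the standard one-sided density scaling, so a single windowed segment pushed through this normalization should agree with an ordinary periodogram away from the edge bins. A quick standalone check of that factor, with an arbitrary segment length and sample rate:

import numpy as np
import scipy.signal

fs = 4096
n = 4 * fs                                   # mirrors fft_length * sample_rate
rng = np.random.default_rng(0)
x = rng.standard_normal(n)

w = scipy.signal.get_window("hann", n)
norm = 2.0 / (fs * np.sum(w ** 2))           # same factor as PSDEstimator._norm_factor
manual = np.abs(np.fft.rfft(w * x)) ** 2 * norm

_, ref = scipy.signal.periodogram(x, fs=fs, window="hann", detrend=False)
# Interior bins agree; scipy halves only the DC and Nyquist bins for one-sided output.
assert np.allclose(manual[1:-1], ref[1:-1])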