ezmsg-sigproc 2.5.0__py3-none-any.whl → 2.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +5 -11
- ezmsg/sigproc/adaptive_lattice_notch.py +11 -30
- ezmsg/sigproc/affinetransform.py +16 -42
- ezmsg/sigproc/aggregate.py +17 -34
- ezmsg/sigproc/bandpower.py +12 -20
- ezmsg/sigproc/base.py +141 -1276
- ezmsg/sigproc/butterworthfilter.py +8 -16
- ezmsg/sigproc/butterworthzerophase.py +7 -16
- ezmsg/sigproc/cheby.py +4 -10
- ezmsg/sigproc/combfilter.py +5 -8
- ezmsg/sigproc/coordinatespaces.py +142 -0
- ezmsg/sigproc/decimate.py +3 -7
- ezmsg/sigproc/denormalize.py +6 -11
- ezmsg/sigproc/detrend.py +3 -4
- ezmsg/sigproc/diff.py +8 -17
- ezmsg/sigproc/downsample.py +11 -20
- ezmsg/sigproc/ewma.py +11 -28
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +3 -4
- ezmsg/sigproc/fbcca.py +34 -59
- ezmsg/sigproc/filter.py +19 -45
- ezmsg/sigproc/filterbank.py +37 -74
- ezmsg/sigproc/filterbankdesign.py +7 -14
- ezmsg/sigproc/fir_hilbert.py +13 -30
- ezmsg/sigproc/fir_pmc.py +5 -10
- ezmsg/sigproc/firfilter.py +12 -14
- ezmsg/sigproc/gaussiansmoothing.py +5 -9
- ezmsg/sigproc/kaiser.py +11 -15
- ezmsg/sigproc/math/abs.py +4 -3
- ezmsg/sigproc/math/add.py +121 -0
- ezmsg/sigproc/math/clip.py +4 -1
- ezmsg/sigproc/math/difference.py +100 -36
- ezmsg/sigproc/math/invert.py +3 -3
- ezmsg/sigproc/math/log.py +5 -6
- ezmsg/sigproc/math/scale.py +2 -0
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +3 -6
- ezmsg/sigproc/resample.py +17 -38
- ezmsg/sigproc/rollingscaler.py +12 -37
- ezmsg/sigproc/sampler.py +19 -37
- ezmsg/sigproc/scaler.py +11 -22
- ezmsg/sigproc/signalinjector.py +7 -18
- ezmsg/sigproc/slicer.py +14 -34
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +12 -19
- ezmsg/sigproc/spectrum.py +17 -38
- ezmsg/sigproc/transpose.py +12 -24
- ezmsg/sigproc/util/asio.py +25 -156
- ezmsg/sigproc/util/axisarray_buffer.py +12 -26
- ezmsg/sigproc/util/buffer.py +22 -43
- ezmsg/sigproc/util/message.py +17 -31
- ezmsg/sigproc/util/profile.py +23 -174
- ezmsg/sigproc/util/sparse.py +7 -15
- ezmsg/sigproc/util/typeresolution.py +17 -83
- ezmsg/sigproc/wavelets.py +10 -19
- ezmsg/sigproc/window.py +29 -83
- ezmsg_sigproc-2.7.0.dist-info/METADATA +60 -0
- ezmsg_sigproc-2.7.0.dist-info/RECORD +64 -0
- ezmsg/sigproc/synth.py +0 -774
- ezmsg_sigproc-2.5.0.dist-info/METADATA +0 -72
- ezmsg_sigproc-2.5.0.dist-info/RECORD +0 -63
- {ezmsg_sigproc-2.5.0.dist-info → ezmsg_sigproc-2.7.0.dist-info}/WHEEL +0 -0
- /ezmsg_sigproc-2.5.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.7.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/util/buffer.py
CHANGED
|
@@ -1,3 +1,7 @@
|
|
|
1
|
+
"""A stateful, FIFO buffer that combines a deque for fast appends with a
|
|
2
|
+
contiguous circular buffer for efficient, advancing reads.
|
|
3
|
+
"""
|
|
4
|
+
|
|
1
5
|
import collections
|
|
2
6
|
import math
|
|
3
7
|
import typing
|
|
@@ -63,9 +67,7 @@ class HybridBuffer:
|
|
|
63
67
|
self._buff_unread = 0 # Number of unread samples in the circular buffer
|
|
64
68
|
self._buff_read = 0 # Tracks samples read and still in buffer
|
|
65
69
|
self._deque_len = 0 # Number of unread samples in the deque
|
|
66
|
-
self._last_overflow =
|
|
67
|
-
0 # Tracks the last overflow count, overwritten or skipped
|
|
68
|
-
)
|
|
70
|
+
self._last_overflow = 0 # Tracks the last overflow count, overwritten or skipped
|
|
69
71
|
self._warned = False # Tracks if we've warned already (for warn_once)
|
|
70
72
|
|
|
71
73
|
@property
|
|
@@ -96,9 +98,7 @@ class HybridBuffer:
|
|
|
96
98
|
block = block[:, self.xp.newaxis]
|
|
97
99
|
|
|
98
100
|
if block.shape[1:] != other_shape:
|
|
99
|
-
raise ValueError(
|
|
100
|
-
f"Block shape {block.shape[1:]} does not match buffer's other_shape {other_shape}"
|
|
101
|
-
)
|
|
101
|
+
raise ValueError(f"Block shape {block.shape[1:]} does not match buffer's other_shape {other_shape}")
|
|
102
102
|
|
|
103
103
|
# Most overflow strategies are handled during flush, but there are a couple
|
|
104
104
|
# scenarios that can be evaluated on write to give immediate feedback.
|
|
@@ -117,8 +117,7 @@ class HybridBuffer:
|
|
|
117
117
|
self._deque_len += block.shape[0]
|
|
118
118
|
|
|
119
119
|
if self._update_strategy == "immediate" or (
|
|
120
|
-
self._update_strategy == "threshold"
|
|
121
|
-
and (0 < self._threshold <= self._deque_len)
|
|
120
|
+
self._update_strategy == "threshold" and (0 < self._threshold <= self._deque_len)
|
|
122
121
|
):
|
|
123
122
|
self.flush()
|
|
124
123
|
|
|
@@ -128,9 +127,7 @@ class HybridBuffer:
|
|
|
128
127
|
from the buffer.
|
|
129
128
|
"""
|
|
130
129
|
if n_samples > self.available():
|
|
131
|
-
raise ValueError(
|
|
132
|
-
f"Requested {n_samples} samples, but only {self.available()} are available."
|
|
133
|
-
)
|
|
130
|
+
raise ValueError(f"Requested {n_samples} samples, but only {self.available()} are available.")
|
|
134
131
|
n_overflow = 0
|
|
135
132
|
if self._deque and (n_samples > self._buff_unread):
|
|
136
133
|
# We would cause a flush, but would that cause an overflow?
|
|
@@ -161,14 +158,10 @@ class HybridBuffer:
|
|
|
161
158
|
n_overflow = self._estimate_overflow(n_samples)
|
|
162
159
|
if n_overflow > 0:
|
|
163
160
|
first_read = self._buff_unread
|
|
164
|
-
if (n_overflow - first_read) < self.capacity or (
|
|
165
|
-
self._overflow_strategy == "drop"
|
|
166
|
-
):
|
|
161
|
+
if (n_overflow - first_read) < self.capacity or (self._overflow_strategy == "drop"):
|
|
167
162
|
# We can prevent the overflow (or at least *some* if using "drop"
|
|
168
163
|
# strategy) by reading the samples in the buffer first to make room.
|
|
169
|
-
data = self.xp.empty(
|
|
170
|
-
(n_samples, *self._buffer.shape[1:]), dtype=self._buffer.dtype
|
|
171
|
-
)
|
|
164
|
+
data = self.xp.empty((n_samples, *self._buffer.shape[1:]), dtype=self._buffer.dtype)
|
|
172
165
|
self.peek(first_read, out=data[:first_read])
|
|
173
166
|
offset += first_read
|
|
174
167
|
self.seek(first_read)
|
|
@@ -204,13 +197,9 @@ class HybridBuffer:
|
|
|
204
197
|
if n_samples is None:
|
|
205
198
|
n_samples = self.available()
|
|
206
199
|
elif n_samples > self.available():
|
|
207
|
-
raise ValueError(
|
|
208
|
-
f"Requested to peek {n_samples} samples, but only {self.available()} are available."
|
|
209
|
-
)
|
|
200
|
+
raise ValueError(f"Requested to peek {n_samples} samples, but only {self.available()} are available.")
|
|
210
201
|
if out is not None and out.shape[0] < n_samples:
|
|
211
|
-
raise ValueError(
|
|
212
|
-
f"Output array shape {out.shape} is smaller than requested {n_samples} samples."
|
|
213
|
-
)
|
|
202
|
+
raise ValueError(f"Output array shape {out.shape} is smaller than requested {n_samples} samples.")
|
|
214
203
|
|
|
215
204
|
if n_samples == 0:
|
|
216
205
|
return self._buffer[:0]
|
|
@@ -224,9 +213,7 @@ class HybridBuffer:
|
|
|
224
213
|
out = (
|
|
225
214
|
out
|
|
226
215
|
if out is not None
|
|
227
|
-
else self.xp.empty(
|
|
228
|
-
(n_samples, *self._buffer.shape[1:]), dtype=self._buffer.dtype
|
|
229
|
-
)
|
|
216
|
+
else self.xp.empty((n_samples, *self._buffer.shape[1:]), dtype=self._buffer.dtype)
|
|
230
217
|
)
|
|
231
218
|
out[:part1_len] = self._buffer[self._tail :]
|
|
232
219
|
out[part1_len:] = self._buffer[:part2_len]
|
|
@@ -258,9 +245,7 @@ class HybridBuffer:
|
|
|
258
245
|
if not allow_flush and idx >= self._buff_unread:
|
|
259
246
|
# The requested sample is in the deque.
|
|
260
247
|
idx -= self._buff_unread
|
|
261
|
-
deq_splits = self.xp.cumsum(
|
|
262
|
-
[0] + [_.shape[0] for _ in self._deque], dtype=int
|
|
263
|
-
)
|
|
248
|
+
deq_splits = self.xp.cumsum([0] + [_.shape[0] for _ in self._deque], dtype=int)
|
|
264
249
|
arr_idx = self.xp.searchsorted(deq_splits, idx, side="right") - 1
|
|
265
250
|
idx -= deq_splits[arr_idx]
|
|
266
251
|
return self._deque[arr_idx][idx : idx + 1]
|
|
@@ -334,7 +319,8 @@ class HybridBuffer:
|
|
|
334
319
|
if n_overflow > 0 and (not self._warn_once or not self._warned):
|
|
335
320
|
self._warned = True
|
|
336
321
|
warnings.warn(
|
|
337
|
-
f"Buffer overflow: {n_new} samples received,
|
|
322
|
+
f"Buffer overflow: {n_new} samples received, "
|
|
323
|
+
f"but only {self._capacity - self._buff_unread} available. "
|
|
338
324
|
f"Overwriting {n_overflow} previous samples.",
|
|
339
325
|
RuntimeWarning,
|
|
340
326
|
)
|
|
@@ -347,10 +333,9 @@ class HybridBuffer:
|
|
|
347
333
|
break
|
|
348
334
|
n_to_copy = min(block.shape[0], samples_to_copy - copied_samples)
|
|
349
335
|
start_idx = block.shape[0] - n_to_copy
|
|
350
|
-
self._buffer[
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
] = block[start_idx:]
|
|
336
|
+
self._buffer[samples_to_copy - copied_samples - n_to_copy : samples_to_copy - copied_samples] = block[
|
|
337
|
+
start_idx:
|
|
338
|
+
]
|
|
354
339
|
copied_samples += n_to_copy
|
|
355
340
|
|
|
356
341
|
self._head = 0
|
|
@@ -362,9 +347,7 @@ class HybridBuffer:
|
|
|
362
347
|
else:
|
|
363
348
|
if n_overflow > 0:
|
|
364
349
|
if self._overflow_strategy == "raise":
|
|
365
|
-
raise OverflowError(
|
|
366
|
-
f"Buffer overflow: {n_new} samples received, but only {n_free} available."
|
|
367
|
-
)
|
|
350
|
+
raise OverflowError(f"Buffer overflow: {n_new} samples received, but only {n_free} available.")
|
|
368
351
|
elif self._overflow_strategy == "warn-overwrite":
|
|
369
352
|
if not self._warn_once or not self._warned:
|
|
370
353
|
self._warned = True
|
|
@@ -430,9 +413,7 @@ class HybridBuffer:
|
|
|
430
413
|
return
|
|
431
414
|
|
|
432
415
|
other_shape = self._buffer.shape[1:]
|
|
433
|
-
max_capacity = self._max_size / (
|
|
434
|
-
self._buffer.dtype.itemsize * math.prod(other_shape)
|
|
435
|
-
)
|
|
416
|
+
max_capacity = self._max_size / (self._buffer.dtype.itemsize * math.prod(other_shape))
|
|
436
417
|
if min_capacity > max_capacity:
|
|
437
418
|
raise OverflowError(
|
|
438
419
|
f"Cannot grow buffer to {min_capacity} samples, "
|
|
@@ -440,9 +421,7 @@ class HybridBuffer:
|
|
|
440
421
|
)
|
|
441
422
|
|
|
442
423
|
new_capacity = min(max_capacity, max(self._capacity * 2, min_capacity))
|
|
443
|
-
new_buffer = self.xp.empty(
|
|
444
|
-
(new_capacity, *other_shape), dtype=self._buffer.dtype
|
|
445
|
-
)
|
|
424
|
+
new_buffer = self.xp.empty((new_capacity, *other_shape), dtype=self._buffer.dtype)
|
|
446
425
|
|
|
447
426
|
# Copy existing data to new buffer
|
|
448
427
|
total_samples = self._buff_read + self._buff_unread
|
ezmsg/sigproc/util/message.py
CHANGED
|
@@ -1,31 +1,17 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
""
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@dataclass
|
|
21
|
-
class SampleMessage:
|
|
22
|
-
trigger: SampleTriggerMessage
|
|
23
|
-
"""The time, window, and value (if any) associated with the trigger."""
|
|
24
|
-
|
|
25
|
-
sample: AxisArray
|
|
26
|
-
"""The data sampled around the trigger."""
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def is_sample_message(message: typing.Any) -> typing.TypeGuard[SampleMessage]:
|
|
30
|
-
"""Check if the message is a SampleMessage."""
|
|
31
|
-
return hasattr(message, "trigger")
|
|
1
|
+
"""
|
|
2
|
+
Backwards-compatible re-exports from ezmsg.baseproc.util.message.
|
|
3
|
+
|
|
4
|
+
New code should import directly from ezmsg.baseproc instead.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from ezmsg.baseproc.util.message import (
|
|
8
|
+
SampleMessage,
|
|
9
|
+
SampleTriggerMessage,
|
|
10
|
+
is_sample_message,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"SampleMessage",
|
|
15
|
+
"SampleTriggerMessage",
|
|
16
|
+
"is_sample_message",
|
|
17
|
+
]
|
ezmsg/sigproc/util/profile.py
CHANGED
|
@@ -1,174 +1,23 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
return logpath
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
def _setup_logger(append: bool = False) -> logging.Logger:
|
|
29
|
-
logpath = get_logger_path()
|
|
30
|
-
logpath.parent.mkdir(parents=True, exist_ok=True)
|
|
31
|
-
|
|
32
|
-
write_header = True
|
|
33
|
-
if logpath.exists() and logpath.is_file():
|
|
34
|
-
if append:
|
|
35
|
-
with open(logpath) as f:
|
|
36
|
-
first_line = f.readline().rstrip()
|
|
37
|
-
if first_line == HEADER:
|
|
38
|
-
write_header = False
|
|
39
|
-
else:
|
|
40
|
-
# Remove the file if appending, but headers do not match
|
|
41
|
-
ezmsg_logger = logging.getLogger("ezmsg")
|
|
42
|
-
ezmsg_logger.warning(
|
|
43
|
-
"Profiling header mismatch: please make sure to use the same version of ezmsg for all processes."
|
|
44
|
-
)
|
|
45
|
-
logpath.unlink()
|
|
46
|
-
else:
|
|
47
|
-
# Remove the file if not appending
|
|
48
|
-
logpath.unlink()
|
|
49
|
-
|
|
50
|
-
# Create a logger with the name "ezprofile"
|
|
51
|
-
_logger = logging.getLogger("ezprofile")
|
|
52
|
-
|
|
53
|
-
# Set the logger's level to EZMSG_LOGLEVEL env var value if it exists, otherwise INFO
|
|
54
|
-
_logger.setLevel(os.environ.get("EZMSG_LOGLEVEL", "INFO").upper())
|
|
55
|
-
|
|
56
|
-
# Create a file handler to write log messages to the log file
|
|
57
|
-
fh = logging.FileHandler(logpath)
|
|
58
|
-
fh.setLevel(logging.DEBUG) # Set the file handler log level to DEBUG
|
|
59
|
-
|
|
60
|
-
# Add the file handler to the logger
|
|
61
|
-
_logger.addHandler(fh)
|
|
62
|
-
|
|
63
|
-
# Add the header if writing to new file or if header matched header in file.
|
|
64
|
-
if write_header:
|
|
65
|
-
_logger.debug(HEADER)
|
|
66
|
-
|
|
67
|
-
# Set the log message format
|
|
68
|
-
formatter = logging.Formatter(
|
|
69
|
-
"%(asctime)s,%(message)s", datefmt="%Y-%m-%dT%H:%M:%S%z"
|
|
70
|
-
)
|
|
71
|
-
fh.setFormatter(formatter)
|
|
72
|
-
|
|
73
|
-
return _logger
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
logger = _setup_logger(append=True)
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
def _process_obj(obj, trace_oldest: bool = True):
|
|
80
|
-
samp_time = None
|
|
81
|
-
if hasattr(obj, "axes") and ("time" in obj.axes or "win" in obj.axes):
|
|
82
|
-
axis = "win" if "win" in obj.axes else "time"
|
|
83
|
-
ax = obj.get_axis(axis)
|
|
84
|
-
len = obj.data.shape[obj.get_axis_idx(axis)]
|
|
85
|
-
if len > 0:
|
|
86
|
-
idx = 0 if trace_oldest else (len - 1)
|
|
87
|
-
if hasattr(ax, "data"):
|
|
88
|
-
samp_time = ax.data[idx]
|
|
89
|
-
else:
|
|
90
|
-
samp_time = ax.value(idx)
|
|
91
|
-
if ax == "win" and "time" in obj.axes:
|
|
92
|
-
if hasattr(obj.axes["time"], "data"):
|
|
93
|
-
samp_time += obj.axes["time"].data[idx]
|
|
94
|
-
else:
|
|
95
|
-
samp_time += obj.axes["time"].value(idx)
|
|
96
|
-
return samp_time
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
def profile_method(trace_oldest: bool = True):
|
|
100
|
-
"""
|
|
101
|
-
Decorator to profile a method by logging its execution time and other details.
|
|
102
|
-
|
|
103
|
-
Args:
|
|
104
|
-
trace_oldest (bool): If True, trace the oldest sample time; otherwise, trace the newest.
|
|
105
|
-
|
|
106
|
-
Returns:
|
|
107
|
-
Callable: The decorated function with profiling.
|
|
108
|
-
"""
|
|
109
|
-
|
|
110
|
-
def profiling_decorator(func: typing.Callable):
|
|
111
|
-
@functools.wraps(func)
|
|
112
|
-
def wrapped_func(caller, *args, **kwargs):
|
|
113
|
-
start = time.perf_counter()
|
|
114
|
-
res = func(caller, *args, **kwargs)
|
|
115
|
-
stop = time.perf_counter()
|
|
116
|
-
source = ".".join((caller.__class__.__module__, caller.__class__.__name__))
|
|
117
|
-
topic = f"{caller.address}"
|
|
118
|
-
samp_time = _process_obj(res, trace_oldest=trace_oldest)
|
|
119
|
-
logger.debug(
|
|
120
|
-
",".join(
|
|
121
|
-
[
|
|
122
|
-
source,
|
|
123
|
-
topic,
|
|
124
|
-
f"{samp_time}",
|
|
125
|
-
f"{stop}",
|
|
126
|
-
f"{(stop - start) * 1e3:0.4f}",
|
|
127
|
-
]
|
|
128
|
-
)
|
|
129
|
-
)
|
|
130
|
-
return res
|
|
131
|
-
|
|
132
|
-
return wrapped_func if logger.level == logging.DEBUG else func
|
|
133
|
-
|
|
134
|
-
return profiling_decorator
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
def profile_subpub(trace_oldest: bool = True):
|
|
138
|
-
"""
|
|
139
|
-
Decorator to profile a subscriber-publisher method in an ezmsg Unit
|
|
140
|
-
by logging its execution time and other details.
|
|
141
|
-
|
|
142
|
-
Args:
|
|
143
|
-
trace_oldest (bool): If True, trace the oldest sample time; otherwise, trace the newest.
|
|
144
|
-
|
|
145
|
-
Returns:
|
|
146
|
-
Callable: The decorated async task with profiling.
|
|
147
|
-
"""
|
|
148
|
-
|
|
149
|
-
def profiling_decorator(func: typing.Callable):
|
|
150
|
-
@functools.wraps(func)
|
|
151
|
-
async def wrapped_task(unit: ez.Unit, msg: typing.Any = None):
|
|
152
|
-
source = ".".join((unit.__class__.__module__, unit.__class__.__name__))
|
|
153
|
-
topic = f"{unit.address}"
|
|
154
|
-
start = time.perf_counter()
|
|
155
|
-
async for stream, obj in func(unit, msg):
|
|
156
|
-
stop = time.perf_counter()
|
|
157
|
-
samp_time = _process_obj(obj, trace_oldest=trace_oldest)
|
|
158
|
-
logger.debug(
|
|
159
|
-
",".join(
|
|
160
|
-
[
|
|
161
|
-
source,
|
|
162
|
-
topic,
|
|
163
|
-
f"{samp_time}",
|
|
164
|
-
f"{stop}",
|
|
165
|
-
f"{(stop - start) * 1e3:0.4f}",
|
|
166
|
-
]
|
|
167
|
-
)
|
|
168
|
-
)
|
|
169
|
-
start = stop
|
|
170
|
-
yield stream, obj
|
|
171
|
-
|
|
172
|
-
return wrapped_task if logger.level == logging.DEBUG else func
|
|
173
|
-
|
|
174
|
-
return profiling_decorator
|
|
1
|
+
"""
|
|
2
|
+
Backwards-compatible re-exports from ezmsg.baseproc.util.profile.
|
|
3
|
+
|
|
4
|
+
New code should import directly from ezmsg.baseproc instead.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from ezmsg.baseproc.util.profile import (
|
|
8
|
+
HEADER,
|
|
9
|
+
_setup_logger,
|
|
10
|
+
get_logger_path,
|
|
11
|
+
logger,
|
|
12
|
+
profile_method,
|
|
13
|
+
profile_subpub,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
__all__ = [
|
|
17
|
+
"HEADER",
|
|
18
|
+
"get_logger_path",
|
|
19
|
+
"logger",
|
|
20
|
+
"profile_method",
|
|
21
|
+
"profile_subpub",
|
|
22
|
+
"_setup_logger",
|
|
23
|
+
]
|
ezmsg/sigproc/util/sparse.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
+
"""Methods for sparse array signal processing operations."""
|
|
2
|
+
|
|
1
3
|
import numpy as np
|
|
2
4
|
import sparse
|
|
3
5
|
|
|
4
6
|
|
|
5
|
-
def sliding_win_oneaxis_old(
|
|
6
|
-
s: sparse.SparseArray, nwin: int, axis: int, step: int = 1
|
|
7
|
-
) -> sparse.SparseArray:
|
|
7
|
+
def sliding_win_oneaxis_old(s: sparse.SparseArray, nwin: int, axis: int, step: int = 1) -> sparse.SparseArray:
|
|
8
8
|
"""
|
|
9
9
|
Like `ezmsg.util.messages.axisarray.sliding_win_oneaxis` but for sparse arrays.
|
|
10
10
|
This approach is about 4x slower than the version that uses coordinate arithmetic below.
|
|
@@ -23,16 +23,12 @@ def sliding_win_oneaxis_old(
|
|
|
23
23
|
targ_slices = [slice(_, _ + nwin) for _ in range(0, s.shape[axis] - nwin + 1, step)]
|
|
24
24
|
s = s.reshape(s.shape[:axis] + (1,) + s.shape[axis:])
|
|
25
25
|
full_slices = (slice(None),) * s.ndim
|
|
26
|
-
full_slices = [
|
|
27
|
-
full_slices[: axis + 1] + (sl,) + full_slices[axis + 2 :] for sl in targ_slices
|
|
28
|
-
]
|
|
26
|
+
full_slices = [full_slices[: axis + 1] + (sl,) + full_slices[axis + 2 :] for sl in targ_slices]
|
|
29
27
|
result = sparse.concatenate([s[_] for _ in full_slices], axis=axis)
|
|
30
28
|
return result
|
|
31
29
|
|
|
32
30
|
|
|
33
|
-
def sliding_win_oneaxis(
|
|
34
|
-
s: sparse.SparseArray, nwin: int, axis: int, step: int = 1
|
|
35
|
-
) -> sparse.SparseArray:
|
|
31
|
+
def sliding_win_oneaxis(s: sparse.SparseArray, nwin: int, axis: int, step: int = 1) -> sparse.SparseArray:
|
|
36
32
|
"""
|
|
37
33
|
Generates a view-like sparse array using a sliding window of specified length along a specified axis.
|
|
38
34
|
Sparse analog of an optimized dense as_strided-based implementation with these properties:
|
|
@@ -72,9 +68,7 @@ def sliding_win_oneaxis(
|
|
|
72
68
|
n_win_out = len(win_starts)
|
|
73
69
|
if n_win_out <= 0:
|
|
74
70
|
# Return array with proper shape except empty along windows axis
|
|
75
|
-
return sparse.zeros(
|
|
76
|
-
s.shape[:axis] + (0,) + (nwin,) + s.shape[axis + 1 :], dtype=s.dtype
|
|
77
|
-
)
|
|
71
|
+
return sparse.zeros(s.shape[:axis] + (0,) + (nwin,) + s.shape[axis + 1 :], dtype=s.dtype)
|
|
78
72
|
|
|
79
73
|
coo = s.asformat("coo")
|
|
80
74
|
coords = coo.coords # shape: (ndim, nnz)
|
|
@@ -112,9 +106,7 @@ def sliding_win_oneaxis(
|
|
|
112
106
|
out_data_blocks.append(data[sel])
|
|
113
107
|
|
|
114
108
|
if not out_coords_blocks:
|
|
115
|
-
return sparse.zeros(
|
|
116
|
-
s.shape[:axis] + (n_win_out,) + (nwin,) + s.shape[axis + 1 :], dtype=s.dtype
|
|
117
|
-
)
|
|
109
|
+
return sparse.zeros(s.shape[:axis] + (n_win_out,) + (nwin,) + s.shape[axis + 1 :], dtype=s.dtype)
|
|
118
110
|
|
|
119
111
|
out_coords = np.hstack(out_coords_blocks)
|
|
120
112
|
out_data = np.hstack(out_data_blocks)
|
|
@@ -1,83 +1,17 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
"""
|
|
19
|
-
for base in cls.__mro__:
|
|
20
|
-
orig_bases = get_original_bases(base)
|
|
21
|
-
for orig_base in orig_bases:
|
|
22
|
-
origin = typing.get_origin(orig_base)
|
|
23
|
-
if origin is None:
|
|
24
|
-
continue
|
|
25
|
-
params = getattr(origin, "__parameters__", ())
|
|
26
|
-
if not params:
|
|
27
|
-
continue
|
|
28
|
-
if target_typevar in params:
|
|
29
|
-
index = params.index(target_typevar)
|
|
30
|
-
args = typing.get_args(orig_base)
|
|
31
|
-
try:
|
|
32
|
-
return args[index]
|
|
33
|
-
except IndexError:
|
|
34
|
-
pass
|
|
35
|
-
raise TypeError(f"Could not resolve {target_typevar} in {cls}")
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
TypeLike = typing.Union[type[typing.Any], typing.Any, type(None), None]
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
def check_message_type_compatibility(type1: TypeLike, type2: TypeLike) -> bool:
|
|
42
|
-
"""
|
|
43
|
-
Check if two types are compatible for message passing.
|
|
44
|
-
Returns True if:
|
|
45
|
-
- Both are None/NoneType
|
|
46
|
-
- Either is typing.Any
|
|
47
|
-
- type1 is a subclass of type2, which includes
|
|
48
|
-
- type1 and type2 are concrete types and type1 is a subclass of type2
|
|
49
|
-
- type1 is None/NoneType and type2 is typing.Optional, or
|
|
50
|
-
- type1 is subtype of the non-None inner type of type2 if type2 is Optional
|
|
51
|
-
- type1 is a Union/Optional type and all inner types are compatible with type2
|
|
52
|
-
Args:
|
|
53
|
-
type1: First type to compare
|
|
54
|
-
type2: Second type to compare
|
|
55
|
-
Returns:
|
|
56
|
-
bool: True if the types are compatible, False otherwise
|
|
57
|
-
"""
|
|
58
|
-
# If either is Any, they are compatible
|
|
59
|
-
if type1 is typing.Any or type2 is typing.Any:
|
|
60
|
-
return True
|
|
61
|
-
|
|
62
|
-
# Handle None as NoneType
|
|
63
|
-
if type1 is None:
|
|
64
|
-
type1 = type(None)
|
|
65
|
-
if type2 is None:
|
|
66
|
-
type2 = type(None)
|
|
67
|
-
|
|
68
|
-
# Handle if type1 is Optional/Union type
|
|
69
|
-
if typing.get_origin(type1) in {typing.Union, UnionType}:
|
|
70
|
-
return all(
|
|
71
|
-
check_message_type_compatibility(inner_type, type2)
|
|
72
|
-
for inner_type in typing.get_args(type1)
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
# Regular issubclass check. Handles cases like:
|
|
76
|
-
# - type1 is a subclass of concrete type2
|
|
77
|
-
# - type1 is a subclass of the inner type of type2 if type2 is Optional
|
|
78
|
-
# - type1 is a subclass of one of the inner types of type2 if type2 is Union
|
|
79
|
-
# - type1 is NoneType and type2 is Optional or Union[None, ...] or Union[NoneType, ...]
|
|
80
|
-
try:
|
|
81
|
-
return issubclass(type1, type2)
|
|
82
|
-
except TypeError:
|
|
83
|
-
return False
|
|
1
|
+
"""
|
|
2
|
+
Backwards-compatible re-exports from ezmsg.baseproc.util.typeresolution.
|
|
3
|
+
|
|
4
|
+
New code should import directly from ezmsg.baseproc instead.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from ezmsg.baseproc.util.typeresolution import (
|
|
8
|
+
TypeLike,
|
|
9
|
+
check_message_type_compatibility,
|
|
10
|
+
resolve_typevar,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"TypeLike",
|
|
15
|
+
"check_message_type_compatibility",
|
|
16
|
+
"resolve_typevar",
|
|
17
|
+
]
|
ezmsg/sigproc/wavelets.py
CHANGED
|
@@ -1,18 +1,18 @@
|
|
|
1
1
|
import typing
|
|
2
2
|
|
|
3
|
+
import ezmsg.core as ez
|
|
3
4
|
import numpy as np
|
|
4
5
|
import numpy.typing as npt
|
|
5
6
|
import pywt
|
|
6
|
-
|
|
7
|
-
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
|
-
from ezmsg.util.messages.util import replace
|
|
9
|
-
|
|
10
|
-
from .base import (
|
|
7
|
+
from ezmsg.baseproc import (
|
|
11
8
|
BaseStatefulTransformer,
|
|
12
9
|
BaseTransformerUnit,
|
|
13
10
|
processor_state,
|
|
14
11
|
)
|
|
15
|
-
from .
|
|
12
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
13
|
+
from ezmsg.util.messages.util import replace
|
|
14
|
+
|
|
15
|
+
from .filterbank import FilterbankMode, MinPhaseMode, filterbank
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
class CWTSettings(ez.Settings):
|
|
@@ -37,9 +37,7 @@ class CWTState:
|
|
|
37
37
|
last_conv_samp: npt.NDArray | None = None
|
|
38
38
|
|
|
39
39
|
|
|
40
|
-
class CWTTransformer(
|
|
41
|
-
BaseStatefulTransformer[CWTSettings, AxisArray, AxisArray, CWTState]
|
|
42
|
-
):
|
|
40
|
+
class CWTTransformer(BaseStatefulTransformer[CWTSettings, AxisArray, AxisArray, CWTState]):
|
|
43
41
|
def _hash_message(self, message: AxisArray) -> int:
|
|
44
42
|
ax_idx = message.get_axis_idx(self.settings.axis)
|
|
45
43
|
in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
|
|
@@ -107,25 +105,18 @@ class CWTTransformer(
|
|
|
107
105
|
# Create output template
|
|
108
106
|
ax_idx = message.get_axis_idx(self.settings.axis)
|
|
109
107
|
in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
|
|
110
|
-
freqs = (
|
|
111
|
-
pywt.scale2frequency(wavelet, scales, precision)
|
|
112
|
-
/ message.axes[self.settings.axis].gain
|
|
113
|
-
)
|
|
108
|
+
freqs = pywt.scale2frequency(wavelet, scales, precision) / message.axes[self.settings.axis].gain
|
|
114
109
|
dummy_shape = in_shape + (len(scales), 0)
|
|
115
110
|
self._state.template = AxisArray(
|
|
116
111
|
np.zeros(dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data),
|
|
117
|
-
dims=message.dims[:ax_idx]
|
|
118
|
-
+ message.dims[ax_idx + 1 :]
|
|
119
|
-
+ ["freq", self.settings.axis],
|
|
112
|
+
dims=message.dims[:ax_idx] + message.dims[ax_idx + 1 :] + ["freq", self.settings.axis],
|
|
120
113
|
axes={
|
|
121
114
|
**message.axes,
|
|
122
115
|
"freq": AxisArray.CoordinateAxis(unit="Hz", data=freqs, dims=["freq"]),
|
|
123
116
|
},
|
|
124
117
|
key=message.key,
|
|
125
118
|
)
|
|
126
|
-
self._state.last_conv_samp = np.zeros(
|
|
127
|
-
dummy_shape[:-1] + (1,), dtype=self._state.template.data.dtype
|
|
128
|
-
)
|
|
119
|
+
self._state.last_conv_samp = np.zeros(dummy_shape[:-1] + (1,), dtype=self._state.template.data.dtype)
|
|
129
120
|
|
|
130
121
|
def _process(self, message: AxisArray) -> AxisArray:
|
|
131
122
|
conv_msg = self._state.fbgen.send(message)
|