ezmsg-sigproc 2.4.1__py3-none-any.whl → 2.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +5 -11
- ezmsg/sigproc/adaptive_lattice_notch.py +11 -29
- ezmsg/sigproc/affinetransform.py +13 -38
- ezmsg/sigproc/aggregate.py +13 -30
- ezmsg/sigproc/bandpower.py +7 -15
- ezmsg/sigproc/base.py +141 -1276
- ezmsg/sigproc/butterworthfilter.py +8 -16
- ezmsg/sigproc/butterworthzerophase.py +123 -0
- ezmsg/sigproc/cheby.py +4 -10
- ezmsg/sigproc/combfilter.py +5 -8
- ezmsg/sigproc/decimate.py +2 -6
- ezmsg/sigproc/denormalize.py +6 -11
- ezmsg/sigproc/detrend.py +3 -4
- ezmsg/sigproc/diff.py +8 -17
- ezmsg/sigproc/downsample.py +6 -14
- ezmsg/sigproc/ewma.py +11 -27
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +3 -4
- ezmsg/sigproc/fbcca.py +31 -56
- ezmsg/sigproc/filter.py +19 -45
- ezmsg/sigproc/filterbank.py +33 -70
- ezmsg/sigproc/filterbankdesign.py +5 -12
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +12 -14
- ezmsg/sigproc/gaussiansmoothing.py +5 -9
- ezmsg/sigproc/kaiser.py +11 -15
- ezmsg/sigproc/math/abs.py +1 -3
- ezmsg/sigproc/math/add.py +121 -0
- ezmsg/sigproc/math/clip.py +1 -1
- ezmsg/sigproc/math/difference.py +98 -36
- ezmsg/sigproc/math/invert.py +1 -3
- ezmsg/sigproc/math/log.py +2 -6
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +2 -4
- ezmsg/sigproc/resample.py +13 -34
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +17 -35
- ezmsg/sigproc/scaler.py +8 -18
- ezmsg/sigproc/signalinjector.py +6 -16
- ezmsg/sigproc/slicer.py +9 -28
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +12 -19
- ezmsg/sigproc/spectrum.py +12 -32
- ezmsg/sigproc/transpose.py +7 -18
- ezmsg/sigproc/util/asio.py +25 -156
- ezmsg/sigproc/util/axisarray_buffer.py +10 -26
- ezmsg/sigproc/util/buffer.py +18 -43
- ezmsg/sigproc/util/message.py +17 -31
- ezmsg/sigproc/util/profile.py +23 -174
- ezmsg/sigproc/util/sparse.py +5 -15
- ezmsg/sigproc/util/typeresolution.py +17 -83
- ezmsg/sigproc/wavelets.py +6 -15
- ezmsg/sigproc/window.py +24 -78
- {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/METADATA +4 -3
- ezmsg_sigproc-2.6.0.dist-info/RECORD +63 -0
- ezmsg/sigproc/synth.py +0 -774
- ezmsg_sigproc-2.4.1.dist-info/RECORD +0 -59
- {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/WHEEL +0 -0
- /ezmsg_sigproc-2.4.1.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.6.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/util/profile.py
CHANGED
|
@@ -1,174 +1,23 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
return logpath
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
def _setup_logger(append: bool = False) -> logging.Logger:
|
|
29
|
-
logpath = get_logger_path()
|
|
30
|
-
logpath.parent.mkdir(parents=True, exist_ok=True)
|
|
31
|
-
|
|
32
|
-
write_header = True
|
|
33
|
-
if logpath.exists() and logpath.is_file():
|
|
34
|
-
if append:
|
|
35
|
-
with open(logpath) as f:
|
|
36
|
-
first_line = f.readline().rstrip()
|
|
37
|
-
if first_line == HEADER:
|
|
38
|
-
write_header = False
|
|
39
|
-
else:
|
|
40
|
-
# Remove the file if appending, but headers do not match
|
|
41
|
-
ezmsg_logger = logging.getLogger("ezmsg")
|
|
42
|
-
ezmsg_logger.warning(
|
|
43
|
-
"Profiling header mismatch: please make sure to use the same version of ezmsg for all processes."
|
|
44
|
-
)
|
|
45
|
-
logpath.unlink()
|
|
46
|
-
else:
|
|
47
|
-
# Remove the file if not appending
|
|
48
|
-
logpath.unlink()
|
|
49
|
-
|
|
50
|
-
# Create a logger with the name "ezprofile"
|
|
51
|
-
_logger = logging.getLogger("ezprofile")
|
|
52
|
-
|
|
53
|
-
# Set the logger's level to EZMSG_LOGLEVEL env var value if it exists, otherwise INFO
|
|
54
|
-
_logger.setLevel(os.environ.get("EZMSG_LOGLEVEL", "INFO").upper())
|
|
55
|
-
|
|
56
|
-
# Create a file handler to write log messages to the log file
|
|
57
|
-
fh = logging.FileHandler(logpath)
|
|
58
|
-
fh.setLevel(logging.DEBUG) # Set the file handler log level to DEBUG
|
|
59
|
-
|
|
60
|
-
# Add the file handler to the logger
|
|
61
|
-
_logger.addHandler(fh)
|
|
62
|
-
|
|
63
|
-
# Add the header if writing to new file or if header matched header in file.
|
|
64
|
-
if write_header:
|
|
65
|
-
_logger.debug(HEADER)
|
|
66
|
-
|
|
67
|
-
# Set the log message format
|
|
68
|
-
formatter = logging.Formatter(
|
|
69
|
-
"%(asctime)s,%(message)s", datefmt="%Y-%m-%dT%H:%M:%S%z"
|
|
70
|
-
)
|
|
71
|
-
fh.setFormatter(formatter)
|
|
72
|
-
|
|
73
|
-
return _logger
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
logger = _setup_logger(append=True)
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
def _process_obj(obj, trace_oldest: bool = True):
|
|
80
|
-
samp_time = None
|
|
81
|
-
if hasattr(obj, "axes") and ("time" in obj.axes or "win" in obj.axes):
|
|
82
|
-
axis = "win" if "win" in obj.axes else "time"
|
|
83
|
-
ax = obj.get_axis(axis)
|
|
84
|
-
len = obj.data.shape[obj.get_axis_idx(axis)]
|
|
85
|
-
if len > 0:
|
|
86
|
-
idx = 0 if trace_oldest else (len - 1)
|
|
87
|
-
if hasattr(ax, "data"):
|
|
88
|
-
samp_time = ax.data[idx]
|
|
89
|
-
else:
|
|
90
|
-
samp_time = ax.value(idx)
|
|
91
|
-
if ax == "win" and "time" in obj.axes:
|
|
92
|
-
if hasattr(obj.axes["time"], "data"):
|
|
93
|
-
samp_time += obj.axes["time"].data[idx]
|
|
94
|
-
else:
|
|
95
|
-
samp_time += obj.axes["time"].value(idx)
|
|
96
|
-
return samp_time
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
def profile_method(trace_oldest: bool = True):
|
|
100
|
-
"""
|
|
101
|
-
Decorator to profile a method by logging its execution time and other details.
|
|
102
|
-
|
|
103
|
-
Args:
|
|
104
|
-
trace_oldest (bool): If True, trace the oldest sample time; otherwise, trace the newest.
|
|
105
|
-
|
|
106
|
-
Returns:
|
|
107
|
-
Callable: The decorated function with profiling.
|
|
108
|
-
"""
|
|
109
|
-
|
|
110
|
-
def profiling_decorator(func: typing.Callable):
|
|
111
|
-
@functools.wraps(func)
|
|
112
|
-
def wrapped_func(caller, *args, **kwargs):
|
|
113
|
-
start = time.perf_counter()
|
|
114
|
-
res = func(caller, *args, **kwargs)
|
|
115
|
-
stop = time.perf_counter()
|
|
116
|
-
source = ".".join((caller.__class__.__module__, caller.__class__.__name__))
|
|
117
|
-
topic = f"{caller.address}"
|
|
118
|
-
samp_time = _process_obj(res, trace_oldest=trace_oldest)
|
|
119
|
-
logger.debug(
|
|
120
|
-
",".join(
|
|
121
|
-
[
|
|
122
|
-
source,
|
|
123
|
-
topic,
|
|
124
|
-
f"{samp_time}",
|
|
125
|
-
f"{stop}",
|
|
126
|
-
f"{(stop - start) * 1e3:0.4f}",
|
|
127
|
-
]
|
|
128
|
-
)
|
|
129
|
-
)
|
|
130
|
-
return res
|
|
131
|
-
|
|
132
|
-
return wrapped_func if logger.level == logging.DEBUG else func
|
|
133
|
-
|
|
134
|
-
return profiling_decorator
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
def profile_subpub(trace_oldest: bool = True):
|
|
138
|
-
"""
|
|
139
|
-
Decorator to profile a subscriber-publisher method in an ezmsg Unit
|
|
140
|
-
by logging its execution time and other details.
|
|
141
|
-
|
|
142
|
-
Args:
|
|
143
|
-
trace_oldest (bool): If True, trace the oldest sample time; otherwise, trace the newest.
|
|
144
|
-
|
|
145
|
-
Returns:
|
|
146
|
-
Callable: The decorated async task with profiling.
|
|
147
|
-
"""
|
|
148
|
-
|
|
149
|
-
def profiling_decorator(func: typing.Callable):
|
|
150
|
-
@functools.wraps(func)
|
|
151
|
-
async def wrapped_task(unit: ez.Unit, msg: typing.Any = None):
|
|
152
|
-
source = ".".join((unit.__class__.__module__, unit.__class__.__name__))
|
|
153
|
-
topic = f"{unit.address}"
|
|
154
|
-
start = time.perf_counter()
|
|
155
|
-
async for stream, obj in func(unit, msg):
|
|
156
|
-
stop = time.perf_counter()
|
|
157
|
-
samp_time = _process_obj(obj, trace_oldest=trace_oldest)
|
|
158
|
-
logger.debug(
|
|
159
|
-
",".join(
|
|
160
|
-
[
|
|
161
|
-
source,
|
|
162
|
-
topic,
|
|
163
|
-
f"{samp_time}",
|
|
164
|
-
f"{stop}",
|
|
165
|
-
f"{(stop - start) * 1e3:0.4f}",
|
|
166
|
-
]
|
|
167
|
-
)
|
|
168
|
-
)
|
|
169
|
-
start = stop
|
|
170
|
-
yield stream, obj
|
|
171
|
-
|
|
172
|
-
return wrapped_task if logger.level == logging.DEBUG else func
|
|
173
|
-
|
|
174
|
-
return profiling_decorator
|
|
1
|
+
"""
|
|
2
|
+
Backwards-compatible re-exports from ezmsg.baseproc.util.profile.
|
|
3
|
+
|
|
4
|
+
New code should import directly from ezmsg.baseproc instead.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from ezmsg.baseproc.util.profile import (
|
|
8
|
+
HEADER,
|
|
9
|
+
_setup_logger,
|
|
10
|
+
get_logger_path,
|
|
11
|
+
logger,
|
|
12
|
+
profile_method,
|
|
13
|
+
profile_subpub,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
__all__ = [
|
|
17
|
+
"HEADER",
|
|
18
|
+
"get_logger_path",
|
|
19
|
+
"logger",
|
|
20
|
+
"profile_method",
|
|
21
|
+
"profile_subpub",
|
|
22
|
+
"_setup_logger",
|
|
23
|
+
]
|
ezmsg/sigproc/util/sparse.py
CHANGED
|
@@ -2,9 +2,7 @@ import numpy as np
|
|
|
2
2
|
import sparse
|
|
3
3
|
|
|
4
4
|
|
|
5
|
-
def sliding_win_oneaxis_old(
|
|
6
|
-
s: sparse.SparseArray, nwin: int, axis: int, step: int = 1
|
|
7
|
-
) -> sparse.SparseArray:
|
|
5
|
+
def sliding_win_oneaxis_old(s: sparse.SparseArray, nwin: int, axis: int, step: int = 1) -> sparse.SparseArray:
|
|
8
6
|
"""
|
|
9
7
|
Like `ezmsg.util.messages.axisarray.sliding_win_oneaxis` but for sparse arrays.
|
|
10
8
|
This approach is about 4x slower than the version that uses coordinate arithmetic below.
|
|
@@ -23,16 +21,12 @@ def sliding_win_oneaxis_old(
|
|
|
23
21
|
targ_slices = [slice(_, _ + nwin) for _ in range(0, s.shape[axis] - nwin + 1, step)]
|
|
24
22
|
s = s.reshape(s.shape[:axis] + (1,) + s.shape[axis:])
|
|
25
23
|
full_slices = (slice(None),) * s.ndim
|
|
26
|
-
full_slices = [
|
|
27
|
-
full_slices[: axis + 1] + (sl,) + full_slices[axis + 2 :] for sl in targ_slices
|
|
28
|
-
]
|
|
24
|
+
full_slices = [full_slices[: axis + 1] + (sl,) + full_slices[axis + 2 :] for sl in targ_slices]
|
|
29
25
|
result = sparse.concatenate([s[_] for _ in full_slices], axis=axis)
|
|
30
26
|
return result
|
|
31
27
|
|
|
32
28
|
|
|
33
|
-
def sliding_win_oneaxis(
|
|
34
|
-
s: sparse.SparseArray, nwin: int, axis: int, step: int = 1
|
|
35
|
-
) -> sparse.SparseArray:
|
|
29
|
+
def sliding_win_oneaxis(s: sparse.SparseArray, nwin: int, axis: int, step: int = 1) -> sparse.SparseArray:
|
|
36
30
|
"""
|
|
37
31
|
Generates a view-like sparse array using a sliding window of specified length along a specified axis.
|
|
38
32
|
Sparse analog of an optimized dense as_strided-based implementation with these properties:
|
|
@@ -72,9 +66,7 @@ def sliding_win_oneaxis(
|
|
|
72
66
|
n_win_out = len(win_starts)
|
|
73
67
|
if n_win_out <= 0:
|
|
74
68
|
# Return array with proper shape except empty along windows axis
|
|
75
|
-
return sparse.zeros(
|
|
76
|
-
s.shape[:axis] + (0,) + (nwin,) + s.shape[axis + 1 :], dtype=s.dtype
|
|
77
|
-
)
|
|
69
|
+
return sparse.zeros(s.shape[:axis] + (0,) + (nwin,) + s.shape[axis + 1 :], dtype=s.dtype)
|
|
78
70
|
|
|
79
71
|
coo = s.asformat("coo")
|
|
80
72
|
coords = coo.coords # shape: (ndim, nnz)
|
|
@@ -112,9 +104,7 @@ def sliding_win_oneaxis(
|
|
|
112
104
|
out_data_blocks.append(data[sel])
|
|
113
105
|
|
|
114
106
|
if not out_coords_blocks:
|
|
115
|
-
return sparse.zeros(
|
|
116
|
-
s.shape[:axis] + (n_win_out,) + (nwin,) + s.shape[axis + 1 :], dtype=s.dtype
|
|
117
|
-
)
|
|
107
|
+
return sparse.zeros(s.shape[:axis] + (n_win_out,) + (nwin,) + s.shape[axis + 1 :], dtype=s.dtype)
|
|
118
108
|
|
|
119
109
|
out_coords = np.hstack(out_coords_blocks)
|
|
120
110
|
out_data = np.hstack(out_data_blocks)
|
|
ezmsg/sigproc/util/typeresolution.py
CHANGED
|
@@ -1,83 +1,17 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
"""
|
|
19
|
-
for base in cls.__mro__:
|
|
20
|
-
orig_bases = get_original_bases(base)
|
|
21
|
-
for orig_base in orig_bases:
|
|
22
|
-
origin = typing.get_origin(orig_base)
|
|
23
|
-
if origin is None:
|
|
24
|
-
continue
|
|
25
|
-
params = getattr(origin, "__parameters__", ())
|
|
26
|
-
if not params:
|
|
27
|
-
continue
|
|
28
|
-
if target_typevar in params:
|
|
29
|
-
index = params.index(target_typevar)
|
|
30
|
-
args = typing.get_args(orig_base)
|
|
31
|
-
try:
|
|
32
|
-
return args[index]
|
|
33
|
-
except IndexError:
|
|
34
|
-
pass
|
|
35
|
-
raise TypeError(f"Could not resolve {target_typevar} in {cls}")
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
TypeLike = typing.Union[type[typing.Any], typing.Any, type(None), None]
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
def check_message_type_compatibility(type1: TypeLike, type2: TypeLike) -> bool:
|
|
42
|
-
"""
|
|
43
|
-
Check if two types are compatible for message passing.
|
|
44
|
-
Returns True if:
|
|
45
|
-
- Both are None/NoneType
|
|
46
|
-
- Either is typing.Any
|
|
47
|
-
- type1 is a subclass of type2, which includes
|
|
48
|
-
- type1 and type2 are concrete types and type1 is a subclass of type2
|
|
49
|
-
- type1 is None/NoneType and type2 is typing.Optional, or
|
|
50
|
-
- type1 is subtype of the non-None inner type of type2 if type2 is Optional
|
|
51
|
-
- type1 is a Union/Optional type and all inner types are compatible with type2
|
|
52
|
-
Args:
|
|
53
|
-
type1: First type to compare
|
|
54
|
-
type2: Second type to compare
|
|
55
|
-
Returns:
|
|
56
|
-
bool: True if the types are compatible, False otherwise
|
|
57
|
-
"""
|
|
58
|
-
# If either is Any, they are compatible
|
|
59
|
-
if type1 is typing.Any or type2 is typing.Any:
|
|
60
|
-
return True
|
|
61
|
-
|
|
62
|
-
# Handle None as NoneType
|
|
63
|
-
if type1 is None:
|
|
64
|
-
type1 = type(None)
|
|
65
|
-
if type2 is None:
|
|
66
|
-
type2 = type(None)
|
|
67
|
-
|
|
68
|
-
# Handle if type1 is Optional/Union type
|
|
69
|
-
if typing.get_origin(type1) in {typing.Union, UnionType}:
|
|
70
|
-
return all(
|
|
71
|
-
check_message_type_compatibility(inner_type, type2)
|
|
72
|
-
for inner_type in typing.get_args(type1)
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
# Regular issubclass check. Handles cases like:
|
|
76
|
-
# - type1 is a subclass of concrete type2
|
|
77
|
-
# - type1 is a subclass of the inner type of type2 if type2 is Optional
|
|
78
|
-
# - type1 is a subclass of one of the inner types of type2 if type2 is Union
|
|
79
|
-
# - type1 is NoneType and type2 is Optional or Union[None, ...] or Union[NoneType, ...]
|
|
80
|
-
try:
|
|
81
|
-
return issubclass(type1, type2)
|
|
82
|
-
except TypeError:
|
|
83
|
-
return False
|
|
1
|
+
"""
|
|
2
|
+
Backwards-compatible re-exports from ezmsg.baseproc.util.typeresolution.
|
|
3
|
+
|
|
4
|
+
New code should import directly from ezmsg.baseproc instead.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from ezmsg.baseproc.util.typeresolution import (
|
|
8
|
+
TypeLike,
|
|
9
|
+
check_message_type_compatibility,
|
|
10
|
+
resolve_typevar,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"TypeLike",
|
|
15
|
+
"check_message_type_compatibility",
|
|
16
|
+
"resolve_typevar",
|
|
17
|
+
]
|
ezmsg/sigproc/wavelets.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import typing
|
|
2
2
|
|
|
3
|
+
import ezmsg.core as ez
|
|
3
4
|
import numpy as np
|
|
4
5
|
import numpy.typing as npt
|
|
5
6
|
import pywt
|
|
6
|
-
import ezmsg.core as ez
|
|
7
7
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
8
|
from ezmsg.util.messages.util import replace
|
|
9
9
|
|
|
@@ -12,7 +12,7 @@ from .base import (
|
|
|
12
12
|
BaseTransformerUnit,
|
|
13
13
|
processor_state,
|
|
14
14
|
)
|
|
15
|
-
from .filterbank import
|
|
15
|
+
from .filterbank import FilterbankMode, MinPhaseMode, filterbank
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
class CWTSettings(ez.Settings):
|
|
@@ -37,9 +37,7 @@ class CWTState:
|
|
|
37
37
|
last_conv_samp: npt.NDArray | None = None
|
|
38
38
|
|
|
39
39
|
|
|
40
|
-
class CWTTransformer(
|
|
41
|
-
BaseStatefulTransformer[CWTSettings, AxisArray, AxisArray, CWTState]
|
|
42
|
-
):
|
|
40
|
+
class CWTTransformer(BaseStatefulTransformer[CWTSettings, AxisArray, AxisArray, CWTState]):
|
|
43
41
|
def _hash_message(self, message: AxisArray) -> int:
|
|
44
42
|
ax_idx = message.get_axis_idx(self.settings.axis)
|
|
45
43
|
in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
|
|
@@ -107,25 +105,18 @@ class CWTTransformer(
|
|
|
107
105
|
# Create output template
|
|
108
106
|
ax_idx = message.get_axis_idx(self.settings.axis)
|
|
109
107
|
in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
|
|
110
|
-
freqs = (
|
|
111
|
-
pywt.scale2frequency(wavelet, scales, precision)
|
|
112
|
-
/ message.axes[self.settings.axis].gain
|
|
113
|
-
)
|
|
108
|
+
freqs = pywt.scale2frequency(wavelet, scales, precision) / message.axes[self.settings.axis].gain
|
|
114
109
|
dummy_shape = in_shape + (len(scales), 0)
|
|
115
110
|
self._state.template = AxisArray(
|
|
116
111
|
np.zeros(dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data),
|
|
117
|
-
dims=message.dims[:ax_idx]
|
|
118
|
-
+ message.dims[ax_idx + 1 :]
|
|
119
|
-
+ ["freq", self.settings.axis],
|
|
112
|
+
dims=message.dims[:ax_idx] + message.dims[ax_idx + 1 :] + ["freq", self.settings.axis],
|
|
120
113
|
axes={
|
|
121
114
|
**message.axes,
|
|
122
115
|
"freq": AxisArray.CoordinateAxis(unit="Hz", data=freqs, dims=["freq"]),
|
|
123
116
|
},
|
|
124
117
|
key=message.key,
|
|
125
118
|
)
|
|
126
|
-
self._state.last_conv_samp = np.zeros(
|
|
127
|
-
dummy_shape[:-1] + (1,), dtype=self._state.template.data.dtype
|
|
128
|
-
)
|
|
119
|
+
self._state.last_conv_samp = np.zeros(dummy_shape[:-1] + (1,), dtype=self._state.template.data.dtype)
|
|
129
120
|
|
|
130
121
|
def _process(self, message: AxisArray) -> AxisArray:
|
|
131
122
|
conv_msg = self._state.fbgen.send(message)
|
ezmsg/sigproc/window.py
CHANGED
|
@@ -5,12 +5,12 @@ import typing
|
|
|
5
5
|
import ezmsg.core as ez
|
|
6
6
|
import numpy.typing as npt
|
|
7
7
|
import sparse
|
|
8
|
-
from array_api_compat import
|
|
8
|
+
from array_api_compat import get_namespace, is_pydata_sparse_namespace
|
|
9
9
|
from ezmsg.util.messages.axisarray import (
|
|
10
10
|
AxisArray,
|
|
11
|
+
replace,
|
|
11
12
|
slice_along_axis,
|
|
12
13
|
sliding_win_oneaxis,
|
|
13
|
-
replace,
|
|
14
14
|
)
|
|
15
15
|
|
|
16
16
|
from .base import (
|
|
@@ -18,8 +18,8 @@ from .base import (
|
|
|
18
18
|
BaseTransformerUnit,
|
|
19
19
|
processor_state,
|
|
20
20
|
)
|
|
21
|
-
from .util.sparse import sliding_win_oneaxis as sparse_sliding_win_oneaxis
|
|
22
21
|
from .util.profile import profile_subpub
|
|
22
|
+
from .util.sparse import sliding_win_oneaxis as sparse_sliding_win_oneaxis
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class Anchor(enum.Enum):
|
|
@@ -55,9 +55,7 @@ class WindowState:
|
|
|
55
55
|
out_dims: list[str] | None = None
|
|
56
56
|
|
|
57
57
|
|
|
58
|
-
class WindowTransformer(
|
|
59
|
-
BaseStatefulTransformer[WindowSettings, AxisArray, AxisArray, WindowState]
|
|
60
|
-
):
|
|
58
|
+
class WindowTransformer(BaseStatefulTransformer[WindowSettings, AxisArray, AxisArray, WindowState]):
|
|
61
59
|
"""
|
|
62
60
|
Apply a sliding window along the specified axis to input streaming data.
|
|
63
61
|
The `windowing` method is perhaps the most useful and versatile method in ezmsg.sigproc, but its parameterization
|
|
@@ -99,19 +97,13 @@ class WindowTransformer(
|
|
|
99
97
|
# if self.settings.newaxis is None:
|
|
100
98
|
# ez.logger.warning("`newaxis=None` will be replaced with `newaxis='win'`.")
|
|
101
99
|
# object.__setattr__(self.settings, "newaxis", "win")
|
|
102
|
-
if (
|
|
103
|
-
self.settings.window_shift is None
|
|
104
|
-
and self.settings.zero_pad_until != "input"
|
|
105
|
-
):
|
|
100
|
+
if self.settings.window_shift is None and self.settings.zero_pad_until != "input":
|
|
106
101
|
ez.logger.warning(
|
|
107
102
|
"`zero_pad_until` must be 'input' if `window_shift` is None. "
|
|
108
103
|
f"Ignoring received argument value: {self.settings.zero_pad_until}"
|
|
109
104
|
)
|
|
110
105
|
object.__setattr__(self.settings, "zero_pad_until", "input")
|
|
111
|
-
elif (
|
|
112
|
-
self.settings.window_shift is not None
|
|
113
|
-
and self.settings.zero_pad_until == "input"
|
|
114
|
-
):
|
|
106
|
+
elif self.settings.window_shift is not None and self.settings.zero_pad_until == "input":
|
|
115
107
|
ez.logger.warning(
|
|
116
108
|
"windowing is non-deterministic with `zero_pad_until='input'` as it depends on the size "
|
|
117
109
|
"of the first input. We recommend using `zero_pad_until='shift'` when `window_shift` is float-valued."
|
|
@@ -135,9 +127,7 @@ class WindowTransformer(
|
|
|
135
127
|
def _reset_state(self, message: AxisArray) -> None:
|
|
136
128
|
_newaxis = self.settings.newaxis or "win"
|
|
137
129
|
if not self._state.newaxis_warned and _newaxis in message.dims:
|
|
138
|
-
ez.logger.warning(
|
|
139
|
-
f"newaxis {_newaxis} present in input dims. Using {_newaxis}_win instead"
|
|
140
|
-
)
|
|
130
|
+
ez.logger.warning(f"newaxis {_newaxis} present in input dims. Using {_newaxis}_win instead")
|
|
141
131
|
self._state.newaxis_warned = True
|
|
142
132
|
self.settings.newaxis = f"{_newaxis}_win"
|
|
143
133
|
|
|
@@ -154,33 +144,20 @@ class WindowTransformer(
|
|
|
154
144
|
self._state.window_shift_samples = int(self.settings.window_shift * fs)
|
|
155
145
|
if self.settings.zero_pad_until == "none":
|
|
156
146
|
req_samples = self._state.window_samples
|
|
157
|
-
elif (
|
|
158
|
-
self.settings.zero_pad_until == "shift"
|
|
159
|
-
and self.settings.window_shift is not None
|
|
160
|
-
):
|
|
147
|
+
elif self.settings.zero_pad_until == "shift" and self.settings.window_shift is not None:
|
|
161
148
|
req_samples = self._state.window_shift_samples
|
|
162
149
|
else: # i.e. zero_pad_until == "input"
|
|
163
150
|
req_samples = message.data.shape[axis_idx]
|
|
164
151
|
n_zero = max(0, self._state.window_samples - req_samples)
|
|
165
|
-
init_buffer_shape = (
|
|
166
|
-
message.data.shape[:axis_idx]
|
|
167
|
-
+ (n_zero,)
|
|
168
|
-
+ message.data.shape[axis_idx + 1 :]
|
|
169
|
-
)
|
|
152
|
+
init_buffer_shape = message.data.shape[:axis_idx] + (n_zero,) + message.data.shape[axis_idx + 1 :]
|
|
170
153
|
self._state.buffer = xp.zeros(init_buffer_shape, dtype=message.data.dtype)
|
|
171
154
|
|
|
172
155
|
# Prepare reusable parts of output
|
|
173
156
|
if self._state.out_newaxis is None:
|
|
174
|
-
self._state.out_dims = (
|
|
175
|
-
list(message.dims[:axis_idx])
|
|
176
|
-
+ [_newaxis]
|
|
177
|
-
+ list(message.dims[axis_idx:])
|
|
178
|
-
)
|
|
157
|
+
self._state.out_dims = list(message.dims[:axis_idx]) + [_newaxis] + list(message.dims[axis_idx:])
|
|
179
158
|
self._state.out_newaxis = replace(
|
|
180
159
|
axis_info,
|
|
181
|
-
gain=0.0
|
|
182
|
-
if self.settings.window_shift is None
|
|
183
|
-
else axis_info.gain * self._state.window_shift_samples,
|
|
160
|
+
gain=0.0 if self.settings.window_shift is None else axis_info.gain * self._state.window_shift_samples,
|
|
184
161
|
offset=0.0, # offset modified per-msg below
|
|
185
162
|
)
|
|
186
163
|
|
|
@@ -204,9 +181,7 @@ class WindowTransformer(
|
|
|
204
181
|
# is generally faster than np.roll and slicing anyway, but this could still
|
|
205
182
|
# be a performance bottleneck for large memory arrays.
|
|
206
183
|
# A circular buffer might be faster.
|
|
207
|
-
self._state.buffer = xp.concatenate(
|
|
208
|
-
(self._state.buffer, message.data), axis=axis_idx
|
|
209
|
-
)
|
|
184
|
+
self._state.buffer = xp.concatenate((self._state.buffer, message.data), axis=axis_idx)
|
|
210
185
|
|
|
211
186
|
# Create a vector of buffer timestamps to track axis `offset` in output(s)
|
|
212
187
|
buffer_t0 = 0.0
|
|
@@ -222,9 +197,7 @@ class WindowTransformer(
|
|
|
222
197
|
if self.settings.window_shift is not None and self._state.shift_deficit > 0:
|
|
223
198
|
n_skip = min(self._state.buffer.shape[axis_idx], self._state.shift_deficit)
|
|
224
199
|
if n_skip > 0:
|
|
225
|
-
self._state.buffer = slice_along_axis(
|
|
226
|
-
self._state.buffer, slice(n_skip, None), axis_idx
|
|
227
|
-
)
|
|
200
|
+
self._state.buffer = slice_along_axis(self._state.buffer, slice(n_skip, None), axis_idx)
|
|
228
201
|
buffer_t0 += n_skip * axis_info.gain
|
|
229
202
|
buffer_tlen -= n_skip
|
|
230
203
|
self._state.shift_deficit -= n_skip
|
|
@@ -249,20 +222,12 @@ class WindowTransformer(
|
|
|
249
222
|
self._state.buffer, slice(-self._state.window_samples, None), axis_idx
|
|
250
223
|
)
|
|
251
224
|
out_dat = self._state.buffer.reshape(
|
|
252
|
-
self._state.buffer.shape[:axis_idx]
|
|
253
|
-
+ (1,)
|
|
254
|
-
+ self._state.buffer.shape[axis_idx:]
|
|
255
|
-
)
|
|
256
|
-
win_offset = buffer_t0 + axis_info.gain * (
|
|
257
|
-
buffer_tlen - self._state.window_samples
|
|
225
|
+
self._state.buffer.shape[:axis_idx] + (1,) + self._state.buffer.shape[axis_idx:]
|
|
258
226
|
)
|
|
227
|
+
win_offset = buffer_t0 + axis_info.gain * (buffer_tlen - self._state.window_samples)
|
|
259
228
|
elif self._state.buffer.shape[axis_idx] >= self._state.window_samples:
|
|
260
229
|
# Deterministic window shifts.
|
|
261
|
-
sliding_win_fun = (
|
|
262
|
-
sparse_sliding_win_oneaxis
|
|
263
|
-
if is_pydata_sparse_namespace(xp)
|
|
264
|
-
else sliding_win_oneaxis
|
|
265
|
-
)
|
|
230
|
+
sliding_win_fun = sparse_sliding_win_oneaxis if is_pydata_sparse_namespace(xp) else sliding_win_oneaxis
|
|
266
231
|
out_dat = sliding_win_fun(
|
|
267
232
|
self._state.buffer,
|
|
268
233
|
self._state.window_samples,
|
|
@@ -273,18 +238,12 @@ class WindowTransformer(
|
|
|
273
238
|
|
|
274
239
|
# Drop expired beginning of buffer and update shift_deficit
|
|
275
240
|
multi_shift = self._state.window_shift_samples * out_dat.shape[axis_idx]
|
|
276
|
-
self._state.shift_deficit = max(
|
|
277
|
-
0, multi_shift - self._state.buffer.shape[axis_idx]
|
278
|
-
)
|
|
279
|
-
self._state.buffer = slice_along_axis(
|
|
280
|
-
self._state.buffer, slice(multi_shift, None), axis_idx
|
|
281
|
-
)
|
|
241
|
+
self._state.shift_deficit = max(0, multi_shift - self._state.buffer.shape[axis_idx])
|
|
242
|
+
self._state.buffer = slice_along_axis(self._state.buffer, slice(multi_shift, None), axis_idx)
|
|
282
243
|
else:
|
|
283
244
|
# Not enough data to make a new window. Return empty data.
|
|
284
245
|
empty_data_shape = (
|
|
285
|
-
message.data.shape[:axis_idx]
|
|
286
|
-
+ (0, self._state.window_samples)
|
|
287
|
-
+ message.data.shape[axis_idx + 1 :]
|
|
246
|
+
message.data.shape[:axis_idx] + (0, self._state.window_samples) + message.data.shape[axis_idx + 1 :]
|
|
288
247
|
)
|
|
289
248
|
out_dat = xp.zeros(empty_data_shape, dtype=message.data.dtype)
|
|
290
249
|
# out_newaxis will have first timestamp in input... but mostly meaningless because output is size-zero.
|
|
@@ -305,9 +264,7 @@ class WindowTransformer(
|
|
|
305
264
|
return msg_out
|
|
306
265
|
|
|
307
266
|
|
|
308
|
-
class Window(
|
|
309
|
-
BaseTransformerUnit[WindowSettings, AxisArray, AxisArray, WindowTransformer]
|
|
310
|
-
):
|
|
267
|
+
class Window(BaseTransformerUnit[WindowSettings, AxisArray, AxisArray, WindowTransformer]):
|
|
311
268
|
SETTINGS = WindowSettings
|
|
312
269
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
313
270
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
@@ -325,30 +282,19 @@ class Window(
|
|
|
325
282
|
try:
|
|
326
283
|
ret = self.processor(message)
|
|
327
284
|
if ret.data.size > 0:
|
|
328
|
-
if (
|
|
329
|
-
self.SETTINGS.newaxis is not None
|
|
330
|
-
or self.SETTINGS.window_dur is None
|
|
331
|
-
):
|
|
285
|
+
if self.SETTINGS.newaxis is not None or self.SETTINGS.window_dur is None:
|
|
332
286
|
# Multi-win mode or pass-through mode.
|
|
333
287
|
yield self.OUTPUT_SIGNAL, ret
|
|
334
288
|
else:
|
|
335
289
|
# We need to split out_msg into multiple yields, dropping newaxis.
|
|
336
290
|
axis_idx = ret.get_axis_idx("win")
|
|
337
291
|
win_axis = ret.axes["win"]
|
|
338
|
-
offsets = win_axis.value(
|
|
339
|
-
xp.asarray(range(ret.data.shape[axis_idx]))
|
|
340
|
-
)
|
|
292
|
+
offsets = win_axis.value(xp.asarray(range(ret.data.shape[axis_idx])))
|
|
341
293
|
for msg_ix in range(ret.data.shape[axis_idx]):
|
|
342
294
|
# Need to drop 'win' and replace self.SETTINGS.axis from axes.
|
|
343
295
|
_out_axes = {
|
|
344
|
-
**{
|
|
345
|
-
k: v
|
346
|
-
for k, v in ret.axes.items()
|
|
347
|
-
if k not in ["win", self.SETTINGS.axis]
|
|
348
|
-
},
|
|
349
|
-
self.SETTINGS.axis: replace(
|
|
350
|
-
ret.axes[self.SETTINGS.axis], offset=offsets[msg_ix]
|
|
351
|
-
),
|
|
296
|
+
**{k: v for k, v in ret.axes.items() if k not in ["win", self.SETTINGS.axis]},
|
|
297
|
+
self.SETTINGS.axis: replace(ret.axes[self.SETTINGS.axis], offset=offsets[msg_ix]),
|
|
352
298
|
}
|
|
353
299
|
_ret = replace(
|
|
354
300
|
ret,
|
|
{ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/METADATA
CHANGED
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ezmsg-sigproc
|
|
3
|
-
Version: 2.4.1
|
|
3
|
+
Version: 2.6.0
|
|
4
4
|
Summary: Timeseries signal processing implementations in ezmsg
|
|
5
|
-
Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>
|
|
5
|
+
Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
|
|
6
6
|
License-Expression: MIT
|
|
7
|
-
License-File: LICENSE.txt
|
|
7
|
+
License-File: LICENSE
|
|
8
8
|
Requires-Python: >=3.10.15
|
|
9
9
|
Requires-Dist: array-api-compat>=1.11.1
|
|
10
|
+
Requires-Dist: ezmsg-baseproc>=1.0.3
|
|
10
11
|
Requires-Dist: ezmsg>=3.6.0
|
|
11
12
|
Requires-Dist: numba>=0.61.0
|
|
12
13
|
Requires-Dist: numpy>=1.26.0
|