ezmsg-sigproc 2.4.1__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +5 -11
  3. ezmsg/sigproc/adaptive_lattice_notch.py +11 -29
  4. ezmsg/sigproc/affinetransform.py +13 -38
  5. ezmsg/sigproc/aggregate.py +13 -30
  6. ezmsg/sigproc/bandpower.py +7 -15
  7. ezmsg/sigproc/base.py +141 -1276
  8. ezmsg/sigproc/butterworthfilter.py +8 -16
  9. ezmsg/sigproc/butterworthzerophase.py +123 -0
  10. ezmsg/sigproc/cheby.py +4 -10
  11. ezmsg/sigproc/combfilter.py +5 -8
  12. ezmsg/sigproc/decimate.py +2 -6
  13. ezmsg/sigproc/denormalize.py +6 -11
  14. ezmsg/sigproc/detrend.py +3 -4
  15. ezmsg/sigproc/diff.py +8 -17
  16. ezmsg/sigproc/downsample.py +6 -14
  17. ezmsg/sigproc/ewma.py +11 -27
  18. ezmsg/sigproc/ewmfilter.py +1 -1
  19. ezmsg/sigproc/extract_axis.py +3 -4
  20. ezmsg/sigproc/fbcca.py +31 -56
  21. ezmsg/sigproc/filter.py +19 -45
  22. ezmsg/sigproc/filterbank.py +33 -70
  23. ezmsg/sigproc/filterbankdesign.py +5 -12
  24. ezmsg/sigproc/fir_hilbert.py +336 -0
  25. ezmsg/sigproc/fir_pmc.py +209 -0
  26. ezmsg/sigproc/firfilter.py +12 -14
  27. ezmsg/sigproc/gaussiansmoothing.py +5 -9
  28. ezmsg/sigproc/kaiser.py +11 -15
  29. ezmsg/sigproc/math/abs.py +1 -3
  30. ezmsg/sigproc/math/add.py +121 -0
  31. ezmsg/sigproc/math/clip.py +1 -1
  32. ezmsg/sigproc/math/difference.py +98 -36
  33. ezmsg/sigproc/math/invert.py +1 -3
  34. ezmsg/sigproc/math/log.py +2 -6
  35. ezmsg/sigproc/messages.py +1 -2
  36. ezmsg/sigproc/quantize.py +2 -4
  37. ezmsg/sigproc/resample.py +13 -34
  38. ezmsg/sigproc/rollingscaler.py +232 -0
  39. ezmsg/sigproc/sampler.py +17 -35
  40. ezmsg/sigproc/scaler.py +8 -18
  41. ezmsg/sigproc/signalinjector.py +6 -16
  42. ezmsg/sigproc/slicer.py +9 -28
  43. ezmsg/sigproc/spectral.py +3 -3
  44. ezmsg/sigproc/spectrogram.py +12 -19
  45. ezmsg/sigproc/spectrum.py +12 -32
  46. ezmsg/sigproc/transpose.py +7 -18
  47. ezmsg/sigproc/util/asio.py +25 -156
  48. ezmsg/sigproc/util/axisarray_buffer.py +10 -26
  49. ezmsg/sigproc/util/buffer.py +18 -43
  50. ezmsg/sigproc/util/message.py +17 -31
  51. ezmsg/sigproc/util/profile.py +23 -174
  52. ezmsg/sigproc/util/sparse.py +5 -15
  53. ezmsg/sigproc/util/typeresolution.py +17 -83
  54. ezmsg/sigproc/wavelets.py +6 -15
  55. ezmsg/sigproc/window.py +24 -78
  56. {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/METADATA +4 -3
  57. ezmsg_sigproc-2.6.0.dist-info/RECORD +63 -0
  58. ezmsg/sigproc/synth.py +0 -774
  59. ezmsg_sigproc-2.4.1.dist-info/RECORD +0 -59
  60. {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/WHEEL +0 -0
  61. /ezmsg_sigproc-2.4.1.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.6.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/util/profile.py CHANGED
@@ -1,174 +1,23 @@
1
- import functools
2
- import logging
3
- import os
4
- from pathlib import Path
5
- import time
6
- import typing
7
-
8
- import ezmsg.core as ez
9
-
10
-
11
- HEADER = "Time,Source,Topic,SampleTime,PerfCounter,Elapsed"
12
-
13
-
14
- def get_logger_path() -> Path:
15
- # Retrieve the logfile name from the environment variable
16
- logfile = os.environ.get("EZMSG_PROFILE", None)
17
-
18
- # Determine the log file path, defaulting to "ezprofiler.log" if not set
19
- logpath = Path(logfile or "ezprofiler.log")
20
-
21
- # If the log path is not absolute, prepend it with the user's home directory and ".ezmsg/profile"
22
- if not logpath.is_absolute():
23
- logpath = Path.home() / ".ezmsg" / "profile" / logpath
24
-
25
- return logpath
26
-
27
-
28
- def _setup_logger(append: bool = False) -> logging.Logger:
29
- logpath = get_logger_path()
30
- logpath.parent.mkdir(parents=True, exist_ok=True)
31
-
32
- write_header = True
33
- if logpath.exists() and logpath.is_file():
34
- if append:
35
- with open(logpath) as f:
36
- first_line = f.readline().rstrip()
37
- if first_line == HEADER:
38
- write_header = False
39
- else:
40
- # Remove the file if appending, but headers do not match
41
- ezmsg_logger = logging.getLogger("ezmsg")
42
- ezmsg_logger.warning(
43
- "Profiling header mismatch: please make sure to use the same version of ezmsg for all processes."
44
- )
45
- logpath.unlink()
46
- else:
47
- # Remove the file if not appending
48
- logpath.unlink()
49
-
50
- # Create a logger with the name "ezprofile"
51
- _logger = logging.getLogger("ezprofile")
52
-
53
- # Set the logger's level to EZMSG_LOGLEVEL env var value if it exists, otherwise INFO
54
- _logger.setLevel(os.environ.get("EZMSG_LOGLEVEL", "INFO").upper())
55
-
56
- # Create a file handler to write log messages to the log file
57
- fh = logging.FileHandler(logpath)
58
- fh.setLevel(logging.DEBUG) # Set the file handler log level to DEBUG
59
-
60
- # Add the file handler to the logger
61
- _logger.addHandler(fh)
62
-
63
- # Add the header if writing to new file or if header matched header in file.
64
- if write_header:
65
- _logger.debug(HEADER)
66
-
67
- # Set the log message format
68
- formatter = logging.Formatter(
69
- "%(asctime)s,%(message)s", datefmt="%Y-%m-%dT%H:%M:%S%z"
70
- )
71
- fh.setFormatter(formatter)
72
-
73
- return _logger
74
-
75
-
76
- logger = _setup_logger(append=True)
77
-
78
-
79
- def _process_obj(obj, trace_oldest: bool = True):
80
- samp_time = None
81
- if hasattr(obj, "axes") and ("time" in obj.axes or "win" in obj.axes):
82
- axis = "win" if "win" in obj.axes else "time"
83
- ax = obj.get_axis(axis)
84
- len = obj.data.shape[obj.get_axis_idx(axis)]
85
- if len > 0:
86
- idx = 0 if trace_oldest else (len - 1)
87
- if hasattr(ax, "data"):
88
- samp_time = ax.data[idx]
89
- else:
90
- samp_time = ax.value(idx)
91
- if ax == "win" and "time" in obj.axes:
92
- if hasattr(obj.axes["time"], "data"):
93
- samp_time += obj.axes["time"].data[idx]
94
- else:
95
- samp_time += obj.axes["time"].value(idx)
96
- return samp_time
97
-
98
-
99
- def profile_method(trace_oldest: bool = True):
100
- """
101
- Decorator to profile a method by logging its execution time and other details.
102
-
103
- Args:
104
- trace_oldest (bool): If True, trace the oldest sample time; otherwise, trace the newest.
105
-
106
- Returns:
107
- Callable: The decorated function with profiling.
108
- """
109
-
110
- def profiling_decorator(func: typing.Callable):
111
- @functools.wraps(func)
112
- def wrapped_func(caller, *args, **kwargs):
113
- start = time.perf_counter()
114
- res = func(caller, *args, **kwargs)
115
- stop = time.perf_counter()
116
- source = ".".join((caller.__class__.__module__, caller.__class__.__name__))
117
- topic = f"{caller.address}"
118
- samp_time = _process_obj(res, trace_oldest=trace_oldest)
119
- logger.debug(
120
- ",".join(
121
- [
122
- source,
123
- topic,
124
- f"{samp_time}",
125
- f"{stop}",
126
- f"{(stop - start) * 1e3:0.4f}",
127
- ]
128
- )
129
- )
130
- return res
131
-
132
- return wrapped_func if logger.level == logging.DEBUG else func
133
-
134
- return profiling_decorator
135
-
136
-
137
- def profile_subpub(trace_oldest: bool = True):
138
- """
139
- Decorator to profile a subscriber-publisher method in an ezmsg Unit
140
- by logging its execution time and other details.
141
-
142
- Args:
143
- trace_oldest (bool): If True, trace the oldest sample time; otherwise, trace the newest.
144
-
145
- Returns:
146
- Callable: The decorated async task with profiling.
147
- """
148
-
149
- def profiling_decorator(func: typing.Callable):
150
- @functools.wraps(func)
151
- async def wrapped_task(unit: ez.Unit, msg: typing.Any = None):
152
- source = ".".join((unit.__class__.__module__, unit.__class__.__name__))
153
- topic = f"{unit.address}"
154
- start = time.perf_counter()
155
- async for stream, obj in func(unit, msg):
156
- stop = time.perf_counter()
157
- samp_time = _process_obj(obj, trace_oldest=trace_oldest)
158
- logger.debug(
159
- ",".join(
160
- [
161
- source,
162
- topic,
163
- f"{samp_time}",
164
- f"{stop}",
165
- f"{(stop - start) * 1e3:0.4f}",
166
- ]
167
- )
168
- )
169
- start = stop
170
- yield stream, obj
171
-
172
- return wrapped_task if logger.level == logging.DEBUG else func
173
-
174
- return profiling_decorator
1
+ """
2
+ Backwards-compatible re-exports from ezmsg.baseproc.util.profile.
3
+
4
+ New code should import directly from ezmsg.baseproc instead.
5
+ """
6
+
7
+ from ezmsg.baseproc.util.profile import (
8
+ HEADER,
9
+ _setup_logger,
10
+ get_logger_path,
11
+ logger,
12
+ profile_method,
13
+ profile_subpub,
14
+ )
15
+
16
+ __all__ = [
17
+ "HEADER",
18
+ "get_logger_path",
19
+ "logger",
20
+ "profile_method",
21
+ "profile_subpub",
22
+ "_setup_logger",
23
+ ]
ezmsg/sigproc/util/sparse.py CHANGED
@@ -2,9 +2,7 @@ import numpy as np
2
2
  import sparse
3
3
 
4
4
 
5
- def sliding_win_oneaxis_old(
6
- s: sparse.SparseArray, nwin: int, axis: int, step: int = 1
7
- ) -> sparse.SparseArray:
5
+ def sliding_win_oneaxis_old(s: sparse.SparseArray, nwin: int, axis: int, step: int = 1) -> sparse.SparseArray:
8
6
  """
9
7
  Like `ezmsg.util.messages.axisarray.sliding_win_oneaxis` but for sparse arrays.
10
8
  This approach is about 4x slower than the version that uses coordinate arithmetic below.
@@ -23,16 +21,12 @@ def sliding_win_oneaxis_old(
23
21
  targ_slices = [slice(_, _ + nwin) for _ in range(0, s.shape[axis] - nwin + 1, step)]
24
22
  s = s.reshape(s.shape[:axis] + (1,) + s.shape[axis:])
25
23
  full_slices = (slice(None),) * s.ndim
26
- full_slices = [
27
- full_slices[: axis + 1] + (sl,) + full_slices[axis + 2 :] for sl in targ_slices
28
- ]
24
+ full_slices = [full_slices[: axis + 1] + (sl,) + full_slices[axis + 2 :] for sl in targ_slices]
29
25
  result = sparse.concatenate([s[_] for _ in full_slices], axis=axis)
30
26
  return result
31
27
 
32
28
 
33
- def sliding_win_oneaxis(
34
- s: sparse.SparseArray, nwin: int, axis: int, step: int = 1
35
- ) -> sparse.SparseArray:
29
+ def sliding_win_oneaxis(s: sparse.SparseArray, nwin: int, axis: int, step: int = 1) -> sparse.SparseArray:
36
30
  """
37
31
  Generates a view-like sparse array using a sliding window of specified length along a specified axis.
38
32
  Sparse analog of an optimized dense as_strided-based implementation with these properties:
@@ -72,9 +66,7 @@ def sliding_win_oneaxis(
72
66
  n_win_out = len(win_starts)
73
67
  if n_win_out <= 0:
74
68
  # Return array with proper shape except empty along windows axis
75
- return sparse.zeros(
76
- s.shape[:axis] + (0,) + (nwin,) + s.shape[axis + 1 :], dtype=s.dtype
77
- )
69
+ return sparse.zeros(s.shape[:axis] + (0,) + (nwin,) + s.shape[axis + 1 :], dtype=s.dtype)
78
70
 
79
71
  coo = s.asformat("coo")
80
72
  coords = coo.coords # shape: (ndim, nnz)
@@ -112,9 +104,7 @@ def sliding_win_oneaxis(
112
104
  out_data_blocks.append(data[sel])
113
105
 
114
106
  if not out_coords_blocks:
115
- return sparse.zeros(
116
- s.shape[:axis] + (n_win_out,) + (nwin,) + s.shape[axis + 1 :], dtype=s.dtype
117
- )
107
+ return sparse.zeros(s.shape[:axis] + (n_win_out,) + (nwin,) + s.shape[axis + 1 :], dtype=s.dtype)
118
108
 
119
109
  out_coords = np.hstack(out_coords_blocks)
120
110
  out_data = np.hstack(out_data_blocks)
ezmsg/sigproc/util/typeresolution.py CHANGED
@@ -1,83 +1,17 @@
1
- from types import UnionType
2
- import typing
3
- from typing_extensions import get_original_bases
4
-
5
-
6
- def resolve_typevar(cls: type, target_typevar: typing.TypeVar) -> type:
7
- """
8
- Resolve the concrete type bound to a TypeVar in a class hierarchy.
9
- This function traverses the method resolution order (MRO) of the class
10
- and checks the original bases of each class in the MRO for the TypeVar.
11
- If the TypeVar is found, it returns the concrete type bound to it.
12
- If the TypeVar is not found, it raises a TypeError.
13
- Args:
14
- cls (type): The class to inspect.
15
- target_typevar (typing.TypeVar): The TypeVar to resolve.
16
- Returns:
17
- type: The concrete type bound to the TypeVar.
18
- """
19
- for base in cls.__mro__:
20
- orig_bases = get_original_bases(base)
21
- for orig_base in orig_bases:
22
- origin = typing.get_origin(orig_base)
23
- if origin is None:
24
- continue
25
- params = getattr(origin, "__parameters__", ())
26
- if not params:
27
- continue
28
- if target_typevar in params:
29
- index = params.index(target_typevar)
30
- args = typing.get_args(orig_base)
31
- try:
32
- return args[index]
33
- except IndexError:
34
- pass
35
- raise TypeError(f"Could not resolve {target_typevar} in {cls}")
36
-
37
-
38
- TypeLike = typing.Union[type[typing.Any], typing.Any, type(None), None]
39
-
40
-
41
- def check_message_type_compatibility(type1: TypeLike, type2: TypeLike) -> bool:
42
- """
43
- Check if two types are compatible for message passing.
44
- Returns True if:
45
- - Both are None/NoneType
46
- - Either is typing.Any
47
- - type1 is a subclass of type2, which includes
48
- - type1 and type2 are concrete types and type1 is a subclass of type2
49
- - type1 is None/NoneType and type2 is typing.Optional, or
50
- - type1 is subtype of the non-None inner type of type2 if type2 is Optional
51
- - type1 is a Union/Optional type and all inner types are compatible with type2
52
- Args:
53
- type1: First type to compare
54
- type2: Second type to compare
55
- Returns:
56
- bool: True if the types are compatible, False otherwise
57
- """
58
- # If either is Any, they are compatible
59
- if type1 is typing.Any or type2 is typing.Any:
60
- return True
61
-
62
- # Handle None as NoneType
63
- if type1 is None:
64
- type1 = type(None)
65
- if type2 is None:
66
- type2 = type(None)
67
-
68
- # Handle if type1 is Optional/Union type
69
- if typing.get_origin(type1) in {typing.Union, UnionType}:
70
- return all(
71
- check_message_type_compatibility(inner_type, type2)
72
- for inner_type in typing.get_args(type1)
73
- )
74
-
75
- # Regular issubclass check. Handles cases like:
76
- # - type1 is a subclass of concrete type2
77
- # - type1 is a subclass of the inner type of type2 if type2 is Optional
78
- # - type1 is a subclass of one of the inner types of type2 if type2 is Union
79
- # - type1 is NoneType and type2 is Optional or Union[None, ...] or Union[NoneType, ...]
80
- try:
81
- return issubclass(type1, type2)
82
- except TypeError:
83
- return False
1
+ """
2
+ Backwards-compatible re-exports from ezmsg.baseproc.util.typeresolution.
3
+
4
+ New code should import directly from ezmsg.baseproc instead.
5
+ """
6
+
7
+ from ezmsg.baseproc.util.typeresolution import (
8
+ TypeLike,
9
+ check_message_type_compatibility,
10
+ resolve_typevar,
11
+ )
12
+
13
+ __all__ = [
14
+ "TypeLike",
15
+ "check_message_type_compatibility",
16
+ "resolve_typevar",
17
+ ]
ezmsg/sigproc/wavelets.py CHANGED
@@ -1,9 +1,9 @@
1
1
  import typing
2
2
 
3
+ import ezmsg.core as ez
3
4
  import numpy as np
4
5
  import numpy.typing as npt
5
6
  import pywt
6
- import ezmsg.core as ez
7
7
  from ezmsg.util.messages.axisarray import AxisArray
8
8
  from ezmsg.util.messages.util import replace
9
9
 
@@ -12,7 +12,7 @@ from .base import (
12
12
  BaseTransformerUnit,
13
13
  processor_state,
14
14
  )
15
- from .filterbank import filterbank, FilterbankMode, MinPhaseMode
15
+ from .filterbank import FilterbankMode, MinPhaseMode, filterbank
16
16
 
17
17
 
18
18
  class CWTSettings(ez.Settings):
@@ -37,9 +37,7 @@ class CWTState:
37
37
  last_conv_samp: npt.NDArray | None = None
38
38
 
39
39
 
40
- class CWTTransformer(
41
- BaseStatefulTransformer[CWTSettings, AxisArray, AxisArray, CWTState]
42
- ):
40
+ class CWTTransformer(BaseStatefulTransformer[CWTSettings, AxisArray, AxisArray, CWTState]):
43
41
  def _hash_message(self, message: AxisArray) -> int:
44
42
  ax_idx = message.get_axis_idx(self.settings.axis)
45
43
  in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
@@ -107,25 +105,18 @@ class CWTTransformer(
107
105
  # Create output template
108
106
  ax_idx = message.get_axis_idx(self.settings.axis)
109
107
  in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
110
- freqs = (
111
- pywt.scale2frequency(wavelet, scales, precision)
112
- / message.axes[self.settings.axis].gain
113
- )
108
+ freqs = pywt.scale2frequency(wavelet, scales, precision) / message.axes[self.settings.axis].gain
114
109
  dummy_shape = in_shape + (len(scales), 0)
115
110
  self._state.template = AxisArray(
116
111
  np.zeros(dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data),
117
- dims=message.dims[:ax_idx]
118
- + message.dims[ax_idx + 1 :]
119
- + ["freq", self.settings.axis],
112
+ dims=message.dims[:ax_idx] + message.dims[ax_idx + 1 :] + ["freq", self.settings.axis],
120
113
  axes={
121
114
  **message.axes,
122
115
  "freq": AxisArray.CoordinateAxis(unit="Hz", data=freqs, dims=["freq"]),
123
116
  },
124
117
  key=message.key,
125
118
  )
126
- self._state.last_conv_samp = np.zeros(
127
- dummy_shape[:-1] + (1,), dtype=self._state.template.data.dtype
128
- )
119
+ self._state.last_conv_samp = np.zeros(dummy_shape[:-1] + (1,), dtype=self._state.template.data.dtype)
129
120
 
130
121
  def _process(self, message: AxisArray) -> AxisArray:
131
122
  conv_msg = self._state.fbgen.send(message)
ezmsg/sigproc/window.py CHANGED
@@ -5,12 +5,12 @@ import typing
5
5
  import ezmsg.core as ez
6
6
  import numpy.typing as npt
7
7
  import sparse
8
- from array_api_compat import is_pydata_sparse_namespace, get_namespace
8
+ from array_api_compat import get_namespace, is_pydata_sparse_namespace
9
9
  from ezmsg.util.messages.axisarray import (
10
10
  AxisArray,
11
+ replace,
11
12
  slice_along_axis,
12
13
  sliding_win_oneaxis,
13
- replace,
14
14
  )
15
15
 
16
16
  from .base import (
@@ -18,8 +18,8 @@ from .base import (
18
18
  BaseTransformerUnit,
19
19
  processor_state,
20
20
  )
21
- from .util.sparse import sliding_win_oneaxis as sparse_sliding_win_oneaxis
22
21
  from .util.profile import profile_subpub
22
+ from .util.sparse import sliding_win_oneaxis as sparse_sliding_win_oneaxis
23
23
 
24
24
 
25
25
  class Anchor(enum.Enum):
@@ -55,9 +55,7 @@ class WindowState:
55
55
  out_dims: list[str] | None = None
56
56
 
57
57
 
58
- class WindowTransformer(
59
- BaseStatefulTransformer[WindowSettings, AxisArray, AxisArray, WindowState]
60
- ):
58
+ class WindowTransformer(BaseStatefulTransformer[WindowSettings, AxisArray, AxisArray, WindowState]):
61
59
  """
62
60
  Apply a sliding window along the specified axis to input streaming data.
63
61
  The `windowing` method is perhaps the most useful and versatile method in ezmsg.sigproc, but its parameterization
@@ -99,19 +97,13 @@ class WindowTransformer(
99
97
  # if self.settings.newaxis is None:
100
98
  # ez.logger.warning("`newaxis=None` will be replaced with `newaxis='win'`.")
101
99
  # object.__setattr__(self.settings, "newaxis", "win")
102
- if (
103
- self.settings.window_shift is None
104
- and self.settings.zero_pad_until != "input"
105
- ):
100
+ if self.settings.window_shift is None and self.settings.zero_pad_until != "input":
106
101
  ez.logger.warning(
107
102
  "`zero_pad_until` must be 'input' if `window_shift` is None. "
108
103
  f"Ignoring received argument value: {self.settings.zero_pad_until}"
109
104
  )
110
105
  object.__setattr__(self.settings, "zero_pad_until", "input")
111
- elif (
112
- self.settings.window_shift is not None
113
- and self.settings.zero_pad_until == "input"
114
- ):
106
+ elif self.settings.window_shift is not None and self.settings.zero_pad_until == "input":
115
107
  ez.logger.warning(
116
108
  "windowing is non-deterministic with `zero_pad_until='input'` as it depends on the size "
117
109
  "of the first input. We recommend using `zero_pad_until='shift'` when `window_shift` is float-valued."
@@ -135,9 +127,7 @@ class WindowTransformer(
135
127
  def _reset_state(self, message: AxisArray) -> None:
136
128
  _newaxis = self.settings.newaxis or "win"
137
129
  if not self._state.newaxis_warned and _newaxis in message.dims:
138
- ez.logger.warning(
139
- f"newaxis {_newaxis} present in input dims. Using {_newaxis}_win instead"
140
- )
130
+ ez.logger.warning(f"newaxis {_newaxis} present in input dims. Using {_newaxis}_win instead")
141
131
  self._state.newaxis_warned = True
142
132
  self.settings.newaxis = f"{_newaxis}_win"
143
133
 
@@ -154,33 +144,20 @@ class WindowTransformer(
154
144
  self._state.window_shift_samples = int(self.settings.window_shift * fs)
155
145
  if self.settings.zero_pad_until == "none":
156
146
  req_samples = self._state.window_samples
157
- elif (
158
- self.settings.zero_pad_until == "shift"
159
- and self.settings.window_shift is not None
160
- ):
147
+ elif self.settings.zero_pad_until == "shift" and self.settings.window_shift is not None:
161
148
  req_samples = self._state.window_shift_samples
162
149
  else: # i.e. zero_pad_until == "input"
163
150
  req_samples = message.data.shape[axis_idx]
164
151
  n_zero = max(0, self._state.window_samples - req_samples)
165
- init_buffer_shape = (
166
- message.data.shape[:axis_idx]
167
- + (n_zero,)
168
- + message.data.shape[axis_idx + 1 :]
169
- )
152
+ init_buffer_shape = message.data.shape[:axis_idx] + (n_zero,) + message.data.shape[axis_idx + 1 :]
170
153
  self._state.buffer = xp.zeros(init_buffer_shape, dtype=message.data.dtype)
171
154
 
172
155
  # Prepare reusable parts of output
173
156
  if self._state.out_newaxis is None:
174
- self._state.out_dims = (
175
- list(message.dims[:axis_idx])
176
- + [_newaxis]
177
- + list(message.dims[axis_idx:])
178
- )
157
+ self._state.out_dims = list(message.dims[:axis_idx]) + [_newaxis] + list(message.dims[axis_idx:])
179
158
  self._state.out_newaxis = replace(
180
159
  axis_info,
181
- gain=0.0
182
- if self.settings.window_shift is None
183
- else axis_info.gain * self._state.window_shift_samples,
160
+ gain=0.0 if self.settings.window_shift is None else axis_info.gain * self._state.window_shift_samples,
184
161
  offset=0.0, # offset modified per-msg below
185
162
  )
186
163
 
@@ -204,9 +181,7 @@ class WindowTransformer(
204
181
  # is generally faster than np.roll and slicing anyway, but this could still
205
182
  # be a performance bottleneck for large memory arrays.
206
183
  # A circular buffer might be faster.
207
- self._state.buffer = xp.concatenate(
208
- (self._state.buffer, message.data), axis=axis_idx
209
- )
184
+ self._state.buffer = xp.concatenate((self._state.buffer, message.data), axis=axis_idx)
210
185
 
211
186
  # Create a vector of buffer timestamps to track axis `offset` in output(s)
212
187
  buffer_t0 = 0.0
@@ -222,9 +197,7 @@ class WindowTransformer(
222
197
  if self.settings.window_shift is not None and self._state.shift_deficit > 0:
223
198
  n_skip = min(self._state.buffer.shape[axis_idx], self._state.shift_deficit)
224
199
  if n_skip > 0:
225
- self._state.buffer = slice_along_axis(
226
- self._state.buffer, slice(n_skip, None), axis_idx
227
- )
200
+ self._state.buffer = slice_along_axis(self._state.buffer, slice(n_skip, None), axis_idx)
228
201
  buffer_t0 += n_skip * axis_info.gain
229
202
  buffer_tlen -= n_skip
230
203
  self._state.shift_deficit -= n_skip
@@ -249,20 +222,12 @@ class WindowTransformer(
249
222
  self._state.buffer, slice(-self._state.window_samples, None), axis_idx
250
223
  )
251
224
  out_dat = self._state.buffer.reshape(
252
- self._state.buffer.shape[:axis_idx]
253
- + (1,)
254
- + self._state.buffer.shape[axis_idx:]
255
- )
256
- win_offset = buffer_t0 + axis_info.gain * (
257
- buffer_tlen - self._state.window_samples
225
+ self._state.buffer.shape[:axis_idx] + (1,) + self._state.buffer.shape[axis_idx:]
258
226
  )
227
+ win_offset = buffer_t0 + axis_info.gain * (buffer_tlen - self._state.window_samples)
259
228
  elif self._state.buffer.shape[axis_idx] >= self._state.window_samples:
260
229
  # Deterministic window shifts.
261
- sliding_win_fun = (
262
- sparse_sliding_win_oneaxis
263
- if is_pydata_sparse_namespace(xp)
264
- else sliding_win_oneaxis
265
- )
230
+ sliding_win_fun = sparse_sliding_win_oneaxis if is_pydata_sparse_namespace(xp) else sliding_win_oneaxis
266
231
  out_dat = sliding_win_fun(
267
232
  self._state.buffer,
268
233
  self._state.window_samples,
@@ -273,18 +238,12 @@ class WindowTransformer(
273
238
 
274
239
  # Drop expired beginning of buffer and update shift_deficit
275
240
  multi_shift = self._state.window_shift_samples * out_dat.shape[axis_idx]
276
- self._state.shift_deficit = max(
277
- 0, multi_shift - self._state.buffer.shape[axis_idx]
278
- )
279
- self._state.buffer = slice_along_axis(
280
- self._state.buffer, slice(multi_shift, None), axis_idx
281
- )
241
+ self._state.shift_deficit = max(0, multi_shift - self._state.buffer.shape[axis_idx])
242
+ self._state.buffer = slice_along_axis(self._state.buffer, slice(multi_shift, None), axis_idx)
282
243
  else:
283
244
  # Not enough data to make a new window. Return empty data.
284
245
  empty_data_shape = (
285
- message.data.shape[:axis_idx]
286
- + (0, self._state.window_samples)
287
- + message.data.shape[axis_idx + 1 :]
246
+ message.data.shape[:axis_idx] + (0, self._state.window_samples) + message.data.shape[axis_idx + 1 :]
288
247
  )
289
248
  out_dat = xp.zeros(empty_data_shape, dtype=message.data.dtype)
290
249
  # out_newaxis will have first timestamp in input... but mostly meaningless because output is size-zero.
@@ -305,9 +264,7 @@ class WindowTransformer(
305
264
  return msg_out
306
265
 
307
266
 
308
- class Window(
309
- BaseTransformerUnit[WindowSettings, AxisArray, AxisArray, WindowTransformer]
310
- ):
267
+ class Window(BaseTransformerUnit[WindowSettings, AxisArray, AxisArray, WindowTransformer]):
311
268
  SETTINGS = WindowSettings
312
269
  INPUT_SIGNAL = ez.InputStream(AxisArray)
313
270
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
@@ -325,30 +282,19 @@ class Window(
325
282
  try:
326
283
  ret = self.processor(message)
327
284
  if ret.data.size > 0:
328
- if (
329
- self.SETTINGS.newaxis is not None
330
- or self.SETTINGS.window_dur is None
331
- ):
285
+ if self.SETTINGS.newaxis is not None or self.SETTINGS.window_dur is None:
332
286
  # Multi-win mode or pass-through mode.
333
287
  yield self.OUTPUT_SIGNAL, ret
334
288
  else:
335
289
  # We need to split out_msg into multiple yields, dropping newaxis.
336
290
  axis_idx = ret.get_axis_idx("win")
337
291
  win_axis = ret.axes["win"]
338
- offsets = win_axis.value(
339
- xp.asarray(range(ret.data.shape[axis_idx]))
340
- )
292
+ offsets = win_axis.value(xp.asarray(range(ret.data.shape[axis_idx])))
341
293
  for msg_ix in range(ret.data.shape[axis_idx]):
342
294
  # Need to drop 'win' and replace self.SETTINGS.axis from axes.
343
295
  _out_axes = {
344
- **{
345
- k: v
346
- for k, v in ret.axes.items()
347
- if k not in ["win", self.SETTINGS.axis]
348
- },
349
- self.SETTINGS.axis: replace(
350
- ret.axes[self.SETTINGS.axis], offset=offsets[msg_ix]
351
- ),
296
+ **{k: v for k, v in ret.axes.items() if k not in ["win", self.SETTINGS.axis]},
297
+ self.SETTINGS.axis: replace(ret.axes[self.SETTINGS.axis], offset=offsets[msg_ix]),
352
298
  }
353
299
  _ret = replace(
354
300
  ret,
{ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/METADATA CHANGED
@@ -1,12 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ezmsg-sigproc
3
- Version: 2.4.1
3
+ Version: 2.6.0
4
4
  Summary: Timeseries signal processing implementations in ezmsg
5
- Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>
5
+ Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
6
6
  License-Expression: MIT
7
- License-File: LICENSE.txt
7
+ License-File: LICENSE
8
8
  Requires-Python: >=3.10.15
9
9
  Requires-Dist: array-api-compat>=1.11.1
10
+ Requires-Dist: ezmsg-baseproc>=1.0.3
10
11
  Requires-Dist: ezmsg>=3.6.0
11
12
  Requires-Dist: numba>=0.61.0
12
13
  Requires-Dist: numpy>=1.26.0