ezmsg-sigproc 1.7.0__py3-none-any.whl → 2.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. ezmsg/sigproc/__version__.py +22 -4
  2. ezmsg/sigproc/activation.py +31 -40
  3. ezmsg/sigproc/adaptive_lattice_notch.py +212 -0
  4. ezmsg/sigproc/affinetransform.py +171 -169
  5. ezmsg/sigproc/aggregate.py +190 -97
  6. ezmsg/sigproc/bandpower.py +60 -55
  7. ezmsg/sigproc/base.py +143 -33
  8. ezmsg/sigproc/butterworthfilter.py +34 -38
  9. ezmsg/sigproc/butterworthzerophase.py +305 -0
  10. ezmsg/sigproc/cheby.py +23 -17
  11. ezmsg/sigproc/combfilter.py +160 -0
  12. ezmsg/sigproc/coordinatespaces.py +159 -0
  13. ezmsg/sigproc/decimate.py +15 -10
  14. ezmsg/sigproc/denormalize.py +78 -0
  15. ezmsg/sigproc/detrend.py +28 -0
  16. ezmsg/sigproc/diff.py +82 -0
  17. ezmsg/sigproc/downsample.py +72 -81
  18. ezmsg/sigproc/ewma.py +217 -0
  19. ezmsg/sigproc/ewmfilter.py +1 -1
  20. ezmsg/sigproc/extract_axis.py +39 -0
  21. ezmsg/sigproc/fbcca.py +307 -0
  22. ezmsg/sigproc/filter.py +254 -148
  23. ezmsg/sigproc/filterbank.py +226 -214
  24. ezmsg/sigproc/filterbankdesign.py +129 -0
  25. ezmsg/sigproc/fir_hilbert.py +336 -0
  26. ezmsg/sigproc/fir_pmc.py +209 -0
  27. ezmsg/sigproc/firfilter.py +117 -0
  28. ezmsg/sigproc/gaussiansmoothing.py +89 -0
  29. ezmsg/sigproc/kaiser.py +106 -0
  30. ezmsg/sigproc/linear.py +120 -0
  31. ezmsg/sigproc/math/abs.py +23 -22
  32. ezmsg/sigproc/math/add.py +120 -0
  33. ezmsg/sigproc/math/clip.py +33 -25
  34. ezmsg/sigproc/math/difference.py +117 -43
  35. ezmsg/sigproc/math/invert.py +18 -25
  36. ezmsg/sigproc/math/log.py +38 -33
  37. ezmsg/sigproc/math/scale.py +24 -25
  38. ezmsg/sigproc/messages.py +1 -2
  39. ezmsg/sigproc/quantize.py +68 -0
  40. ezmsg/sigproc/resample.py +278 -0
  41. ezmsg/sigproc/rollingscaler.py +232 -0
  42. ezmsg/sigproc/sampler.py +209 -254
  43. ezmsg/sigproc/scaler.py +93 -218
  44. ezmsg/sigproc/signalinjector.py +44 -43
  45. ezmsg/sigproc/slicer.py +74 -102
  46. ezmsg/sigproc/spectral.py +3 -3
  47. ezmsg/sigproc/spectrogram.py +70 -70
  48. ezmsg/sigproc/spectrum.py +187 -173
  49. ezmsg/sigproc/transpose.py +134 -0
  50. ezmsg/sigproc/util/__init__.py +0 -0
  51. ezmsg/sigproc/util/asio.py +25 -0
  52. ezmsg/sigproc/util/axisarray_buffer.py +365 -0
  53. ezmsg/sigproc/util/buffer.py +449 -0
  54. ezmsg/sigproc/util/message.py +17 -0
  55. ezmsg/sigproc/util/profile.py +23 -0
  56. ezmsg/sigproc/util/sparse.py +115 -0
  57. ezmsg/sigproc/util/typeresolution.py +17 -0
  58. ezmsg/sigproc/wavelets.py +147 -154
  59. ezmsg/sigproc/window.py +248 -210
  60. ezmsg_sigproc-2.10.0.dist-info/METADATA +60 -0
  61. ezmsg_sigproc-2.10.0.dist-info/RECORD +65 -0
  62. {ezmsg_sigproc-1.7.0.dist-info → ezmsg_sigproc-2.10.0.dist-info}/WHEEL +1 -1
  63. ezmsg/sigproc/synth.py +0 -621
  64. ezmsg_sigproc-1.7.0.dist-info/METADATA +0 -58
  65. ezmsg_sigproc-1.7.0.dist-info/RECORD +0 -36
  66. /ezmsg_sigproc-1.7.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.10.0.dist-info/licenses/LICENSE +0 -0
@@ -1,195 +1,198 @@
1
+ """Affine transformations via matrix multiplication: y = Ax or y = Ax + B.
2
+
3
+ For full matrix transformations where channels are mixed (off-diagonal weights),
4
+ use :obj:`AffineTransformTransformer` or the `AffineTransform` unit.
5
+
6
+ For simple per-channel scaling and offset (diagonal weights only), use
7
+ :obj:`LinearTransformTransformer` from :mod:`ezmsg.sigproc.linear` instead,
8
+ which is more efficient as it avoids matrix multiplication.
9
+ """
10
+
1
11
  import os
2
12
  from pathlib import Path
3
- import typing
4
13
 
14
+ import ezmsg.core as ez
5
15
  import numpy as np
6
16
  import numpy.typing as npt
7
- import ezmsg.core as ez
17
+ from ezmsg.baseproc import (
18
+ BaseStatefulTransformer,
19
+ BaseTransformer,
20
+ BaseTransformerUnit,
21
+ processor_state,
22
+ )
8
23
  from ezmsg.util.messages.axisarray import AxisArray, AxisBase
9
24
  from ezmsg.util.messages.util import replace
10
- from ezmsg.util.generator import consumer
11
-
12
- from .base import GenAxisArray
13
25
 
14
26
 
15
- @consumer
16
- def affine_transform(
17
- weights: np.ndarray | str | Path,
18
- axis: str | None = None,
19
- right_multiply: bool = True,
20
- ) -> typing.Generator[AxisArray, AxisArray, None]:
27
+ class AffineTransformSettings(ez.Settings):
28
+ """
29
+ Settings for :obj:`AffineTransform`.
21
30
  """
22
- Perform affine transformations on streaming data.
23
31
 
24
- Args:
25
- weights: An array of weights or a path to a file with weights compatible with np.loadtxt.
26
- axis: The name of the axis to apply the transformation to. Defaults to the leading (0th) axis in the array.
27
- right_multiply: Set False to transpose the weights before applying.
32
+ weights: np.ndarray | str | Path
33
+ """An array of weights or a path to a file with weights compatible with np.loadtxt."""
28
34
 
29
- Returns:
30
- A primed generator object that yields an :obj:`AxisArray` object for every
31
- :obj:`AxisArray` it receives via `send`.
32
- """
33
- msg_out = AxisArray(np.array([]), dims=[""])
35
+ axis: str | None = None
36
+ """The name of the axis to apply the transformation to. Defaults to the leading (0th) axis in the array."""
34
37
 
35
- # Check parameters
36
- if isinstance(weights, str):
37
- if weights == "passthrough":
38
- weights = None
39
- else:
40
- weights = Path(os.path.abspath(os.path.expanduser(weights)))
41
- if isinstance(weights, Path):
42
- weights = np.loadtxt(weights, delimiter=",")
43
- if not right_multiply:
44
- weights = weights.T
45
- if weights is not None:
46
- weights = np.ascontiguousarray(weights)
47
-
48
- # State variables
49
- # New axis with transformed labels, if required
38
+ right_multiply: bool = True
39
+ """Set False to transpose the weights before applying."""
40
+
41
+
42
+ @processor_state
43
+ class AffineTransformState:
44
+ weights: npt.NDArray | None = None
50
45
  new_axis: AxisBase | None = None
51
46
 
52
- # Reset if any of these change.
53
- check_input = {"key": None}
54
- # We assume key change catches labels change; we don't want to check labels every message
55
- # We don't need to check if input size has changed because weights multiplication will fail if so.
56
-
57
- while True:
58
- msg_in: AxisArray = yield msg_out
59
-
60
- if weights is None:
61
- msg_out = msg_in
62
- continue
63
-
64
- axis = axis or msg_in.dims[-1] # Note: Most nodes default do dim[0]
65
- axis_idx = msg_in.get_axis_idx(axis)
66
-
67
- b_reset = msg_in.key != check_input["key"]
68
- if b_reset:
69
- # First sample or key has changed. Reset the state.
70
- check_input["key"] = msg_in.key
71
- # Determine if we need to modify the transformed axis.
72
- if (
73
- axis in msg_in.axes
74
- and hasattr(msg_in.axes[axis], "data")
75
- and weights.shape[0] != weights.shape[1]
76
- ):
77
- in_labels = msg_in.axes[axis].data
78
- new_labels = []
79
- n_in, n_out = weights.shape
80
- if len(in_labels) != n_in:
81
- # Something upstream did something it wasn't supposed to. We will drop the labels.
82
- ez.logger.warning(
83
- f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels."
84
- )
85
- else:
86
- b_filled_outputs = np.any(weights, axis=0)
87
- b_used_inputs = np.any(weights, axis=1)
88
- if np.all(b_used_inputs) and np.all(b_filled_outputs):
89
- # All inputs are used and all outputs are used, but n_in != n_out.
90
- # Mapping cannot be determined.
91
- new_labels = []
92
- elif np.all(b_used_inputs):
93
- # Strange scenario: New outputs are filled with empty data.
94
- in_ix = 0
95
- new_labels = []
96
- for out_ix in range(n_out):
97
- if b_filled_outputs[out_ix]:
98
- new_labels.append(in_labels[in_ix])
99
- in_ix += 1
100
- else:
101
- new_labels.append("")
102
- elif np.all(b_filled_outputs):
103
- # Transform is dropping some of the inputs.
104
- new_labels = np.array(in_labels)[b_used_inputs]
105
- new_axis = replace(msg_in.axes[axis], data=np.array(new_labels))
106
-
107
- data = msg_in.data
108
-
109
- if data.shape[axis_idx] == (weights.shape[0] - 1):
47
+
48
+ class AffineTransformTransformer(
49
+ BaseStatefulTransformer[AffineTransformSettings, AxisArray, AxisArray, AffineTransformState]
50
+ ):
51
+ """Apply affine transformation via matrix multiplication: y = Ax or y = Ax + B.
52
+
53
+ Use this transformer when you need full matrix transformations that mix
54
+ channels (off-diagonal weights), such as spatial filters or projections.
55
+
56
+ For simple per-channel scaling and offset where each output channel depends
57
+ only on its corresponding input channel (diagonal weight matrix), use
58
+ :obj:`LinearTransformTransformer` instead, which is more efficient.
59
+
60
+ The weights matrix can include an offset row (stacked as [A|B]) where the
61
+ input is automatically augmented with a column of ones to compute y = Ax + B.
62
+ """
63
+
64
+ def __call__(self, message: AxisArray) -> AxisArray:
65
+ # Override __call__ so we can shortcut if weights are None.
66
+ if self.settings.weights is None or (
67
+ isinstance(self.settings.weights, str) and self.settings.weights == "passthrough"
68
+ ):
69
+ return message
70
+ return super().__call__(message)
71
+
72
+ def _hash_message(self, message: AxisArray) -> int:
73
+ return hash(message.key)
74
+
75
+ def _reset_state(self, message: AxisArray) -> None:
76
+ weights = self.settings.weights
77
+ if isinstance(weights, str):
78
+ weights = Path(os.path.abspath(os.path.expanduser(weights)))
79
+ if isinstance(weights, Path):
80
+ weights = np.loadtxt(weights, delimiter=",")
81
+ if not self.settings.right_multiply:
82
+ weights = weights.T
83
+ if weights is not None:
84
+ weights = np.ascontiguousarray(weights)
85
+
86
+ self._state.weights = weights
87
+
88
+ axis = self.settings.axis or message.dims[-1]
89
+ if axis in message.axes and hasattr(message.axes[axis], "data") and weights.shape[0] != weights.shape[1]:
90
+ in_labels = message.axes[axis].data
91
+ new_labels = []
92
+ n_in, n_out = weights.shape
93
+ if len(in_labels) != n_in:
94
+ ez.logger.warning(f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels.")
95
+ else:
96
+ b_filled_outputs = np.any(weights, axis=0)
97
+ b_used_inputs = np.any(weights, axis=1)
98
+ if np.all(b_used_inputs) and np.all(b_filled_outputs):
99
+ new_labels = []
100
+ elif np.all(b_used_inputs):
101
+ in_ix = 0
102
+ new_labels = []
103
+ for out_ix in range(n_out):
104
+ if b_filled_outputs[out_ix]:
105
+ new_labels.append(in_labels[in_ix])
106
+ in_ix += 1
107
+ else:
108
+ new_labels.append("")
109
+ elif np.all(b_filled_outputs):
110
+ new_labels = np.array(in_labels)[b_used_inputs]
111
+
112
+ self._state.new_axis = replace(message.axes[axis], data=np.array(new_labels))
113
+
114
+ def _process(self, message: AxisArray) -> AxisArray:
115
+ axis = self.settings.axis or message.dims[-1]
116
+ axis_idx = message.get_axis_idx(axis)
117
+ data = message.data
118
+
119
+ if data.shape[axis_idx] == (self._state.weights.shape[0] - 1):
110
120
  # The weights are stacked A|B where A is the transform and B is a single row
111
121
  # in the equation y = Ax + B. This supports NeuroKey's weights matrices.
112
122
  sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
113
- data = np.concatenate(
114
- (data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx
115
- )
123
+ data = np.concatenate((data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx)
116
124
 
117
- if axis_idx in [-1, len(msg_in.dims) - 1]:
118
- data = np.matmul(data, weights)
125
+ if axis_idx in [-1, len(message.dims) - 1]:
126
+ data = np.matmul(data, self._state.weights)
119
127
  else:
120
128
  data = np.moveaxis(data, axis_idx, -1)
121
- data = np.matmul(data, weights)
129
+ data = np.matmul(data, self._state.weights)
122
130
  data = np.moveaxis(data, -1, axis_idx)
123
131
 
124
132
  replace_kwargs = {"data": data}
125
- if new_axis is not None:
126
- replace_kwargs["axes"] = {**msg_in.axes, axis: new_axis}
127
- msg_out = replace(msg_in, **replace_kwargs)
133
+ if self._state.new_axis is not None:
134
+ replace_kwargs["axes"] = {**message.axes, axis: self._state.new_axis}
128
135
 
136
+ return replace(message, **replace_kwargs)
129
137
 
130
- class AffineTransformSettings(ez.Settings):
131
- """
132
- Settings for :obj:`AffineTransform`.
133
- See :obj:`affine_transform` for argument details.
134
- """
135
138
 
136
- weights: np.ndarray | str | Path
137
- axis: str | None = None
138
- right_multiply: bool = True
139
+ class AffineTransform(BaseTransformerUnit[AffineTransformSettings, AxisArray, AxisArray, AffineTransformTransformer]):
140
+ SETTINGS = AffineTransformSettings
139
141
 
140
142
 
141
- class AffineTransform(GenAxisArray):
142
- """:obj:`Unit` for :obj:`affine_transform`"""
143
+ def affine_transform(
144
+ weights: np.ndarray | str | Path,
145
+ axis: str | None = None,
146
+ right_multiply: bool = True,
147
+ ) -> AffineTransformTransformer:
148
+ """
149
+ Perform affine transformations on streaming data.
143
150
 
144
- SETTINGS = AffineTransformSettings
151
+ Args:
152
+ weights: An array of weights or a path to a file with weights compatible with np.loadtxt.
153
+ axis: The name of the axis to apply the transformation to. Defaults to the leading (0th) axis in the array.
154
+ right_multiply: Set False to transpose the weights before applying.
145
155
 
146
- def construct_generator(self):
147
- self.STATE.gen = affine_transform(
148
- weights=self.SETTINGS.weights,
149
- axis=self.SETTINGS.axis,
150
- right_multiply=self.SETTINGS.right_multiply,
151
- )
156
+ Returns:
157
+ :obj:`AffineTransformTransformer`.
158
+ """
159
+ return AffineTransformTransformer(
160
+ AffineTransformSettings(weights=weights, axis=axis, right_multiply=right_multiply)
161
+ )
152
162
 
153
163
 
154
164
  def zeros_for_noop(data: npt.NDArray, **ignore_kwargs) -> npt.NDArray:
155
165
  return np.zeros_like(data)
156
166
 
157
167
 
158
- @consumer
159
- def common_rereference(
160
- mode: str = "mean", axis: str | None = None, include_current: bool = True
161
- ) -> typing.Generator[AxisArray, AxisArray, None]:
168
+ class CommonRereferenceSettings(ez.Settings):
169
+ """
170
+ Settings for :obj:`CommonRereference`
162
171
  """
163
- Perform common average referencing (CAR) on streaming data.
164
172
 
165
- Args:
166
- mode: The statistical mode to apply -- either "mean" or "median"
167
- axis: The name of hte axis to apply the transformation to.
168
- include_current: Set False to exclude each channel from participating in the calculation of its reference.
173
+ mode: str = "mean"
174
+ """The statistical mode to apply -- either "mean" or "median"."""
169
175
 
170
- Returns:
171
- A primed generator object that yields an :obj:`AxisArray` object
172
- for every :obj:`AxisArray` it receives via `send`.
173
- """
174
- msg_out = AxisArray(np.array([]), dims=[""])
176
+ axis: str | None = None
177
+ """The name of the axis to apply the transformation to."""
175
178
 
176
- if mode == "passthrough":
177
- include_current = True
179
+ include_current: bool = True
180
+ """Set False to exclude each channel from participating in the calculation of its reference."""
178
181
 
179
- func = {"mean": np.mean, "median": np.median, "passthrough": zeros_for_noop}[mode]
180
182
 
181
- while True:
182
- msg_in: AxisArray = yield msg_out
183
+ class CommonRereferenceTransformer(BaseTransformer[CommonRereferenceSettings, AxisArray, AxisArray]):
184
+ def _process(self, message: AxisArray) -> AxisArray:
185
+ if self.settings.mode == "passthrough":
186
+ return message
183
187
 
184
- if axis is None:
185
- axis = msg_in.dims[-1]
186
- axis_idx = -1
187
- else:
188
- axis_idx = msg_in.get_axis_idx(axis)
188
+ axis = self.settings.axis or message.dims[-1]
189
+ axis_idx = message.get_axis_idx(axis)
190
+
191
+ func = {"mean": np.mean, "median": np.median, "passthrough": zeros_for_noop}[self.settings.mode]
189
192
 
190
- ref_data = func(msg_in.data, axis=axis_idx, keepdims=True)
193
+ ref_data = func(message.data, axis=axis_idx, keepdims=True)
191
194
 
192
- if not include_current:
195
+ if not self.settings.include_current:
193
196
  # Typical `CAR = x[0]/N + x[1]/N + ... x[i-1]/N + x[i]/N + x[i+1]/N + ... + x[N-1]/N`
194
197
  # and is the same for all i, so it is calculated only once in `ref_data`.
195
198
  # However, if we had excluded the current channel,
@@ -200,34 +203,33 @@ def common_rereference(
200
203
  # from the current channel (i.e., `x[i] / (N-1)`)
201
204
  # i.e., `CAR[i] = (N / (N-1)) * common_CAR - x[i]/(N-1)`
202
205
  # We can use broadcasting subtraction instead of looping over channels.
203
- N = msg_in.data.shape[axis_idx]
204
- ref_data = (N / (N - 1)) * ref_data - msg_in.data / (N - 1)
205
- # Side note: I profiled using affine_transform and it's about 30x slower than this implementation.
206
+ N = message.data.shape[axis_idx]
207
+ ref_data = (N / (N - 1)) * ref_data - message.data / (N - 1)
208
+ # Note: I profiled using AffineTransformTransformer; it's ~30x slower than this implementation.
206
209
 
207
- msg_out = replace(msg_in, data=msg_in.data - ref_data)
210
+ return replace(message, data=message.data - ref_data)
208
211
 
209
212
 
210
- class CommonRereferenceSettings(ez.Settings):
211
- """
212
- Settings for :obj:`CommonRereference`
213
- See :obj:`common_rereference` for argument details.
214
- """
215
-
216
- mode: str = "mean"
217
- axis: str | None = None
218
- include_current: bool = True
213
+ class CommonRereference(
214
+ BaseTransformerUnit[CommonRereferenceSettings, AxisArray, AxisArray, CommonRereferenceTransformer]
215
+ ):
216
+ SETTINGS = CommonRereferenceSettings
219
217
 
220
218
 
221
- class CommonRereference(GenAxisArray):
222
- """
223
- :obj:`Unit` for :obj:`common_rereference`.
219
+ def common_rereference(
220
+ mode: str = "mean", axis: str | None = None, include_current: bool = True
221
+ ) -> CommonRereferenceTransformer:
224
222
  """
223
+ Perform common average referencing (CAR) on streaming data.
225
224
 
226
- SETTINGS = CommonRereferenceSettings
225
+ Args:
226
+ mode: The statistical mode to apply -- either "mean" or "median"
227
+ axis: The name of hte axis to apply the transformation to.
228
+ include_current: Set False to exclude each channel from participating in the calculation of its reference.
227
229
 
228
- def construct_generator(self):
229
- self.STATE.gen = common_rereference(
230
- mode=self.SETTINGS.mode,
231
- axis=self.SETTINGS.axis,
232
- include_current=self.SETTINGS.include_current,
233
- )
230
+ Returns:
231
+ :obj:`CommonRereferenceTransformer`
232
+ """
233
+ return CommonRereferenceTransformer(
234
+ CommonRereferenceSettings(mode=mode, axis=axis, include_current=include_current)
235
+ )