ezmsg-sigproc 1.1.1__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ezmsg/sigproc/__init__.py CHANGED
@@ -1 +1 @@
1
- from .__version__ import __version__
1
+ from .__version__ import __version__
@@ -1 +1 @@
1
- __version__ = '1.1.1'
1
+ __version__ = "1.2.0"
@@ -1,49 +1,40 @@
1
- from dataclasses import dataclass, field
2
- import logging
3
-
4
1
  import ezmsg.core as ez
5
2
  import scipy.signal
6
3
  import numpy as np
7
4
 
8
- from .filter import Filter, FilterState, FilterSettings
5
+ from .filter import Filter, FilterState, FilterSettingsBase
9
6
 
10
7
  from typing import Optional, Tuple, Union
11
8
 
12
- logger = logging.getLogger('ezmsg')
13
-
14
9
 
15
- @dataclass( frozen = True )
16
- class ButterworthFilterDesign:
10
+ class ButterworthFilterSettings(FilterSettingsBase):
17
11
  order: int = 0
18
12
  cuton: Optional[float] = None # Hz
19
13
  cutoff: Optional[float] = None # Hz
20
14
 
21
- def filter_specs( self ) -> Optional[ Tuple[ str, Union[ float, Tuple[ float, float ] ] ] ]:
15
+ def filter_specs(self) -> Optional[Tuple[str, Union[float, Tuple[float, float]]]]:
22
16
  if self.cuton is None and self.cutoff is None:
23
17
  return None
24
18
  elif self.cuton is None and self.cutoff is not None:
25
- return 'lowpass', self.cutoff
19
+ return "lowpass", self.cutoff
26
20
  elif self.cuton is not None and self.cutoff is None:
27
- return 'highpass', self.cuton
21
+ return "highpass", self.cuton
28
22
  elif self.cuton is not None and self.cutoff is not None:
29
- if self.cuton <= self.cutoff:
30
- return 'bandpass', ( self.cuton, self.cutoff )
31
- else: return 'bandstop', ( self.cutoff, self.cuton )
32
-
33
-
34
- class ButterworthFilterSettings(ButterworthFilterDesign, FilterSettings):
35
- ...
23
+ if self.cuton <= self.cutoff:
24
+ return "bandpass", (self.cuton, self.cutoff)
25
+ else:
26
+ return "bandstop", (self.cutoff, self.cuton)
36
27
 
37
28
 
38
29
  class ButterworthFilterState(FilterState):
39
- design: ButterworthFilterDesign
30
+ design: ButterworthFilterSettings
40
31
 
41
32
 
42
33
  class ButterworthFilter(Filter):
43
34
  SETTINGS: ButterworthFilterSettings
44
35
  STATE: ButterworthFilterState
45
36
 
46
- INPUT_FILTER = ez.InputStream(ButterworthFilterDesign)
37
+ INPUT_FILTER = ez.InputStream(ButterworthFilterSettings)
47
38
 
48
39
  def initialize(self) -> None:
49
40
  self.STATE.design = self.SETTINGS
@@ -55,18 +46,17 @@ class ButterworthFilter(Filter):
55
46
  if self.STATE.design.order > 0 and specs is not None:
56
47
  btype, cut = specs
57
48
  return scipy.signal.butter(
58
- self.STATE.design.order,
59
- Wn=cut,
60
- btype=btype,
61
- fs=self.STATE.fs,
62
- output="ba"
49
+ self.STATE.design.order,
50
+ Wn=cut,
51
+ btype=btype,
52
+ fs=self.STATE.fs,
53
+ output="ba",
63
54
  )
64
55
  else:
65
56
  return None
66
57
 
67
-
68
58
  @ez.subscriber(INPUT_FILTER)
69
- async def redesign(self, message: ButterworthFilterDesign) -> None:
59
+ async def redesign(self, message: ButterworthFilterSettings) -> None:
70
60
  if self.STATE.design.order != message.order:
71
61
  self.STATE.zi = None
72
62
  self.STATE.design = message
ezmsg/sigproc/decimate.py CHANGED
@@ -2,23 +2,22 @@ import ezmsg.core as ez
2
2
 
3
3
  import scipy.signal
4
4
 
5
+ from ezmsg.util.messages.axisarray import AxisArray
6
+
5
7
  from .downsample import Downsample, DownsampleSettings
6
8
  from .filter import Filter, FilterCoefficients, FilterSettings
7
- from .messages import TSMessage as TimeSeriesMessage
8
9
 
9
10
 
10
11
  class Decimate(ez.Collection):
11
-
12
12
  SETTINGS: DownsampleSettings
13
13
 
14
- INPUT_SIGNAL = ez.InputStream(TimeSeriesMessage)
15
- OUTPUT_SIGNAL = ez.OutputStream(TimeSeriesMessage)
14
+ INPUT_SIGNAL = ez.InputStream(AxisArray)
15
+ OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
16
16
 
17
17
  FILTER = Filter()
18
18
  DOWNSAMPLE = Downsample()
19
19
 
20
20
  def configure(self) -> None:
21
-
22
21
  self.DOWNSAMPLE.apply_settings(self.SETTINGS)
23
22
 
24
23
  if self.SETTINGS.factor < 1:
@@ -27,11 +26,9 @@ class Decimate(ez.Collection):
27
26
  filt = FilterCoefficients()
28
27
  else:
29
28
  # See scipy.signal.decimate for IIR Filter Condition
30
- system = scipy.signal.dlti(
31
- *scipy.signal.cheby1(8, 0.05, 0.8 / self.SETTINGS.factor)
32
- )
33
-
34
- filt = FilterCoefficients(b=system.num, a=system.den)
29
+ b, a = scipy.signal.cheby1(8, 0.05, 0.8 / self.SETTINGS.factor)
30
+ system = scipy.signal.dlti(b, a)
31
+ filt = FilterCoefficients(b=system.num, a=system.den) # type: ignore
35
32
 
36
33
  self.FILTER.apply_settings(FilterSettings(filt=filt))
37
34
 
@@ -1,69 +1,63 @@
1
- from dataclasses import dataclass, replace
1
+ from dataclasses import replace
2
+
3
+ from ezmsg.util.messages.axisarray import AxisArray
2
4
 
3
5
  import ezmsg.core as ez
4
6
  import numpy as np
5
7
 
6
- from .messages import TSMessage as TimeSeriesMessage
7
-
8
8
  from typing import (
9
9
  AsyncGenerator,
10
10
  Optional,
11
11
  )
12
12
 
13
- @dataclass( frozen = True )
14
- class DownsampleSettingsMessage:
15
- factor: int = 1
16
13
 
17
-
18
- class DownsampleSettings(DownsampleSettingsMessage, ez.Settings):
19
- ...
14
+ class DownsampleSettings(ez.Settings):
15
+ axis: Optional[str] = None
16
+ factor: int = 1
20
17
 
21
18
 
22
19
  class DownsampleState(ez.State):
23
- cur_settings: DownsampleSettingsMessage
20
+ cur_settings: DownsampleSettings
24
21
  s_idx: int = 0
25
22
 
26
23
 
27
24
  class Downsample(ez.Unit):
28
-
29
25
  SETTINGS: DownsampleSettings
30
26
  STATE: DownsampleState
31
27
 
32
- INPUT_SETTINGS = ez.InputStream(DownsampleSettingsMessage)
33
- INPUT_SIGNAL = ez.InputStream(TimeSeriesMessage)
34
- OUTPUT_SIGNAL = ez.OutputStream(TimeSeriesMessage)
28
+ INPUT_SETTINGS = ez.InputStream(DownsampleSettings)
29
+ INPUT_SIGNAL = ez.InputStream(AxisArray)
30
+ OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
35
31
 
36
32
  def initialize(self) -> None:
37
33
  self.STATE.cur_settings = self.SETTINGS
38
34
 
39
35
  @ez.subscriber(INPUT_SETTINGS)
40
- async def on_settings(self, msg: DownsampleSettingsMessage) -> None:
36
+ async def on_settings(self, msg: DownsampleSettings) -> None:
41
37
  self.STATE.cur_settings = msg
42
38
 
43
- @ez.subscriber(INPUT_SIGNAL)
39
+ @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
44
40
  @ez.publisher(OUTPUT_SIGNAL)
45
- async def on_signal(self, msg: TimeSeriesMessage) -> AsyncGenerator:
46
-
41
+ async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
47
42
  if self.STATE.cur_settings.factor < 1:
48
43
  raise ValueError("Downsample factor must be at least 1 (no downsampling)")
49
44
 
50
- samples = np.arange(msg.n_time) + self.STATE.s_idx
45
+ axis_name = self.STATE.cur_settings.axis
46
+ if axis_name is None:
47
+ axis_name = msg.dims[0]
48
+ axis = msg.get_axis(axis_name)
49
+ axis_idx = msg.get_axis_idx(axis_name)
50
+
51
+ samples = np.arange(msg.data.shape[axis_idx]) + self.STATE.s_idx
51
52
  samples = samples % self.STATE.cur_settings.factor
52
53
  self.STATE.s_idx = samples[-1] + 1
53
54
 
54
55
  pub_samples = np.where(samples == 0)[0]
55
-
56
56
  if len(pub_samples) != 0:
57
-
58
- time_view = np.moveaxis(msg.data, msg.time_dim, 0)
59
- data_down = time_view[pub_samples, ...]
60
- data_down = np.moveaxis(data_down, 0, -(msg.time_dim))
61
-
62
- yield (
63
- self.OUTPUT_SIGNAL,
64
- replace(
65
- msg,
66
- data=data_down,
67
- fs=msg.fs / self.SETTINGS.factor
68
- )
69
- )
57
+ new_axes = {ax_name: msg.get_axis(ax_name) for ax_name in msg.dims}
58
+ new_offset = axis.offset + (axis.gain * pub_samples[0].item())
59
+ new_gain = axis.gain * self.STATE.cur_settings.factor
60
+ new_axes[axis_name] = replace(axis, gain=new_gain, offset=new_offset)
61
+ down_data = np.take(msg.data, pub_samples, axis_idx)
62
+ out_msg = replace(msg, data=down_data, dims=msg.dims, axes=new_axes)
63
+ yield self.OUTPUT_SIGNAL, out_msg
@@ -1,29 +1,24 @@
1
1
  import asyncio
2
- from dataclasses import field, replace
2
+ from dataclasses import replace
3
3
 
4
4
  import ezmsg.core as ez
5
+ from ezmsg.util.messages.axisarray import AxisArray
6
+
5
7
  import numpy as np
6
8
 
7
9
  from .window import Window, WindowSettings
8
- from .messages import TSMessage
9
10
 
10
- from typing import (
11
- AsyncGenerator,
12
- Optional
13
- )
11
+ from typing import AsyncGenerator, Optional
14
12
 
15
13
 
16
14
  class EWMSettings(ez.Settings):
15
+ axis: Optional[str] = None
17
16
  zero_offset: bool = True # If true, we assume zero DC offset
18
17
 
19
18
 
20
19
  class EWMState(ez.State):
21
- last_signal: Optional[TSMessage] = None
22
- buffer_queue: "asyncio.Queue[ TSMessage ]" = field(default_factory=asyncio.Queue)
23
-
24
- pows: Optional[np.ndarray] = None
25
- scale_arr: Optional[np.ndarray] = None
26
- pw0: Optional[np.ndarray] = None
20
+ buffer_queue: "asyncio.Queue[AxisArray]"
21
+ signal_queue: "asyncio.Queue[AxisArray]"
27
22
 
28
23
 
29
24
  class EWM(ez.Unit):
@@ -31,84 +26,95 @@ class EWM(ez.Unit):
31
26
  Exponentially Weighted Moving Average Standardization
32
27
 
33
28
  References https://stackoverflow.com/a/42926270
34
- FIXME: Assumes time axis is on dimension 0
35
29
  """
30
+
36
31
  SETTINGS: EWMSettings
37
32
  STATE: EWMState
38
33
 
39
- INPUT_SIGNAL = ez.InputStream(TSMessage)
40
- INPUT_BUFFER = ez.InputStream(TSMessage)
41
- OUTPUT_SIGNAL = ez.OutputStream(TSMessage)
34
+ INPUT_SIGNAL = ez.InputStream(AxisArray)
35
+ INPUT_BUFFER = ez.InputStream(AxisArray)
36
+ OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
37
+
38
+ def initialize(self) -> None:
39
+ self.STATE.signal_queue = asyncio.Queue()
40
+ self.STATE.buffer_queue = asyncio.Queue()
42
41
 
43
42
  @ez.subscriber(INPUT_SIGNAL)
44
- async def on_signal(self, message: TSMessage) -> None:
45
- self.STATE.last_signal = message
43
+ async def on_signal(self, message: AxisArray) -> None:
44
+ self.STATE.signal_queue.put_nowait(message)
46
45
 
47
46
  @ez.subscriber(INPUT_BUFFER)
48
- async def on_buffer(self, message: TSMessage) -> None:
49
- buffer_len = message.n_time
50
- block_len = self.STATE.last_signal.n_time
51
- window = buffer_len - block_len
52
-
53
- alpha = 2 / (window + 1.0)
54
- alpha_rev = 1 - alpha
55
-
56
- self.STATE.pows = alpha_rev ** (np.arange(buffer_len + 1))
57
- self.STATE.scale_arr = 1 / self.STATE.pows[:-1]
58
- self.STATE.pw0 = alpha * alpha_rev ** (buffer_len - 1)
59
-
47
+ async def on_buffer(self, message: AxisArray) -> None:
60
48
  self.STATE.buffer_queue.put_nowait(message)
61
49
 
62
50
  @ez.publisher(OUTPUT_SIGNAL)
63
51
  async def sync_output(self) -> AsyncGenerator:
52
+ while True:
53
+ signal = await self.STATE.signal_queue.get()
54
+ buffer = await self.STATE.buffer_queue.get() # includes signal
64
55
 
65
- def ewma(data: np.ndarray) -> np.ndarray:
66
- """ Assumes time axis is dim 0 """
67
- mult: np.ndarray = self.STATE.scale_arr[:, np.newaxis] * data * self.STATE.pw0
68
- out = self.STATE.scale_arr[::-1, np.newaxis] * mult.cumsum(axis=0)
56
+ axis_name = self.SETTINGS.axis
57
+ if axis_name is None:
58
+ axis_name = signal.dims[0]
69
59
 
70
- if not self.SETTINGS.zero_offset:
71
- out = (data[0, :, np.newaxis] * self.STATE.pows[1:]).T + out
60
+ axis_idx = signal.get_axis_idx(axis_name)
72
61
 
73
- return out
62
+ buffer_len = buffer.shape[axis_idx]
63
+ block_len = signal.shape[axis_idx]
64
+ window = buffer_len - block_len
74
65
 
75
- while True:
76
- buffer = await self.STATE.buffer_queue.get() # includes signal
77
- signal = self.STATE.last_signal # necessarily not "None" once there's a buffer.
66
+ alpha = 2 / (window + 1.0)
67
+ alpha_rev = 1 - alpha
78
68
 
79
- block_len = signal.n_time
69
+ pows = alpha_rev ** (np.arange(buffer_len + 1))
70
+ scale_arr = 1 / pows[:-1]
71
+ pw0 = alpha * alpha_rev ** (buffer_len - 1)
80
72
 
81
- mean = ewma(buffer.data)
82
- std: np.ndarray = ((buffer.data - mean) ** 2.0)
83
- std = ewma(std)
73
+ buffer_data = buffer.data
74
+ buffer_data = np.moveaxis(buffer_data, axis_idx, 0)
84
75
 
85
- standardized: np.ndarray = (buffer.data - mean) / np.sqrt(std).clip(1e-4)
76
+ def ewma(data: np.ndarray) -> np.ndarray:
77
+ mult = scale_arr[:, np.newaxis] * data * pw0
78
+ out = scale_arr[::-1, np.newaxis] * mult.cumsum(axis=0)
86
79
 
87
- yield (
88
- self.OUTPUT_SIGNAL,
89
- replace(signal, data=standardized[-block_len:, ...])
90
- )
80
+ if not self.SETTINGS.zero_offset:
81
+ out = (data[0, :, np.newaxis] * pows[1:]).T + out
82
+
83
+ return out
84
+
85
+ mean = ewma(buffer_data)
86
+ std = ewma((buffer_data - mean) ** 2.0)
87
+
88
+ standardized = (buffer_data - mean) / np.sqrt(std).clip(1e-4)
89
+ standardized = standardized[-signal.shape[axis_idx] :, ...]
90
+ standardized = np.moveaxis(standardized, axis_idx, 0)
91
+
92
+ yield self.OUTPUT_SIGNAL, replace(signal, data=standardized)
91
93
 
92
94
 
93
95
  class EWMFilterSettings(ez.Settings):
94
96
  history_dur: float # previous data to accumulate for standardization
97
+ axis: Optional[str] = None
98
+ zero_offset: bool = True # If true, we assume zero DC offset for input data
95
99
 
96
100
 
97
101
  class EWMFilter(ez.Collection):
98
-
99
102
  SETTINGS: EWMFilterSettings
100
103
 
101
- INPUT_SIGNAL = ez.InputStream(TSMessage)
102
- OUTPUT_SIGNAL = ez.OutputStream(TSMessage)
104
+ INPUT_SIGNAL = ez.InputStream(AxisArray)
105
+ OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
103
106
 
104
107
  WINDOW = Window()
105
108
  EWM = EWM()
106
109
 
107
110
  def configure(self) -> None:
111
+ self.EWM.apply_settings(EWMSettings(axis=self.SETTINGS.axis, zero_offset=True))
112
+
108
113
  self.WINDOW.apply_settings(
109
114
  WindowSettings(
115
+ axis=self.SETTINGS.axis,
110
116
  window_dur=self.SETTINGS.history_dur,
111
- window_shift=None # 1:1 mode
117
+ window_shift=None, # 1:1 mode
112
118
  )
113
119
  )
114
120
 
@@ -117,5 +123,5 @@ class EWMFilter(ez.Collection):
117
123
  (self.INPUT_SIGNAL, self.WINDOW.INPUT_SIGNAL),
118
124
  (self.WINDOW.OUTPUT_SIGNAL, self.EWM.INPUT_BUFFER),
119
125
  (self.INPUT_SIGNAL, self.EWM.INPUT_SIGNAL),
120
- (self.EWM.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL)
126
+ (self.EWM.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),
121
127
  )
ezmsg/sigproc/filter.py CHANGED
@@ -4,26 +4,30 @@ import ezmsg.core as ez
4
4
  import scipy.signal
5
5
  import numpy as np
6
6
  import asyncio
7
- import logging
8
7
 
9
- from .messages import TSMessage as TimeSeriesMessage
8
+ from ezmsg.util.messages.axisarray import AxisArray
10
9
 
11
10
  from typing import AsyncGenerator, Optional, Tuple
12
11
 
13
12
 
14
13
  @dataclass
15
14
  class FilterCoefficients:
16
- b: np.ndarray = field(default_factory = lambda: np.array([1.0, 0.0]))
17
- a: np.ndarray = field(default_factory = lambda: np.array([1.0, 0.0]))
15
+ b: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
16
+ a: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
18
17
 
19
18
 
20
- class FilterSettings(ez.Settings):
19
+ class FilterSettingsBase(ez.Settings):
20
+ axis: Optional[str] = None
21
+ fs: Optional[float] = None
22
+
23
+
24
+ class FilterSettings(FilterSettingsBase):
21
25
  # If you'd like to statically design a filter, define it in settings
22
26
  filt: Optional[FilterCoefficients] = None
23
- fs: Optional[float] = None
24
27
 
25
28
 
26
29
  class FilterState(ez.State):
30
+ axis: Optional[str] = None
27
31
  zi: Optional[np.ndarray] = None
28
32
  filt_designed: bool = False
29
33
  filt: Optional[FilterCoefficients] = None
@@ -33,21 +37,25 @@ class FilterState(ez.State):
33
37
 
34
38
 
35
39
  class Filter(ez.Unit):
36
- SETTINGS: FilterSettings
40
+ SETTINGS: FilterSettingsBase
37
41
  STATE: FilterState
38
42
 
39
43
  INPUT_FILTER = ez.InputStream(FilterCoefficients)
40
- INPUT_SIGNAL = ez.InputStream(TimeSeriesMessage)
41
- OUTPUT_SIGNAL = ez.OutputStream(TimeSeriesMessage)
44
+ INPUT_SIGNAL = ez.InputStream(AxisArray)
45
+ OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
42
46
 
43
47
  def design_filter(self) -> Optional[Tuple[np.ndarray, np.ndarray]]:
44
48
  raise NotImplementedError("Must implement 'design_filter' in Unit subclass!")
45
49
 
46
50
  # Set up filter with static initialization if specified
47
51
  def initialize(self) -> None:
48
- if self.SETTINGS.filt is not None:
49
- self.STATE.filt = self.SETTINGS.filt
50
- self.STATE.filt_set.set()
52
+ if self.SETTINGS.axis is not None:
53
+ self.STATE.axis = self.SETTINGS.axis
54
+
55
+ if isinstance(self.SETTINGS, FilterSettings):
56
+ if self.SETTINGS.filt is not None:
57
+ self.STATE.filt = self.SETTINGS.filt
58
+ self.STATE.filt_set.set()
51
59
  else:
52
60
  self.STATE.filt_set.clear()
53
61
 
@@ -64,7 +72,9 @@ class Filter(ez.Unit):
64
72
  def update_filter(self):
65
73
  try:
66
74
  coefs = self.design_filter()
67
- self.STATE.filt = FilterCoefficients() if coefs is None else FilterCoefficients( *coefs )
75
+ self.STATE.filt = (
76
+ FilterCoefficients() if coefs is None else FilterCoefficients(*coefs)
77
+ )
68
78
  self.STATE.filt_set.set()
69
79
  self.STATE.filt_designed = True
70
80
  except NotImplementedError as e:
@@ -74,30 +84,36 @@ class Filter(ez.Unit):
74
84
 
75
85
  @ez.subscriber(INPUT_SIGNAL)
76
86
  @ez.publisher(OUTPUT_SIGNAL)
77
- async def apply_filter(self, message: TimeSeriesMessage) -> AsyncGenerator:
78
- if self.STATE.fs != message.fs and self.STATE.filt_designed is True:
79
- self.STATE.fs = message.fs
87
+ async def apply_filter(self, msg: AxisArray) -> AsyncGenerator:
88
+ axis_name = msg.dims[0] if self.STATE.axis is None else self.STATE.axis
89
+ axis_idx = msg.get_axis_idx(axis_name)
90
+ axis = msg.get_axis(axis_name)
91
+ fs = 1.0 / axis.gain
92
+
93
+ if self.STATE.fs != fs and self.STATE.filt_designed is True:
94
+ self.STATE.fs = fs
80
95
  self.update_filter()
81
96
 
82
97
  # Ensure filter is defined
98
+ # TODO: Maybe have me be a passthrough filter until coefficients are received
83
99
  if self.STATE.filt is None:
84
100
  self.STATE.filt_set.clear()
85
101
  ez.logger.info("Awaiting filter coefficients...")
86
102
  await self.STATE.filt_set.wait()
87
103
  ez.logger.info("Filter coefficients received.")
88
104
 
89
- arr_in: np.ndarray
105
+ assert self.STATE.filt is not None
106
+
107
+ arr_in = msg.data
90
108
 
91
109
  # If the array is one dimensional, add a temporary second dimension so that the math works out
92
110
  one_dimensional = False
93
- if message.data.ndim == 1:
94
- arr_in = np.expand_dims(message.data, axis=1)
111
+ if arr_in.ndim == 1:
112
+ arr_in = np.expand_dims(arr_in, axis=1)
95
113
  one_dimensional = True
96
- else:
97
- arr_in = message.data
98
114
 
99
115
  # We will perform filter with time dimension as last axis
100
- arr_in = np.moveaxis(arr_in, message.time_dim, -1)
116
+ arr_in = np.moveaxis(arr_in, axis_idx, -1)
101
117
  samp_shape = arr_in[..., 0].shape
102
118
 
103
119
  # Re-calculate/reset zi if necessary
@@ -115,10 +131,10 @@ class Filter(ez.Unit):
115
131
  self.STATE.filt.b, self.STATE.filt.a, arr_in, zi=self.STATE.zi
116
132
  )
117
133
 
118
- arr_out = np.moveaxis(arr_out, -1, message.time_dim)
134
+ arr_out = np.moveaxis(arr_out, -1, axis_idx)
119
135
 
120
136
  # Remove temporary first dimension if necessary
121
137
  if one_dimensional:
122
138
  arr_out = np.squeeze(arr_out, axis=1)
123
139
 
124
- yield (self.OUTPUT_SIGNAL, replace(message, data=arr_out))
140
+ yield self.OUTPUT_SIGNAL, replace(msg, data=arr_out),
ezmsg/sigproc/messages.py CHANGED
@@ -1,51 +1,31 @@
1
1
  import warnings
2
2
  import time
3
3
 
4
- from dataclasses import dataclass, field
5
- from ezmsg.util.messages import AxisArray, TimeAxis
4
+ import numpy.typing as npt
6
5
 
6
+ from ezmsg.util.messages.axisarray import AxisArray
7
+
8
+ from typing import Optional
7
9
 
8
10
  # UPCOMING: TSMessage Deprecation
9
11
  # TSMessage is deprecated because it doesn't handle multiple time axes well.
10
- # AxisArray has an incompatible API but supports a superset of functionality
11
- # including messages with multiple time axes and dimensional or categorical axes.
12
- # warnings.warn(
13
- # "TimeSeriesMessage/TSMessage is deprecated. Please use ezmsg.utils.AxisArray",
14
- # DeprecationWarning,
15
- # stacklevel=2
16
- # )
17
-
18
- @dataclass
19
- class TSMessage(AxisArray):
20
- """
21
- NOTE: UPCOMING DEPRECATION. Please use ezmsg.utils.AxisArray
22
- This class remains as a backwards-compatible API for AxisArray
23
-
24
- TS(TimeSeries)Message:
25
-
26
- Base message type for timeseries data within ezmsg.sigproc
27
- TSMessages have one time dimension, and a sampling rate along that one time axis.
28
- Any higher dimensions are treated as "channels"
29
- """
30
-
31
- fs: float = 1.0
32
- time_dim: int = 0
33
- timestamp: float = field(default_factory = time.time)
34
-
35
- def __post_init__(self):
36
- super().__post_init__()
37
- self.axes[self.dims[self.time_dim]] = TimeAxis(fs=self.fs)
38
-
39
- @property
40
- def n_time(self) -> int:
41
- """ Number of time values in the message """
42
- return self.data.shape[self.time_dim]
43
-
44
- @property
45
- def n_ch(self) -> int:
46
- """ Number of channels in the message """
47
- return self.shape2d(self.dims[self.time_dim])[1]
48
-
49
- @property
50
- def _timestamp(self) -> float:
51
- return self.timestamp
12
+ # AxisArray has an incompatible API but supports a superset of functionality.
13
+ warnings.warn(
14
+ "TimeSeriesMessage/TSMessage is deprecated. Please use ezmsg.utils.AxisArray",
15
+ DeprecationWarning,
16
+ stacklevel=2,
17
+ )
18
+
19
+
20
+ def TSMessage(
21
+ data: npt.NDArray,
22
+ fs: float = 1.0,
23
+ time_dim: int = 0,
24
+ timestamp: Optional[float] = None,
25
+ ) -> AxisArray:
26
+ dims = [f"dim_{i}" for i in range(data.ndim)]
27
+ dims[time_dim] = "time"
28
+ offset = time.time() if timestamp is None else timestamp
29
+ offset_adj = data.shape[time_dim] / fs # offset corresponds to idx[0] on time_dim
30
+ axis = AxisArray.Axis.TimeAxis(fs, offset=offset - offset_adj)
31
+ return AxisArray(data, dims=dims, axes=dict(time=axis))