ezmsg-sigproc 1.1.0__tar.gz → 1.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. {ezmsg-sigproc-1.1.0/ezmsg_sigproc.egg-info → ezmsg-sigproc-1.2.0}/PKG-INFO +1 -1
  2. ezmsg-sigproc-1.2.0/ezmsg/sigproc/__init__.py +1 -0
  3. ezmsg-sigproc-1.2.0/ezmsg/sigproc/__version__.py +1 -0
  4. {ezmsg-sigproc-1.1.0 → ezmsg-sigproc-1.2.0}/ezmsg/sigproc/butterworthfilter.py +17 -27
  5. {ezmsg-sigproc-1.1.0 → ezmsg-sigproc-1.2.0}/ezmsg/sigproc/decimate.py +7 -10
  6. ezmsg-sigproc-1.2.0/ezmsg/sigproc/downsample.py +63 -0
  7. ezmsg-sigproc-1.2.0/ezmsg/sigproc/ewmfilter.py +127 -0
  8. {ezmsg-sigproc-1.1.0 → ezmsg-sigproc-1.2.0}/ezmsg/sigproc/filter.py +43 -30
  9. ezmsg-sigproc-1.2.0/ezmsg/sigproc/messages.py +31 -0
  10. ezmsg-sigproc-1.2.0/ezmsg/sigproc/sampler.py +287 -0
  11. ezmsg-sigproc-1.2.0/ezmsg/sigproc/spectral.py +132 -0
  12. ezmsg-sigproc-1.2.0/ezmsg/sigproc/synth.py +411 -0
  13. ezmsg-sigproc-1.2.0/ezmsg/sigproc/window.py +144 -0
  14. {ezmsg-sigproc-1.1.0 → ezmsg-sigproc-1.2.0/ezmsg_sigproc.egg-info}/PKG-INFO +1 -1
  15. {ezmsg-sigproc-1.1.0 → ezmsg-sigproc-1.2.0}/ezmsg_sigproc.egg-info/SOURCES.txt +6 -2
  16. {ezmsg-sigproc-1.1.0 → ezmsg-sigproc-1.2.0}/ezmsg_sigproc.egg-info/requires.txt +1 -1
  17. {ezmsg-sigproc-1.1.0 → ezmsg-sigproc-1.2.0}/setup.cfg +2 -2
  18. ezmsg-sigproc-1.2.0/setup.py +7 -0
  19. ezmsg-sigproc-1.2.0/tests/test_butterworth.py +143 -0
  20. ezmsg-sigproc-1.2.0/tests/test_downsample.py +133 -0
  21. ezmsg-sigproc-1.2.0/tests/test_window.py +140 -0
  22. ezmsg-sigproc-1.1.0/ezmsg/sigproc/__init__.py +0 -0
  23. ezmsg-sigproc-1.1.0/ezmsg/sigproc/downsample.py +0 -69
  24. ezmsg-sigproc-1.1.0/ezmsg/sigproc/ewmfilter.py +0 -121
  25. ezmsg-sigproc-1.1.0/ezmsg/sigproc/messages.py +0 -51
  26. ezmsg-sigproc-1.1.0/ezmsg/sigproc/sampler.py +0 -253
  27. ezmsg-sigproc-1.1.0/ezmsg/sigproc/synth.py +0 -236
  28. ezmsg-sigproc-1.1.0/ezmsg/sigproc/timeseriesmessage.py +0 -1
  29. ezmsg-sigproc-1.1.0/ezmsg/sigproc/window.py +0 -112
  30. ezmsg-sigproc-1.1.0/setup.py +0 -7
  31. {ezmsg-sigproc-1.1.0 → ezmsg-sigproc-1.2.0}/LICENSE.txt +0 -0
  32. {ezmsg-sigproc-1.1.0 → ezmsg-sigproc-1.2.0}/README.md +0 -0
  33. {ezmsg-sigproc-1.1.0 → ezmsg-sigproc-1.2.0}/ezmsg_sigproc.egg-info/dependency_links.txt +0 -0
  34. {ezmsg-sigproc-1.1.0 → ezmsg-sigproc-1.2.0}/ezmsg_sigproc.egg-info/not-zip-safe +0 -0
  35. {ezmsg-sigproc-1.1.0 → ezmsg-sigproc-1.2.0}/ezmsg_sigproc.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ezmsg-sigproc
3
- Version: 1.1.0
3
+ Version: 1.2.0
4
4
  Summary: Timeseries signal processing implementations in ezmsg
5
5
  Home-page: https://github.com/iscoe/ezmsg
6
6
  Author: Griffin Milsap
@@ -0,0 +1 @@
1
+ from .__version__ import __version__
@@ -0,0 +1 @@
1
+ __version__ = "1.2.0"
@@ -1,49 +1,40 @@
1
- from dataclasses import dataclass, field
2
- import logging
3
-
4
1
  import ezmsg.core as ez
5
2
  import scipy.signal
6
3
  import numpy as np
7
4
 
8
- from .filter import Filter, FilterState, FilterSettings
5
+ from .filter import Filter, FilterState, FilterSettingsBase
9
6
 
10
7
  from typing import Optional, Tuple, Union
11
8
 
12
- logger = logging.getLogger('ezmsg')
13
-
14
9
 
15
- @dataclass( frozen = True )
16
- class ButterworthFilterDesign:
10
+ class ButterworthFilterSettings(FilterSettingsBase):
17
11
  order: int = 0
18
12
  cuton: Optional[float] = None # Hz
19
13
  cutoff: Optional[float] = None # Hz
20
14
 
21
- def filter_specs( self ) -> Optional[ Tuple[ str, Union[ float, Tuple[ float, float ] ] ] ]:
15
+ def filter_specs(self) -> Optional[Tuple[str, Union[float, Tuple[float, float]]]]:
22
16
  if self.cuton is None and self.cutoff is None:
23
17
  return None
24
18
  elif self.cuton is None and self.cutoff is not None:
25
- return 'lowpass', self.cutoff
19
+ return "lowpass", self.cutoff
26
20
  elif self.cuton is not None and self.cutoff is None:
27
- return 'highpass', self.cuton
21
+ return "highpass", self.cuton
28
22
  elif self.cuton is not None and self.cutoff is not None:
29
- if self.cuton <= self.cutoff:
30
- return 'bandpass', ( self.cuton, self.cutoff )
31
- else: return 'bandstop', ( self.cutoff, self.cuton )
32
-
33
-
34
- class ButterworthFilterSettings(ButterworthFilterDesign, FilterSettings):
35
- ...
23
+ if self.cuton <= self.cutoff:
24
+ return "bandpass", (self.cuton, self.cutoff)
25
+ else:
26
+ return "bandstop", (self.cutoff, self.cuton)
36
27
 
37
28
 
38
29
  class ButterworthFilterState(FilterState):
39
- design: ButterworthFilterDesign
30
+ design: ButterworthFilterSettings
40
31
 
41
32
 
42
33
  class ButterworthFilter(Filter):
43
34
  SETTINGS: ButterworthFilterSettings
44
35
  STATE: ButterworthFilterState
45
36
 
46
- INPUT_FILTER = ez.InputStream(ButterworthFilterDesign)
37
+ INPUT_FILTER = ez.InputStream(ButterworthFilterSettings)
47
38
 
48
39
  def initialize(self) -> None:
49
40
  self.STATE.design = self.SETTINGS
@@ -55,18 +46,17 @@ class ButterworthFilter(Filter):
55
46
  if self.STATE.design.order > 0 and specs is not None:
56
47
  btype, cut = specs
57
48
  return scipy.signal.butter(
58
- self.STATE.design.order,
59
- Wn=cut,
60
- btype=btype,
61
- fs=self.STATE.fs,
62
- output="ba"
49
+ self.STATE.design.order,
50
+ Wn=cut,
51
+ btype=btype,
52
+ fs=self.STATE.fs,
53
+ output="ba",
63
54
  )
64
55
  else:
65
56
  return None
66
57
 
67
-
68
58
  @ez.subscriber(INPUT_FILTER)
69
- async def redesign(self, message: ButterworthFilterDesign) -> None:
59
+ async def redesign(self, message: ButterworthFilterSettings) -> None:
70
60
  if self.STATE.design.order != message.order:
71
61
  self.STATE.zi = None
72
62
  self.STATE.design = message
@@ -2,23 +2,22 @@ import ezmsg.core as ez
2
2
 
3
3
  import scipy.signal
4
4
 
5
+ from ezmsg.util.messages.axisarray import AxisArray
6
+
5
7
  from .downsample import Downsample, DownsampleSettings
6
8
  from .filter import Filter, FilterCoefficients, FilterSettings
7
- from .messages import TSMessage as TimeSeriesMessage
8
9
 
9
10
 
10
11
  class Decimate(ez.Collection):
11
-
12
12
  SETTINGS: DownsampleSettings
13
13
 
14
- INPUT_SIGNAL = ez.InputStream(TimeSeriesMessage)
15
- OUTPUT_SIGNAL = ez.OutputStream(TimeSeriesMessage)
14
+ INPUT_SIGNAL = ez.InputStream(AxisArray)
15
+ OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
16
16
 
17
17
  FILTER = Filter()
18
18
  DOWNSAMPLE = Downsample()
19
19
 
20
20
  def configure(self) -> None:
21
-
22
21
  self.DOWNSAMPLE.apply_settings(self.SETTINGS)
23
22
 
24
23
  if self.SETTINGS.factor < 1:
@@ -27,11 +26,9 @@ class Decimate(ez.Collection):
27
26
  filt = FilterCoefficients()
28
27
  else:
29
28
  # See scipy.signal.decimate for IIR Filter Condition
30
- system = scipy.signal.dlti(
31
- *scipy.signal.cheby1(8, 0.05, 0.8 / self.SETTINGS.factor)
32
- )
33
-
34
- filt = FilterCoefficients(b=system.num, a=system.den)
29
+ b, a = scipy.signal.cheby1(8, 0.05, 0.8 / self.SETTINGS.factor)
30
+ system = scipy.signal.dlti(b, a)
31
+ filt = FilterCoefficients(b=system.num, a=system.den) # type: ignore
35
32
 
36
33
  self.FILTER.apply_settings(FilterSettings(filt=filt))
37
34
 
@@ -0,0 +1,63 @@
1
+ from dataclasses import replace
2
+
3
+ from ezmsg.util.messages.axisarray import AxisArray
4
+
5
+ import ezmsg.core as ez
6
+ import numpy as np
7
+
8
+ from typing import (
9
+ AsyncGenerator,
10
+ Optional,
11
+ )
12
+
13
+
14
class DownsampleSettings(ez.Settings):
    """Configuration for the Downsample unit."""

    # Name of the axis to downsample along; when None the unit falls back
    # to the first entry of the incoming message's dims.
    axis: Optional[str] = None

    # Keep one sample out of every `factor`; the unit rejects values below 1.
    factor: int = 1
17
+
18
+
19
class DownsampleState(ez.State):
    """Runtime state the Downsample unit carries from one message to the next."""

    # Settings currently in effect (seeded from SETTINGS, replaceable at runtime).
    cur_settings: DownsampleSettings

    # Running sample index used to keep the decimation phase continuous
    # across message boundaries.
    s_idx: int = 0
22
+
23
+
24
class Downsample(ez.Unit):
    """Keep every ``factor``-th sample of an AxisArray along one axis.

    The decimation phase is tracked in ``STATE.s_idx`` so that the selected
    samples stay evenly spaced across consecutive messages. The output axis
    has its gain multiplied by ``factor`` and its offset moved to the first
    sample that was kept. No message is published when a block contains no
    sample that falls on the decimation grid.
    """

    SETTINGS: DownsampleSettings
    STATE: DownsampleState

    INPUT_SETTINGS = ez.InputStream(DownsampleSettings)
    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    def initialize(self) -> None:
        """Seed the live settings from the statically supplied SETTINGS."""
        self.STATE.cur_settings = self.SETTINGS

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: DownsampleSettings) -> None:
        """Replace the live settings with ones received at runtime."""
        self.STATE.cur_settings = msg

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(OUTPUT_SIGNAL)
    async def on_signal(self, msg: AxisArray) -> AsyncGenerator:
        """Downsample one incoming block and publish the kept samples.

        Raises:
            ValueError: if the configured factor is below 1.
        """
        factor = self.STATE.cur_settings.factor
        if factor < 1:
            raise ValueError("Downsample factor must be at least 1 (no downsampling)")

        axis_name = self.STATE.cur_settings.axis
        if axis_name is None:
            axis_name = msg.dims[0]
        axis = msg.get_axis(axis_name)
        axis_idx = msg.get_axis_idx(axis_name)

        n_samples = msg.data.shape[axis_idx]
        if n_samples == 0:
            # Nothing to select and no phase to advance; indexing samples[-1]
            # below would otherwise raise IndexError on an empty block.
            return

        # Phase of every sample in this block relative to the decimation grid.
        samples = (np.arange(n_samples) + self.STATE.s_idx) % factor
        # Stored as a plain int so the state matches its annotation.
        self.STATE.s_idx = int(samples[-1]) + 1

        pub_samples = np.where(samples == 0)[0]
        if len(pub_samples) != 0:
            new_axes = {ax_name: msg.get_axis(ax_name) for ax_name in msg.dims}
            new_offset = axis.offset + (axis.gain * pub_samples[0].item())
            new_gain = axis.gain * factor
            new_axes[axis_name] = replace(axis, gain=new_gain, offset=new_offset)
            # np.take copies, so the zero-copy input message is never mutated.
            down_data = np.take(msg.data, pub_samples, axis_idx)
            out_msg = replace(msg, data=down_data, dims=msg.dims, axes=new_axes)
            yield self.OUTPUT_SIGNAL, out_msg
@@ -0,0 +1,127 @@
1
+ import asyncio
2
+ from dataclasses import replace
3
+
4
+ import ezmsg.core as ez
5
+ from ezmsg.util.messages.axisarray import AxisArray
6
+
7
+ import numpy as np
8
+
9
+ from .window import Window, WindowSettings
10
+
11
+ from typing import AsyncGenerator, Optional
12
+
13
+
14
class EWMSettings(ez.Settings):
    """Configuration for the EWM standardization unit."""

    # Name of the axis to standardize along; when None the unit uses the
    # first entry of the incoming signal's dims.
    axis: Optional[str] = None

    # If true, we assume zero DC offset
    zero_offset: bool = True
17
+
18
+
19
class EWMState(ez.State):
    """Queues used by EWM to pair each signal block with its history buffer."""

    # Windowed history blocks received on INPUT_BUFFER.
    buffer_queue: "asyncio.Queue[AxisArray]"

    # Raw signal blocks received on INPUT_SIGNAL.
    signal_queue: "asyncio.Queue[AxisArray]"
22
+
23
+
24
class EWM(ez.Unit):
    """
    Exponentially Weighted Moving Average Standardization

    Each signal block is paired with a history buffer (which includes that
    block). An exponentially weighted mean and variance are computed over the
    buffer, and the standardized samples corresponding to the signal block
    are published.

    References https://stackoverflow.com/a/42926270
    """

    SETTINGS: EWMSettings
    STATE: EWMState

    INPUT_SIGNAL = ez.InputStream(AxisArray)
    INPUT_BUFFER = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    def initialize(self) -> None:
        """Create the queues that decouple the two input streams."""
        self.STATE.signal_queue = asyncio.Queue()
        self.STATE.buffer_queue = asyncio.Queue()

    @ez.subscriber(INPUT_SIGNAL)
    async def on_signal(self, message: AxisArray) -> None:
        """Enqueue a raw signal block."""
        self.STATE.signal_queue.put_nowait(message)

    @ez.subscriber(INPUT_BUFFER)
    async def on_buffer(self, message: AxisArray) -> None:
        """Enqueue a history buffer block."""
        self.STATE.buffer_queue.put_nowait(message)

    @ez.publisher(OUTPUT_SIGNAL)
    async def sync_output(self) -> AsyncGenerator:
        """Pair signal and buffer blocks in arrival order and publish the result."""
        while True:
            signal = await self.STATE.signal_queue.get()
            buffer = await self.STATE.buffer_queue.get()  # includes signal

            axis_name = self.SETTINGS.axis
            if axis_name is None:
                axis_name = signal.dims[0]

            axis_idx = signal.get_axis_idx(axis_name)

            buffer_len = buffer.shape[axis_idx]
            block_len = signal.shape[axis_idx]
            window = buffer_len - block_len

            alpha = 2 / (window + 1.0)
            alpha_rev = 1 - alpha

            pows = alpha_rev ** (np.arange(buffer_len + 1))
            scale_arr = 1 / pows[:-1]
            pw0 = alpha * alpha_rev ** (buffer_len - 1)

            # Work with the target axis in front so cumsum runs along axis 0.
            buffer_data = np.moveaxis(buffer.data, axis_idx, 0)

            # Shape the per-sample weights as (T, 1, ..., 1) so they broadcast
            # against data of any rank, not only 2-D (time, channel) arrays.
            time_shape = (buffer_len,) + (1,) * (buffer_data.ndim - 1)
            scale_col = scale_arr.reshape(time_shape)
            decay_col = pows[1:].reshape(time_shape)

            def ewma(data: np.ndarray) -> np.ndarray:
                mult = scale_col * data * pw0
                out = scale_col[::-1] * mult.cumsum(axis=0)

                if not self.SETTINGS.zero_offset:
                    # Seed the average with the first sample, decayed over time.
                    out = (data[0][np.newaxis, ...] * decay_col) + out

                return out

            mean = ewma(buffer_data)
            std = ewma((buffer_data - mean) ** 2.0)

            standardized = (buffer_data - mean) / np.sqrt(std).clip(1e-4)
            # Keep only the trailing samples that belong to the signal block.
            standardized = standardized[-block_len:, ...]
            # Undo the earlier moveaxis: the front axis goes back to axis_idx.
            standardized = np.moveaxis(standardized, 0, axis_idx)

            yield self.OUTPUT_SIGNAL, replace(signal, data=standardized)
93
+
94
+
95
class EWMFilterSettings(ez.Settings):
    """Configuration for the EWMFilter collection."""

    # previous data to accumulate for standardization
    history_dur: float

    # Name of the axis to operate along; None leaves the choice to the
    # contained units.
    axis: Optional[str] = None

    # If true, we assume zero DC offset for input data
    zero_offset: bool = True
99
+
100
+
101
class EWMFilter(ez.Collection):
    """Standardize a signal against an exponentially weighted history.

    The input is routed both directly to the EWM unit and through a Window
    unit configured with ``history_dur``; EWM consumes the windowed history
    together with the raw block and publishes the standardized block.
    """

    SETTINGS: EWMFilterSettings

    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)

    WINDOW = Window()
    EWM = EWM()

    def configure(self) -> None:
        """Push the collection's settings down to the contained units."""
        # Forward the user's zero_offset choice instead of forcing True,
        # otherwise EWMFilterSettings.zero_offset has no effect.
        self.EWM.apply_settings(
            EWMSettings(
                axis=self.SETTINGS.axis,
                zero_offset=self.SETTINGS.zero_offset,
            )
        )

        self.WINDOW.apply_settings(
            WindowSettings(
                axis=self.SETTINGS.axis,
                window_dur=self.SETTINGS.history_dur,
                window_shift=None,  # 1:1 mode
            )
        )

    def network(self) -> ez.NetworkDefinition:
        """Wire input -> (Window, EWM) -> output."""
        return (
            (self.INPUT_SIGNAL, self.WINDOW.INPUT_SIGNAL),
            (self.WINDOW.OUTPUT_SIGNAL, self.EWM.INPUT_BUFFER),
            (self.INPUT_SIGNAL, self.EWM.INPUT_SIGNAL),
            (self.EWM.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),
        )
@@ -4,29 +4,30 @@ import ezmsg.core as ez
4
4
  import scipy.signal
5
5
  import numpy as np
6
6
  import asyncio
7
- import logging
8
7
 
9
- from .messages import TSMessage as TimeSeriesMessage
8
+ from ezmsg.util.messages.axisarray import AxisArray
10
9
 
11
10
  from typing import AsyncGenerator, Optional, Tuple
12
11
 
13
- logger = logging.getLogger('ezmsg')
14
-
15
12
 
16
13
  @dataclass
17
14
  class FilterCoefficients:
18
- b: np.ndarray = field(default_factory = lambda: np.array([1.0, 0.0]))
19
- a: np.ndarray = field(default_factory = lambda: np.array([1.0, 0.0]))
15
+ b: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
16
+ a: np.ndarray = field(default_factory=lambda: np.array([1.0, 0.0]))
17
+
20
18
 
19
+ class FilterSettingsBase(ez.Settings):
20
+ axis: Optional[str] = None
21
+ fs: Optional[float] = None
21
22
 
22
23
 
23
- class FilterSettings(ez.Settings):
24
+ class FilterSettings(FilterSettingsBase):
24
25
  # If you'd like to statically design a filter, define it in settings
25
26
  filt: Optional[FilterCoefficients] = None
26
- fs: Optional[float] = None
27
27
 
28
28
 
29
29
  class FilterState(ez.State):
30
+ axis: Optional[str] = None
30
31
  zi: Optional[np.ndarray] = None
31
32
  filt_designed: bool = False
32
33
  filt: Optional[FilterCoefficients] = None
@@ -36,21 +37,25 @@ class FilterState(ez.State):
36
37
 
37
38
 
38
39
  class Filter(ez.Unit):
39
- SETTINGS: FilterSettings
40
+ SETTINGS: FilterSettingsBase
40
41
  STATE: FilterState
41
42
 
42
43
  INPUT_FILTER = ez.InputStream(FilterCoefficients)
43
- INPUT_SIGNAL = ez.InputStream(TimeSeriesMessage)
44
- OUTPUT_SIGNAL = ez.OutputStream(TimeSeriesMessage)
44
+ INPUT_SIGNAL = ez.InputStream(AxisArray)
45
+ OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
45
46
 
46
47
  def design_filter(self) -> Optional[Tuple[np.ndarray, np.ndarray]]:
47
48
  raise NotImplementedError("Must implement 'design_filter' in Unit subclass!")
48
49
 
49
50
  # Set up filter with static initialization if specified
50
51
  def initialize(self) -> None:
51
- if self.SETTINGS.filt is not None:
52
- self.STATE.filt = self.SETTINGS.filt
53
- self.STATE.filt_set.set()
52
+ if self.SETTINGS.axis is not None:
53
+ self.STATE.axis = self.SETTINGS.axis
54
+
55
+ if isinstance(self.SETTINGS, FilterSettings):
56
+ if self.SETTINGS.filt is not None:
57
+ self.STATE.filt = self.SETTINGS.filt
58
+ self.STATE.filt_set.set()
54
59
  else:
55
60
  self.STATE.filt_set.clear()
56
61
 
@@ -58,7 +63,7 @@ class Filter(ez.Unit):
58
63
  try:
59
64
  self.update_filter()
60
65
  except NotImplementedError:
61
- logger.debug("Using filter coefficients.")
66
+ ez.logger.debug("Using filter coefficients.")
62
67
 
63
68
  @ez.subscriber(INPUT_FILTER)
64
69
  async def redesign(self, message: FilterCoefficients):
@@ -67,40 +72,48 @@ class Filter(ez.Unit):
67
72
  def update_filter(self):
68
73
  try:
69
74
  coefs = self.design_filter()
70
- self.STATE.filt = FilterCoefficients() if coefs is None else FilterCoefficients( *coefs )
75
+ self.STATE.filt = (
76
+ FilterCoefficients() if coefs is None else FilterCoefficients(*coefs)
77
+ )
71
78
  self.STATE.filt_set.set()
72
79
  self.STATE.filt_designed = True
73
80
  except NotImplementedError as e:
74
81
  raise e
75
82
  except Exception as e:
76
- logger.warning(f"Error when designing filter: {e}")
83
+ ez.logger.warning(f"Error when designing filter: {e}")
77
84
 
78
85
  @ez.subscriber(INPUT_SIGNAL)
79
86
  @ez.publisher(OUTPUT_SIGNAL)
80
- async def apply_filter(self, message: TimeSeriesMessage) -> AsyncGenerator:
81
- if self.STATE.fs != message.fs and self.STATE.filt_designed is True:
82
- self.STATE.fs = message.fs
87
+ async def apply_filter(self, msg: AxisArray) -> AsyncGenerator:
88
+ axis_name = msg.dims[0] if self.STATE.axis is None else self.STATE.axis
89
+ axis_idx = msg.get_axis_idx(axis_name)
90
+ axis = msg.get_axis(axis_name)
91
+ fs = 1.0 / axis.gain
92
+
93
+ if self.STATE.fs != fs and self.STATE.filt_designed is True:
94
+ self.STATE.fs = fs
83
95
  self.update_filter()
84
96
 
85
97
  # Ensure filter is defined
98
+ # TODO: Maybe have me be a passthrough filter until coefficients are received
86
99
  if self.STATE.filt is None:
87
100
  self.STATE.filt_set.clear()
88
- logger.info("Awaiting filter coefficients...")
101
+ ez.logger.info("Awaiting filter coefficients...")
89
102
  await self.STATE.filt_set.wait()
90
- logger.info("Filter coefficients received.")
103
+ ez.logger.info("Filter coefficients received.")
91
104
 
92
- arr_in: np.ndarray
105
+ assert self.STATE.filt is not None
106
+
107
+ arr_in = msg.data
93
108
 
94
109
  # If the array is one dimensional, add a temporary second dimension so that the math works out
95
110
  one_dimensional = False
96
- if message.data.ndim == 1:
97
- arr_in = np.expand_dims(message.data, axis=1)
111
+ if arr_in.ndim == 1:
112
+ arr_in = np.expand_dims(arr_in, axis=1)
98
113
  one_dimensional = True
99
- else:
100
- arr_in = message.data
101
114
 
102
115
  # We will perform filter with time dimension as last axis
103
- arr_in = np.moveaxis(arr_in, message.time_dim, -1)
116
+ arr_in = np.moveaxis(arr_in, axis_idx, -1)
104
117
  samp_shape = arr_in[..., 0].shape
105
118
 
106
119
  # Re-calculate/reset zi if necessary
@@ -118,10 +131,10 @@ class Filter(ez.Unit):
118
131
  self.STATE.filt.b, self.STATE.filt.a, arr_in, zi=self.STATE.zi
119
132
  )
120
133
 
121
- arr_out = np.moveaxis(arr_out, -1, message.time_dim)
134
+ arr_out = np.moveaxis(arr_out, -1, axis_idx)
122
135
 
123
136
  # Remove temporary first dimension if necessary
124
137
  if one_dimensional:
125
138
  arr_out = np.squeeze(arr_out, axis=1)
126
139
 
127
- yield (self.OUTPUT_SIGNAL, replace(message, data=arr_out))
140
+ yield self.OUTPUT_SIGNAL, replace(msg, data=arr_out),
@@ -0,0 +1,31 @@
1
+ import warnings
2
+ import time
3
+
4
+ import numpy.typing as npt
5
+
6
+ from ezmsg.util.messages.axisarray import AxisArray
7
+
8
+ from typing import Optional
9
+
10
# UPCOMING: TSMessage Deprecation
# TSMessage is deprecated because it doesn't handle multiple time axes well.
# AxisArray has an incompatible API but supports a superset of functionality.
# The warning fires once at import time; stacklevel=2 attributes it to the
# importing module rather than to this file. The message names the module
# AxisArray is actually imported from in this file.
warnings.warn(
    "TimeSeriesMessage/TSMessage is deprecated. "
    "Please use ezmsg.util.messages.axisarray.AxisArray",
    DeprecationWarning,
    stacklevel=2,
)
18
+
19
+
20
def TSMessage(
    data: npt.NDArray,
    fs: float = 1.0,
    time_dim: int = 0,
    timestamp: Optional[float] = None,
) -> AxisArray:
    """Build an AxisArray from legacy TSMessage-style arguments.

    Args:
        data: sample array; one of its dimensions is the time dimension.
        fs: value handed to ``AxisArray.Axis.TimeAxis`` as the sampling rate.
        time_dim: index of the dimension that is labelled "time".
        timestamp: reference time for the block; defaults to ``time.time()``.

    Returns:
        An AxisArray whose dims are "dim_<i>" except for the time dimension,
        carrying a single "time" axis.
    """
    dim_names = [f"dim_{i}" for i in range(data.ndim)]
    dim_names[time_dim] = "time"

    if timestamp is None:
        reference = time.time()
    else:
        reference = timestamp

    # offset corresponds to idx[0] on time_dim, so step back by the
    # duration of the block from the reference time.
    block_dur = data.shape[time_dim] / fs
    time_axis = AxisArray.Axis.TimeAxis(fs, offset=reference - block_dur)
    return AxisArray(data, dims=dim_names, axes={"time": time_axis})