ezmsg-sigproc 1.3.2__tar.gz → 1.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/PKG-INFO +1 -1
  2. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/__version__.py +2 -2
  3. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/activation.py +12 -0
  4. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/affinetransform.py +4 -5
  5. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/aggregate.py +5 -1
  6. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/bandpower.py +2 -1
  7. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/base.py +2 -1
  8. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/butterworthfilter.py +11 -7
  9. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/downsample.py +3 -3
  10. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/filter.py +7 -3
  11. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/filterbank.py +6 -4
  12. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/math/abs.py +7 -1
  13. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/math/clip.py +12 -2
  14. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/math/difference.py +9 -0
  15. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/math/invert.py +8 -2
  16. ezmsg_sigproc-1.4.1/src/ezmsg/sigproc/math/log.py +51 -0
  17. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/math/scale.py +11 -2
  18. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/sampler.py +1 -1
  19. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/scaler.py +5 -6
  20. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/signalinjector.py +5 -0
  21. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/slicer.py +14 -1
  22. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/spectrogram.py +7 -3
  23. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/spectrum.py +2 -2
  24. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/wavelets.py +5 -3
  25. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/window.py +13 -6
  26. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/test_affine_transform.py +33 -4
  27. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/test_aggregate.py +7 -1
  28. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/test_butter.py +17 -0
  29. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/test_filterbank.py +2 -0
  30. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/test_math.py +7 -3
  31. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/test_slicer.py +46 -0
  32. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/test_wavelets.py +2 -0
  33. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/uv.lock +1 -1
  34. ezmsg_sigproc-1.3.2/src/ezmsg/sigproc/math/log.py +0 -32
  35. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/.github/workflows/python-publish-ezmsg-sigproc.yml +0 -0
  36. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/.github/workflows/python-tests.yml +0 -0
  37. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/.gitignore +0 -0
  38. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/.pre-commit-config.yaml +0 -0
  39. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/LICENSE.txt +0 -0
  40. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/README.md +0 -0
  41. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/pyproject.toml +0 -0
  42. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/__init__.py +0 -0
  43. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/decimate.py +0 -0
  44. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/ewmfilter.py +0 -0
  45. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/math/__init__.py +0 -0
  46. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/messages.py +0 -0
  47. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/spectral.py +0 -0
  48. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/src/ezmsg/sigproc/synth.py +0 -0
  49. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/conftest.py +0 -0
  50. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/helpers/__init__.py +0 -0
  51. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/helpers/util.py +0 -0
  52. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/resources/xform.csv +0 -0
  53. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/test_activation.py +0 -0
  54. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/test_bandpower.py +0 -0
  55. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/test_butterworth.py +0 -0
  56. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/test_downsample.py +0 -0
  57. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/test_sampler.py +0 -0
  58. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/test_scaler.py +0 -0
  59. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/test_spectrogram.py +0 -0
  60. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/test_spectrum.py +0 -0
  61. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/test_synth.py +0 -0
  62. {ezmsg_sigproc-1.3.2 → ezmsg_sigproc-1.4.1}/tests/test_window.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: ezmsg-sigproc
3
- Version: 1.3.2
3
+ Version: 1.4.1
4
4
  Summary: Timeseries signal processing implementations in ezmsg
5
5
  Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>
6
6
  License-Expression: MIT
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '1.3.2'
16
- __version_tuple__ = version_tuple = (1, 3, 2)
15
+ __version__ = version = '1.4.1'
16
+ __version_tuple__ = version_tuple = (1, 4, 1)
@@ -43,6 +43,18 @@ ACTIVATIONS = {
43
43
  def activation(
44
44
  function: typing.Union[str, ActivationFunction],
45
45
  ) -> typing.Generator[AxisArray, AxisArray, None]:
46
+ """
47
+ Transform the data with a simple activation function.
48
+
49
+ Args:
50
+ function: An enum value from ActivationFunction or a string representing the activation function.
51
+ Possible values are: SIGMOID, EXPIT, LOGIT, LOGEXPIT, "sigmoid", "expit", "logit", "log_expit".
52
+ SIGMOID and EXPIT are equivalent. See :obj:`scipy.special.expit` for more details.
53
+
54
+ Returns: A primed generator that, when passed an input message via `.send(msg)`, yields an AxisArray
55
+ with the data payload containing a transformed version of the input data.
56
+
57
+ """
46
58
  if type(function) is ActivationFunction:
47
59
  func = ACTIVATIONS[function]
48
60
  else:
@@ -24,7 +24,7 @@ def affine_transform(
24
24
  Args:
25
25
  weights: An array of weights or a path to a file with weights compatible with np.loadtxt.
26
26
  axis: The name of the axis to apply the transformation to. Defaults to the leading (0th) axis in the array.
27
- right_multiply: Set False to tranpose the weights before applying.
27
+ right_multiply: Set False to transpose the weights before applying.
28
28
 
29
29
  Returns:
30
30
  A primed generator object that yields an :obj:`AxisArray` object for every
@@ -76,16 +76,15 @@ def affine_transform(
76
76
  ):
77
77
  in_labels = msg_in.axes[axis].labels
78
78
  new_labels = []
79
- n_in = weights.shape[1 if right_multiply else 0]
80
- n_out = weights.shape[0 if right_multiply else 1]
79
+ n_in, n_out = weights.shape
81
80
  if len(in_labels) != n_in:
82
81
  # Something upstream did something it wasn't supposed to. We will drop the labels.
83
82
  ez.logger.warning(
84
83
  f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels."
85
84
  )
86
85
  else:
87
- b_used_inputs = np.any(weights, axis=0 if right_multiply else 1)
88
- b_filled_outputs = np.any(weights, axis=1 if right_multiply else 0)
86
+ b_filled_outputs = np.any(weights, axis=0)
87
+ b_used_inputs = np.any(weights, axis=1)
89
88
  if np.all(b_used_inputs) and np.all(b_filled_outputs):
90
89
  # All inputs are used and all outputs are used, but n_in != n_out.
91
90
  # Mapping cannot be determined.
@@ -20,11 +20,13 @@ class AggregationFunction(OptionsEnum):
20
20
  MEAN = "mean"
21
21
  MEDIAN = "median"
22
22
  STD = "std"
23
+ SUM = "sum"
23
24
  NANMAX = "nanmax"
24
25
  NANMIN = "nanmin"
25
26
  NANMEAN = "nanmean"
26
27
  NANMEDIAN = "nanmedian"
27
28
  NANSTD = "nanstd"
29
+ NANSUM = "nansum"
28
30
  ARGMIN = "argmin"
29
31
  ARGMAX = "argmax"
30
32
 
@@ -36,11 +38,13 @@ AGGREGATORS = {
36
38
  AggregationFunction.MEAN: np.mean,
37
39
  AggregationFunction.MEDIAN: np.median,
38
40
  AggregationFunction.STD: np.std,
41
+ AggregationFunction.SUM: np.sum,
39
42
  AggregationFunction.NANMAX: np.nanmax,
40
43
  AggregationFunction.NANMIN: np.nanmin,
41
44
  AggregationFunction.NANMEAN: np.nanmean,
42
45
  AggregationFunction.NANMEDIAN: np.nanmedian,
43
46
  AggregationFunction.NANSTD: np.nanstd,
47
+ AggregationFunction.NANSUM: np.nansum,
44
48
  AggregationFunction.ARGMIN: np.argmin,
45
49
  AggregationFunction.ARGMAX: np.argmax,
46
50
  }
@@ -62,7 +66,7 @@ def ranged_aggregate(
62
66
  operation: :obj:`AggregationFunction` to apply to each band.
63
67
 
64
68
  Returns:
65
- A primed generator object ready to yield an AxisArray for each .send(axis_array)
69
+ A primed generator object ready to yield an :obj:`AxisArray` for each .send(axis_array)
66
70
  """
67
71
  msg_out = AxisArray(np.array([]), dims=[""])
68
72
 
@@ -27,7 +27,8 @@ def bandpower(
27
27
  bands: (min, max) tuples of band limits in Hz.
28
28
 
29
29
  Returns:
30
- A primed generator object ready to yield an AxisArray for each .send(axis_array)
30
+ A primed generator object ready to yield an :obj:`AxisArray` for each .send(axis_array)
31
+ with the data payload being the average spectral power in each band of the input data.
31
32
  """
32
33
  msg_out = AxisArray(np.array([]), dims=[""])
33
34
 
@@ -1,3 +1,4 @@
1
+ import math
1
2
  import traceback
2
3
  import typing
3
4
 
@@ -30,7 +31,7 @@ class GenAxisArray(ez.Unit):
30
31
  async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
31
32
  try:
32
33
  ret = self.STATE.gen.send(message)
33
- if ret.data.size > 0:
34
+ if math.prod(ret.data.shape) > 0:
34
35
  yield self.OUTPUT_SIGNAL, ret
35
36
  except (StopIteration, GeneratorExit):
36
37
  ez.logger.debug(f"Generator closed in {self.address}")
@@ -13,19 +13,22 @@ class ButterworthFilterSettings(FilterSettingsBase):
13
13
  """Settings for :obj:`ButterworthFilter`."""
14
14
 
15
15
  order: int = 0
16
+ """
17
+ Filter order
18
+ """
16
19
 
17
20
  cuton: typing.Optional[float] = None
18
21
  """
19
- Cuton frequency (Hz). If cutoff is not specified then this is the highpass corner, otherwise
20
- if it is lower than cutoff then this is the beginning of the bandpass
21
- or if it is greater than cuton then it is the end of the bandstop.
22
+ Cuton frequency (Hz). If `cutoff` is not specified then this is the highpass corner. Otherwise,
23
+ if this is lower than `cutoff` then this is the beginning of the bandpass
24
+ or if this is greater than `cutoff` then this is the end of the bandstop.
22
25
  """
23
26
 
24
27
  cutoff: typing.Optional[float] = None
25
28
  """
26
- Cutoff frequency (Hz). If cuton is not specified then this is the lowpass corner, otherwise
27
- if it is greater than cuton then this is the end of the bandpass,
28
- or if it is less than cuton then it is the beginning of the bandstop.
29
+ Cutoff frequency (Hz). If `cuton` is not specified then this is the lowpass corner. Otherwise,
30
+ if this is greater than `cuton` then this is the end of the bandpass,
31
+ or if this is less than `cuton` then this is the beginning of the bandstop.
29
32
  """
30
33
 
31
34
  def filter_specs(
@@ -76,7 +79,8 @@ def butter(
76
79
  coef_type: "ba" or "sos"
77
80
 
78
81
  Returns:
79
- A primed generator object which accepts .send(axis_array) and yields filtered axis array.
82
+ A primed generator object which accepts an :obj:`AxisArray` via .send(axis_array)
83
+ and yields an :obj:`AxisArray` with filtered data.
80
84
 
81
85
  """
82
86
  # IO
@@ -25,10 +25,10 @@ def downsample(
25
25
  factor: Downsampling factor.
26
26
 
27
27
  Returns:
28
- A primed generator object ready to receive a `.send(axis_array)`
29
- and yields the downsampled data.
28
+ A primed generator object ready to receive an :obj:`AxisArray` via `.send(axis_array)`
29
+ and yields an :obj:`AxisArray` with its data downsampled.
30
30
  Note that if a send chunk does not have sufficient samples to reach the
31
- next downsample interval then `None` is yielded.
31
+ next downsample interval then an :obj:`AxisArray` with size-zero data is yielded.
32
32
 
33
33
  """
34
34
  msg_out = AxisArray(np.array([]), dims=[""])
@@ -38,7 +38,7 @@ def filtergen(
38
38
  axis: str, coefs: typing.Optional[typing.Tuple[np.ndarray]], coef_type: str
39
39
  ) -> typing.Generator[AxisArray, AxisArray, None]:
40
40
  """
41
- Construct a generic filter generator function.
41
+ Filter data using the provided coefficients.
42
42
 
43
43
  Args:
44
44
  axis: The name of the axis to operate on.
@@ -46,7 +46,8 @@ def filtergen(
46
46
  coef_type: The type of filter coefficients. One of "ba" or "sos".
47
47
 
48
48
  Returns:
49
- A generator that expects .send(axis_array) and yields the filtered :obj:`AxisArray`.
49
+ A primed generator that, when passed an :obj:`AxisArray` via `.send(axis_array)`,
50
+ yields an :obj:`AxisArray` with the data filtered.
50
51
  """
51
52
  # Massage inputs
52
53
  if coefs is not None and not isinstance(coefs, tuple):
@@ -97,7 +98,10 @@ def filtergen(
97
98
  n_tile = (1,) + n_tile
98
99
  zi = np.tile(zi[zi_expand], n_tile)
99
100
 
100
- dat_out, zi = filt_func(*coefs, msg_in.data, axis=axis_idx, zi=zi)
101
+ if msg_in.data.size > 0:
102
+ dat_out, zi = filt_func(*coefs, msg_in.data, axis=axis_idx, zi=zi)
103
+ else:
104
+ dat_out = msg_in.data
101
105
  msg_out = replace(msg_in, data=dat_out)
102
106
 
103
107
 
@@ -43,9 +43,9 @@ def filterbank(
43
43
  new_axis: str = "kernel",
44
44
  ) -> typing.Generator[AxisArray, AxisArray, None]:
45
45
  """
46
- Returns a generator that perform multiple (direct or fft) convolutions on a signal using a bank of kernels.
47
- This generator is intended to be used during online processing, therefore both direct and fft convolutions
48
- use the overlap-add method.
46
+ Perform multiple (direct or fft) convolutions on a signal using a bank of kernels.
47
+ This is intended to be used during online processing, therefore both direct and fft convolutions
48
+ use the overlap-add method.
49
49
  Args:
50
50
  kernels:
51
51
  mode: "conv", "fft", or "auto". If "auto", the mode is determined by the size of the input data.
@@ -59,7 +59,8 @@ def filterbank(
59
59
  axis: The name of the axis to operate on. This should usually be "time".
60
60
  new_axis: The name of the new axis corresponding to the kernel index.
61
61
 
62
- Returns:
62
+ Returns: A primed generator that, when passed an input message via `.send(msg)`, yields an :obj:`AxisArray`
63
+ with the data payload containing the input :obj:`AxisArray` data convolved with each kernel in the bank.
63
64
 
64
65
  """
65
66
  msg_out: typing.Optional[AxisArray] = None
@@ -133,6 +134,7 @@ def filterbank(
133
134
  + msg_in.dims[targ_ax_ix + 1 :]
134
135
  + [new_axis, axis],
135
136
  axes=msg_in.axes.copy(), # We do not have info for kernel/filter axis :(.
137
+ key=msg_in.key,
136
138
  )
137
139
 
138
140
  # Determine optimal mode. Assumes 100 msec chunks.
@@ -11,9 +11,15 @@ from ..base import GenAxisArray
11
11
 
12
12
  @consumer
13
13
  def abs() -> typing.Generator[AxisArray, AxisArray, None]:
14
+ """
15
+ Take the absolute value of the data. See :obj:`np.abs` for more details.
16
+
17
+ Returns: A primed generator that, when passed an input message via `.send(msg)`, yields an :obj:`AxisArray`
18
+ with the data payload containing the absolute value of the input :obj:`AxisArray` data.
19
+ """
14
20
  msg_out = AxisArray(np.array([]), dims=[""])
15
21
  while True:
16
- msg_in = yield msg_out
22
+ msg_in: AxisArray = yield msg_out
17
23
  msg_out = replace(msg_in, data=np.abs(msg_in.data))
18
24
 
19
25
 
@@ -11,10 +11,20 @@ from ..base import GenAxisArray
11
11
 
12
12
  @consumer
13
13
  def clip(a_min: float, a_max: float) -> typing.Generator[AxisArray, AxisArray, None]:
14
- msg_in = AxisArray(np.array([]), dims=[""])
14
+ """
15
+ Clips the data to be within the specified range. See :obj:`np.clip` for more details.
16
+
17
+ Args:
18
+ a_min: Lower clip bound
19
+ a_max: Upper clip bound
20
+
21
+ Returns: A primed generator that, when passed an input message via `.send(msg)`, yields an :obj:`AxisArray`
22
+ with the data payload containing the clipped version of the input :obj:`AxisArray` data.
23
+
24
+ """
15
25
  msg_out = AxisArray(np.array([]), dims=[""])
16
26
  while True:
17
- msg_in = yield msg_out
27
+ msg_in: AxisArray = yield msg_out
18
28
  msg_out = replace(msg_in, data=np.clip(msg_in.data, a_min, a_max))
19
29
 
20
30
 
@@ -16,6 +16,15 @@ def const_difference(
16
16
  """
17
17
  result = (in_data - value) if subtrahend else (value - in_data)
18
18
  https://en.wikipedia.org/wiki/Template:Arithmetic_operations
19
+
20
+ Args:
21
+ value: number to subtract or be subtracted from the input data
22
+ subtrahend: If True (default) then value is subtracted from the input data.
23
+ If False, the input data is subtracted from value.
24
+
25
+ Returns: A primed generator that, when passed an input message via `.send(msg)`, yields an :obj:`AxisArray`
26
+ with the data payload containing the difference between the input :obj:`AxisArray` data and the value.
27
+
19
28
  """
20
29
  msg_out = AxisArray(np.array([]), dims=[""])
21
30
  while True:
@@ -11,10 +11,16 @@ from ..base import GenAxisArray
11
11
 
12
12
  @consumer
13
13
  def invert() -> typing.Generator[AxisArray, AxisArray, None]:
14
- msg_in = AxisArray(np.array([]), dims=[""])
14
+ """
15
+ Take the element-wise multiplicative inverse (reciprocal, `1 / x`) of the data.
16
+
17
+ Returns: A primed generator that, when passed an input message via `.send(msg)`, yields an :obj:`AxisArray`
18
+ with the data payload containing the element-wise reciprocal of the input :obj:`AxisArray` data.
19
+
20
+ """
15
21
  msg_out = AxisArray(np.array([]), dims=[""])
16
22
  while True:
17
- msg_in = yield msg_out
23
+ msg_in: AxisArray = yield msg_out
18
24
  msg_out = replace(msg_in, data=1 / msg_in.data)
19
25
 
20
26
 
@@ -0,0 +1,51 @@
1
+ from dataclasses import replace
2
+ import typing
3
+
4
+ import numpy as np
5
+ import ezmsg.core as ez
6
+ from ezmsg.util.generator import consumer
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+
9
+ from ..base import GenAxisArray
10
+
11
+
12
+ @consumer
13
+ def log(
14
+ base: float = 10.0,
15
+ clip_zero: bool = False,
16
+ ) -> typing.Generator[AxisArray, AxisArray, None]:
17
+ """
18
+ Take the logarithm of the data. See :obj:`np.log` for more details.
19
+
20
+ Args:
21
+ base: The base of the logarithm. Default is 10.
22
+ clip_zero: If True, clip the data to the minimum positive value of the data type before taking the log.
23
+
24
+ Returns: A primed generator that, when passed an input message via `.send(msg)`, yields an :obj:`AxisArray`
25
+ with the data payload containing the logarithm of the input :obj:`AxisArray` data.
26
+
27
+ """
28
+ msg_out = AxisArray(np.array([]), dims=[""])
29
+ log_base = np.log(base)
30
+ while True:
31
+ msg_in: AxisArray = yield msg_out
32
+ if (
33
+ clip_zero
34
+ and np.any(msg_in.data <= 0)
35
+ and np.issubdtype(msg_in.data.dtype, np.floating)
36
+ ):
37
+ msg_in.data = np.clip(
38
+ msg_in.data, a_min=np.finfo(msg_in.data.dtype).tiny, a_max=None
39
+ )
40
+ msg_out = replace(msg_in, data=np.log(msg_in.data) / log_base)
41
+
42
+
43
+ class LogSettings(ez.Settings):
44
+ base: float = 10.0
45
+
46
+
47
+ class Log(GenAxisArray):
48
+ SETTINGS = LogSettings
49
+
50
+ def construct_generator(self):
51
+ self.STATE.gen = log(base=self.SETTINGS.base)
@@ -11,10 +11,19 @@ from ..base import GenAxisArray
11
11
 
12
12
  @consumer
13
13
  def scale(scale: float = 1.0) -> typing.Generator[AxisArray, AxisArray, None]:
14
- msg_in = AxisArray(np.array([]), dims=[""])
14
+ """
15
+ Scale the data by a constant factor.
16
+
17
+ Args:
18
+ scale: Factor by which to scale the data magnitude.
19
+
20
+ Returns: A primed generator that, when passed an input message via `.send(msg)`, yields an :obj:`AxisArray`
21
+ with the data payload containing the input :obj:`AxisArray` data scaled by a constant factor.
22
+
23
+ """
15
24
  msg_out = AxisArray(np.array([]), dims=[""])
16
25
  while True:
17
- msg_in = yield msg_out
26
+ msg_in: AxisArray = yield msg_out
18
27
  msg_out = replace(msg_in, data=scale * msg_in.data)
19
28
 
20
29
 
@@ -43,7 +43,7 @@ def sampler(
43
43
  typing.List[SampleMessage], typing.Union[AxisArray, SampleTriggerMessage], None
44
44
  ]:
45
45
  """
46
- A generator function that samples data into a buffer, accepts triggers, and returns slices of sampled
46
+ Sample data into a buffer, accept triggers, and return slices of sampled
47
47
  data around the trigger time.
48
48
 
49
49
  Args:
@@ -33,8 +33,7 @@ def scaler(
33
33
  time_constant: float = 1.0, axis: typing.Optional[str] = None
34
34
  ) -> typing.Generator[AxisArray, AxisArray, None]:
35
35
  """
36
- Create a generator function that applies the
37
- adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
36
+ Apply the adaptive standard scaler from https://riverml.xyz/latest/api/preprocessing/AdaptiveStandardScaler/
38
37
  This is faster than :obj:`scaler_np` for single-channel data.
39
38
 
40
39
  Args:
@@ -42,8 +41,8 @@ def scaler(
42
41
  axis: The name of the axis to accumulate statistics over.
43
42
 
44
43
  Returns:
45
- A primed generator object that expects `.send(axis_array)` and yields a
46
- standardized, or "Z-scored" version of the input.
44
+ A primed generator object that expects to be sent an :obj:`AxisArray` via `.send(axis_array)`
45
+ and yields an :obj:`AxisArray` with its data being a standardized, or "Z-scored" version of the input data.
47
46
  """
48
47
  from river import preprocessing
49
48
 
@@ -90,8 +89,8 @@ def scaler_np(
90
89
  axis: The name of the axis to accumulate statistics over.
91
90
 
92
91
  Returns:
93
- A primed generator object that expects `.send(axis_array)` and yields a
94
- standardized, or "Z-scored" version of the input.
92
+ A primed generator object that expects to be sent an :obj:`AxisArray` via `.send(axis_array)`
93
+ and yields an :obj:`AxisArray` with its data being a standardized, or "Z-scored" version of the input data.
95
94
  """
96
95
  msg_out = AxisArray(np.array([]), dims=[""])
97
96
 
@@ -22,6 +22,11 @@ class SignalInjectorState(ez.State):
22
22
 
23
23
 
24
24
  class SignalInjector(ez.Unit):
25
+ """
26
+ Add a sinusoidal signal to the input signal. Each feature gets a different amplitude of the sinusoid.
27
+ All features get the same frequency sinusoid. The frequency and base amplitude can be changed while running.
28
+ """
29
+
25
30
  SETTINGS = SignalInjectorSettings
26
31
  STATE = SignalInjectorState
27
32
 
@@ -48,6 +48,18 @@ def parse_slice(s: str) -> typing.Tuple[typing.Union[slice, int], ...]:
48
48
  def slicer(
49
49
  selection: str = "", axis: typing.Optional[str] = None
50
50
  ) -> typing.Generator[AxisArray, AxisArray, None]:
51
+ """
52
+ Slice along a particular axis.
53
+
54
+ Args:
55
+ selection: See :obj:`ezmsg.sigproc.slicer.parse_slice` for details.
56
+ axis: The name of the axis to slice along. If None, the last axis is used.
57
+
58
+ Returns:
59
+ A primed generator object ready to yield an :obj:`AxisArray` for each .send(axis_array)
60
+ with the data payload containing a sliced view of the input data.
61
+
62
+ """
51
63
  msg_out = AxisArray(np.array([]), dims=[""])
52
64
 
53
65
  # State variables
@@ -98,7 +110,8 @@ def slicer(
98
110
  and hasattr(msg_in.axes[axis], "labels")
99
111
  and len(msg_in.axes[axis].labels) > 0
100
112
  ):
101
- new_labels = msg_in.axes[axis].labels[_slice]
113
+ in_labels = np.array(msg_in.axes[axis].labels)
114
+ new_labels = in_labels[_slice].tolist()
102
115
  new_axis = replace(msg_in.axes[axis], labels=new_labels)
103
116
 
104
117
  replace_kwargs = {}
@@ -33,13 +33,17 @@ def spectrogram(
33
33
  output: See :obj:`ezmsg.sigproc.spectrum.spectrum`
34
34
 
35
35
  Returns:
36
- A primed generator object that expects `.send(axis_array)` of continuous data
37
- and yields an AxisArray of time-frequency power values.
36
+ A primed generator object that expects an :obj:`AxisArray` via `.send(axis_array)`
37
+ with continuous data in its .data payload, and yields an :obj:`AxisArray` of time-frequency power values.
38
38
  """
39
39
 
40
40
  pipeline = compose(
41
41
  windowing(
42
- axis="time", newaxis="win", window_dur=window_dur, window_shift=window_shift
42
+ axis="time",
43
+ newaxis="win",
44
+ window_dur=window_dur,
45
+ window_shift=window_shift,
46
+ zero_pad_until="shift" if window_shift is not None else "input",
43
47
  ),
44
48
  spectrum(axis="time", window=window, transform=transform, output=output),
45
49
  modify_axis(name_map={"win": "time"}),
@@ -92,8 +92,8 @@ def spectrum(
92
92
  nfft: The number of points to use for the FFT. If None, the length of the input data is used.
93
93
 
94
94
  Returns:
95
- A primed generator object that expects `.send(axis_array)` of continuous data
96
- and yields an AxisArray of spectral magnitudes or powers.
95
+ A primed generator object that expects an :obj:`AxisArray` via `.send(axis_array)` containing continuous data
96
+ and yields an :obj:`AxisArray` with data of spectral magnitudes or powers.
97
97
  """
98
98
  msg_out = AxisArray(np.array([]), dims=[""])
99
99
 
@@ -20,8 +20,8 @@ def cwt(
20
20
  axis: str = "time",
21
21
  ) -> typing.Generator[AxisArray, AxisArray, None]:
22
22
  """
23
- Build a generator to perform a continuous wavelet transform on sent AxisArray messages.
24
- The function is equivalent to the `pywt.cwt` function, but is designed to work with streaming data.
23
+ Perform a continuous wavelet transform.
24
+ The function is equivalent to the :obj:`pywt.cwt` function, but is designed to work with streaming data.
25
25
 
26
26
  Args:
27
27
  scales: The wavelet scales to use.
@@ -31,7 +31,8 @@ def cwt(
31
31
  because fft and matrix multiplication is much faster on the last axis.
32
32
 
33
33
  Returns:
34
- A Generator object that expects `.send(axis_array)` of continuous data
34
+ A primed generator object that expects an :obj:`AxisArray` of continuous data via `.send(axis_array)`
35
+ and yields an :obj:`AxisArray` with a continuous wavelet transform in its data.
35
36
  """
36
37
  msg_out: typing.Optional[AxisArray] = None
37
38
 
@@ -114,6 +115,7 @@ def cwt(
114
115
  **msg_in.axes,
115
116
  "freq": AxisArray.Axis("Hz", offset=freqs[0], gain=fstep),
116
117
  },
118
+ key=msg_in.key,
117
119
  )
118
120
  last_conv_samp = np.zeros(
119
121
  dummy_shape[:-1] + (1,), dtype=template.data.dtype
@@ -24,7 +24,9 @@ def windowing(
24
24
  zero_pad_until: str = "input",
25
25
  ) -> typing.Generator[AxisArray, AxisArray, None]:
26
26
  """
27
- Construct a generator that yields windows of data from an input :obj:`AxisArray`.
27
+ Apply a sliding window along the specified axis to input streaming data.
28
+ The `windowing` method is perhaps the most useful and versatile method in ezmsg.sigproc, but its parameterization
29
+ can be difficult. Please read the argument descriptions carefully.
28
30
 
29
31
  Args:
30
32
  axis: The axis along which to segment windows.
@@ -48,8 +50,8 @@ def windowing(
48
50
  - "none" does not pad the buffer. No outputs will be yielded until at least `window_dur` data has been seen.
49
51
 
50
52
  Returns:
51
- A (primed) generator that accepts .send(an AxisArray object) and yields a list of windowed
52
- AxisArray objects. The list will always be length-1 if `newaxis` is not None or `window_shift` is None.
53
+ A primed generator that accepts an :obj:`AxisArray` via `.send(axis_array)`
54
+ and yields an :obj:`AxisArray` with the data payload containing a windowed version of the input data.
53
55
  """
54
56
  # Check arguments
55
57
  if newaxis is None:
@@ -77,7 +79,7 @@ def windowing(
77
79
  b_1to1 = window_shift is None
78
80
  newaxis_warned: bool = b_1to1
79
81
  out_newaxis: typing.Optional[AxisArray.Axis] = None
80
- out_dims: typing.typing.Optional[typing.List[str]] = None
82
+ out_dims: typing.Optional[typing.List[str]] = None
81
83
 
82
84
  check_inputs = {"samp_shape": None, "fs": None, "key": None}
83
85
 
@@ -123,11 +125,12 @@ def windowing(
123
125
  else: # i.e. zero_pad_until == "input"
124
126
  req_samples = msg_in.data.shape[axis_idx]
125
127
  n_zero = max(0, window_samples - req_samples)
126
- buffer = np.zeros(
128
+ init_buffer_shape = (
127
129
  msg_in.data.shape[:axis_idx]
128
130
  + (n_zero,)
129
131
  + msg_in.data.shape[axis_idx + 1 :]
130
132
  )
133
+ buffer = np.zeros(init_buffer_shape, dtype=msg_in.data.dtype)
131
134
 
132
135
  # Add new data to buffer.
133
136
  # Currently, we concatenate the new time samples and clip the output.
@@ -174,10 +177,14 @@ def windowing(
174
177
  if b_1to1:
175
178
  # one-to-one mode -- Each send yields exactly one window containing only the most recent samples.
176
179
  buffer = slice_along_axis(buffer, slice(-window_samples, None), axis_idx)
177
- out_dat = np.expand_dims(buffer, axis=axis_idx)
180
+ out_dat = buffer.reshape(
181
+ buffer.shape[:axis_idx] + (1,) + buffer.shape[axis_idx:]
182
+ )
178
183
  out_newaxis = replace(out_newaxis, offset=buffer_offset[-window_samples])
179
184
  elif buffer.shape[axis_idx] >= window_samples:
180
185
  # Deterministic window shifts.
186
+ # Note: After https://github.com/ezmsg-org/ezmsg/pull/152, add `window_shift_samples` as the last arg
187
+ # to `sliding_win_oneaxis` and remove the call to `slice_along_axis`.
181
188
  out_dat = sliding_win_oneaxis(buffer, window_samples, axis_idx)
182
189
  out_dat = slice_along_axis(
183
190
  out_dat, slice(None, None, window_shift_samples), axis_idx
@@ -1,5 +1,7 @@
1
1
  import copy
2
+ from dataclasses import dataclass, field
2
3
  from pathlib import Path
4
+ import typing
3
5
 
4
6
  import numpy as np
5
7
  from ezmsg.util.messages.axisarray import AxisArray
@@ -9,11 +11,33 @@ from ezmsg.sigproc.affinetransform import affine_transform, common_rereference
9
11
  from util import assert_messages_equal
10
12
 
11
13
 
14
+ # Define a custom Axis class that has a `labels` attribute
15
+ @dataclass
16
+ class CustomAxis(AxisArray.Axis):
17
+ labels: typing.List[str] = field(default_factory=lambda: [])
18
+
19
+ @classmethod
20
+ def SpaceAxis(
21
+ cls, labels: typing.List[str]
22
+ ): # , locs: typing.Optional[npt.NDArray] = None):
23
+ return cls(unit="mm", labels=labels)
24
+
25
+
26
+ # Monkey-patch AxisArray with our customized Axis
27
+ AxisArray.Axis = CustomAxis
28
+
29
+
12
30
  def test_affine_generator():
13
31
  n_times = 13
14
32
  n_chans = 64
15
33
  in_dat = np.arange(n_times * n_chans).reshape(n_times, n_chans)
16
- msg_in = AxisArray(in_dat, dims=["time", "ch"])
34
+ msg_in = AxisArray(
35
+ data=in_dat,
36
+ dims=["time", "ch"],
37
+ axes={
38
+ "ch": AxisArray.Axis.SpaceAxis(labels=[f"ch_{i}" for i in range(n_chans)])
39
+ },
40
+ )
17
41
 
18
42
  backup = [copy.deepcopy(msg_in)]
19
43
 
@@ -25,33 +49,38 @@ def test_affine_generator():
25
49
 
26
50
  assert_messages_equal([msg_in], backup)
27
51
 
52
+ # Send again just to make sure the generator doesn't crash
53
+ _ = gen.send(msg_in)
54
+
28
55
  # Test with weights from a CSV file.
29
56
  csv_path = Path(__file__).parent / "resources" / "xform.csv"
30
57
  weights = np.loadtxt(csv_path, delimiter=",")
31
58
  expected_out = in_dat @ weights.T
32
59
  # Same result: expected_out = np.vstack([(step[None, :] * weights).sum(axis=1) for step in in_dat])
33
60
 
34
- # Send again just to make sure the generator doesn't crash
35
- _ = gen.send(msg_in)
36
-
37
61
  gen = affine_transform(weights=csv_path, axis="ch", right_multiply=False)
38
62
  msg_out = gen.send(msg_in)
39
63
  assert np.allclose(msg_out.data, expected_out)
64
+ assert len(msg_out.axes["ch"].labels) == weights.shape[0]
65
+ assert msg_out.axes["ch"].labels[:-1] == msg_in.axes["ch"].labels
40
66
 
41
67
  # Try again as str, not Path
42
68
  gen = affine_transform(weights=str(csv_path), axis="ch", right_multiply=False)
43
69
  msg_out = gen.send(msg_in)
44
70
  assert np.allclose(msg_out.data, expected_out)
71
+ assert len(msg_out.axes["ch"].labels) == weights.shape[0]
45
72
 
46
73
  # Try again as direct ndarray
47
74
  gen = affine_transform(weights=weights, axis="ch", right_multiply=False)
48
75
  msg_out = gen.send(msg_in)
49
76
  assert np.allclose(msg_out.data, expected_out)
77
+ assert len(msg_out.axes["ch"].labels) == weights.shape[0]
50
78
 
51
79
  # One more time, but we pre-transpose the weights and do not override right_multiply
52
80
  gen = affine_transform(weights=weights.T, axis="ch", right_multiply=True)
53
81
  msg_out = gen.send(msg_in)
54
82
  assert np.allclose(msg_out.data, expected_out)
83
+ assert len(msg_out.axes["ch"].labels) == weights.shape[0]
55
84
 
56
85
 
57
86
  def test_affine_passthrough():
@@ -37,7 +37,12 @@ def get_msg_gen(n_chans=20, n_freqs=100, data_dur=30.0, fs=1024.0, key=""):
37
37
 
38
38
  @pytest.mark.parametrize(
39
39
  "agg_func",
40
- [AggregationFunction.MEAN, AggregationFunction.MEDIAN, AggregationFunction.STD],
40
+ [
41
+ AggregationFunction.MEAN,
42
+ AggregationFunction.MEDIAN,
43
+ AggregationFunction.STD,
44
+ AggregationFunction.SUM,
45
+ ],
41
46
  )
42
47
  def test_aggregate(agg_func: AggregationFunction):
43
48
  bands = [(5.0, 20.0), (30.0, 50.0)]
@@ -74,6 +79,7 @@ def test_aggregate(agg_func: AggregationFunction):
74
79
  AggregationFunction.MEAN: partial(np.mean, axis=-1, keepdims=True),
75
80
  AggregationFunction.MEDIAN: partial(np.median, axis=-1, keepdims=True),
76
81
  AggregationFunction.STD: partial(np.std, axis=-1, keepdims=True),
82
+ AggregationFunction.SUM: partial(np.sum, axis=-1, keepdims=True),
77
83
  }[agg_func]
78
84
  expected_data = np.concatenate(
79
85
  [
@@ -138,3 +138,20 @@ def test_butterworth(
138
138
 
139
139
  result = np.concatenate([gen.send(_).data for _ in messages], axis=time_ax)
140
140
  assert np.allclose(result, out_dat)
141
+
142
+
143
+ def test_butterworth_empty_msg():
144
+ proc = butter(
145
+ axis="time",
146
+ order=2,
147
+ cuton=0.1,
148
+ cutoff=1.0,
149
+ coef_type="sos",
150
+ )
151
+ msg_in = AxisArray(
152
+ data=np.zeros((0, 2)),
153
+ dims=["time", "ch"],
154
+ axes={"time": AxisArray.Axis.TimeAxis(fs=19.0, offset=0)},
155
+ )
156
+ res = proc.send(msg_in)
157
+ assert res.data.size == 0
@@ -63,6 +63,7 @@ def test_filterbank(mode: str, kernel_type: str):
63
63
  data=chirp[:, idx : idx + step_size],
64
64
  dims=["ch", "time"],
65
65
  axes={"time": AxisArray.Axis.TimeAxis(offset=tvec[idx], fs=fs)},
66
+ key="test_filterbank",
66
67
  )
67
68
  )
68
69
 
@@ -91,6 +92,7 @@ def test_filterbank(mode: str, kernel_type: str):
91
92
  # Pass the messages
92
93
  out_messages = [gen.send(msg_in) for msg_in in in_messages]
93
94
  result = AxisArray.concatenate(*out_messages, dim="time")
95
+ assert result.key == "test_filterbank"
94
96
 
95
97
  # Compare to sps.oaconvolve(chirp), with the following differences:
96
98
  # - conv has transients at the beginning that we need to skip over
@@ -62,13 +62,17 @@ def test_invert():
62
62
 
63
63
 
64
64
  @pytest.mark.parametrize("base", [np.e, 2, 10])
65
- def test_log(base: float):
65
+ @pytest.mark.parametrize("dtype", [int, float])
66
+ @pytest.mark.parametrize("clip", [False, True])
67
+ def test_log(base: float, dtype, clip: bool):
66
68
  n_times = 130
67
69
  n_chans = 255
68
- in_dat = np.arange(n_times * n_chans).reshape(n_times, n_chans)
70
+ in_dat = np.arange(n_times * n_chans).reshape(n_times, n_chans).astype(dtype)
69
71
  msg_in = AxisArray(in_dat, dims=["time", "ch"])
70
- proc = log(base)
72
+ proc = log(base, clip_zero=clip)
71
73
  msg_out = proc.send(msg_in)
74
+ if clip and dtype is float:
75
+ in_dat = np.clip(in_dat, a_min=np.finfo(msg_in.data.dtype).tiny, a_max=None)
72
76
  assert np.array_equal(msg_out.data, np.log(in_dat) / np.log(base))
73
77
 
74
78
 
@@ -1,13 +1,31 @@
1
1
  import copy
2
+ from dataclasses import dataclass, field
3
+ import typing
2
4
 
3
5
  import numpy as np
4
6
  from ezmsg.util.messages.axisarray import AxisArray
7
+ import pytest
5
8
 
6
9
  from ezmsg.sigproc.slicer import slicer, parse_slice
7
10
 
8
11
  from util import assert_messages_equal
9
12
 
10
13
 
14
+ @dataclass
15
+ class CustomAxis(AxisArray.Axis):
16
+ labels: typing.List[str] = field(default_factory=lambda: [])
17
+
18
+ @classmethod
19
+ def SpaceAxis(
20
+ cls, labels: typing.List[str]
21
+ ): # , locs: typing.Optional[npt.NDArray] = None):
22
+ return cls(unit="mm", labels=labels)
23
+
24
+
25
+ # Monkey-patch AxisArray with our customized Axis
26
+ AxisArray.Axis = CustomAxis
27
+
28
+
11
29
  def test_parse_slice():
12
30
  assert parse_slice("") == (slice(None),)
13
31
  assert parse_slice(":") == (slice(None),)
@@ -76,3 +94,31 @@ def test_slicer_gen_drop_dim():
76
94
  assert_messages_equal([msg_in], backup)
77
95
  assert msg_out.data.shape == (n_times,)
78
96
  assert np.array_equal(msg_out.data, msg_in.data[:, 5])
97
+
98
+
99
+ @pytest.mark.parametrize("selection", [":3", "0, 1, 2"])
100
+ def test_slicer_label(selection: str):
101
+ """
102
+ We use the monkey-patched AxisArray `labels` field that exists in several other ezmsg
103
+ modules that generate data.
104
+ """
105
+ n_times = 50
106
+ n_chans = 10
107
+ in_dat = np.arange(n_times * n_chans).reshape(n_times, n_chans)
108
+ msg_in = AxisArray(
109
+ in_dat,
110
+ dims=["time", "ch"],
111
+ axes={
112
+ "time": AxisArray.Axis.TimeAxis(fs=100.0, offset=0.1),
113
+ "ch": CustomAxis.SpaceAxis(labels=[str(_) for _ in range(n_chans)]),
114
+ },
115
+ )
116
+ backup = [copy.deepcopy(msg_in)]
117
+
118
+ gen = slicer(selection=selection, axis="ch")
119
+ # gen = slicer(selection=":3", axis="ch")
120
+ msg_out = gen.send(msg_in)
121
+ assert_messages_equal([msg_in], backup)
122
+ assert msg_out.data.shape == (n_times, 3)
123
+ assert np.array_equal(msg_out.data, msg_in.data[:, :3])
124
+ assert msg_out.axes["ch"].labels == msg_in.axes["ch"].labels[:3]
@@ -89,6 +89,7 @@ def test_cwt():
89
89
  data=chirp[:, idx : idx + step_size],
90
90
  dims=["ch", "time"],
91
91
  axes={"time": AxisArray.Axis.TimeAxis(offset=tvec[idx], fs=fs)},
92
+ key="test_cwt",
92
93
  )
93
94
  )
94
95
 
@@ -106,6 +107,7 @@ def test_cwt():
106
107
  out_messages = [gen.send(in_messages[0])]
107
108
  out_messages += [gen.send(msg_in) for msg_in in in_messages[1:]]
108
109
  result = AxisArray.concatenate(*out_messages, dim="time")
110
+ assert result.key == "test_cwt"
109
111
 
110
112
  # TODO: Compare result to expected
111
113
 
@@ -125,7 +125,7 @@ wheels = [
125
125
 
126
126
  [[package]]
127
127
  name = "ezmsg-sigproc"
128
- version = "1.3.2.dev0+g021511b.d20240926"
128
+ version = "1.3.3.dev10+g5aaad7f.d20241101"
129
129
  source = { editable = "." }
130
130
  dependencies = [
131
131
  { name = "ezmsg" },
@@ -1,32 +0,0 @@
1
- from dataclasses import replace
2
- import typing
3
-
4
- import numpy as np
5
- import ezmsg.core as ez
6
- from ezmsg.util.generator import consumer
7
- from ezmsg.util.messages.axisarray import AxisArray
8
-
9
- from ..base import GenAxisArray
10
-
11
-
12
- @consumer
13
- def log(
14
- base: float = 10.0,
15
- ) -> typing.Generator[AxisArray, AxisArray, None]:
16
- msg_in = AxisArray(np.array([]), dims=[""])
17
- msg_out = AxisArray(np.array([]), dims=[""])
18
- log_base = np.log(base)
19
- while True:
20
- msg_in = yield msg_out
21
- msg_out = replace(msg_in, data=np.log(msg_in.data) / log_base)
22
-
23
-
24
- class LogSettings(ez.Settings):
25
- base: float = 10.0
26
-
27
-
28
- class Log(GenAxisArray):
29
- SETTINGS = LogSettings
30
-
31
- def construct_generator(self):
32
- self.STATE.gen = log(base=self.SETTINGS.base)
File without changes
File without changes
File without changes