ezmsg-sigproc 1.8.2__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +36 -39
  3. ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
  4. ezmsg/sigproc/affinetransform.py +169 -163
  5. ezmsg/sigproc/aggregate.py +133 -101
  6. ezmsg/sigproc/bandpower.py +64 -52
  7. ezmsg/sigproc/base.py +1242 -0
  8. ezmsg/sigproc/butterworthfilter.py +37 -33
  9. ezmsg/sigproc/cheby.py +29 -17
  10. ezmsg/sigproc/combfilter.py +163 -0
  11. ezmsg/sigproc/decimate.py +19 -10
  12. ezmsg/sigproc/detrend.py +29 -0
  13. ezmsg/sigproc/diff.py +81 -0
  14. ezmsg/sigproc/downsample.py +78 -84
  15. ezmsg/sigproc/ewma.py +197 -0
  16. ezmsg/sigproc/extract_axis.py +41 -0
  17. ezmsg/sigproc/filter.py +257 -141
  18. ezmsg/sigproc/filterbank.py +247 -199
  19. ezmsg/sigproc/math/abs.py +17 -22
  20. ezmsg/sigproc/math/clip.py +24 -24
  21. ezmsg/sigproc/math/difference.py +34 -30
  22. ezmsg/sigproc/math/invert.py +13 -25
  23. ezmsg/sigproc/math/log.py +28 -33
  24. ezmsg/sigproc/math/scale.py +18 -26
  25. ezmsg/sigproc/quantize.py +71 -0
  26. ezmsg/sigproc/resample.py +298 -0
  27. ezmsg/sigproc/sampler.py +241 -259
  28. ezmsg/sigproc/scaler.py +55 -218
  29. ezmsg/sigproc/signalinjector.py +52 -43
  30. ezmsg/sigproc/slicer.py +81 -89
  31. ezmsg/sigproc/spectrogram.py +77 -75
  32. ezmsg/sigproc/spectrum.py +203 -168
  33. ezmsg/sigproc/synth.py +546 -393
  34. ezmsg/sigproc/transpose.py +131 -0
  35. ezmsg/sigproc/util/asio.py +156 -0
  36. ezmsg/sigproc/util/message.py +31 -0
  37. ezmsg/sigproc/util/profile.py +55 -12
  38. ezmsg/sigproc/util/typeresolution.py +83 -0
  39. ezmsg/sigproc/wavelets.py +154 -153
  40. ezmsg/sigproc/window.py +269 -211
  41. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/METADATA +2 -1
  42. ezmsg_sigproc-2.1.0.dist-info/RECORD +51 -0
  43. ezmsg_sigproc-1.8.2.dist-info/RECORD +0 -39
  44. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/WHEEL +0 -0
  45. {ezmsg_sigproc-1.8.2.dist-info → ezmsg_sigproc-2.1.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,131 @@
1
+ from types import EllipsisType
2
+ import numpy as np
3
+ import ezmsg.core as ez
4
+ from ezmsg.util.messages.axisarray import (
5
+ AxisArray,
6
+ replace,
7
+ )
8
+
9
+ from .base import (
10
+ BaseStatefulTransformer,
11
+ BaseTransformerUnit,
12
+ processor_state,
13
+ )
14
+
15
+
16
class TransposeSettings(ez.Settings):
    """
    Settings for :obj:`Transpose` node.

    Fields:
        axes:
            Desired dimension order as a tuple mixing integer positions, dim
            names, and at most one Ellipsis standing in for all dims not
            otherwise listed. ``None`` disables transposition.
        order:
            Required memory layout for the output data: "C" (row-major) or "F"
            (column-major); only the first character, case-insensitive, is
            inspected. ``None`` leaves the memory layout untouched.
    """

    axes: tuple[int | str | EllipsisType, ...] | None = None
    order: str | None = None
26
+
27
+
28
@processor_state
class TransposeState:
    # Pre-computed integer permutation passed to np.transpose, resolved from
    # the settings' axes against the incoming message's dims. None means the
    # resolved permutation is the identity (or axes were unset), so no
    # transpose is needed.
    axes_ints: tuple[int, ...] | None = None
31
+
32
+
33
class TransposeTransformer(
    BaseStatefulTransformer[TransposeSettings, AxisArray, AxisArray, TransposeState]
):
    """
    Permute the dimensions of incoming :obj:`AxisArray` messages and/or force a
    memory layout ("C" or "F") on the underlying array.

    ``settings.axes`` may mix integer positions, dim names, and at most one
    Ellipsis which expands to every dim not otherwise listed. If the resolved
    permutation is the identity, no transpose is applied. When
    ``settings.order`` is set, :func:`np.require` makes the output data
    contiguous in that order (a no-op if it already is).

    (The previous docstring described the Downsample node; it was a
    copy-paste error.)
    """

    def _hash_message(self, message: AxisArray) -> int:
        # State depends only on the dim names/order; re-resolve when they change.
        return hash(tuple(message.dims))

    def _axis_to_index(self, ax: int | str, dims: list[str]) -> int:
        # Resolve one `axes` entry to an integer dim index.
        # NOTE(review): negative integers are passed through unchanged; combined
        # with the remaining-dims computation in _reset_state they could yield a
        # duplicate axis — confirm callers only pass non-negative ints.
        if isinstance(ax, int):
            return ax
        if ax not in dims:
            raise ValueError(f"Axis {ax} not found in message dims.")
        return dims.index(ax)

    def _reset_state(self, message: AxisArray) -> None:
        if self.settings.axes is None:
            self._state.axes_ints = None
        else:
            ell_ix = [ix for ix, ax in enumerate(self.settings.axes) if ax is Ellipsis]
            if len(ell_ix) > 1:
                raise ValueError("Only one Ellipsis is allowed in axes.")
            # Without an Ellipsis, any dims not listed are appended at the end.
            ell_ix = ell_ix[0] if len(ell_ix) == 1 else len(message.dims)
            prefix = [
                self._axis_to_index(ax, message.dims)
                for ax in self.settings.axes[:ell_ix]
            ]
            suffix = [
                self._axis_to_index(ax, message.dims)
                for ax in self.settings.axes[ell_ix + 1 :]
            ]
            # Dims consumed by the Ellipsis: everything not named explicitly.
            ells = [
                ix
                for ix in range(message.data.ndim)
                if ix not in prefix and ix not in suffix
            ]
            re_ix = tuple(prefix + ells + suffix)
            # Identity permutation -> skip transposing entirely in _process.
            if re_ix == tuple(range(message.data.ndim)):
                self._state.axes_ints = None
            else:
                self._state.axes_ints = re_ix
        if self.settings.order is not None and self.settings.order.upper()[0] not in [
            "C",
            "F",
        ]:
            raise ValueError("order must be 'C' or 'F'.")

    def __call__(self, message: AxisArray) -> AxisArray:
        if self.settings.axes is None and self.settings.order is None:
            # Passthrough: nothing to permute and no layout requirement.
            return message
        return super().__call__(message)

    def _process(self, message: AxisArray) -> AxisArray:
        if self.state.axes_ints is None:
            # No transpose required.
            if self.settings.order is None:
                # No memory relayout required.
                # Note: We should not be able to reach here because it should
                # be shortcutted at passthrough in __call__.
                msg_out = message
            else:
                # np.require is a no-op when memory is already contiguous in
                # the requested order.
                msg_out = replace(
                    message,
                    data=np.require(
                        message.data, requirements=self.settings.order.upper()[0]
                    ),
                )
        else:
            dims_out = [message.dims[ix] for ix in self.state.axes_ints]
            data_out = np.transpose(message.data, axes=self.state.axes_ints)
            if self.settings.order is not None:
                data_out = np.require(
                    data_out, requirements=self.settings.order.upper()[0]
                )
            msg_out = replace(
                message,
                data=data_out,
                dims=dims_out,
            )
        return msg_out
120
+
121
+
122
class Transpose(
    BaseTransformerUnit[TransposeSettings, AxisArray, AxisArray, TransposeTransformer]
):
    """ezmsg Unit wrapping :obj:`TransposeTransformer`; configured via :obj:`TransposeSettings`."""

    SETTINGS = TransposeSettings
126
+
127
+
128
def transpose(
    axes: tuple[int | str | EllipsisType, ...] | None = None, order: str | None = None
) -> TransposeTransformer:
    """
    Convenience factory: build a :obj:`TransposeTransformer` for the given
    axes permutation and optional memory-layout requirement.
    """
    settings = TransposeSettings(axes=axes, order=order)
    return TransposeTransformer(settings)
@@ -0,0 +1,156 @@
1
+ import asyncio
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ import contextlib
4
+ import inspect
5
+ import threading
6
+ from typing import Any, Coroutine, TypeVar
7
+
8
# Result type produced by the coroutine handled by run_coroutine_sync.
T = TypeVar("T")
9
+
10
+
11
class CoroutineExecutionError(Exception):
    """
    Raised by :func:`run_coroutine_sync` when a coroutine cannot be scheduled
    or executed; the underlying error is attached as ``__cause__``.
    """
15
+
16
+
17
def run_coroutine_sync(coroutine: Coroutine[Any, Any, T], timeout: float = 30) -> T:
    """
    Executes an asyncio coroutine synchronously, with enhanced error handling.

    Args:
        coroutine: The asyncio coroutine to execute
        timeout: Maximum time in seconds to wait for coroutine completion (default: 30)

    Returns:
        The result of the coroutine execution

    Raises:
        CoroutineExecutionError: If execution fails due to threading or event loop issues
        TimeoutError: If execution exceeds the timeout period
        Exception: Any exception raised by the coroutine
    """

    def run_in_new_loop() -> T:
        """
        Creates and runs a new event loop in the current thread.
        Ensures proper cleanup of the loop.
        """
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        try:
            return new_loop.run_until_complete(
                asyncio.wait_for(coroutine, timeout=timeout)
            )
        finally:
            # Best-effort cleanup: never let teardown errors mask the result.
            with contextlib.suppress(Exception):
                # Clean up any pending tasks
                pending = asyncio.all_tasks(new_loop)
                for task in pending:
                    task.cancel()
                new_loop.run_until_complete(
                    asyncio.gather(*pending, return_exceptions=True)
                )
            new_loop.close()

    # Determine whether this thread already hosts a running event loop.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No running loop in this thread: asyncio.run creates and tears down
        # a fresh loop around the coroutine.
        try:
            return asyncio.run(asyncio.wait_for(coroutine, timeout=timeout))
        except Exception as e:
            raise CoroutineExecutionError(
                f"Failed to execute coroutine: {str(e)}"
            ) from e

    if threading.current_thread() is threading.main_thread():
        if not loop.is_running():
            # NOTE(review): asyncio.get_running_loop() only returns a loop that
            # is currently running, so this branch appears unreachable — confirm.
            try:
                return loop.run_until_complete(
                    asyncio.wait_for(coroutine, timeout=timeout)
                )
            except Exception as e:
                raise CoroutineExecutionError(
                    f"Failed to execute coroutine in main loop: {str(e)}"
                ) from e
        else:
            # Main thread with a running loop: we cannot block it here, so run
            # the coroutine on a brand-new loop in a worker thread.
            with ThreadPoolExecutor() as pool:
                try:
                    future = pool.submit(run_in_new_loop)
                    return future.result(timeout=timeout)
                except Exception as e:
                    raise CoroutineExecutionError(
                        f"Failed to execute coroutine in thread: {str(e)}"
                    ) from e
    else:
        # Non-main thread with a running loop: schedule onto that loop.
        # NOTE(review): get_running_loop() returns the loop running in the
        # *current* thread, so blocking on future.result() here risks
        # deadlocking the loop that must service the coroutine — confirm the
        # intended call pattern.
        try:
            future = asyncio.run_coroutine_threadsafe(coroutine, loop)
            return future.result(timeout=timeout)
        except Exception as e:
            raise CoroutineExecutionError(
                f"Failed to execute coroutine threadsafe: {str(e)}"
            ) from e
93
+
94
+
95
class SyncToAsyncGeneratorWrapper:
    """
    Adapts a synchronous generator to the async-generator protocol.

    The generator is primed on construction (advanced to its first yield),
    and every subsequent send/next/close runs via ``asyncio.to_thread`` so it
    never blocks the event loop. Unknown attributes fall through to the
    wrapped generator.
    """

    def __init__(self, gen):
        self._gen = gen
        self._closed = False
        # Prime the generator so the first asend()/anext() resumes it at a yield.
        try:
            needs_priming = (
                inspect.getgeneratorstate(self._gen) == inspect.GEN_CREATED
            )
        except AttributeError as e:
            # getgeneratorstate reads generator-only attributes; anything else
            # raises AttributeError here.
            raise TypeError(
                "The provided generator is not a valid generator object"
            ) from e
        if not needs_priming:
            return
        try:
            next(self._gen)
        except StopIteration:
            # Generator finished immediately; mark closed rather than raising.
            self._closed = True
        except Exception as e:
            raise RuntimeError(f"Failed to prime generator: {e}") from e

    async def _step(self, fn, args, err_msg):
        # Shared advance logic for asend/__anext__: run the blocking generator
        # call in a worker thread and translate exhaustion/errors.
        if self._closed:
            raise StopAsyncIteration("Generator is closed")
        try:
            return await asyncio.to_thread(fn, *args)
        except StopIteration as e:
            self._closed = True
            raise StopAsyncIteration("Generator is closed") from e
        except Exception as e:
            raise RuntimeError(f"{err_msg}: {e}") from e

    async def asend(self, value):
        return await self._step(
            self._gen.send, (value,), "Error while sending value to generator"
        )

    async def __anext__(self):
        return await self._step(
            self._gen.__next__, (), "Error while getting next value from generator"
        )

    async def aclose(self):
        if self._closed:
            return
        try:
            await asyncio.to_thread(self._gen.close)
        except Exception as e:
            raise RuntimeError(f"Error while closing generator: {e}") from e
        finally:
            self._closed = True

    def __aiter__(self):
        return self

    def __getattr__(self, name):
        # Delegate anything we don't implement to the wrapped generator.
        return getattr(self._gen, name)
@@ -0,0 +1,31 @@
1
+ import time
2
+ import typing
3
+ from dataclasses import dataclass, field
4
+
5
+ from ezmsg.util.messages.axisarray import AxisArray
6
+
7
+
8
@dataclass(unsafe_hash=True)
class SampleTriggerMessage:
    """
    A trigger marking a point in time (and optionally a window around it)
    at which data should be sampled. ``unsafe_hash=True`` makes instances
    hashable despite the dataclass being mutable.
    """

    timestamp: float = field(default_factory=time.time)
    """Time of the trigger, in seconds. The Clock depends on the input but defaults to time.time"""

    period: tuple[float, float] | None = None
    """The period around the timestamp, in seconds"""

    value: typing.Any = None
    """A value or 'label' associated with the trigger."""
18
+
19
+
20
@dataclass
class SampleMessage:
    """Pairs a trigger with the data that was sampled around it."""

    trigger: SampleTriggerMessage
    """The time, window, and value (if any) associated with the trigger."""

    sample: AxisArray
    """The data sampled around the trigger."""
27
+
28
+
29
def is_sample_message(message: typing.Any) -> typing.TypeGuard[SampleMessage]:
    """
    Check if the message is a SampleMessage.

    Duck-typed: any object exposing a ``trigger`` attribute passes.
    """
    try:
        message.trigger
    except AttributeError:
        return False
    return True
@@ -8,6 +8,9 @@ import typing
8
8
  import ezmsg.core as ez
9
9
 
10
10
 
11
+ HEADER = "Time,Source,Topic,SampleTime,PerfCounter,Elapsed"
12
+
13
+
11
14
  def get_logger_path() -> Path:
12
15
  # Retrieve the logfile name from the environment variable
13
16
  logfile = os.environ.get("EZMSG_PROFILE", None)
@@ -26,9 +29,23 @@ def _setup_logger(append: bool = False) -> logging.Logger:
26
29
  logpath = get_logger_path()
27
30
  logpath.parent.mkdir(parents=True, exist_ok=True)
28
31
 
29
- if not append:
30
- # Remove the file if it exists
31
- logpath.unlink(missing_ok=True)
32
+ write_header = True
33
+ if logpath.exists() and logpath.is_file():
34
+ if append:
35
+ with open(logpath) as f:
36
+ first_line = f.readline().rstrip()
37
+ if first_line == HEADER:
38
+ write_header = False
39
+ else:
40
+ # Remove the file if appending, but headers do not match
41
+ ezmsg_logger = logging.getLogger("ezmsg")
42
+ ezmsg_logger.warning(
43
+ "Profiling header mismatch: please make sure to use the same version of ezmsg for all processes."
44
+ )
45
+ logpath.unlink()
46
+ else:
47
+ # Remove the file if not appending
48
+ logpath.unlink()
32
49
 
33
50
  # Create a logger with the name "ezprofile"
34
51
  _logger = logging.getLogger("ezprofile")
@@ -43,13 +60,13 @@ def _setup_logger(append: bool = False) -> logging.Logger:
43
60
  # Add the file handler to the logger
44
61
  _logger.addHandler(fh)
45
62
 
46
- # Add the first row without formatting.
47
- _logger.debug(",".join(["Time", "Source", "Topic", "SampleTime", "PerfCounter", "Elapsed"]))
63
+ # Add the header if writing to new file or if header matched header in file.
64
+ if write_header:
65
+ _logger.debug(HEADER)
48
66
 
49
67
  # Set the log message format
50
68
  formatter = logging.Formatter(
51
- "%(asctime)s,%(message)s",
52
- datefmt="%Y-%m-%dT%H:%M:%S%z"
69
+ "%(asctime)s,%(message)s", datefmt="%Y-%m-%dT%H:%M:%S%z"
53
70
  )
54
71
  fh.setFormatter(formatter)
55
72
 
@@ -89,18 +106,31 @@ def profile_method(trace_oldest: bool = True):
89
106
  Returns:
90
107
  Callable: The decorated function with profiling.
91
108
  """
109
+
92
110
  def profiling_decorator(func: typing.Callable):
93
111
  @functools.wraps(func)
94
112
  def wrapped_func(caller, *args, **kwargs):
95
113
  start = time.perf_counter()
96
114
  res = func(caller, *args, **kwargs)
97
115
  stop = time.perf_counter()
98
- source = '.'.join((caller.__class__.__module__, caller.__class__.__name__))
116
+ source = ".".join((caller.__class__.__module__, caller.__class__.__name__))
99
117
  topic = f"{caller.address}"
100
118
  samp_time = _process_obj(res, trace_oldest=trace_oldest)
101
- logger.debug(",".join([source, topic, f"{samp_time}", f"{stop}", f"{(stop - start) * 1e3:0.4f}"]))
119
+ logger.debug(
120
+ ",".join(
121
+ [
122
+ source,
123
+ topic,
124
+ f"{samp_time}",
125
+ f"{stop}",
126
+ f"{(stop - start) * 1e3:0.4f}",
127
+ ]
128
+ )
129
+ )
102
130
  return res
131
+
103
132
  return wrapped_func if logger.level == logging.DEBUG else func
133
+
104
134
  return profiling_decorator
105
135
 
106
136
 
@@ -115,17 +145,30 @@ def profile_subpub(trace_oldest: bool = True):
115
145
  Returns:
116
146
  Callable: The decorated async task with profiling.
117
147
  """
148
+
118
149
  def profiling_decorator(func: typing.Callable):
119
150
  @functools.wraps(func)
120
- async def wrapped_task(unit: ez.Unit, msg: typing.Any = None) -> None:
121
- source = '.'.join((unit.__class__.__module__, unit.__class__.__name__))
151
+ async def wrapped_task(unit: ez.Unit, msg: typing.Any = None):
152
+ source = ".".join((unit.__class__.__module__, unit.__class__.__name__))
122
153
  topic = f"{unit.address}"
123
154
  start = time.perf_counter()
124
155
  async for stream, obj in func(unit, msg):
125
156
  stop = time.perf_counter()
126
157
  samp_time = _process_obj(obj, trace_oldest=trace_oldest)
127
- logger.debug(",".join([source, topic, f"{samp_time}", f"{stop}", f"{(stop - start) * 1e3:0.4f}"]))
158
+ logger.debug(
159
+ ",".join(
160
+ [
161
+ source,
162
+ topic,
163
+ f"{samp_time}",
164
+ f"{stop}",
165
+ f"{(stop - start) * 1e3:0.4f}",
166
+ ]
167
+ )
168
+ )
128
169
  start = stop
129
170
  yield stream, obj
171
+
130
172
  return wrapped_task if logger.level == logging.DEBUG else func
173
+
131
174
  return profiling_decorator
@@ -0,0 +1,83 @@
1
+ from types import UnionType
2
+ import typing
3
+ from typing_extensions import get_original_bases
4
+
5
+
6
+ def resolve_typevar(cls: type, target_typevar: typing.TypeVar) -> type:
7
+ """
8
+ Resolve the concrete type bound to a TypeVar in a class hierarchy.
9
+ This function traverses the method resolution order (MRO) of the class
10
+ and checks the original bases of each class in the MRO for the TypeVar.
11
+ If the TypeVar is found, it returns the concrete type bound to it.
12
+ If the TypeVar is not found, it raises a TypeError.
13
+ Args:
14
+ cls (type): The class to inspect.
15
+ target_typevar (typing.TypeVar): The TypeVar to resolve.
16
+ Returns:
17
+ type: The concrete type bound to the TypeVar.
18
+ """
19
+ for base in cls.__mro__:
20
+ orig_bases = get_original_bases(base)
21
+ for orig_base in orig_bases:
22
+ origin = typing.get_origin(orig_base)
23
+ if origin is None:
24
+ continue
25
+ params = getattr(origin, "__parameters__", ())
26
+ if not params:
27
+ continue
28
+ if target_typevar in params:
29
+ index = params.index(target_typevar)
30
+ args = typing.get_args(orig_base)
31
+ try:
32
+ return args[index]
33
+ except IndexError:
34
+ pass
35
+ raise TypeError(f"Could not resolve {target_typevar} in {cls}")
36
+
37
+
38
# Loose alias for "anything usable as a message type annotation": a concrete
# type, a typing special form, NoneType, or None itself.
# NOTE(review): because the Union includes typing.Any, this alias collapses to
# Any at type-check time — confirm that is intended.
TypeLike = typing.Union[type[typing.Any], typing.Any, type(None), None]
39
+
40
+
41
+ def check_message_type_compatibility(type1: TypeLike, type2: TypeLike) -> bool:
42
+ """
43
+ Check if two types are compatible for message passing.
44
+ Returns True if:
45
+ - Both are None/NoneType
46
+ - Either is typing.Any
47
+ - type1 is a subclass of type2, which includes
48
+ - type1 and type2 are concrete types and type1 is a subclass of type2
49
+ - type1 is None/NoneType and type2 is typing.Optional, or
50
+ - type1 is subtype of the non-None inner type of type2 if type2 is Optional
51
+ - type1 is a Union/Optional type and all inner types are compatible with type2
52
+ Args:
53
+ type1: First type to compare
54
+ type2: Second type to compare
55
+ Returns:
56
+ bool: True if the types are compatible, False otherwise
57
+ """
58
+ # If either is Any, they are compatible
59
+ if type1 is typing.Any or type2 is typing.Any:
60
+ return True
61
+
62
+ # Handle None as NoneType
63
+ if type1 is None:
64
+ type1 = type(None)
65
+ if type2 is None:
66
+ type2 = type(None)
67
+
68
+ # Handle if type1 is Optional/Union type
69
+ if typing.get_origin(type1) in {typing.Union, UnionType}:
70
+ return all(
71
+ check_message_type_compatibility(inner_type, type2)
72
+ for inner_type in typing.get_args(type1)
73
+ )
74
+
75
+ # Regular issubclass check. Handles cases like:
76
+ # - type1 is a subclass of concrete type2
77
+ # - type1 is a subclass of the inner type of type2 if type2 is Optional
78
+ # - type1 is a subclass of one of the inner types of type2 if type2 is Union
79
+ # - type1 is NoneType and type2 is Optional or Union[None, ...] or Union[NoneType, ...]
80
+ try:
81
+ return issubclass(type1, type2)
82
+ except TypeError:
83
+ return False