ezmsg-sigproc 2.5.0__py3-none-any.whl → 2.7.0__py3-none-any.whl

Files changed (64)
  1. ezmsg/sigproc/__version__.py +2 -2
  2. ezmsg/sigproc/activation.py +5 -11
  3. ezmsg/sigproc/adaptive_lattice_notch.py +11 -30
  4. ezmsg/sigproc/affinetransform.py +16 -42
  5. ezmsg/sigproc/aggregate.py +17 -34
  6. ezmsg/sigproc/bandpower.py +12 -20
  7. ezmsg/sigproc/base.py +141 -1276
  8. ezmsg/sigproc/butterworthfilter.py +8 -16
  9. ezmsg/sigproc/butterworthzerophase.py +7 -16
  10. ezmsg/sigproc/cheby.py +4 -10
  11. ezmsg/sigproc/combfilter.py +5 -8
  12. ezmsg/sigproc/coordinatespaces.py +142 -0
  13. ezmsg/sigproc/decimate.py +3 -7
  14. ezmsg/sigproc/denormalize.py +6 -11
  15. ezmsg/sigproc/detrend.py +3 -4
  16. ezmsg/sigproc/diff.py +8 -17
  17. ezmsg/sigproc/downsample.py +11 -20
  18. ezmsg/sigproc/ewma.py +11 -28
  19. ezmsg/sigproc/ewmfilter.py +1 -1
  20. ezmsg/sigproc/extract_axis.py +3 -4
  21. ezmsg/sigproc/fbcca.py +34 -59
  22. ezmsg/sigproc/filter.py +19 -45
  23. ezmsg/sigproc/filterbank.py +37 -74
  24. ezmsg/sigproc/filterbankdesign.py +7 -14
  25. ezmsg/sigproc/fir_hilbert.py +13 -30
  26. ezmsg/sigproc/fir_pmc.py +5 -10
  27. ezmsg/sigproc/firfilter.py +12 -14
  28. ezmsg/sigproc/gaussiansmoothing.py +5 -9
  29. ezmsg/sigproc/kaiser.py +11 -15
  30. ezmsg/sigproc/math/abs.py +4 -3
  31. ezmsg/sigproc/math/add.py +121 -0
  32. ezmsg/sigproc/math/clip.py +4 -1
  33. ezmsg/sigproc/math/difference.py +100 -36
  34. ezmsg/sigproc/math/invert.py +3 -3
  35. ezmsg/sigproc/math/log.py +5 -6
  36. ezmsg/sigproc/math/scale.py +2 -0
  37. ezmsg/sigproc/messages.py +1 -2
  38. ezmsg/sigproc/quantize.py +3 -6
  39. ezmsg/sigproc/resample.py +17 -38
  40. ezmsg/sigproc/rollingscaler.py +12 -37
  41. ezmsg/sigproc/sampler.py +19 -37
  42. ezmsg/sigproc/scaler.py +11 -22
  43. ezmsg/sigproc/signalinjector.py +7 -18
  44. ezmsg/sigproc/slicer.py +14 -34
  45. ezmsg/sigproc/spectral.py +3 -3
  46. ezmsg/sigproc/spectrogram.py +12 -19
  47. ezmsg/sigproc/spectrum.py +17 -38
  48. ezmsg/sigproc/transpose.py +12 -24
  49. ezmsg/sigproc/util/asio.py +25 -156
  50. ezmsg/sigproc/util/axisarray_buffer.py +12 -26
  51. ezmsg/sigproc/util/buffer.py +22 -43
  52. ezmsg/sigproc/util/message.py +17 -31
  53. ezmsg/sigproc/util/profile.py +23 -174
  54. ezmsg/sigproc/util/sparse.py +7 -15
  55. ezmsg/sigproc/util/typeresolution.py +17 -83
  56. ezmsg/sigproc/wavelets.py +10 -19
  57. ezmsg/sigproc/window.py +29 -83
  58. ezmsg_sigproc-2.7.0.dist-info/METADATA +60 -0
  59. ezmsg_sigproc-2.7.0.dist-info/RECORD +64 -0
  60. ezmsg/sigproc/synth.py +0 -774
  61. ezmsg_sigproc-2.5.0.dist-info/METADATA +0 -72
  62. ezmsg_sigproc-2.5.0.dist-info/RECORD +0 -63
  63. {ezmsg_sigproc-2.5.0.dist-info → ezmsg_sigproc-2.7.0.dist-info}/WHEEL +0 -0
  64. /ezmsg_sigproc-2.5.0.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.7.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/util/buffer.py CHANGED
@@ -1,3 +1,7 @@
+ """A stateful, FIFO buffer that combines a deque for fast appends with a
+ contiguous circular buffer for efficient, advancing reads.
+ """
+
  import collections
  import math
  import typing
@@ -63,9 +67,7 @@ class HybridBuffer:
          self._buff_unread = 0  # Number of unread samples in the circular buffer
          self._buff_read = 0  # Tracks samples read and still in buffer
          self._deque_len = 0  # Number of unread samples in the deque
-         self._last_overflow = (
-             0  # Tracks the last overflow count, overwritten or skipped
-         )
+         self._last_overflow = 0  # Tracks the last overflow count, overwritten or skipped
          self._warned = False  # Tracks if we've warned already (for warn_once)

      @property
@@ -96,9 +98,7 @@ class HybridBuffer:
              block = block[:, self.xp.newaxis]

          if block.shape[1:] != other_shape:
-             raise ValueError(
-                 f"Block shape {block.shape[1:]} does not match buffer's other_shape {other_shape}"
-             )
+             raise ValueError(f"Block shape {block.shape[1:]} does not match buffer's other_shape {other_shape}")

          # Most overflow strategies are handled during flush, but there are a couple
          # scenarios that can be evaluated on write to give immediate feedback.
@@ -117,8 +117,7 @@ class HybridBuffer:
          self._deque_len += block.shape[0]

          if self._update_strategy == "immediate" or (
-             self._update_strategy == "threshold"
-             and (0 < self._threshold <= self._deque_len)
+             self._update_strategy == "threshold" and (0 < self._threshold <= self._deque_len)
          ):
              self.flush()

@@ -128,9 +127,7 @@ class HybridBuffer:
          from the buffer.
          """
          if n_samples > self.available():
-             raise ValueError(
-                 f"Requested {n_samples} samples, but only {self.available()} are available."
-             )
+             raise ValueError(f"Requested {n_samples} samples, but only {self.available()} are available.")
          n_overflow = 0
          if self._deque and (n_samples > self._buff_unread):
              # We would cause a flush, but would that cause an overflow?
@@ -161,14 +158,10 @@ class HybridBuffer:
          n_overflow = self._estimate_overflow(n_samples)
          if n_overflow > 0:
              first_read = self._buff_unread
-             if (n_overflow - first_read) < self.capacity or (
-                 self._overflow_strategy == "drop"
-             ):
+             if (n_overflow - first_read) < self.capacity or (self._overflow_strategy == "drop"):
                  # We can prevent the overflow (or at least *some* if using "drop"
                  # strategy) by reading the samples in the buffer first to make room.
-                 data = self.xp.empty(
-                     (n_samples, *self._buffer.shape[1:]), dtype=self._buffer.dtype
-                 )
+                 data = self.xp.empty((n_samples, *self._buffer.shape[1:]), dtype=self._buffer.dtype)
                  self.peek(first_read, out=data[:first_read])
                  offset += first_read
                  self.seek(first_read)
@@ -204,13 +197,9 @@ class HybridBuffer:
          if n_samples is None:
              n_samples = self.available()
          elif n_samples > self.available():
-             raise ValueError(
-                 f"Requested to peek {n_samples} samples, but only {self.available()} are available."
-             )
+             raise ValueError(f"Requested to peek {n_samples} samples, but only {self.available()} are available.")
          if out is not None and out.shape[0] < n_samples:
-             raise ValueError(
-                 f"Output array shape {out.shape} is smaller than requested {n_samples} samples."
-             )
+             raise ValueError(f"Output array shape {out.shape} is smaller than requested {n_samples} samples.")

          if n_samples == 0:
              return self._buffer[:0]
@@ -224,9 +213,7 @@ class HybridBuffer:
              out = (
                  out
                  if out is not None
-                 else self.xp.empty(
-                     (n_samples, *self._buffer.shape[1:]), dtype=self._buffer.dtype
-                 )
+                 else self.xp.empty((n_samples, *self._buffer.shape[1:]), dtype=self._buffer.dtype)
              )
              out[:part1_len] = self._buffer[self._tail :]
              out[part1_len:] = self._buffer[:part2_len]
@@ -258,9 +245,7 @@ class HybridBuffer:
          if not allow_flush and idx >= self._buff_unread:
              # The requested sample is in the deque.
              idx -= self._buff_unread
-             deq_splits = self.xp.cumsum(
-                 [0] + [_.shape[0] for _ in self._deque], dtype=int
-             )
+             deq_splits = self.xp.cumsum([0] + [_.shape[0] for _ in self._deque], dtype=int)
              arr_idx = self.xp.searchsorted(deq_splits, idx, side="right") - 1
              idx -= deq_splits[arr_idx]
              return self._deque[arr_idx][idx : idx + 1]
@@ -334,7 +319,8 @@ class HybridBuffer:
            if n_overflow > 0 and (not self._warn_once or not self._warned):
                self._warned = True
                warnings.warn(
-                   f"Buffer overflow: {n_new} samples received, but only {self._capacity - self._buff_unread} available. "
+                   f"Buffer overflow: {n_new} samples received, "
+                   f"but only {self._capacity - self._buff_unread} available. "
                    f"Overwriting {n_overflow} previous samples.",
                    RuntimeWarning,
                )
@@ -347,10 +333,9 @@ class HybridBuffer:
                      break
                  n_to_copy = min(block.shape[0], samples_to_copy - copied_samples)
                  start_idx = block.shape[0] - n_to_copy
-                 self._buffer[
-                     samples_to_copy - copied_samples - n_to_copy : samples_to_copy
-                     - copied_samples
-                 ] = block[start_idx:]
+                 self._buffer[samples_to_copy - copied_samples - n_to_copy : samples_to_copy - copied_samples] = block[
+                     start_idx:
+                 ]
                  copied_samples += n_to_copy

              self._head = 0
@@ -362,9 +347,7 @@ class HybridBuffer:
          else:
              if n_overflow > 0:
                  if self._overflow_strategy == "raise":
-                     raise OverflowError(
-                         f"Buffer overflow: {n_new} samples received, but only {n_free} available."
-                     )
+                     raise OverflowError(f"Buffer overflow: {n_new} samples received, but only {n_free} available.")
                  elif self._overflow_strategy == "warn-overwrite":
                      if not self._warn_once or not self._warned:
                          self._warned = True
@@ -430,9 +413,7 @@ class HybridBuffer:
              return

          other_shape = self._buffer.shape[1:]
-         max_capacity = self._max_size / (
-             self._buffer.dtype.itemsize * math.prod(other_shape)
-         )
+         max_capacity = self._max_size / (self._buffer.dtype.itemsize * math.prod(other_shape))
          if min_capacity > max_capacity:
              raise OverflowError(
                  f"Cannot grow buffer to {min_capacity} samples, "
@@ -440,9 +421,7 @@ class HybridBuffer:
              )

          new_capacity = min(max_capacity, max(self._capacity * 2, min_capacity))
-         new_buffer = self.xp.empty(
-             (new_capacity, *other_shape), dtype=self._buffer.dtype
-         )
+         new_buffer = self.xp.empty((new_capacity, *other_shape), dtype=self._buffer.dtype)

          # Copy existing data to new buffer
          total_samples = self._buff_read + self._buff_unread
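
The module docstring added above summarizes the two-tier design: appends land in a deque, and reads are served from a contiguous circular buffer. Below is a minimal conceptual sketch of that pattern, for orientation only; it borrows method names visible in this diff (write, flush, peek/seek-style reads) but is not the HybridBuffer API, whose constructor and overflow/update strategies are configured elsewhere.

import collections

import numpy as np


class TinyHybridBuffer:
    """Toy illustration of the deque + circular-buffer pattern; not HybridBuffer."""

    def __init__(self, capacity: int, width: int):
        self._deque = collections.deque()  # fast O(1) block appends
        self._buffer = np.empty((capacity, width))  # contiguous storage for reads
        self._tail = 0  # index of the next unread sample
        self._count = 0  # unread samples currently in the circular buffer

    def write(self, block: np.ndarray) -> None:
        self._deque.append(block)  # defer copying until a flush

    def flush(self) -> None:
        # Move queued blocks into the circular buffer (no overflow handling here).
        for block in self._deque:
            n = block.shape[0]
            start = (self._tail + self._count) % len(self._buffer)
            idx = (start + np.arange(n)) % len(self._buffer)
            self._buffer[idx] = block
            self._count += n
        self._deque.clear()

    def read(self, n: int) -> np.ndarray:
        # An advancing read; the real class splits this into peek() and seek().
        self.flush()
        idx = (self._tail + np.arange(n)) % len(self._buffer)
        out = self._buffer[idx]
        self._tail = (self._tail + n) % len(self._buffer)
        self._count -= n
        return out


buf = TinyHybridBuffer(capacity=32, width=2)
buf.write(np.zeros((4, 2)))
buf.write(np.ones((4, 2)))
print(buf.read(6).shape)  # (6, 2)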
ezmsg/sigproc/util/message.py CHANGED
@@ -1,31 +1,17 @@
- import time
- import typing
- from dataclasses import dataclass, field
-
- from ezmsg.util.messages.axisarray import AxisArray
-
-
- @dataclass(unsafe_hash=True)
- class SampleTriggerMessage:
-     timestamp: float = field(default_factory=time.time)
-     """Time of the trigger, in seconds. The Clock depends on the input but defaults to time.time"""
-
-     period: tuple[float, float] | None = None
-     """The period around the timestamp, in seconds"""
-
-     value: typing.Any = None
-     """A value or 'label' associated with the trigger."""
-
-
- @dataclass
- class SampleMessage:
-     trigger: SampleTriggerMessage
-     """The time, window, and value (if any) associated with the trigger."""
-
-     sample: AxisArray
-     """The data sampled around the trigger."""
-
-
- def is_sample_message(message: typing.Any) -> typing.TypeGuard[SampleMessage]:
-     """Check if the message is a SampleMessage."""
-     return hasattr(message, "trigger")
+ """
+ Backwards-compatible re-exports from ezmsg.baseproc.util.message.
+
+ New code should import directly from ezmsg.baseproc instead.
+ """
+
+ from ezmsg.baseproc.util.message import (
+     SampleMessage,
+     SampleTriggerMessage,
+     is_sample_message,
+ )
+
+ __all__ = [
+     "SampleMessage",
+     "SampleTriggerMessage",
+     "is_sample_message",
+ ]
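
Because util/message.py is now a pure re-export shim, existing imports continue to work unchanged. A quick sketch, assuming the new ezmsg.baseproc dependency is installed:

from ezmsg.baseproc.util.message import SampleTriggerMessage as NewSTM
from ezmsg.sigproc.util.message import SampleTriggerMessage as OldSTM

assert OldSTM is NewSTM  # the legacy path is now just an alias

trig = OldSTM(value="go")  # timestamp still defaults to time.time()
print(trig.value, trig.period)  # go None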
ezmsg/sigproc/util/profile.py CHANGED
@@ -1,174 +1,23 @@
- import functools
- import logging
- import os
- from pathlib import Path
- import time
- import typing
-
- import ezmsg.core as ez
-
-
- HEADER = "Time,Source,Topic,SampleTime,PerfCounter,Elapsed"
-
-
- def get_logger_path() -> Path:
-     # Retrieve the logfile name from the environment variable
-     logfile = os.environ.get("EZMSG_PROFILE", None)
-
-     # Determine the log file path, defaulting to "ezprofiler.log" if not set
-     logpath = Path(logfile or "ezprofiler.log")
-
-     # If the log path is not absolute, prepend it with the user's home directory and ".ezmsg/profile"
-     if not logpath.is_absolute():
-         logpath = Path.home() / ".ezmsg" / "profile" / logpath
-
-     return logpath
-
-
- def _setup_logger(append: bool = False) -> logging.Logger:
-     logpath = get_logger_path()
-     logpath.parent.mkdir(parents=True, exist_ok=True)
-
-     write_header = True
-     if logpath.exists() and logpath.is_file():
-         if append:
-             with open(logpath) as f:
-                 first_line = f.readline().rstrip()
-             if first_line == HEADER:
-                 write_header = False
-             else:
-                 # Remove the file if appending, but headers do not match
-                 ezmsg_logger = logging.getLogger("ezmsg")
-                 ezmsg_logger.warning(
-                     "Profiling header mismatch: please make sure to use the same version of ezmsg for all processes."
-                 )
-                 logpath.unlink()
-         else:
-             # Remove the file if not appending
-             logpath.unlink()
-
-     # Create a logger with the name "ezprofile"
-     _logger = logging.getLogger("ezprofile")
-
-     # Set the logger's level to EZMSG_LOGLEVEL env var value if it exists, otherwise INFO
-     _logger.setLevel(os.environ.get("EZMSG_LOGLEVEL", "INFO").upper())
-
-     # Create a file handler to write log messages to the log file
-     fh = logging.FileHandler(logpath)
-     fh.setLevel(logging.DEBUG)  # Set the file handler log level to DEBUG
-
-     # Add the file handler to the logger
-     _logger.addHandler(fh)
-
-     # Add the header if writing to new file or if header matched header in file.
-     if write_header:
-         _logger.debug(HEADER)
-
-     # Set the log message format
-     formatter = logging.Formatter(
-         "%(asctime)s,%(message)s", datefmt="%Y-%m-%dT%H:%M:%S%z"
-     )
-     fh.setFormatter(formatter)
-
-     return _logger
-
-
- logger = _setup_logger(append=True)
-
-
- def _process_obj(obj, trace_oldest: bool = True):
-     samp_time = None
-     if hasattr(obj, "axes") and ("time" in obj.axes or "win" in obj.axes):
-         axis = "win" if "win" in obj.axes else "time"
-         ax = obj.get_axis(axis)
-         len = obj.data.shape[obj.get_axis_idx(axis)]
-         if len > 0:
-             idx = 0 if trace_oldest else (len - 1)
-             if hasattr(ax, "data"):
-                 samp_time = ax.data[idx]
-             else:
-                 samp_time = ax.value(idx)
-             if ax == "win" and "time" in obj.axes:
-                 if hasattr(obj.axes["time"], "data"):
-                     samp_time += obj.axes["time"].data[idx]
-                 else:
-                     samp_time += obj.axes["time"].value(idx)
-     return samp_time
-
-
- def profile_method(trace_oldest: bool = True):
-     """
-     Decorator to profile a method by logging its execution time and other details.
-
-     Args:
-         trace_oldest (bool): If True, trace the oldest sample time; otherwise, trace the newest.
-
-     Returns:
-         Callable: The decorated function with profiling.
-     """
-
-     def profiling_decorator(func: typing.Callable):
-         @functools.wraps(func)
-         def wrapped_func(caller, *args, **kwargs):
-             start = time.perf_counter()
-             res = func(caller, *args, **kwargs)
-             stop = time.perf_counter()
-             source = ".".join((caller.__class__.__module__, caller.__class__.__name__))
-             topic = f"{caller.address}"
-             samp_time = _process_obj(res, trace_oldest=trace_oldest)
-             logger.debug(
-                 ",".join(
-                     [
-                         source,
-                         topic,
-                         f"{samp_time}",
-                         f"{stop}",
-                         f"{(stop - start) * 1e3:0.4f}",
-                     ]
-                 )
-             )
-             return res
-
-         return wrapped_func if logger.level == logging.DEBUG else func
-
-     return profiling_decorator
-
-
- def profile_subpub(trace_oldest: bool = True):
-     """
-     Decorator to profile a subscriber-publisher method in an ezmsg Unit
-     by logging its execution time and other details.
-
-     Args:
-         trace_oldest (bool): If True, trace the oldest sample time; otherwise, trace the newest.
-
-     Returns:
-         Callable: The decorated async task with profiling.
-     """
-
-     def profiling_decorator(func: typing.Callable):
-         @functools.wraps(func)
-         async def wrapped_task(unit: ez.Unit, msg: typing.Any = None):
-             source = ".".join((unit.__class__.__module__, unit.__class__.__name__))
-             topic = f"{unit.address}"
-             start = time.perf_counter()
-             async for stream, obj in func(unit, msg):
-                 stop = time.perf_counter()
-                 samp_time = _process_obj(obj, trace_oldest=trace_oldest)
-                 logger.debug(
-                     ",".join(
-                         [
-                             source,
-                             topic,
-                             f"{samp_time}",
-                             f"{stop}",
-                             f"{(stop - start) * 1e3:0.4f}",
-                         ]
-                     )
-                 )
-                 start = stop
-                 yield stream, obj
-
-         return wrapped_task if logger.level == logging.DEBUG else func
-
-     return profiling_decorator
+ """
+ Backwards-compatible re-exports from ezmsg.baseproc.util.profile.
+
+ New code should import directly from ezmsg.baseproc instead.
+ """
+
+ from ezmsg.baseproc.util.profile import (
+     HEADER,
+     _setup_logger,
+     get_logger_path,
+     logger,
+     profile_method,
+     profile_subpub,
+ )
+
+ __all__ = [
+     "HEADER",
+     "get_logger_path",
+     "logger",
+     "profile_method",
+     "profile_subpub",
+     "_setup_logger",
+ ]
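
The profiling behavior itself is unchanged by the move. Per the removed code above, the log path comes from the EZMSG_PROFILE environment variable (relative names resolve under ~/.ezmsg/profile/) and the decorators are no-ops unless EZMSG_LOGLEVEL resolves to DEBUG. A usage sketch; the environment variables must be set before first import, since the module-level logger is created at import time, and the decorated class here is hypothetical (profiled callables need an address attribute, as ezmsg Units provide):

import os

os.environ.setdefault("EZMSG_PROFILE", "myrun.log")  # -> ~/.ezmsg/profile/myrun.log
os.environ.setdefault("EZMSG_LOGLEVEL", "DEBUG")  # enables the profiling wrappers

from ezmsg.sigproc.util.profile import get_logger_path, profile_method

print(get_logger_path())  # resolved CSV path; rows follow the HEADER columns


class MyProcessor:
    address = "example/topic"  # stand-in for the address an ezmsg Unit provides

    @profile_method(trace_oldest=True)
    def __call__(self, msg):
        return msg  # wall-clock elapsed ms is logged around each call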
ezmsg/sigproc/util/sparse.py CHANGED
@@ -1,10 +1,10 @@
+ """Methods for sparse array signal processing operations."""
+
  import numpy as np
  import sparse


- def sliding_win_oneaxis_old(
-     s: sparse.SparseArray, nwin: int, axis: int, step: int = 1
- ) -> sparse.SparseArray:
+ def sliding_win_oneaxis_old(s: sparse.SparseArray, nwin: int, axis: int, step: int = 1) -> sparse.SparseArray:
      """
      Like `ezmsg.util.messages.axisarray.sliding_win_oneaxis` but for sparse arrays.
      This approach is about 4x slower than the version that uses coordinate arithmetic below.
@@ -23,16 +23,12 @@ def sliding_win_oneaxis_old(
      targ_slices = [slice(_, _ + nwin) for _ in range(0, s.shape[axis] - nwin + 1, step)]
      s = s.reshape(s.shape[:axis] + (1,) + s.shape[axis:])
      full_slices = (slice(None),) * s.ndim
-     full_slices = [
-         full_slices[: axis + 1] + (sl,) + full_slices[axis + 2 :] for sl in targ_slices
-     ]
+     full_slices = [full_slices[: axis + 1] + (sl,) + full_slices[axis + 2 :] for sl in targ_slices]
      result = sparse.concatenate([s[_] for _ in full_slices], axis=axis)
      return result


- def sliding_win_oneaxis(
-     s: sparse.SparseArray, nwin: int, axis: int, step: int = 1
- ) -> sparse.SparseArray:
+ def sliding_win_oneaxis(s: sparse.SparseArray, nwin: int, axis: int, step: int = 1) -> sparse.SparseArray:
      """
      Generates a view-like sparse array using a sliding window of specified length along a specified axis.
      Sparse analog of an optimized dense as_strided-based implementation with these properties:
@@ -72,9 +68,7 @@ def sliding_win_oneaxis(
      n_win_out = len(win_starts)
      if n_win_out <= 0:
          # Return array with proper shape except empty along windows axis
-         return sparse.zeros(
-             s.shape[:axis] + (0,) + (nwin,) + s.shape[axis + 1 :], dtype=s.dtype
-         )
+         return sparse.zeros(s.shape[:axis] + (0,) + (nwin,) + s.shape[axis + 1 :], dtype=s.dtype)

      coo = s.asformat("coo")
      coords = coo.coords  # shape: (ndim, nnz)
@@ -112,9 +106,7 @@ def sliding_win_oneaxis(
          out_data_blocks.append(data[sel])

      if not out_coords_blocks:
-         return sparse.zeros(
-             s.shape[:axis] + (n_win_out,) + (nwin,) + s.shape[axis + 1 :], dtype=s.dtype
-         )
+         return sparse.zeros(s.shape[:axis] + (n_win_out,) + (nwin,) + s.shape[axis + 1 :], dtype=s.dtype)

      out_coords = np.hstack(out_coords_blocks)
      out_data = np.hstack(out_data_blocks)
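
A quick shape check for the reformatted sliding_win_oneaxis, consistent with the zero-fill shapes above: the windowed axis is replaced by an (n_windows, nwin) pair of axes.

import sparse

from ezmsg.sigproc.util.sparse import sliding_win_oneaxis

s = sparse.random((3, 10, 4), density=0.2)  # random COO sparse array
w = sliding_win_oneaxis(s, nwin=4, axis=1, step=2)
print(w.shape)  # (3, 4, 4, 4): window starts 0, 2, 4, 6 along the original axis 1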
ezmsg/sigproc/util/typeresolution.py CHANGED
@@ -1,83 +1,17 @@
- from types import UnionType
- import typing
- from typing_extensions import get_original_bases
-
-
- def resolve_typevar(cls: type, target_typevar: typing.TypeVar) -> type:
-     """
-     Resolve the concrete type bound to a TypeVar in a class hierarchy.
-     This function traverses the method resolution order (MRO) of the class
-     and checks the original bases of each class in the MRO for the TypeVar.
-     If the TypeVar is found, it returns the concrete type bound to it.
-     If the TypeVar is not found, it raises a TypeError.
-     Args:
-         cls (type): The class to inspect.
-         target_typevar (typing.TypeVar): The TypeVar to resolve.
-     Returns:
-         type: The concrete type bound to the TypeVar.
-     """
-     for base in cls.__mro__:
-         orig_bases = get_original_bases(base)
-         for orig_base in orig_bases:
-             origin = typing.get_origin(orig_base)
-             if origin is None:
-                 continue
-             params = getattr(origin, "__parameters__", ())
-             if not params:
-                 continue
-             if target_typevar in params:
-                 index = params.index(target_typevar)
-                 args = typing.get_args(orig_base)
-                 try:
-                     return args[index]
-                 except IndexError:
-                     pass
-     raise TypeError(f"Could not resolve {target_typevar} in {cls}")
-
-
- TypeLike = typing.Union[type[typing.Any], typing.Any, type(None), None]
-
-
- def check_message_type_compatibility(type1: TypeLike, type2: TypeLike) -> bool:
-     """
-     Check if two types are compatible for message passing.
-     Returns True if:
-     - Both are None/NoneType
-     - Either is typing.Any
-     - type1 is a subclass of type2, which includes
-       - type1 and type2 are concrete types and type1 is a subclass of type2
-       - type1 is None/NoneType and type2 is typing.Optional, or
-       - type1 is subtype of the non-None inner type of type2 if type2 is Optional
-     - type1 is a Union/Optional type and all inner types are compatible with type2
-     Args:
-         type1: First type to compare
-         type2: Second type to compare
-     Returns:
-         bool: True if the types are compatible, False otherwise
-     """
-     # If either is Any, they are compatible
-     if type1 is typing.Any or type2 is typing.Any:
-         return True
-
-     # Handle None as NoneType
-     if type1 is None:
-         type1 = type(None)
-     if type2 is None:
-         type2 = type(None)
-
-     # Handle if type1 is Optional/Union type
-     if typing.get_origin(type1) in {typing.Union, UnionType}:
-         return all(
-             check_message_type_compatibility(inner_type, type2)
-             for inner_type in typing.get_args(type1)
-         )
-
-     # Regular issubclass check. Handles cases like:
-     # - type1 is a subclass of concrete type2
-     # - type1 is a subclass of the inner type of type2 if type2 is Optional
-     # - type1 is a subclass of one of the inner types of type2 if type2 is Union
-     # - type1 is NoneType and type2 is Optional or Union[None, ...] or Union[NoneType, ...]
-     try:
-         return issubclass(type1, type2)
-     except TypeError:
-         return False
+ """
+ Backwards-compatible re-exports from ezmsg.baseproc.util.typeresolution.
+
+ New code should import directly from ezmsg.baseproc instead.
+ """
+
+ from ezmsg.baseproc.util.typeresolution import (
+     TypeLike,
+     check_message_type_compatibility,
+     resolve_typevar,
+ )
+
+ __all__ = [
+     "TypeLike",
+     "check_message_type_compatibility",
+     "resolve_typevar",
+ ]
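
The semantics documented in the removed docstring carry over through the re-export. A few illustrative checks, written with PEP 604 unions since plain classes and X | Y unions are what the final issubclass fallback accepts:

import typing

from ezmsg.sigproc.util.typeresolution import check_message_type_compatibility

assert check_message_type_compatibility(int, typing.Any)  # Any matches everything
assert check_message_type_compatibility(None, str | None)  # None is treated as NoneType
assert check_message_type_compatibility(bool, int)  # plain subclass check
assert check_message_type_compatibility(int | bool, int)  # every union member must be compatible
assert not check_message_type_compatibility(int, str)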
ezmsg/sigproc/wavelets.py CHANGED
@@ -1,18 +1,18 @@
  import typing

+ import ezmsg.core as ez
  import numpy as np
  import numpy.typing as npt
  import pywt
- import ezmsg.core as ez
- from ezmsg.util.messages.axisarray import AxisArray
- from ezmsg.util.messages.util import replace
-
- from .base import (
+ from ezmsg.baseproc import (
      BaseStatefulTransformer,
      BaseTransformerUnit,
      processor_state,
  )
- from .filterbank import filterbank, FilterbankMode, MinPhaseMode
+ from ezmsg.util.messages.axisarray import AxisArray
+ from ezmsg.util.messages.util import replace
+
+ from .filterbank import FilterbankMode, MinPhaseMode, filterbank


  class CWTSettings(ez.Settings):
@@ -37,9 +37,7 @@ class CWTState:
      last_conv_samp: npt.NDArray | None = None


- class CWTTransformer(
-     BaseStatefulTransformer[CWTSettings, AxisArray, AxisArray, CWTState]
- ):
+ class CWTTransformer(BaseStatefulTransformer[CWTSettings, AxisArray, AxisArray, CWTState]):
      def _hash_message(self, message: AxisArray) -> int:
          ax_idx = message.get_axis_idx(self.settings.axis)
          in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
@@ -107,25 +105,18 @@ class CWTTransformer(
          # Create output template
          ax_idx = message.get_axis_idx(self.settings.axis)
          in_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
-         freqs = (
-             pywt.scale2frequency(wavelet, scales, precision)
-             / message.axes[self.settings.axis].gain
-         )
+         freqs = pywt.scale2frequency(wavelet, scales, precision) / message.axes[self.settings.axis].gain
          dummy_shape = in_shape + (len(scales), 0)
          self._state.template = AxisArray(
              np.zeros(dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data),
-             dims=message.dims[:ax_idx]
-             + message.dims[ax_idx + 1 :]
-             + ["freq", self.settings.axis],
+             dims=message.dims[:ax_idx] + message.dims[ax_idx + 1 :] + ["freq", self.settings.axis],
              axes={
                  **message.axes,
                  "freq": AxisArray.CoordinateAxis(unit="Hz", data=freqs, dims=["freq"]),
              },
              key=message.key,
          )
-         self._state.last_conv_samp = np.zeros(
-             dummy_shape[:-1] + (1,), dtype=self._state.template.data.dtype
-         )
+         self._state.last_conv_samp = np.zeros(dummy_shape[:-1] + (1,), dtype=self._state.template.data.dtype)

      def _process(self, message: AxisArray) -> AxisArray:
          conv_msg = self._state.fbgen.send(message)