ezmsg-sigproc 1.8.1__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +36 -39
- ezmsg/sigproc/adaptive_lattice_notch.py +231 -0
- ezmsg/sigproc/affinetransform.py +169 -163
- ezmsg/sigproc/aggregate.py +119 -104
- ezmsg/sigproc/bandpower.py +58 -52
- ezmsg/sigproc/base.py +1242 -0
- ezmsg/sigproc/butterworthfilter.py +37 -33
- ezmsg/sigproc/cheby.py +29 -17
- ezmsg/sigproc/combfilter.py +163 -0
- ezmsg/sigproc/decimate.py +19 -10
- ezmsg/sigproc/detrend.py +29 -0
- ezmsg/sigproc/diff.py +81 -0
- ezmsg/sigproc/downsample.py +78 -78
- ezmsg/sigproc/ewma.py +197 -0
- ezmsg/sigproc/extract_axis.py +41 -0
- ezmsg/sigproc/filter.py +257 -141
- ezmsg/sigproc/filterbank.py +247 -199
- ezmsg/sigproc/math/abs.py +17 -22
- ezmsg/sigproc/math/clip.py +24 -24
- ezmsg/sigproc/math/difference.py +34 -30
- ezmsg/sigproc/math/invert.py +13 -25
- ezmsg/sigproc/math/log.py +28 -33
- ezmsg/sigproc/math/scale.py +18 -26
- ezmsg/sigproc/quantize.py +71 -0
- ezmsg/sigproc/resample.py +298 -0
- ezmsg/sigproc/sampler.py +241 -259
- ezmsg/sigproc/scaler.py +55 -218
- ezmsg/sigproc/signalinjector.py +52 -43
- ezmsg/sigproc/slicer.py +81 -89
- ezmsg/sigproc/spectrogram.py +77 -75
- ezmsg/sigproc/spectrum.py +203 -168
- ezmsg/sigproc/synth.py +546 -393
- ezmsg/sigproc/transpose.py +131 -0
- ezmsg/sigproc/util/asio.py +156 -0
- ezmsg/sigproc/util/message.py +31 -0
- ezmsg/sigproc/util/profile.py +55 -12
- ezmsg/sigproc/util/typeresolution.py +83 -0
- ezmsg/sigproc/wavelets.py +154 -153
- ezmsg/sigproc/window.py +269 -211
- {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/METADATA +2 -1
- ezmsg_sigproc-2.0.0.dist-info/RECORD +51 -0
- ezmsg_sigproc-1.8.1.dist-info/RECORD +0 -39
- {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/WHEEL +0 -0
- {ezmsg_sigproc-1.8.1.dist-info → ezmsg_sigproc-2.0.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
from types import EllipsisType
|
|
2
|
+
import numpy as np
|
|
3
|
+
import ezmsg.core as ez
|
|
4
|
+
from ezmsg.util.messages.axisarray import (
|
|
5
|
+
AxisArray,
|
|
6
|
+
replace,
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
from .base import (
|
|
10
|
+
BaseStatefulTransformer,
|
|
11
|
+
BaseTransformerUnit,
|
|
12
|
+
processor_state,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TransposeSettings(ez.Settings):
    """
    Settings for :obj:`Transpose` node.

    Fields:
        axes: Desired output axis order, or ``None`` (default) to leave the axis
            order unchanged. Each entry is an integer axis index, the name of an
            axis in the input message's ``dims``, or a single ``...`` (Ellipsis)
            standing in for every axis not otherwise listed, kept in its original
            relative order. If no ``...`` is given, unlisted axes are placed after
            the listed ones.
        order: Optional memory layout for the output data: ``"C"`` (row-major)
            or ``"F"`` (column-major). Only the first character is checked,
            case-insensitively. ``None`` (default) leaves the layout as produced
            by the transpose, which may be a non-contiguous view.
    """

    axes: tuple[int | str | EllipsisType, ...] | None = None
    order: str | None = None
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@processor_state
class TransposeState:
    # Permutation handed to ``np.transpose`` for the current input structure.
    # ``None`` means the requested order already matches the input, so no
    # axis reordering is performed.
    axes_ints: tuple[int, ...] | None = None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class TransposeTransformer(
    BaseStatefulTransformer[TransposeSettings, AxisArray, AxisArray, TransposeState]
):
    """
    Reorder the axes of an :obj:`AxisArray` and/or change its memory layout.

    ``settings.axes`` lists the desired output axis order using integer indices
    (negative indices count from the end), axis names from ``message.dims``,
    and at most one ``...`` standing for all unlisted axes in their original
    relative order. If no ``...`` is given, unlisted axes are placed after the
    listed ones. ``settings.order`` ('C' or 'F', case-insensitive) additionally
    forces the output data into that memory layout via :func:`numpy.require`.
    When both settings are ``None`` messages are passed through untouched.
    """

    def _hash_message(self, message: AxisArray) -> int:
        # The resolved permutation depends only on the input's dimension names.
        return hash(tuple(message.dims))

    @staticmethod
    def _resolve_axis(ax: int | str, message: AxisArray) -> int:
        """Convert one non-Ellipsis entry of ``settings.axes`` to a non-negative axis index."""
        ndim = message.data.ndim
        if isinstance(ax, (int, np.integer)):
            if not -ndim <= ax < ndim:
                raise ValueError(
                    f"Axis {ax} is out of bounds for data with {ndim} dimensions."
                )
            # Normalize negative indices so they compare correctly against range(ndim);
            # otherwise e.g. -1 and ndim-1 would both appear in the permutation.
            return int(ax) % ndim
        if ax not in message.dims:
            raise ValueError(f"Axis {ax} not found in message dims.")
        return message.dims.index(ax)

    def _reset_state(self, message: AxisArray) -> None:
        # Validate `order` up front; slicing avoids IndexError on an empty string.
        if self.settings.order is not None and self.settings.order.upper()[:1] not in (
            "C",
            "F",
        ):
            raise ValueError("order must be 'C' or 'F'.")

        axes = self.settings.axes
        if axes is None:
            self._state.axes_ints = None
            return

        ell_positions = [ix for ix, ax in enumerate(axes) if ax is Ellipsis]
        if len(ell_positions) > 1:
            raise ValueError("Only one Ellipsis is allowed in axes.")
        # Without an Ellipsis, behave as if one trailed the spec: unlisted axes go last.
        ell_ix = ell_positions[0] if ell_positions else len(axes)

        prefix = [self._resolve_axis(ax, message) for ax in axes[:ell_ix]]
        suffix = [self._resolve_axis(ax, message) for ax in axes[ell_ix + 1 :]]
        listed = prefix + suffix
        if len(set(listed)) != len(listed):
            raise ValueError(f"Repeated axis in axes: {axes}")

        ndim = message.data.ndim
        middle = [ix for ix in range(ndim) if ix not in listed]
        re_ix = tuple(prefix + middle + suffix)
        # An identity permutation needs no transpose at all.
        self._state.axes_ints = None if re_ix == tuple(range(ndim)) else re_ix

    def __call__(self, message: AxisArray) -> AxisArray:
        if self.settings.axes is None and self.settings.order is None:
            # Passthrough: nothing to do, skip state bookkeeping entirely.
            return message
        return super().__call__(message)

    def _process(self, message: AxisArray) -> AxisArray:
        order = None if self.settings.order is None else self.settings.order.upper()[0]

        if self.state.axes_ints is None:
            # No transpose required.
            if order is None:
                # Reached when `axes` resolves to the identity permutation and no
                # layout was requested.
                return message
            # If the memory is already contiguous in the correct order, np.require won't copy.
            return replace(message, data=np.require(message.data, requirements=order))

        dims_out = [message.dims[ix] for ix in self.state.axes_ints]
        data_out = np.transpose(message.data, axes=self.state.axes_ints)
        if order is not None:
            data_out = np.require(data_out, requirements=order)
        return replace(message, data=data_out, dims=dims_out)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class Transpose(
    BaseTransformerUnit[TransposeSettings, AxisArray, AxisArray, TransposeTransformer]
):
    """
    Unit wrapper around :obj:`TransposeTransformer`, configured via
    :obj:`TransposeSettings`.
    """

    SETTINGS = TransposeSettings
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def transpose(
    axes: tuple[int | str | EllipsisType, ...] | None = None, order: str | None = None
) -> TransposeTransformer:
    """
    Build a :obj:`TransposeTransformer` from keyword settings.

    Args:
        axes: Desired output axis order (integer indices, axis names, and at most
            one ``...``), or ``None`` to keep the current order.
        order: ``"C"`` or ``"F"`` to force the output memory layout, or ``None``.

    Returns:
        A new transformer configured with these settings.
    """
    settings = TransposeSettings(axes=axes, order=order)
    return TransposeTransformer(settings)
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
3
|
+
import contextlib
|
|
4
|
+
import inspect
|
|
5
|
+
import threading
|
|
6
|
+
from typing import Any, Coroutine, TypeVar
|
|
7
|
+
|
|
8
|
+
T = TypeVar("T")
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CoroutineExecutionError(Exception):
|
|
12
|
+
"""Custom exception for coroutine execution failures"""
|
|
13
|
+
|
|
14
|
+
pass
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def run_coroutine_sync(coroutine: Coroutine[Any, Any, T], timeout: float = 30) -> T:
|
|
18
|
+
"""
|
|
19
|
+
Executes an asyncio coroutine synchronously, with enhanced error handling.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
coroutine: The asyncio coroutine to execute
|
|
23
|
+
timeout: Maximum time in seconds to wait for coroutine completion (default: 30)
|
|
24
|
+
|
|
25
|
+
Returns:
|
|
26
|
+
The result of the coroutine execution
|
|
27
|
+
|
|
28
|
+
Raises:
|
|
29
|
+
CoroutineExecutionError: If execution fails due to threading or event loop issues
|
|
30
|
+
TimeoutError: If execution exceeds the timeout period
|
|
31
|
+
Exception: Any exception raised by the coroutine
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
def run_in_new_loop() -> T:
|
|
35
|
+
"""
|
|
36
|
+
Creates and runs a new event loop in the current thread.
|
|
37
|
+
Ensures proper cleanup of the loop.
|
|
38
|
+
"""
|
|
39
|
+
new_loop = asyncio.new_event_loop()
|
|
40
|
+
asyncio.set_event_loop(new_loop)
|
|
41
|
+
try:
|
|
42
|
+
return new_loop.run_until_complete(
|
|
43
|
+
asyncio.wait_for(coroutine, timeout=timeout)
|
|
44
|
+
)
|
|
45
|
+
finally:
|
|
46
|
+
with contextlib.suppress(Exception):
|
|
47
|
+
# Clean up any pending tasks
|
|
48
|
+
pending = asyncio.all_tasks(new_loop)
|
|
49
|
+
for task in pending:
|
|
50
|
+
task.cancel()
|
|
51
|
+
new_loop.run_until_complete(
|
|
52
|
+
asyncio.gather(*pending, return_exceptions=True)
|
|
53
|
+
)
|
|
54
|
+
new_loop.close()
|
|
55
|
+
|
|
56
|
+
try:
|
|
57
|
+
loop = asyncio.get_running_loop()
|
|
58
|
+
except RuntimeError:
|
|
59
|
+
try:
|
|
60
|
+
return asyncio.run(asyncio.wait_for(coroutine, timeout=timeout))
|
|
61
|
+
except Exception as e:
|
|
62
|
+
raise CoroutineExecutionError(
|
|
63
|
+
f"Failed to execute coroutine: {str(e)}"
|
|
64
|
+
) from e
|
|
65
|
+
|
|
66
|
+
if threading.current_thread() is threading.main_thread():
|
|
67
|
+
if not loop.is_running():
|
|
68
|
+
try:
|
|
69
|
+
return loop.run_until_complete(
|
|
70
|
+
asyncio.wait_for(coroutine, timeout=timeout)
|
|
71
|
+
)
|
|
72
|
+
except Exception as e:
|
|
73
|
+
raise CoroutineExecutionError(
|
|
74
|
+
f"Failed to execute coroutine in main loop: {str(e)}"
|
|
75
|
+
) from e
|
|
76
|
+
else:
|
|
77
|
+
with ThreadPoolExecutor() as pool:
|
|
78
|
+
try:
|
|
79
|
+
future = pool.submit(run_in_new_loop)
|
|
80
|
+
return future.result(timeout=timeout)
|
|
81
|
+
except Exception as e:
|
|
82
|
+
raise CoroutineExecutionError(
|
|
83
|
+
f"Failed to execute coroutine in thread: {str(e)}"
|
|
84
|
+
) from e
|
|
85
|
+
else:
|
|
86
|
+
try:
|
|
87
|
+
future = asyncio.run_coroutine_threadsafe(coroutine, loop)
|
|
88
|
+
return future.result(timeout=timeout)
|
|
89
|
+
except Exception as e:
|
|
90
|
+
raise CoroutineExecutionError(
|
|
91
|
+
f"Failed to execute coroutine threadsafe: {str(e)}"
|
|
92
|
+
) from e
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class SyncToAsyncGeneratorWrapper:
    """
    A wrapper for synchronous generators to be used in an async context.

    Each ``send``/``next``/``close`` on the wrapped generator runs in a worker
    thread via :func:`asyncio.to_thread`. The generator is primed (advanced to
    its first ``yield``) on construction, so the first :meth:`asend` delivers a
    value to that ``yield``. Exhaustion is reported as ``StopAsyncIteration``.
    Unknown attributes are forwarded to the wrapped generator.
    """

    def __init__(self, gen):
        self._gen = gen
        self._closed = False
        # Prime the generator to ready for first send/next call
        try:
            is_not_primed = inspect.getgeneratorstate(self._gen) is inspect.GEN_CREATED
        except AttributeError as e:
            raise TypeError(
                "The provided generator is not a valid generator object"
            ) from e
        if is_not_primed:
            try:
                next(self._gen)
            except StopIteration:
                self._closed = True
            except Exception as e:
                raise RuntimeError(f"Failed to prime generator: {e}") from e

    @staticmethod
    def _step(func, *args):
        """
        Call ``func(*args)`` and report exhaustion instead of raising StopIteration.

        Runs inside the worker thread. StopIteration must not escape it: asyncio
        cannot propagate StopIteration through a Future (depending on the Python
        version it is converted to RuntimeError or the await never completes).

        Returns:
            ``(False, value)`` for a normal yield, or ``(True, stop_exc)`` when
            the generator is exhausted.
        """
        try:
            return False, func(*args)
        except StopIteration as stop:
            return True, stop

    async def asend(self, value):
        if self._closed:
            raise StopAsyncIteration("Generator is closed")
        try:
            exhausted, result = await asyncio.to_thread(
                self._step, self._gen.send, value
            )
        except Exception as e:
            raise RuntimeError(f"Error while sending value to generator: {e}") from e
        if exhausted:
            self._closed = True
            raise StopAsyncIteration("Generator is closed") from result
        return result

    async def __anext__(self):
        if self._closed:
            raise StopAsyncIteration("Generator is closed")
        try:
            exhausted, result = await asyncio.to_thread(self._step, self._gen.__next__)
        except Exception as e:
            raise RuntimeError(
                f"Error while getting next value from generator: {e}"
            ) from e
        if exhausted:
            self._closed = True
            raise StopAsyncIteration("Generator is closed") from result
        return result

    async def aclose(self):
        if self._closed:
            return
        try:
            await asyncio.to_thread(self._gen.close)
        except Exception as e:
            raise RuntimeError(f"Error while closing generator: {e}") from e
        finally:
            self._closed = True

    def __aiter__(self):
        return self

    def __getattr__(self, name):
        # Only called for attributes not found normally. Guard `_gen` itself so a
        # partially-constructed instance (bare __new__, copy/unpickle) raises
        # AttributeError instead of recursing forever.
        if name == "_gen":
            raise AttributeError(name)
        return getattr(self._gen, name)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import time
|
|
2
|
+
import typing
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
|
|
5
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass(unsafe_hash=True)
|
|
9
|
+
class SampleTriggerMessage:
|
|
10
|
+
timestamp: float = field(default_factory=time.time)
|
|
11
|
+
"""Time of the trigger, in seconds. The Clock depends on the input but defaults to time.time"""
|
|
12
|
+
|
|
13
|
+
period: tuple[float, float] | None = None
|
|
14
|
+
"""The period around the timestamp, in seconds"""
|
|
15
|
+
|
|
16
|
+
value: typing.Any = None
|
|
17
|
+
"""A value or 'label' associated with the trigger."""
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class SampleMessage:
    """Pairs a trigger with the data that was sampled around it."""

    trigger: SampleTriggerMessage
    """The trigger (time, window, and optional value) this sample belongs to."""

    sample: AxisArray
    """The data sampled around the trigger."""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def is_sample_message(message: typing.Any) -> typing.TypeGuard[SampleMessage]:
    """
    Return True if ``message`` can be treated as a :obj:`SampleMessage`.

    The check is duck-typed: any object exposing a ``trigger`` attribute passes,
    not only ``SampleMessage`` instances.
    """
    try:
        getattr(message, "trigger")
    except AttributeError:
        return False
    return True
|
ezmsg/sigproc/util/profile.py
CHANGED
|
@@ -8,6 +8,9 @@ import typing
|
|
|
8
8
|
import ezmsg.core as ez
|
|
9
9
|
|
|
10
10
|
|
|
11
|
+
HEADER = "Time,Source,Topic,SampleTime,PerfCounter,Elapsed"
|
|
12
|
+
|
|
13
|
+
|
|
11
14
|
def get_logger_path() -> Path:
|
|
12
15
|
# Retrieve the logfile name from the environment variable
|
|
13
16
|
logfile = os.environ.get("EZMSG_PROFILE", None)
|
|
@@ -26,9 +29,23 @@ def _setup_logger(append: bool = False) -> logging.Logger:
|
|
|
26
29
|
logpath = get_logger_path()
|
|
27
30
|
logpath.parent.mkdir(parents=True, exist_ok=True)
|
|
28
31
|
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
+
write_header = True
|
|
33
|
+
if logpath.exists() and logpath.is_file():
|
|
34
|
+
if append:
|
|
35
|
+
with open(logpath) as f:
|
|
36
|
+
first_line = f.readline().rstrip()
|
|
37
|
+
if first_line == HEADER:
|
|
38
|
+
write_header = False
|
|
39
|
+
else:
|
|
40
|
+
# Remove the file if appending, but headers do not match
|
|
41
|
+
ezmsg_logger = logging.getLogger("ezmsg")
|
|
42
|
+
ezmsg_logger.warning(
|
|
43
|
+
"Profiling header mismatch: please make sure to use the same version of ezmsg for all processes."
|
|
44
|
+
)
|
|
45
|
+
logpath.unlink()
|
|
46
|
+
else:
|
|
47
|
+
# Remove the file if not appending
|
|
48
|
+
logpath.unlink()
|
|
32
49
|
|
|
33
50
|
# Create a logger with the name "ezprofile"
|
|
34
51
|
_logger = logging.getLogger("ezprofile")
|
|
@@ -43,13 +60,13 @@ def _setup_logger(append: bool = False) -> logging.Logger:
|
|
|
43
60
|
# Add the file handler to the logger
|
|
44
61
|
_logger.addHandler(fh)
|
|
45
62
|
|
|
46
|
-
# Add the
|
|
47
|
-
|
|
63
|
+
# Add the header if writing to new file or if header matched header in file.
|
|
64
|
+
if write_header:
|
|
65
|
+
_logger.debug(HEADER)
|
|
48
66
|
|
|
49
67
|
# Set the log message format
|
|
50
68
|
formatter = logging.Formatter(
|
|
51
|
-
"%(asctime)s,%(message)s",
|
|
52
|
-
datefmt="%Y-%m-%dT%H:%M:%S%z"
|
|
69
|
+
"%(asctime)s,%(message)s", datefmt="%Y-%m-%dT%H:%M:%S%z"
|
|
53
70
|
)
|
|
54
71
|
fh.setFormatter(formatter)
|
|
55
72
|
|
|
@@ -89,18 +106,31 @@ def profile_method(trace_oldest: bool = True):
|
|
|
89
106
|
Returns:
|
|
90
107
|
Callable: The decorated function with profiling.
|
|
91
108
|
"""
|
|
109
|
+
|
|
92
110
|
def profiling_decorator(func: typing.Callable):
|
|
93
111
|
@functools.wraps(func)
|
|
94
112
|
def wrapped_func(caller, *args, **kwargs):
|
|
95
113
|
start = time.perf_counter()
|
|
96
114
|
res = func(caller, *args, **kwargs)
|
|
97
115
|
stop = time.perf_counter()
|
|
98
|
-
source =
|
|
116
|
+
source = ".".join((caller.__class__.__module__, caller.__class__.__name__))
|
|
99
117
|
topic = f"{caller.address}"
|
|
100
118
|
samp_time = _process_obj(res, trace_oldest=trace_oldest)
|
|
101
|
-
logger.debug(
|
|
119
|
+
logger.debug(
|
|
120
|
+
",".join(
|
|
121
|
+
[
|
|
122
|
+
source,
|
|
123
|
+
topic,
|
|
124
|
+
f"{samp_time}",
|
|
125
|
+
f"{stop}",
|
|
126
|
+
f"{(stop - start) * 1e3:0.4f}",
|
|
127
|
+
]
|
|
128
|
+
)
|
|
129
|
+
)
|
|
102
130
|
return res
|
|
131
|
+
|
|
103
132
|
return wrapped_func if logger.level == logging.DEBUG else func
|
|
133
|
+
|
|
104
134
|
return profiling_decorator
|
|
105
135
|
|
|
106
136
|
|
|
@@ -115,17 +145,30 @@ def profile_subpub(trace_oldest: bool = True):
|
|
|
115
145
|
Returns:
|
|
116
146
|
Callable: The decorated async task with profiling.
|
|
117
147
|
"""
|
|
148
|
+
|
|
118
149
|
def profiling_decorator(func: typing.Callable):
|
|
119
150
|
@functools.wraps(func)
|
|
120
|
-
async def wrapped_task(unit: ez.Unit, msg: typing.Any = None)
|
|
121
|
-
source =
|
|
151
|
+
async def wrapped_task(unit: ez.Unit, msg: typing.Any = None):
|
|
152
|
+
source = ".".join((unit.__class__.__module__, unit.__class__.__name__))
|
|
122
153
|
topic = f"{unit.address}"
|
|
123
154
|
start = time.perf_counter()
|
|
124
155
|
async for stream, obj in func(unit, msg):
|
|
125
156
|
stop = time.perf_counter()
|
|
126
157
|
samp_time = _process_obj(obj, trace_oldest=trace_oldest)
|
|
127
|
-
logger.debug(
|
|
158
|
+
logger.debug(
|
|
159
|
+
",".join(
|
|
160
|
+
[
|
|
161
|
+
source,
|
|
162
|
+
topic,
|
|
163
|
+
f"{samp_time}",
|
|
164
|
+
f"{stop}",
|
|
165
|
+
f"{(stop - start) * 1e3:0.4f}",
|
|
166
|
+
]
|
|
167
|
+
)
|
|
168
|
+
)
|
|
128
169
|
start = stop
|
|
129
170
|
yield stream, obj
|
|
171
|
+
|
|
130
172
|
return wrapped_task if logger.level == logging.DEBUG else func
|
|
173
|
+
|
|
131
174
|
return profiling_decorator
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from types import UnionType
|
|
2
|
+
import typing
|
|
3
|
+
from typing_extensions import get_original_bases
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def resolve_typevar(cls: type, target_typevar: typing.TypeVar) -> type:
    """
    Resolve the type argument bound to a TypeVar somewhere in a class hierarchy.

    Walks ``cls.__mro__`` from most to least derived. For each class, it inspects
    the parameterized bases it was declared with (its original bases). The
    first base whose generic origin declares ``target_typevar`` as a parameter,
    and which actually supplies an argument at that position, determines the
    result. If that argument is itself a TypeVar, it is returned unchanged.

    Args:
        cls (type): The class to inspect.
        target_typevar (typing.TypeVar): The TypeVar to resolve.

    Returns:
        type: The type argument bound to the TypeVar.

    Raises:
        TypeError: If no base in the hierarchy binds ``target_typevar``.
    """
    for klass in cls.__mro__:
        for parametrized in get_original_bases(klass):
            generic_origin = typing.get_origin(parametrized)
            if generic_origin is None:
                continue
            type_params = getattr(generic_origin, "__parameters__", ())
            if target_typevar not in type_params:
                continue
            position = type_params.index(target_typevar)
            bound_args = typing.get_args(parametrized)
            if position < len(bound_args):
                return bound_args[position]
    raise TypeError(f"Could not resolve {target_typevar} in {cls}")
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
TypeLike = typing.Union[type[typing.Any], typing.Any, type(None), None]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def check_message_type_compatibility(type1: TypeLike, type2: TypeLike) -> bool:
|
|
42
|
+
"""
|
|
43
|
+
Check if two types are compatible for message passing.
|
|
44
|
+
Returns True if:
|
|
45
|
+
- Both are None/NoneType
|
|
46
|
+
- Either is typing.Any
|
|
47
|
+
- type1 is a subclass of type2, which includes
|
|
48
|
+
- type1 and type2 are concrete types and type1 is a subclass of type2
|
|
49
|
+
- type1 is None/NoneType and type2 is typing.Optional, or
|
|
50
|
+
- type1 is subtype of the non-None inner type of type2 if type2 is Optional
|
|
51
|
+
- type1 is a Union/Optional type and all inner types are compatible with type2
|
|
52
|
+
Args:
|
|
53
|
+
type1: First type to compare
|
|
54
|
+
type2: Second type to compare
|
|
55
|
+
Returns:
|
|
56
|
+
bool: True if the types are compatible, False otherwise
|
|
57
|
+
"""
|
|
58
|
+
# If either is Any, they are compatible
|
|
59
|
+
if type1 is typing.Any or type2 is typing.Any:
|
|
60
|
+
return True
|
|
61
|
+
|
|
62
|
+
# Handle None as NoneType
|
|
63
|
+
if type1 is None:
|
|
64
|
+
type1 = type(None)
|
|
65
|
+
if type2 is None:
|
|
66
|
+
type2 = type(None)
|
|
67
|
+
|
|
68
|
+
# Handle if type1 is Optional/Union type
|
|
69
|
+
if typing.get_origin(type1) in {typing.Union, UnionType}:
|
|
70
|
+
return all(
|
|
71
|
+
check_message_type_compatibility(inner_type, type2)
|
|
72
|
+
for inner_type in typing.get_args(type1)
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
# Regular issubclass check. Handles cases like:
|
|
76
|
+
# - type1 is a subclass of concrete type2
|
|
77
|
+
# - type1 is a subclass of the inner type of type2 if type2 is Optional
|
|
78
|
+
# - type1 is a subclass of one of the inner types of type2 if type2 is Union
|
|
79
|
+
# - type1 is NoneType and type2 is Optional or Union[None, ...] or Union[NoneType, ...]
|
|
80
|
+
try:
|
|
81
|
+
return issubclass(type1, type2)
|
|
82
|
+
except TypeError:
|
|
83
|
+
return False
|