ezmsg-baseproc 1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/baseproc/__init__.py +155 -0
- ezmsg/baseproc/__version__.py +34 -0
- ezmsg/baseproc/composite.py +323 -0
- ezmsg/baseproc/processor.py +209 -0
- ezmsg/baseproc/protocols.py +147 -0
- ezmsg/baseproc/stateful.py +323 -0
- ezmsg/baseproc/units.py +282 -0
- ezmsg/baseproc/util/__init__.py +1 -0
- ezmsg/baseproc/util/asio.py +138 -0
- ezmsg/baseproc/util/message.py +31 -0
- ezmsg/baseproc/util/profile.py +171 -0
- ezmsg/baseproc/util/typeresolution.py +81 -0
- ezmsg_baseproc-1.0.dist-info/METADATA +106 -0
- ezmsg_baseproc-1.0.dist-info/RECORD +16 -0
- ezmsg_baseproc-1.0.dist-info/WHEEL +4 -0
- ezmsg_baseproc-1.0.dist-info/licenses/LICENSE +21 -0
ezmsg/baseproc/units.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
"""Base Unit classes for ezmsg integration."""
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
import traceback
|
|
5
|
+
import typing
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
|
|
8
|
+
import ezmsg.core as ez
|
|
9
|
+
from ezmsg.util.generator import GenState
|
|
10
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
11
|
+
|
|
12
|
+
from .composite import CompositeProcessor
|
|
13
|
+
from .processor import BaseConsumer, BaseProducer, BaseTransformer
|
|
14
|
+
from .protocols import MessageInType, MessageOutType, SettingsType
|
|
15
|
+
from .stateful import BaseAdaptiveTransformer, BaseStatefulConsumer, BaseStatefulTransformer
|
|
16
|
+
from .util.message import SampleMessage
|
|
17
|
+
from .util.profile import profile_subpub
|
|
18
|
+
from .util.typeresolution import resolve_typevar
|
|
19
|
+
|
|
20
|
+
# --- Type variables for Unit classes ---
|
|
21
|
+
# Each Unit base class below is generic over one of these processor TypeVars. The
# concrete processor class bound to it in a subclass's parameterization is looked up
# at runtime (see the get_base_*_type helpers) and instantiated with the Unit's SETTINGS.
ProducerType = typing.TypeVar(
    "ProducerType",
    bound=BaseProducer,
)
ConsumerType = typing.TypeVar(
    "ConsumerType",
    bound=BaseConsumer | BaseStatefulConsumer,
)
TransformerType = typing.TypeVar(
    "TransformerType",
    bound=BaseTransformer | BaseStatefulTransformer | CompositeProcessor,
)
AdaptiveTransformerType = typing.TypeVar(
    "AdaptiveTransformerType",
    bound=BaseAdaptiveTransformer,
)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def get_base_producer_type(cls: type) -> type:
    """
    Look up the producer class a producer Unit class was parameterized with.

    Args:
        cls: A Unit class whose generic parameters include ``ProducerType``.

    Returns:
        Whatever ``resolve_typevar`` resolves ``ProducerType`` to for ``cls``.
    """
    resolved = resolve_typevar(cls, ProducerType)
    return resolved


def get_base_consumer_type(cls: type) -> type:
    """
    Look up the consumer class a consumer Unit class was parameterized with.

    Args:
        cls: A Unit class whose generic parameters include ``ConsumerType``.

    Returns:
        Whatever ``resolve_typevar`` resolves ``ConsumerType`` to for ``cls``.
    """
    resolved = resolve_typevar(cls, ConsumerType)
    return resolved


def get_base_transformer_type(cls: type) -> type:
    """
    Look up the transformer class a transformer Unit class was parameterized with.

    Args:
        cls: A Unit class whose generic parameters include ``TransformerType``.

    Returns:
        Whatever ``resolve_typevar`` resolves ``TransformerType`` to for ``cls``.
    """
    resolved = resolve_typevar(cls, TransformerType)
    return resolved


def get_base_adaptive_transformer_type(cls: type) -> type:
    """
    Look up the adaptive transformer class an adaptive Unit class was parameterized with.

    Args:
        cls: A Unit class whose generic parameters include ``AdaptiveTransformerType``.

    Returns:
        Whatever ``resolve_typevar`` resolves ``AdaptiveTransformerType`` to for ``cls``.
    """
    resolved = resolve_typevar(cls, AdaptiveTransformerType)
    return resolved
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# --- Base classes for ezmsg Unit with specific processing capabilities ---
|
|
47
|
+
class BaseProducerUnit(ez.Unit, ABC, typing.Generic[SettingsType, MessageOutType, ProducerType]):
    """
    Base class for producer units -- i.e. units that generate messages without consuming inputs.

    A concrete unit only needs its type parameters and its settings class:

        class CustomUnit(BaseProducerUnit[
            CustomProducerSettings,    # SettingsType
            AxisArray,                 # MessageOutType
            CustomProducer,            # ProducerType
        ]):
            SETTINGS = CustomProducerSettings

    CustomProducerSettings is an ez.Settings subclass, and CustomProducer is a
    BaseProducer or BaseStatefulProducer implementation.
    """

    INPUT_SETTINGS = ez.InputStream(SettingsType)
    OUTPUT_SIGNAL = ez.OutputStream(MessageOutType)

    async def initialize(self) -> None:
        """Lifecycle hook: build the producer from the current settings."""
        self.create_producer()

    def create_producer(self) -> None:
        """
        (Re)build ``self.producer``.

        The class bound to ``ProducerType`` for this unit is instantiated with ``self.SETTINGS``.
        """
        producer_cls = get_base_producer_type(self.__class__)
        self.producer = producer_cls(settings=self.SETTINGS)

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: SettingsType) -> None:
        """
        Apply a new settings message and rebuild the producer from scratch.

        Override this in subclasses that need finer control over whether a settings
        change resets the underlying producer.

        Args:
            msg: a settings message.
        """
        self.apply_settings(msg)  # type: ignore
        self.create_producer()

    @ez.publisher(OUTPUT_SIGNAL)
    async def produce(self) -> typing.AsyncGenerator:
        """Repeatedly await the producer and publish every result that is not None."""
        while True:
            out = await self.producer.__acall__()
            if out is None:
                continue
            yield self.OUTPUT_SIGNAL, out
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class BaseProcessorUnit(ez.Unit, ABC, typing.Generic[SettingsType]):
    """
    Base class for processor units -- i.e. units that process messages.

    This class holds the plumbing shared by consumer and transformer units:
    - It builds the processor when the unit initializes.
    - It rebuilds the processor whenever a settings message arrives.

    Subclasses must implement `create_processor`, so you will rarely derive from this
    class directly. Use BaseConsumerUnit or BaseTransformerUnit instead.
    """

    INPUT_SETTINGS = ez.InputStream(SettingsType)

    async def initialize(self) -> None:
        """Lifecycle hook: build the processor from the current settings."""
        self.create_processor()

    @abstractmethod
    def create_processor(self) -> None:
        """(Re)build the wrapped processor from ``self.SETTINGS``."""

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: SettingsType) -> None:
        """
        Apply a new settings message and rebuild the processor from scratch.

        Override this in subclasses that need finer control over whether a settings
        change resets the underlying processor.

        Args:
            msg: a settings message.
        """
        self.apply_settings(msg)  # type: ignore
        self.create_processor()
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class BaseConsumerUnit(
    BaseProcessorUnit[SettingsType],
    ABC,
    typing.Generic[SettingsType, MessageInType, ConsumerType],
):
    """
    Base class for consumer units -- i.e. units that receive messages but do not return results.

    A concrete unit only needs its type parameters and its settings class:

        class CustomUnit(BaseConsumerUnit[
            CustomConsumerSettings,    # SettingsType
            AxisArray,                 # MessageInType
            CustomConsumer,            # ConsumerType
        ]):
            SETTINGS = CustomConsumerSettings

    The two custom pieces are:
      - CustomConsumerSettings: an ez.Settings subclass.
      - CustomConsumer: a BaseConsumer or BaseStatefulConsumer implementation.
    """

    INPUT_SIGNAL = ez.InputStream(MessageInType)

    def create_processor(self):
        """Instantiate the class bound to ``ConsumerType`` with ``self.SETTINGS``."""
        consumer_cls = get_base_consumer_type(self.__class__)
        self.processor = consumer_cls(settings=self.SETTINGS)

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    async def on_signal(self, message: MessageInType):
        """
        Hand an incoming message to the consumer; its return value is ignored.

        Subscribed with zero_copy=True, so the message is presumably shared with other
        subscribers and should not be mutated -- NOTE(review): confirm ezmsg semantics.

        Args:
            message: Input message to be consumed
        """
        await self.processor.__acall__(message)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class BaseTransformerUnit(
    BaseProcessorUnit[SettingsType],
    ABC,
    typing.Generic[SettingsType, MessageInType, MessageOutType, TransformerType],
):
    """
    Base class for transformer units -- i.e. units that transform input messages into output messages.

    A concrete unit only needs its type parameters and its settings class:

        class CustomUnit(BaseTransformerUnit[
            CustomTransformerSettings,    # SettingsType
            AxisArray,                    # MessageInType
            AxisArray,                    # MessageOutType
            CustomTransformer,            # TransformerType
        ]):
            SETTINGS = CustomTransformerSettings

    The two custom pieces are:
      - CustomTransformerSettings: an ez.Settings subclass.
      - CustomTransformer: one of BaseTransformer, BaseStatefulTransformer, or CompositeProcessor.
    """

    INPUT_SIGNAL = ez.InputStream(MessageInType)
    OUTPUT_SIGNAL = ez.OutputStream(MessageOutType)

    def create_processor(self):
        """Instantiate the class bound to ``TransformerType`` with ``self.SETTINGS``."""
        transformer_cls = get_base_transformer_type(self.__class__)
        self.processor = transformer_cls(settings=self.SETTINGS)

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, message: MessageInType) -> typing.AsyncGenerator:
        """Transform ``message`` and publish the result, unless the transformer returned None."""
        result = await self.processor.__acall__(message)
        if result is None:
            return
        yield self.OUTPUT_SIGNAL, result
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class BaseAdaptiveTransformerUnit(
    BaseProcessorUnit[SettingsType],
    ABC,
    typing.Generic[SettingsType, MessageInType, MessageOutType, AdaptiveTransformerType],
):
    """
    Base class for adaptive transformer units.

    This unit uses two input streams:
    - Messages arriving on INPUT_SIGNAL are transformed and the non-None results are
      published on OUTPUT_SIGNAL.
    - SampleMessages arriving on INPUT_SAMPLE are passed to the processor's
      ``apartial_fit``.

    The processor is an instance of the class bound to ``AdaptiveTransformerType``.
    """

    INPUT_SAMPLE = ez.InputStream(SampleMessage)
    INPUT_SIGNAL = ez.InputStream(MessageInType)
    OUTPUT_SIGNAL = ez.OutputStream(MessageOutType)

    def create_processor(self) -> None:
        """Instantiate the class bound to ``AdaptiveTransformerType`` with ``self.SETTINGS``."""
        adaptive_cls = get_base_adaptive_transformer_type(self.__class__)
        self.processor = adaptive_cls(settings=self.SETTINGS)

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, message: MessageInType) -> typing.AsyncGenerator:
        """Transform ``message`` and publish the result, unless the processor returned None."""
        result = await self.processor.__acall__(message)
        if result is None:
            return
        yield self.OUTPUT_SIGNAL, result

    @ez.subscriber(INPUT_SAMPLE)
    async def on_sample(self, msg: SampleMessage) -> None:
        """Forward a SampleMessage to the processor's ``apartial_fit``."""
        await self.processor.apartial_fit(msg)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
# Legacy class
|
|
244
|
+
class GenAxisArray(ez.Unit):
    """
    Legacy unit wrapping a generator-based AxisArray processor.

    Subclasses implement `construct_generator`, which must set ``self.STATE.gen`` to
    something accepting ``send(message)``. In practice this is presumably an
    already-primed generator. Each incoming AxisArray is sent to it, and the result is
    published if it contains any data.
    """

    STATE = GenState

    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
    INPUT_SETTINGS = ez.InputStream(ez.Settings)

    async def initialize(self) -> None:
        """Lifecycle hook: build the generator from the current settings."""
        self.construct_generator()

    def construct_generator(self):
        """Build ``self.STATE.gen``; subclasses must implement this."""
        raise NotImplementedError

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: ez.Settings) -> None:
        """
        Apply a new settings message and rebuild the generator from scratch.

        Not every unit needs a full reset on new settings; override this method to
        implement a selective reset.

        Args:
            msg: Instance of SETTINGS object.
        """
        self.apply_settings(msg)
        self.construct_generator()

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
        """
        Push ``message`` through the generator and publish a non-empty result.

        A finished generator is logged at DEBUG. Any other exception has its traceback
        logged at INFO level and is then swallowed.
        """
        try:
            output = self.STATE.gen.send(message)
            if math.prod(output.data.shape) > 0:
                yield self.OUTPUT_SIGNAL, output
        except (StopIteration, GeneratorExit):
            ez.logger.debug(f"Generator closed in {self.address}")
        except Exception:
            ez.logger.info(traceback.format_exc())
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Utility modules for ezmsg-baseproc
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import contextlib
|
|
3
|
+
import inspect
|
|
4
|
+
import threading
|
|
5
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
6
|
+
from typing import Any, Coroutine, TypeVar
|
|
7
|
+
|
|
8
|
+
T = TypeVar("T")


class CoroutineExecutionError(Exception):
    """
    Raised by `run_coroutine_sync` when executing a coroutine fails.

    The underlying exception (a timeout, or whatever the coroutine raised) is chained
    as ``__cause__``.
    """


def run_coroutine_sync(coroutine: Coroutine[Any, Any, T], timeout: float = 30) -> T:
    """
    Execute an asyncio coroutine synchronously and return its result.

    How the coroutine is run depends on the calling thread:
    - If no event loop is running in this thread, the coroutine runs via ``asyncio.run``.
    - If an event loop *is* running in this thread (main thread or not), that loop cannot
      be used. Blocking on it from its own thread would deadlock.
    - In that second case the coroutine runs to completion on a fresh event loop in a
      worker thread, and the calling thread (and its loop) blocks until it finishes or
      times out.

    Args:
        coroutine: The asyncio coroutine to execute.
        timeout: Maximum time in seconds to wait for coroutine completion (default: 30).

    Returns:
        The result of the coroutine execution.

    Raises:
        CoroutineExecutionError: If execution fails for any reason. This includes a
            timeout and any exception raised by the coroutine itself. The original
            exception is chained as ``__cause__``.
    """

    def run_in_new_loop() -> T:
        """
        Run the coroutine on a brand-new event loop in the current (worker) thread.

        The loop is always cleaned up afterwards.
        """
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        try:
            return new_loop.run_until_complete(asyncio.wait_for(coroutine, timeout=timeout))
        finally:
            with contextlib.suppress(Exception):
                # Cancel and drain any tasks the coroutine left behind before closing.
                pending = asyncio.all_tasks(new_loop)
                for task in pending:
                    task.cancel()
                new_loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            new_loop.close()

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so we can own one for the duration.
        try:
            return asyncio.run(asyncio.wait_for(coroutine, timeout=timeout))
        except Exception as e:
            raise CoroutineExecutionError(f"Failed to execute coroutine: {str(e)}") from e

    # A loop is already running in *this* thread. Scheduling onto it (for example with
    # run_coroutine_threadsafe) and then blocking on the result would deadlock, because
    # the loop cannot make progress while its own thread is blocked here. Run on a new
    # loop in a worker thread instead.
    with ThreadPoolExecutor(max_workers=1) as pool:
        try:
            future = pool.submit(run_in_new_loop)
            return future.result(timeout=timeout)
        except Exception as e:
            raise CoroutineExecutionError(f"Failed to execute coroutine in thread: {str(e)}") from e
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class SyncToAsyncGeneratorWrapper:
    """
    A wrapper for synchronous generators to be used in an async context.

    Behavior:
    - Priming: on construction the wrapped generator is advanced to its first ``yield``
      (the value yielded there is discarded), so ``asend`` can immediately send into it.
    - Stepping: each subsequent step runs in a worker thread via ``asyncio.to_thread``,
      keeping the event loop free.
    - Exhaustion: when the generator finishes, ``StopAsyncIteration`` is raised and the
      wrapper is marked closed.
    - Errors: any other exception from the generator is re-raised as ``RuntimeError``.
    """

    def __init__(self, gen):
        self._gen = gen
        self._closed = False
        # Prime the generator so it is ready for the first send()/next() call.
        try:
            is_not_primed = inspect.getgeneratorstate(self._gen) is inspect.GEN_CREATED
        except AttributeError as e:
            raise TypeError("The provided generator is not a valid generator object") from e
        if is_not_primed:
            try:
                next(self._gen)
            except StopIteration:
                self._closed = True
            except Exception as e:
                raise RuntimeError(f"Failed to prime generator: {e}") from e

    @staticmethod
    def _step(func, *args):
        """
        Call ``func(*args)`` in the worker thread, reporting generator exhaustion as a value.

        StopIteration must not escape from ``asyncio.to_thread``. asyncio futures refuse
        to carry it, so it would never reach an ``except StopIteration`` in the awaiting
        coroutine.

        Returns:
            ``(False, result)`` on success, or ``(True, stop_iteration)`` if the generator finished.
        """
        try:
            return False, func(*args)
        except StopIteration as stop:
            return True, stop

    async def asend(self, value):
        """
        Send ``value`` into the generator and return the next value it yields.

        Raises:
            StopAsyncIteration: If the generator is (or becomes) exhausted.
            RuntimeError: If the generator raises any other exception.
        """
        if self._closed:
            raise StopAsyncIteration("Generator is closed")
        try:
            finished, result = await asyncio.to_thread(self._step, self._gen.send, value)
        except Exception as e:
            raise RuntimeError(f"Error while sending value to generator: {e}") from e
        if finished:
            self._closed = True
            raise StopAsyncIteration("Generator is closed") from result
        return result

    async def __anext__(self):
        """
        Advance the generator and return its next value.

        Raises:
            StopAsyncIteration: If the generator is (or becomes) exhausted.
            RuntimeError: If the generator raises any other exception.
        """
        if self._closed:
            raise StopAsyncIteration("Generator is closed")
        try:
            finished, result = await asyncio.to_thread(self._step, self._gen.__next__)
        except Exception as e:
            raise RuntimeError(f"Error while getting next value from generator: {e}") from e
        if finished:
            self._closed = True
            raise StopAsyncIteration("Generator is closed") from result
        return result

    async def aclose(self):
        """Close the wrapped generator (in a worker thread); a no-op if already closed."""
        if self._closed:
            return
        try:
            await asyncio.to_thread(self._gen.close)
        except Exception as e:
            raise RuntimeError(f"Error while closing generator: {e}") from e
        finally:
            self._closed = True

    def __aiter__(self):
        return self

    def __getattr__(self, name):
        # Only reached for attributes not found normally. ``_gen`` is guarded so that an
        # instance whose __init__ has not run (e.g. during copy/unpickle) raises
        # AttributeError instead of recursing forever.
        if name == "_gen":
            raise AttributeError(name)
        return getattr(self._gen, name)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import time
|
|
2
|
+
import typing
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
|
|
5
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass(unsafe_hash=True)
|
|
9
|
+
class SampleTriggerMessage:
|
|
10
|
+
timestamp: float = field(default_factory=time.time)
|
|
11
|
+
"""Time of the trigger, in seconds. The Clock depends on the input but defaults to time.time"""
|
|
12
|
+
|
|
13
|
+
period: tuple[float, float] | None = None
|
|
14
|
+
"""The period around the timestamp, in seconds"""
|
|
15
|
+
|
|
16
|
+
value: typing.Any = None
|
|
17
|
+
"""A value or 'label' associated with the trigger."""
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class SampleMessage:
    """A chunk of data captured around a trigger event."""

    trigger: SampleTriggerMessage
    """The trigger's time, window, and value (if any)."""

    sample: AxisArray
    """The data captured around the trigger."""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def is_sample_message(message: typing.Any) -> typing.TypeGuard[SampleMessage]:
    """
    Duck-typed check for a SampleMessage.

    Only the presence of a ``trigger`` attribute is tested. The object does not have to
    be an actual ``SampleMessage`` instance.
    """
    missing = object()
    return getattr(message, "trigger", missing) is not missing
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
import typing
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import ezmsg.core as ez
|
|
9
|
+
|
|
10
|
+
# CSV column header written as the first line of every profiling log.
HEADER = "Time,Source,Topic,SampleTime,PerfCounter,Elapsed"


def get_logger_path() -> Path:
    """
    Return the path of the profiling log file.

    The file name comes from the ``EZMSG_PROFILE`` environment variable. It defaults to
    ``ezprofiler.log`` when that variable is unset or empty. A relative name is placed
    under ``~/.ezmsg/profile/``; an absolute path is used as-is.
    """
    # Retrieve the logfile name from the environment variable
    logfile = os.environ.get("EZMSG_PROFILE", None)

    # Determine the log file path, defaulting to "ezprofiler.log" if not set
    logpath = Path(logfile or "ezprofiler.log")

    # If the log path is not absolute, prepend it with the user's home directory and ".ezmsg/profile"
    if not logpath.is_absolute():
        logpath = Path.home() / ".ezmsg" / "profile" / logpath

    return logpath


def _setup_logger(append: bool = False) -> logging.Logger:
    """
    Configure the ``ezprofile`` logger to write CSV profiling lines to `get_logger_path()`.

    Args:
        append: Controls what happens to an existing log file.
            - True, and the file starts with `HEADER`: keep it and append to it.
            - True, and the file starts with a different non-empty line: log a warning
              on the ``ezmsg`` logger and replace the file.
            - True, and the file is empty: replace it silently.
            - False: replace any existing file.

    Returns:
        The ``ezprofile`` logger. Its level comes from ``EZMSG_LOGLEVEL`` (default
        ``INFO``). Profiling lines and the header are emitted at DEBUG, so nothing is
        written to the file unless that level allows DEBUG (e.g. ``EZMSG_LOGLEVEL=DEBUG``).

    Note:
        Every call adds another file handler to the same logger, so call it once.
    """
    logpath = get_logger_path()
    logpath.parent.mkdir(parents=True, exist_ok=True)

    write_header = True
    if logpath.is_file():
        if append:
            with open(logpath) as f:
                first_line = f.readline().rstrip()
            if first_line == HEADER:
                write_header = False
            else:
                # An empty file is what a previous run leaves when profiling was not at
                # DEBUG: the FileHandler creates the file, but the DEBUG header is filtered
                # out. Only a non-empty, different first line is a genuine mismatch.
                if first_line:
                    ezmsg_logger = logging.getLogger("ezmsg")
                    ezmsg_logger.warning(
                        "Profiling header mismatch: please make sure to use the same version of ezmsg for all processes."
                    )
                logpath.unlink()
        else:
            # Remove the file if not appending
            logpath.unlink()

    # Create a logger with the name "ezprofile"
    _logger = logging.getLogger("ezprofile")

    # Set the logger's level to EZMSG_LOGLEVEL env var value if it exists, otherwise INFO
    _logger.setLevel(os.environ.get("EZMSG_LOGLEVEL", "INFO").upper())

    # Create a file handler to write log messages to the log file
    fh = logging.FileHandler(logpath)
    fh.setLevel(logging.DEBUG)  # Set the file handler log level to DEBUG

    # Add the file handler to the logger
    _logger.addHandler(fh)

    # Write the header unless we are appending to a file that already starts with it.
    # This happens before the formatter is attached, so the header line is written
    # without the "<asctime>," prefix that data lines get.
    if write_header:
        _logger.debug(HEADER)

    # Set the log message format
    formatter = logging.Formatter("%(asctime)s,%(message)s", datefmt="%Y-%m-%dT%H:%M:%S%z")
    fh.setFormatter(formatter)

    return _logger
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# Module-level profiling logger used by the profiling decorators in this module.
# It is configured once, at import time, so importing this module creates the log
# directory and file. append=True keeps an existing log whose first line matches HEADER.
logger = _setup_logger(append=True)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _process_obj(obj, trace_oldest: bool = True):
    """
    Extract a representative sample time from a message, for the profiling log.

    How the time is chosen:
    - ``obj`` must have ``axes`` containing a ``"win"`` or ``"time"`` axis; ``"win"`` is
      preferred when both are present.
    - The value of that axis's first entry (``trace_oldest=True``) or last entry is used.
    - When the ``"win"`` axis is used and a ``"time"`` axis is also present, the value of
      the first/last entry of the ``"time"`` axis is added. The time axis's own length
      determines that entry.

    Args:
        obj: The message to inspect, presumably an AxisArray. Anything without ``axes``
            yields None.
        trace_oldest: If True, use the oldest sample; otherwise the newest.

    Returns:
        The sample time, or None if no suitable, non-empty axis is found.
    """

    def _axis_value(axis_obj, index):
        # Axes exposing ``data`` hold explicit coordinates; others are evaluated via ``value()``.
        if hasattr(axis_obj, "data"):
            return axis_obj.data[index]
        return axis_obj.value(index)

    def _edge_index(axis_name):
        # Index of the oldest/newest entry along ``axis_name``, or None if that axis is empty.
        n_entries = obj.data.shape[obj.get_axis_idx(axis_name)]
        if n_entries <= 0:
            return None
        return 0 if trace_oldest else n_entries - 1

    axes = getattr(obj, "axes", None)
    if axes is None or not ("time" in axes or "win" in axes):
        return None

    axis = "win" if "win" in axes else "time"
    idx = _edge_index(axis)
    if idx is None:
        return None
    samp_time = _axis_value(obj.get_axis(axis), idx)

    # Compare the axis *name*. The previous code compared the axis object with "win",
    # which is never equal, so this offset was never applied.
    if axis == "win" and "time" in axes:
        # Presumably the window start plus the sample's offset within the window.
        # NOTE(review): confirm the "time" axis holds within-window offsets.
        time_idx = _edge_index("time")
        if time_idx is not None:
            samp_time += _axis_value(axes["time"], time_idx)
    return samp_time
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def profile_method(trace_oldest: bool = True):
    """
    Build a decorator that logs the execution time of each call to a method.

    Each call writes one CSV line to the profiling logger. The fields are:
    - the caller's class (``module.Name``);
    - ``caller.address``;
    - the sample time of the return value;
    - ``time.perf_counter()`` when the call finished;
    - the elapsed time in milliseconds.

    Whether to profile is decided once, at decoration time. Unless the profiling
    logger's level is exactly DEBUG, the method is returned unwrapped.

    Args:
        trace_oldest (bool): If True, trace the oldest sample time; otherwise, trace the newest.

    Returns:
        Callable: A decorator for methods whose instance has an ``address`` attribute.
    """

    def profiling_decorator(func: typing.Callable):
        if logger.level != logging.DEBUG:
            return func

        @functools.wraps(func)
        def wrapped_func(caller, *args, **kwargs):
            t_start = time.perf_counter()
            result = func(caller, *args, **kwargs)
            t_stop = time.perf_counter()
            cls = caller.__class__
            fields = (
                f"{cls.__module__}.{cls.__name__}",
                f"{caller.address}",
                f"{_process_obj(result, trace_oldest=trace_oldest)}",
                f"{t_stop}",
                f"{(t_stop - t_start) * 1e3:0.4f}",
            )
            logger.debug(",".join(fields))
            return result

        return wrapped_func

    return profiling_decorator
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def profile_subpub(trace_oldest: bool = True):
    """
    Build a decorator that profiles a subscriber-publisher async generator in an ezmsg Unit.

    For every ``(stream, obj)`` pair the wrapped generator yields, one CSV line is written
    to the profiling logger. The fields are:
    - the unit's class (``module.Name``);
    - ``unit.address``;
    - the sample time of ``obj``;
    - ``time.perf_counter()`` when the item arrived;
    - the elapsed time in milliseconds since the previous mark.

    Whether to profile is decided once, at decoration time. Unless the profiling logger's
    level is exactly DEBUG, the function is returned unwrapped.

    Args:
        trace_oldest (bool): If True, trace the oldest sample time; otherwise, trace the newest.

    Returns:
        Callable: The decorated async task with profiling.
    """

    def profiling_decorator(func: typing.Callable):
        if logger.level != logging.DEBUG:
            return func

        @functools.wraps(func)
        async def wrapped_task(unit: ez.Unit, msg: typing.Any = None):
            cls = unit.__class__
            source = f"{cls.__module__}.{cls.__name__}"
            topic = f"{unit.address}"
            t_mark = time.perf_counter()
            async for stream, obj in func(unit, msg):
                t_now = time.perf_counter()
                samp_time = _process_obj(obj, trace_oldest=trace_oldest)
                elapsed_ms = f"{(t_now - t_mark) * 1e3:0.4f}"
                logger.debug(",".join((source, topic, f"{samp_time}", f"{t_now}", elapsed_ms)))
                # The next interval starts at this item's arrival time, so it also covers the
                # logging above and whatever the consumer does before resuming this generator.
                t_mark = t_now
                yield stream, obj

        return wrapped_task

    return profiling_decorator
|