ezmsg-sigproc 2.4.1__py3-none-any.whl → 2.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/sigproc/__version__.py +2 -2
- ezmsg/sigproc/activation.py +5 -11
- ezmsg/sigproc/adaptive_lattice_notch.py +11 -29
- ezmsg/sigproc/affinetransform.py +13 -38
- ezmsg/sigproc/aggregate.py +13 -30
- ezmsg/sigproc/bandpower.py +7 -15
- ezmsg/sigproc/base.py +141 -1276
- ezmsg/sigproc/butterworthfilter.py +8 -16
- ezmsg/sigproc/butterworthzerophase.py +123 -0
- ezmsg/sigproc/cheby.py +4 -10
- ezmsg/sigproc/combfilter.py +5 -8
- ezmsg/sigproc/decimate.py +2 -6
- ezmsg/sigproc/denormalize.py +6 -11
- ezmsg/sigproc/detrend.py +3 -4
- ezmsg/sigproc/diff.py +8 -17
- ezmsg/sigproc/downsample.py +6 -14
- ezmsg/sigproc/ewma.py +11 -27
- ezmsg/sigproc/ewmfilter.py +1 -1
- ezmsg/sigproc/extract_axis.py +3 -4
- ezmsg/sigproc/fbcca.py +31 -56
- ezmsg/sigproc/filter.py +19 -45
- ezmsg/sigproc/filterbank.py +33 -70
- ezmsg/sigproc/filterbankdesign.py +5 -12
- ezmsg/sigproc/fir_hilbert.py +336 -0
- ezmsg/sigproc/fir_pmc.py +209 -0
- ezmsg/sigproc/firfilter.py +12 -14
- ezmsg/sigproc/gaussiansmoothing.py +5 -9
- ezmsg/sigproc/kaiser.py +11 -15
- ezmsg/sigproc/math/abs.py +1 -3
- ezmsg/sigproc/math/add.py +121 -0
- ezmsg/sigproc/math/clip.py +1 -1
- ezmsg/sigproc/math/difference.py +98 -36
- ezmsg/sigproc/math/invert.py +1 -3
- ezmsg/sigproc/math/log.py +2 -6
- ezmsg/sigproc/messages.py +1 -2
- ezmsg/sigproc/quantize.py +2 -4
- ezmsg/sigproc/resample.py +13 -34
- ezmsg/sigproc/rollingscaler.py +232 -0
- ezmsg/sigproc/sampler.py +17 -35
- ezmsg/sigproc/scaler.py +8 -18
- ezmsg/sigproc/signalinjector.py +6 -16
- ezmsg/sigproc/slicer.py +9 -28
- ezmsg/sigproc/spectral.py +3 -3
- ezmsg/sigproc/spectrogram.py +12 -19
- ezmsg/sigproc/spectrum.py +12 -32
- ezmsg/sigproc/transpose.py +7 -18
- ezmsg/sigproc/util/asio.py +25 -156
- ezmsg/sigproc/util/axisarray_buffer.py +10 -26
- ezmsg/sigproc/util/buffer.py +18 -43
- ezmsg/sigproc/util/message.py +17 -31
- ezmsg/sigproc/util/profile.py +23 -174
- ezmsg/sigproc/util/sparse.py +5 -15
- ezmsg/sigproc/util/typeresolution.py +17 -83
- ezmsg/sigproc/wavelets.py +6 -15
- ezmsg/sigproc/window.py +24 -78
- {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/METADATA +4 -3
- ezmsg_sigproc-2.6.0.dist-info/RECORD +63 -0
- ezmsg/sigproc/synth.py +0 -774
- ezmsg_sigproc-2.4.1.dist-info/RECORD +0 -59
- {ezmsg_sigproc-2.4.1.dist-info → ezmsg_sigproc-2.6.0.dist-info}/WHEEL +0 -0
- /ezmsg_sigproc-2.4.1.dist-info/licenses/LICENSE.txt → /ezmsg_sigproc-2.6.0.dist-info/licenses/LICENSE +0 -0
ezmsg/sigproc/base.py
CHANGED
|
@@ -1,1284 +1,149 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
import functools
|
|
4
|
-
import inspect
|
|
5
|
-
import math
|
|
6
|
-
import pickle
|
|
7
|
-
import traceback
|
|
8
|
-
from types import GeneratorType
|
|
9
|
-
import typing
|
|
1
|
+
"""
|
|
2
|
+
Backwards-compatible re-exports from ezmsg.baseproc.
|
|
10
3
|
|
|
11
|
-
|
|
12
|
-
from ezmsg.
|
|
13
|
-
from ezmsg.util.generator import GenState
|
|
4
|
+
This module re-exports all symbols from ezmsg.baseproc to maintain backwards
|
|
5
|
+
compatibility for code that imports from ezmsg.sigproc.base.
|
|
14
6
|
|
|
15
|
-
from ezmsg.
|
|
16
|
-
|
|
17
|
-
resolve_typevar,
|
|
18
|
-
)
|
|
19
|
-
|
|
20
|
-
from .util.profile import profile_subpub
|
|
21
|
-
from .util.message import SampleMessage, is_sample_message
|
|
22
|
-
from .util.asio import SyncToAsyncGeneratorWrapper, run_coroutine_sync
|
|
7
|
+
New code should import directly from ezmsg.baseproc instead.
|
|
8
|
+
"""
|
|
23
9
|
|
|
10
|
+
import warnings
|
|
24
11
|
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
12
|
+
warnings.warn(
|
|
13
|
+
"Importing from 'ezmsg.sigproc.base' is deprecated. Please import from 'ezmsg.baseproc' instead.",
|
|
14
|
+
DeprecationWarning,
|
|
15
|
+
stacklevel=2,
|
|
29
16
|
)
|
|
30
17
|
|
|
31
|
-
#
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
def __call__(self, message: typing.Any) -> typing.Any: ...
|
|
97
|
-
async def __acall__(self, message: typing.Any) -> typing.Any: ...
|
|
98
|
-
|
|
99
|
-
def stateful_op(
|
|
100
|
-
self,
|
|
101
|
-
state: typing.Any,
|
|
102
|
-
message: typing.Any,
|
|
103
|
-
) -> tuple[typing.Any, typing.Any]: ...
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
class StatefulProducer(typing.Protocol[SettingsType, MessageOutType, StateType]):
|
|
107
|
-
"""Protocol for producers that generate messages without consuming inputs."""
|
|
108
|
-
|
|
109
|
-
@property
|
|
110
|
-
def state(self) -> StateType: ...
|
|
111
|
-
|
|
112
|
-
@state.setter
|
|
113
|
-
def state(self, state: StateType | bytes | None) -> None: ...
|
|
114
|
-
|
|
115
|
-
def __call__(self) -> MessageOutType: ...
|
|
116
|
-
async def __acall__(self) -> MessageOutType: ...
|
|
117
|
-
|
|
118
|
-
def stateful_op(
|
|
119
|
-
self,
|
|
120
|
-
state: typing.Any,
|
|
121
|
-
) -> tuple[typing.Any, typing.Any]: ...
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
class StatefulConsumer(
|
|
125
|
-
StatefulProcessor[SettingsType, MessageInType, None, StateType], typing.Protocol
|
|
126
|
-
):
|
|
127
|
-
"""Protocol specifically for processors that consume messages without producing output."""
|
|
128
|
-
|
|
129
|
-
def __call__(self, message: MessageInType) -> None: ...
|
|
130
|
-
async def __acall__(self, message: MessageInType) -> None: ...
|
|
131
|
-
|
|
132
|
-
def stateful_op(
|
|
133
|
-
self,
|
|
134
|
-
state: tuple[StateType, int],
|
|
135
|
-
message: MessageInType,
|
|
136
|
-
) -> tuple[tuple[StateType, int], None]: ...
|
|
137
|
-
|
|
138
|
-
"""
|
|
139
|
-
Note: The return type is still a tuple even though the second entry is always None.
|
|
140
|
-
This is intentional so we can use the same protocol for both consumers and transformers,
|
|
141
|
-
and chain them together in a pipeline (e.g., `CompositeProcessor`).
|
|
142
|
-
"""
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
class StatefulTransformer(
|
|
146
|
-
StatefulProcessor[SettingsType, MessageInType, MessageOutType, StateType],
|
|
147
|
-
typing.Protocol,
|
|
148
|
-
):
|
|
149
|
-
"""
|
|
150
|
-
Protocol specifically for processors that transform messages.
|
|
151
|
-
"""
|
|
152
|
-
|
|
153
|
-
def __call__(self, message: MessageInType) -> MessageOutType: ...
|
|
154
|
-
async def __acall__(self, message: MessageInType) -> MessageOutType: ...
|
|
155
|
-
|
|
156
|
-
def stateful_op(
|
|
157
|
-
self,
|
|
158
|
-
state: tuple[StateType, int],
|
|
159
|
-
message: MessageInType,
|
|
160
|
-
) -> tuple[tuple[StateType, int], MessageOutType]: ...
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
class AdaptiveTransformer(StatefulTransformer, typing.Protocol):
|
|
164
|
-
def partial_fit(self, message: SampleMessage) -> None:
|
|
165
|
-
"""Update transformer state using labeled training data.
|
|
166
|
-
|
|
167
|
-
This method should update the internal state/parameters of the transformer
|
|
168
|
-
based on the provided labeled samples, without performing any transformation.
|
|
169
|
-
"""
|
|
170
|
-
...
|
|
171
|
-
|
|
172
|
-
async def apartial_fit(self, message: SampleMessage) -> None: ...
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
# --- Base implementation classes for processors ---
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
def _get_base_processor_settings_type(cls: type) -> type:
|
|
179
|
-
try:
|
|
180
|
-
return resolve_typevar(cls, SettingsType)
|
|
181
|
-
except TypeError as e:
|
|
182
|
-
raise TypeError(
|
|
183
|
-
f"Could not resolve settings type for {cls}. "
|
|
184
|
-
f"Ensure that the class is properly annotated with a SettingsType."
|
|
185
|
-
) from e
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
def _get_base_processor_message_in_type(cls: type) -> type:
|
|
189
|
-
return resolve_typevar(cls, MessageInType)
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
def _get_base_processor_message_out_type(cls: type) -> type:
|
|
193
|
-
return resolve_typevar(cls, MessageOutType)
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
def _unify_settings(
|
|
197
|
-
obj: typing.Any, settings: object | None, *args, **kwargs
|
|
198
|
-
) -> typing.Any:
|
|
199
|
-
"""Helper function to unify settings for processor initialization."""
|
|
200
|
-
settings_type = _get_base_processor_settings_type(obj.__class__)
|
|
201
|
-
|
|
202
|
-
if settings is None:
|
|
203
|
-
if len(args) > 0 and isinstance(args[0], settings_type):
|
|
204
|
-
settings = args[0]
|
|
205
|
-
elif len(args) > 0 or len(kwargs) > 0:
|
|
206
|
-
settings = settings_type(*args, **kwargs)
|
|
207
|
-
else:
|
|
208
|
-
settings = settings_type()
|
|
209
|
-
assert isinstance(settings, settings_type), "Settings must be of type " + str(
|
|
210
|
-
settings_type
|
|
211
|
-
)
|
|
212
|
-
return settings
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
class BaseProcessor(ABC, typing.Generic[SettingsType, MessageInType, MessageOutType]):
|
|
216
|
-
"""
|
|
217
|
-
Base class for processors. You probably do not want to inherit from this class directly.
|
|
218
|
-
Refer instead to the more specific base classes.
|
|
219
|
-
* Use :obj:`BaseConsumer` or :obj:`BaseTransformer` for ops that return a result or not, respectively.
|
|
220
|
-
* Use :obj:`BaseStatefulProcessor` and its children for operations that require state.
|
|
221
|
-
|
|
222
|
-
Note that `BaseProcessor` and its children are sync by default. If you need async by defualt, then
|
|
223
|
-
override the async methods and call them from the sync methods. Look to `BaseProducer` for examples of
|
|
224
|
-
calling async methods from sync methods.
|
|
225
|
-
"""
|
|
226
|
-
|
|
227
|
-
settings: SettingsType
|
|
228
|
-
|
|
229
|
-
@classmethod
|
|
230
|
-
def get_settings_type(cls) -> type[SettingsType]:
|
|
231
|
-
return _get_base_processor_settings_type(cls)
|
|
232
|
-
|
|
233
|
-
@classmethod
|
|
234
|
-
def get_message_type(cls, dir: str) -> typing.Any:
|
|
235
|
-
if dir == "in":
|
|
236
|
-
return _get_base_processor_message_in_type(cls)
|
|
237
|
-
elif dir == "out":
|
|
238
|
-
return _get_base_processor_message_out_type(cls)
|
|
239
|
-
else:
|
|
240
|
-
raise ValueError(f"Invalid direction: {dir}. Use 'in' or 'out'.")
|
|
241
|
-
|
|
242
|
-
def __init__(self, *args, settings: SettingsType | None = None, **kwargs) -> None:
|
|
243
|
-
self.settings = _unify_settings(self, settings, *args, **kwargs)
|
|
244
|
-
|
|
245
|
-
@abstractmethod
|
|
246
|
-
def _process(self, message: typing.Any) -> typing.Any: ...
|
|
247
|
-
|
|
248
|
-
async def _aprocess(self, message: typing.Any) -> typing.Any:
|
|
249
|
-
"""Override this for native async processing."""
|
|
250
|
-
return self._process(message)
|
|
251
|
-
|
|
252
|
-
def __call__(self, message: typing.Any) -> typing.Any:
|
|
253
|
-
# Note: We use the indirection to `_process` because this allows us to
|
|
254
|
-
# modify __call__ in derived classes with common functionality while
|
|
255
|
-
# minimizing the boilerplate code in derived classes as they only need to
|
|
256
|
-
# implement `_process`.
|
|
257
|
-
return self._process(message)
|
|
258
|
-
|
|
259
|
-
async def __acall__(self, message: typing.Any) -> typing.Any:
|
|
260
|
-
"""
|
|
261
|
-
In Python 3.12+, we can invoke this method simply with `await obj(message)`,
|
|
262
|
-
but earlier versions require direct syntax: `await obj.__acall__(message)`.
|
|
263
|
-
"""
|
|
264
|
-
return await self._aprocess(message)
|
|
265
|
-
|
|
266
|
-
def send(self, message: typing.Any) -> typing.Any:
|
|
267
|
-
"""Alias for __call__."""
|
|
268
|
-
return self(message)
|
|
269
|
-
|
|
270
|
-
async def asend(self, message: typing.Any) -> typing.Any:
|
|
271
|
-
"""Alias for __acall__."""
|
|
272
|
-
return await self.__acall__(message)
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
class BaseProducer(ABC, typing.Generic[SettingsType, MessageOutType]):
|
|
276
|
-
"""
|
|
277
|
-
Base class for producers -- processors that generate messages without consuming inputs.
|
|
278
|
-
|
|
279
|
-
Note that `BaseProducer` and its children are async by default, and the sync methods simply wrap
|
|
280
|
-
the async methods. This is the opposite of :obj:`BaseProcessor` and its children which are sync by default.
|
|
281
|
-
These classes are designed this way because it is highly likely that a producer, which (probably) does not
|
|
282
|
-
receive inputs, will require some sort of IO which will benefit from being async.
|
|
283
|
-
"""
|
|
284
|
-
|
|
285
|
-
@classmethod
|
|
286
|
-
def get_settings_type(cls) -> type[SettingsType]:
|
|
287
|
-
return _get_base_processor_settings_type(cls)
|
|
288
|
-
|
|
289
|
-
@classmethod
|
|
290
|
-
def get_message_type(cls, dir: str) -> type[MessageOutType] | None:
|
|
291
|
-
if dir == "out":
|
|
292
|
-
return _get_base_processor_message_out_type(cls)
|
|
293
|
-
elif dir == "in":
|
|
294
|
-
return None
|
|
295
|
-
else:
|
|
296
|
-
raise ValueError(f"Invalid direction: {dir}. Use 'in' or 'out'.")
|
|
297
|
-
|
|
298
|
-
def __init__(self, *args, settings: SettingsType | None = None, **kwargs) -> None:
|
|
299
|
-
self.settings = _unify_settings(self, settings, *args, **kwargs)
|
|
300
|
-
|
|
301
|
-
@abstractmethod
|
|
302
|
-
async def _produce(self) -> MessageOutType: ...
|
|
303
|
-
|
|
304
|
-
async def __acall__(self) -> MessageOutType:
|
|
305
|
-
return await self._produce()
|
|
306
|
-
|
|
307
|
-
def __call__(self) -> MessageOutType:
|
|
308
|
-
# Warning: This is a bit slow. Override this method in derived classes if performance is critical.
|
|
309
|
-
return run_coroutine_sync(self.__acall__())
|
|
310
|
-
|
|
311
|
-
def __iter__(self) -> typing.Iterator[MessageOutType]:
|
|
312
|
-
# Make self an iterator
|
|
313
|
-
return self
|
|
314
|
-
|
|
315
|
-
async def __anext__(self) -> MessageOutType:
|
|
316
|
-
# So this can be used as an async generator.
|
|
317
|
-
return await self.__acall__()
|
|
318
|
-
|
|
319
|
-
def __next__(self) -> MessageOutType:
|
|
320
|
-
# So this can be used as a generator.
|
|
321
|
-
return self()
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
class BaseConsumer(
|
|
325
|
-
BaseProcessor[SettingsType, MessageInType, None],
|
|
326
|
-
ABC,
|
|
327
|
-
typing.Generic[SettingsType, MessageInType],
|
|
328
|
-
):
|
|
329
|
-
"""
|
|
330
|
-
Base class for consumers -- processors that receive messages but don't produce output.
|
|
331
|
-
This base simply overrides type annotations of BaseProcessor to remove the outputs.
|
|
332
|
-
(We don't bother overriding `send` and `asend` because those are deprecated.)
|
|
333
|
-
"""
|
|
334
|
-
|
|
335
|
-
@classmethod
|
|
336
|
-
def get_message_type(cls, dir: str) -> type[MessageInType] | None:
|
|
337
|
-
if dir == "in":
|
|
338
|
-
return _get_base_processor_message_in_type(cls)
|
|
339
|
-
elif dir == "out":
|
|
340
|
-
return None
|
|
341
|
-
else:
|
|
342
|
-
raise ValueError(f"Invalid direction: {dir}. Use 'in' or 'out'.")
|
|
343
|
-
|
|
344
|
-
@abstractmethod
|
|
345
|
-
def _process(self, message: MessageInType) -> None: ...
|
|
346
|
-
|
|
347
|
-
async def _aprocess(self, message: MessageInType) -> None:
|
|
348
|
-
"""Override this for native async processing."""
|
|
349
|
-
return self._process(message)
|
|
350
|
-
|
|
351
|
-
def __call__(self, message: MessageInType) -> None:
|
|
352
|
-
return super().__call__(message)
|
|
353
|
-
|
|
354
|
-
async def __acall__(self, message: MessageInType) -> None:
|
|
355
|
-
return await super().__acall__(message)
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
class BaseTransformer(
|
|
359
|
-
BaseProcessor[SettingsType, MessageInType, MessageOutType],
|
|
360
|
-
ABC,
|
|
361
|
-
typing.Generic[SettingsType, MessageInType, MessageOutType],
|
|
362
|
-
):
|
|
363
|
-
"""
|
|
364
|
-
Base class for transformers -- processors which receive messages and produce output.
|
|
365
|
-
This base simply overrides type annotations of :obj:`BaseProcessor` to indicate that outputs are not optional.
|
|
366
|
-
(We don't bother overriding `send` and `asend` because those are deprecated.)
|
|
367
|
-
"""
|
|
368
|
-
|
|
369
|
-
@abstractmethod
|
|
370
|
-
def _process(self, message: MessageInType) -> MessageOutType: ...
|
|
371
|
-
|
|
372
|
-
async def _aprocess(self, message: MessageInType) -> MessageOutType:
|
|
373
|
-
"""Override this for native async processing."""
|
|
374
|
-
return self._process(message)
|
|
375
|
-
|
|
376
|
-
def __call__(self, message: MessageInType) -> MessageOutType:
|
|
377
|
-
return super().__call__(message)
|
|
378
|
-
|
|
379
|
-
async def __acall__(self, message: MessageInType) -> MessageOutType:
|
|
380
|
-
return await super().__acall__(message)
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
def _get_base_processor_state_type(cls: type) -> type:
|
|
384
|
-
try:
|
|
385
|
-
return resolve_typevar(cls, StateType)
|
|
386
|
-
except TypeError as e:
|
|
387
|
-
raise TypeError(
|
|
388
|
-
f"Could not resolve state type for {cls}. "
|
|
389
|
-
f"Ensure that the class is properly annotated with a StateType."
|
|
390
|
-
) from e
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
class Stateful(ABC, typing.Generic[StateType]):
|
|
394
|
-
"""
|
|
395
|
-
Mixin class for stateful processors. DO NOT use this class directly.
|
|
396
|
-
Used to enforce that the processor/producer has a state attribute and stateful_op method.
|
|
397
|
-
"""
|
|
398
|
-
|
|
399
|
-
_state: StateType
|
|
400
|
-
|
|
401
|
-
@classmethod
|
|
402
|
-
def get_state_type(cls) -> type[StateType]:
|
|
403
|
-
return _get_base_processor_state_type(cls)
|
|
404
|
-
|
|
405
|
-
@property
|
|
406
|
-
def state(self) -> StateType:
|
|
407
|
-
return self._state
|
|
408
|
-
|
|
409
|
-
@state.setter
|
|
410
|
-
def state(self, state: StateType | bytes | None) -> None:
|
|
411
|
-
if state is not None:
|
|
412
|
-
if isinstance(state, bytes):
|
|
413
|
-
self._state = pickle.loads(state)
|
|
414
|
-
else:
|
|
415
|
-
self._state = state # type: ignore
|
|
416
|
-
|
|
417
|
-
def _hash_message(self, message: typing.Any) -> int:
|
|
418
|
-
"""
|
|
419
|
-
Check if the message metadata indicates a need for state reset.
|
|
420
|
-
|
|
421
|
-
This method is not abstract because there are some processors that might only
|
|
422
|
-
need to reset once but are otherwise insensitive to the message structure.
|
|
423
|
-
|
|
424
|
-
For example, an activation function that benefits greatly from pre-computed values should
|
|
425
|
-
do this computation in `_reset_state` and attach those values to the processor state,
|
|
426
|
-
but if it e.g. operates elementwise on the input then it doesn't care if the incoming
|
|
427
|
-
data changes shape or sample rate so you don't need to reset again.
|
|
428
|
-
|
|
429
|
-
All processors' initial state should have `.hash = -1` then by returning `0` here
|
|
430
|
-
we force an update on the first message.
|
|
431
|
-
"""
|
|
432
|
-
return 0
|
|
433
|
-
|
|
434
|
-
@abstractmethod
|
|
435
|
-
def _reset_state(self, *args: typing.Any, **kwargs: typing.Any) -> None:
|
|
436
|
-
"""
|
|
437
|
-
Reset internal state based on
|
|
438
|
-
- new message metadata (processors), or
|
|
439
|
-
- after first call (producers).
|
|
440
|
-
"""
|
|
441
|
-
...
|
|
442
|
-
|
|
443
|
-
@abstractmethod
|
|
444
|
-
def stateful_op(self, *args: typing.Any, **kwargs: typing.Any) -> tuple: ...
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
class BaseStatefulProcessor(
|
|
448
|
-
BaseProcessor[SettingsType, MessageInType, MessageOutType],
|
|
449
|
-
Stateful[StateType],
|
|
450
|
-
ABC,
|
|
451
|
-
typing.Generic[SettingsType, MessageInType, MessageOutType, StateType],
|
|
452
|
-
):
|
|
453
|
-
"""
|
|
454
|
-
Base class implementing common stateful processor functionality.
|
|
455
|
-
You probably do not want to inherit from this class directly.
|
|
456
|
-
Refer instead to the more specific base classes.
|
|
457
|
-
Use BaseStatefulConsumer for operations that do not return a result,
|
|
458
|
-
or BaseStatefulTransformer for operations that do return a result.
|
|
459
|
-
"""
|
|
460
|
-
|
|
461
|
-
def __init__(self, *args, **kwargs) -> None:
|
|
462
|
-
super().__init__(*args, **kwargs)
|
|
463
|
-
self._hash = -1
|
|
464
|
-
state_type = self.__class__.get_state_type()
|
|
465
|
-
self._state: StateType = state_type()
|
|
466
|
-
# TODO: Enforce that StateType has .hash: int field.
|
|
467
|
-
|
|
468
|
-
@abstractmethod
|
|
469
|
-
def _reset_state(self, message: typing.Any) -> None:
|
|
470
|
-
"""
|
|
471
|
-
Reset internal state based on new message metadata.
|
|
472
|
-
This method will only be called when there is a significant change in the message metadata,
|
|
473
|
-
such as sample rate or shape (criteria defined by `_hash_message`), and not for every message,
|
|
474
|
-
so use it to do all the expensive pre-allocation and caching of variables that can speed up
|
|
475
|
-
the processing of subsequent messages in `_process`.
|
|
476
|
-
"""
|
|
477
|
-
...
|
|
478
|
-
|
|
479
|
-
@abstractmethod
|
|
480
|
-
def _process(self, message: typing.Any) -> typing.Any: ...
|
|
481
|
-
|
|
482
|
-
def __call__(self, message: typing.Any) -> typing.Any:
|
|
483
|
-
msg_hash = self._hash_message(message)
|
|
484
|
-
if msg_hash != self._hash:
|
|
485
|
-
self._reset_state(message)
|
|
486
|
-
self._hash = msg_hash
|
|
487
|
-
return self._process(message)
|
|
488
|
-
|
|
489
|
-
async def __acall__(self, message: typing.Any) -> typing.Any:
|
|
490
|
-
msg_hash = self._hash_message(message)
|
|
491
|
-
if msg_hash != self._hash:
|
|
492
|
-
self._reset_state(message)
|
|
493
|
-
self._hash = msg_hash
|
|
494
|
-
return await self._aprocess(message)
|
|
495
|
-
|
|
496
|
-
def stateful_op(
|
|
497
|
-
self,
|
|
498
|
-
state: tuple[StateType, int] | None,
|
|
499
|
-
message: typing.Any,
|
|
500
|
-
) -> tuple[tuple[StateType, int], typing.Any]:
|
|
501
|
-
if state is not None:
|
|
502
|
-
self.state, self._hash = state
|
|
503
|
-
result = self(message)
|
|
504
|
-
return (self.state, self._hash), result
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
class BaseStatefulProducer(
|
|
508
|
-
BaseProducer[SettingsType, MessageOutType],
|
|
509
|
-
Stateful[StateType],
|
|
510
|
-
ABC,
|
|
511
|
-
typing.Generic[SettingsType, MessageOutType, StateType],
|
|
512
|
-
):
|
|
513
|
-
"""
|
|
514
|
-
Base class implementing common stateful producer functionality.
|
|
515
|
-
Examples of stateful producers are things that require counters, clocks,
|
|
516
|
-
or to cycle through a set of values.
|
|
517
|
-
|
|
518
|
-
Unlike BaseStatefulProcessor, this class does not message hashing because there
|
|
519
|
-
are no input messages. We still use self._hash to simply track the transition from
|
|
520
|
-
initialization (.hash == -1) to state reset (.hash == 0).
|
|
521
|
-
"""
|
|
522
|
-
|
|
523
|
-
def __init__(self, *args, **kwargs) -> None:
|
|
524
|
-
super().__init__(*args, **kwargs) # .settings
|
|
525
|
-
self._hash = -1
|
|
526
|
-
state_type = self.__class__.get_state_type()
|
|
527
|
-
self._state: StateType = state_type()
|
|
528
|
-
|
|
529
|
-
@abstractmethod
|
|
530
|
-
def _reset_state(self) -> None:
|
|
531
|
-
"""
|
|
532
|
-
Reset internal state upon first call.
|
|
533
|
-
"""
|
|
534
|
-
...
|
|
535
|
-
|
|
536
|
-
async def __acall__(self) -> MessageOutType:
|
|
537
|
-
if self._hash == -1:
|
|
538
|
-
self._reset_state()
|
|
539
|
-
self._hash = 0
|
|
540
|
-
return await self._produce()
|
|
541
|
-
|
|
542
|
-
def stateful_op(
|
|
543
|
-
self,
|
|
544
|
-
state: tuple[StateType, int] | None,
|
|
545
|
-
) -> tuple[tuple[StateType, int], MessageOutType]:
|
|
546
|
-
if state is not None:
|
|
547
|
-
self.state, self._hash = state # Update state via setter
|
|
548
|
-
result = self() # Uses synchronous call
|
|
549
|
-
return (self.state, self._hash), result
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
class BaseStatefulConsumer(
|
|
553
|
-
BaseStatefulProcessor[SettingsType, MessageInType, None, StateType],
|
|
554
|
-
ABC,
|
|
555
|
-
typing.Generic[SettingsType, MessageInType, StateType],
|
|
556
|
-
):
|
|
557
|
-
"""
|
|
558
|
-
Base class for stateful message consumers that don't produce output.
|
|
559
|
-
This class merely overrides the type annotations of BaseStatefulProcessor.
|
|
560
|
-
"""
|
|
561
|
-
|
|
562
|
-
@classmethod
|
|
563
|
-
def get_message_type(cls, dir: str) -> type[MessageInType] | None:
|
|
564
|
-
if dir == "in":
|
|
565
|
-
return _get_base_processor_message_in_type(cls)
|
|
566
|
-
elif dir == "out":
|
|
567
|
-
return None
|
|
568
|
-
else:
|
|
569
|
-
raise ValueError(f"Invalid direction: {dir}. Use 'in' or 'out'.")
|
|
570
|
-
|
|
571
|
-
@abstractmethod
|
|
572
|
-
def _process(self, message: MessageInType) -> None: ...
|
|
573
|
-
|
|
574
|
-
async def _aprocess(self, message: MessageInType) -> None:
|
|
575
|
-
return self._process(message)
|
|
576
|
-
|
|
577
|
-
def __call__(self, message: MessageInType) -> None:
|
|
578
|
-
return super().__call__(message)
|
|
579
|
-
|
|
580
|
-
async def __acall__(self, message: MessageInType) -> None:
|
|
581
|
-
return await super().__acall__(message)
|
|
582
|
-
|
|
583
|
-
def stateful_op(
|
|
584
|
-
self,
|
|
585
|
-
state: tuple[StateType, int] | None,
|
|
586
|
-
message: MessageInType,
|
|
587
|
-
) -> tuple[tuple[StateType, int], None]:
|
|
588
|
-
state, _ = super().stateful_op(state, message)
|
|
589
|
-
return state, None
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
class BaseStatefulTransformer(
|
|
593
|
-
BaseStatefulProcessor[SettingsType, MessageInType, MessageOutType, StateType],
|
|
594
|
-
ABC,
|
|
595
|
-
typing.Generic[SettingsType, MessageInType, MessageOutType, StateType],
|
|
596
|
-
):
|
|
597
|
-
"""
|
|
598
|
-
Base class for stateful message transformers that produce output.
|
|
599
|
-
This class merely overrides the type annotations of BaseStatefulProcessor.
|
|
600
|
-
"""
|
|
601
|
-
|
|
602
|
-
@abstractmethod
|
|
603
|
-
def _process(self, message: MessageInType) -> MessageOutType: ...
|
|
604
|
-
|
|
605
|
-
async def _aprocess(self, message: MessageInType) -> MessageOutType:
|
|
606
|
-
return self._process(message)
|
|
607
|
-
|
|
608
|
-
def __call__(self, message: MessageInType) -> MessageOutType:
|
|
609
|
-
return super().__call__(message)
|
|
610
|
-
|
|
611
|
-
async def __acall__(self, message: MessageInType) -> MessageOutType:
|
|
612
|
-
return await super().__acall__(message)
|
|
613
|
-
|
|
614
|
-
def stateful_op(
|
|
615
|
-
self,
|
|
616
|
-
state: tuple[StateType, int] | None,
|
|
617
|
-
message: MessageInType,
|
|
618
|
-
) -> tuple[tuple[StateType, int], MessageOutType]:
|
|
619
|
-
return super().stateful_op(state, message)
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
class BaseAdaptiveTransformer(
|
|
623
|
-
BaseStatefulTransformer[
|
|
624
|
-
SettingsType,
|
|
625
|
-
MessageInType | SampleMessage,
|
|
626
|
-
MessageOutType | None,
|
|
627
|
-
StateType,
|
|
628
|
-
],
|
|
629
|
-
ABC,
|
|
630
|
-
typing.Generic[SettingsType, MessageInType, MessageOutType, StateType],
|
|
631
|
-
):
|
|
632
|
-
@abstractmethod
|
|
633
|
-
def partial_fit(self, message: SampleMessage) -> None: ...
|
|
634
|
-
|
|
635
|
-
async def apartial_fit(self, message: SampleMessage) -> None:
|
|
636
|
-
"""Override me if you need async partial fitting."""
|
|
637
|
-
return self.partial_fit(message)
|
|
638
|
-
|
|
639
|
-
def __call__(self, message: MessageInType | SampleMessage) -> MessageOutType | None:
|
|
640
|
-
"""
|
|
641
|
-
Adapt transformer with training data (and optionally labels)
|
|
642
|
-
in SampleMessage
|
|
643
|
-
|
|
644
|
-
Args:
|
|
645
|
-
message: An instance of SampleMessage with optional
|
|
646
|
-
labels (y) in message.trigger.value.data and
|
|
647
|
-
data (X) in message.sample.data
|
|
648
|
-
|
|
649
|
-
Returns: None
|
|
650
|
-
"""
|
|
651
|
-
if is_sample_message(message):
|
|
652
|
-
return self.partial_fit(message)
|
|
653
|
-
return super().__call__(message)
|
|
654
|
-
|
|
655
|
-
async def __acall__(
|
|
656
|
-
self, message: MessageInType | SampleMessage
|
|
657
|
-
) -> MessageOutType | None:
|
|
658
|
-
if is_sample_message(message):
|
|
659
|
-
return await self.apartial_fit(message)
|
|
660
|
-
return await super().__acall__(message)
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
class BaseAsyncTransformer(
|
|
664
|
-
BaseStatefulTransformer[SettingsType, MessageInType, MessageOutType, StateType],
|
|
665
|
-
ABC,
|
|
666
|
-
typing.Generic[SettingsType, MessageInType, MessageOutType, StateType],
|
|
667
|
-
):
|
|
668
|
-
"""
|
|
669
|
-
This reverses the priority of async and sync methods from :obj:`BaseStatefulTransformer`.
|
|
670
|
-
Whereas in :obj:`BaseStatefulTransformer`, the async methods simply called the sync methods,
|
|
671
|
-
here the sync methods call the async methods, more similar to :obj:`BaseStatefulProducer`.
|
|
672
|
-
"""
|
|
673
|
-
|
|
674
|
-
def _process(self, message: MessageInType) -> MessageOutType:
|
|
675
|
-
return run_coroutine_sync(self._aprocess(message))
|
|
676
|
-
|
|
677
|
-
@abstractmethod
|
|
678
|
-
async def _aprocess(self, message: MessageInType) -> MessageOutType: ...
|
|
679
|
-
|
|
680
|
-
def __call__(self, message: MessageInType) -> MessageOutType:
|
|
681
|
-
# Override (synchronous) __call__ to run coroutine `aprocess`.
|
|
682
|
-
return run_coroutine_sync(self.__acall__(message))
|
|
683
|
-
|
|
684
|
-
async def __acall__(self, message: MessageInType) -> MessageOutType:
|
|
685
|
-
# Note: In Python 3.12, we can invoke this with `await obj(message)`
|
|
686
|
-
# Earlier versions must be explicit: `await obj.__acall__(message)`
|
|
687
|
-
msg_hash = self._hash_message(message)
|
|
688
|
-
if msg_hash != self._hash:
|
|
689
|
-
self._reset_state(message)
|
|
690
|
-
self._hash = msg_hash
|
|
691
|
-
return await self._aprocess(message)
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
# Composite processor for building pipelines
|
|
695
|
-
def _get_processor_message_type(
|
|
696
|
-
proc: BaseProcessor | BaseProducer | GeneratorType | SyncToAsyncGeneratorWrapper,
|
|
697
|
-
dir: str,
|
|
698
|
-
) -> type | None:
|
|
699
|
-
"""Extract the input type from a processor."""
|
|
700
|
-
if isinstance(proc, GeneratorType) or isinstance(proc, SyncToAsyncGeneratorWrapper):
|
|
701
|
-
gen_func = proc.gi_frame.f_globals[proc.gi_frame.f_code.co_name]
|
|
702
|
-
args = typing.get_args(gen_func.__annotations__.get("return"))
|
|
703
|
-
return args[0] if dir == "out" else args[1] # yield type / send type
|
|
704
|
-
return proc.__class__.get_message_type(dir)
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
def _has_stateful_op(proc: typing.Any) -> typing.TypeGuard[Stateful]:
|
|
708
|
-
"""
|
|
709
|
-
Check if the processor has a stateful_op method.
|
|
710
|
-
This is used to determine if the processor is stateful or not.
|
|
711
|
-
"""
|
|
712
|
-
return hasattr(proc, "stateful_op")
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
class CompositeStateful(
    Stateful[dict[str, typing.Any]], ABC, typing.Generic[SettingsType, MessageOutType]
):
    """
    Mixin class for composite processor/producer chains. DO NOT use this class directly.

    Used to enforce statefulness of the composite processor/producer chain and provide
    initialization and validation methods. Members live in ``_procs``, an ordered
    ``name -> member`` mapping that is executed in insertion order. The composite
    ``state`` is a ``name -> member.state`` mapping for every member exposing ``state``.
    """

    # Ordered name -> member mapping; iteration order is the execution order.
    _procs: dict[
        str, BaseProducer | BaseProcessor | GeneratorType | SyncToAsyncGeneratorWrapper
    ]
    # "producer" or "processor"; set by the concrete subclass, used in error messages.
    _processor_type: typing.Literal["producer", "processor"]

    def _validate_processor_chain(self) -> None:
        """
        Validate the composite chain types at runtime and normalize ``_procs``.

        ``None`` members are dropped. The first member's input type and the last member's
        output type are checked against the composite's own message types, and every
        adjacent pair is checked for compatibility. Finally, legacy generators are wrapped
        in ``SyncToAsyncGeneratorWrapper`` and ``_procs`` is replaced by the cleaned mapping.

        Raises:
            ValueError: If the chain has no (non-``None``) members.
            TypeError: If any type check fails.
        """
        # Filter BEFORE the emptiness check: a chain made only of ``None`` placeholders
        # used to pass the check and then crash with IndexError on ``procs[0]``.
        procs = [p for p in (self._procs or {}).items() if p[1] is not None]
        if not procs:
            raise ValueError(
                f"Composite {self._processor_type} requires at least one processor"
            )

        expected_in_type = _get_processor_message_type(self, "in")
        expected_out_type = _get_processor_message_type(self, "out")

        first_name, first_proc = procs[0]
        in_type = _get_processor_message_type(first_proc, "in")
        if not check_message_type_compatibility(expected_in_type, in_type):
            raise TypeError(
                f"Input type mismatch: Composite {self._processor_type} expects {expected_in_type}, "
                f"but its first processor (name: {first_name}, type: {first_proc.__class__.__name__}) accepts {in_type}"
            )

        last_name, last_proc = procs[-1]
        out_type = _get_processor_message_type(last_proc, "out")
        if not check_message_type_compatibility(out_type, expected_out_type):
            raise TypeError(
                f"Output type mismatch: Composite {self._processor_type} wants to return {expected_out_type}, "
                f"but its last processor (name: {last_name}, type: {last_proc.__class__.__name__}) returns {out_type}"
            )

        # Check intermediate connections
        for i, ((cur_name, cur_proc), (next_name, next_proc)) in enumerate(
            zip(procs, procs[1:])
        ):
            current_out_type = _get_processor_message_type(cur_proc, "out")
            next_in_type = _get_processor_message_type(next_proc, "in")

            if current_out_type is None or current_out_type is type(None):
                raise TypeError(
                    f"Processor {i} (name: {cur_name}, type: {cur_proc.__class__.__name__}) is a consumer "
                    f"or returns None. Consumers can only be the last processor of a composite {self._processor_type} chain."
                )
            if next_in_type is None or next_in_type is type(None):
                raise TypeError(
                    f"Processor {i + 1} (name: {next_name}, type: {next_proc.__class__.__name__}) is a producer "
                    f"or receives only None. Producers can only be the first processor of a composite producer chain."
                )
            if not check_message_type_compatibility(current_out_type, next_in_type):
                raise TypeError(
                    f"Message type mismatch between processors {i} (name: {cur_name}, type: {cur_proc.__class__.__name__}) "
                    f"and {i + 1} (name: {next_name}, type: {next_proc.__class__.__name__}): "
                    f"{cur_proc.__class__.__name__} outputs {current_out_type}, "
                    f"but {next_proc.__class__.__name__} expects {next_in_type}"
                )

        # All checks above ran against the raw members. Only now wrap legacy sync
        # generators (presumably to give them the async ``asend``/``__anext__`` API).
        self._procs = {
            name: SyncToAsyncGeneratorWrapper(proc) if inspect.isgenerator(proc) else proc
            for name, proc in procs
        }

    @staticmethod
    @abstractmethod
    def _initialize_processors(
        settings: SettingsType,
    ) -> dict[str, typing.Any]: ...

    @property
    def state(self) -> dict[str, typing.Any]:
        """Mapping of member name -> member ``state`` for members that expose one."""
        return {
            k: getattr(proc, "state")
            for k, proc in self._procs.items()
            if hasattr(proc, "state")
        }

    @state.setter
    def state(self, state: dict[str, typing.Any] | bytes | None) -> None:
        """
        Restore member states from a ``name -> state`` mapping or its pickled bytes.

        ``None`` is a no-op, and members without a ``state`` attribute are skipped.
        Every key is validated before any member is touched, so an unknown key
        leaves the chain unchanged.

        Raises:
            KeyError: If a key does not name a member of the chain.
        """
        if state is None:
            return
        if isinstance(state, bytes):
            # SECURITY: pickle.loads can execute arbitrary code. Only pass bytes that
            # come from a trusted source.
            state = pickle.loads(state)
        unknown = [k for k in state.keys() if k not in self._procs]  # type: ignore
        if unknown:
            raise KeyError(
                f"Processor (name: {unknown[0]}) in provided state not found in composite {self._processor_type} chain. "
                f"Available keys: {list(self._procs.keys())}"
            )
        for k, v in state.items():  # type: ignore
            if hasattr(self._procs[k], "state"):
                setattr(self._procs[k], "state", v)

    def _reset_state(self, *args: typing.Any, **kwargs: typing.Any) -> None:
        # By default, we don't expect to change the state of a composite processor/producer
        pass

    @abstractmethod
    def stateful_op(
        self,
        state: dict[str, tuple[typing.Any, int]] | None,
        *args: typing.Any,
        **kwargs: typing.Any,
    ) -> tuple[
        dict[str, tuple[typing.Any, int]],
        MessageOutType | None,
    ]: ...
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
class CompositeProcessor(
    BaseProcessor[SettingsType, MessageInType, MessageOutType],
    CompositeStateful[SettingsType, MessageOutType],
    ABC,
    typing.Generic[SettingsType, MessageInType, MessageOutType],
):
    """
    A processor that chains multiple processors together in a feedforward non-branching graph.

    The individual processors may be stateless or stateful. The last processor may be a consumer,
    otherwise processors must be transformers. Use CompositeProducer if you want the first
    processor to be a producer. Concrete subclasses must implement `_initialize_processors`.
    Optionally override `_reset_state` if you want adaptive state behaviour.
    Example implementation:

    class CustomCompositeProcessor(CompositeProcessor[CustomSettings, AxisArray, AxisArray]):
        @staticmethod
        def _initialize_processors(settings: CustomSettings) -> dict[str, BaseProcessor]:
            return {
                "stateful_transformer": CustomStatefulTransformer(**settings),
                "transformer": CustomTransformer(**settings),
            }

    Where **settings should be replaced with initialisation arguments for each processor.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)  # .settings
        self._processor_type = "processor"
        self._procs = self._initialize_processors(self.settings)
        self._validate_processor_chain()
        # A processor chain must not start with a producer (a member that only accepts None).
        first_proc = next(iter(self._procs.items()))
        first_proc_in_type = _get_processor_message_type(first_proc[1], "in")
        if first_proc_in_type is None or first_proc_in_type is type(None):
            raise TypeError(
                f"First processor (name: {first_proc[0]}, type: {first_proc[1].__class__.__name__}) "
                f"is a producer or receives only None. Please use CompositeProducer, not "
                f"CompositeProcessor for this composite chain."
            )
        # NOTE(review): presumably a sentinel hash so the first message triggers _reset_state.
        self._hash = -1

    @staticmethod
    @abstractmethod
    def _initialize_processors(settings: SettingsType) -> dict[str, typing.Any]:
        """Build the ordered ``name -> member`` mapping for this chain from ``settings``."""

    def _process(self, message: MessageInType | None = None) -> MessageOutType | None:
        """
        Pass a message through every member's synchronous ``send``, in chain order.

        The first member receives ``message`` as-is. Each later member receives the
        previous member's return value. This will be invoked via `__call__` or `send`.
        We use `send` to allow using legacy generators that have yet to be converted to transformers.

        Warning: All processors will be called using their synchronous API, which may invoke a slow sync->async wrapper
        for processors that are async-first (i.e., children of BaseProducer or BaseAsyncTransformer).
        If you are in an async context, please use instead this object's `asend` or `__acall__`,
        which is much faster for async processors and does not incur penalty on sync processors.
        """
        result = message
        for proc in self._procs.values():
            result = proc.send(result)
        return result

    async def _aprocess(
        self, message: MessageInType | None = None
    ) -> MessageOutType | None:
        """
        Pass a message through every member's asynchronous ``asend``, in chain order.

        The first member receives ``message`` as-is. Each later member receives the
        previous member's return value. We use `asend` to allow using legacy generators
        that have yet to be converted to transformers.
        """
        result = message
        for proc in self._procs.values():
            result = await proc.asend(result)
        return result

    def stateful_op(
        self,
        state: dict[str, tuple[typing.Any, int]] | None,
        message: MessageInType | None,
    ) -> tuple[
        dict[str, tuple[typing.Any, int]],
        MessageOutType | None,
    ]:
        """
        Run ``message`` through the chain using explicit, externally-held state.

        Members exposing ``stateful_op`` get their entry from ``state`` (``None`` if absent)
        and their updated entry is stored in the returned state. Other members are
        driven with ``send``. The caller's ``state`` dict is not modified.

        Returns:
            ``(new_state, result)``.

        Raises:
            AttributeError: If ``state`` is neither a dict (mapping) nor ``None``.
            KeyError: If ``state`` contains a key that is not a member of the chain.
        """
        state = state or {}
        try:
            state_keys = list(state.keys())
        except AttributeError as e:
            raise AttributeError(
                "state provided to stateful_op must be a dict or None"
            ) from e
        for key in state_keys:
            if key not in self._procs:
                raise KeyError(
                    f"Processor (name: {key}) in provided state not found in composite processor chain. "
                    f"Available keys: {list(self._procs.keys())}"
                )
        # Shallow copy: previously a non-empty caller dict was mutated in place while an
        # empty one was replaced, so aliasing depended on the input.
        new_state = dict(state)
        result = message
        for k, proc in self._procs.items():
            if _has_stateful_op(proc):
                new_state[k], result = proc.stateful_op(new_state.get(k, None), result)
            else:
                result = proc.send(result)
        return new_state, result
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
class CompositeProducer(
    BaseProducer[SettingsType, MessageOutType],
    CompositeStateful[SettingsType, MessageOutType],
    ABC,
    typing.Generic[SettingsType, MessageOutType],
):
    """
    A producer that chains multiple processors (starting with a producer) together in a feedforward
    non-branching graph. The individual processors may be stateless or stateful.
    The first processor must be a producer, the last processor may be a consumer, otherwise
    processors must be transformers. Concrete subclasses must implement `_initialize_processors`.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)  # .settings
        self._processor_type = "producer"
        self._procs = self._initialize_processors(self.settings)
        self._validate_processor_chain()
        # A producer chain must start with a producer (a member that only accepts None).
        first_proc = next(iter(self._procs.items()))
        first_proc_in_type = _get_processor_message_type(first_proc[1], "in")
        if first_proc_in_type is not None and first_proc_in_type is not type(None):
            raise TypeError(
                f"First processor (name: {first_proc[0]}, type: {first_proc[1].__class__.__name__}) "
                f"is not a producer. Please use CompositeProcessor, not "
                f"CompositeProducer for this composite chain."
            )
        # NOTE(review): presumably a sentinel hash, mirroring CompositeProcessor.
        self._hash = -1

    @staticmethod
    @abstractmethod
    def _initialize_processors(
        settings: SettingsType,
    ) -> dict[str, typing.Any]:
        """Build the ordered ``name -> member`` mapping for this chain from ``settings``."""

    async def _produce(self) -> MessageOutType:
        """
        Produce one message by pulling from the first member and piping it through the rest.

        The first member is advanced with ``__anext__``. Its output is passed through each
        remaining member's ``asend`` in chain order, and the last member's return value is
        returned. We use `__anext__` and `asend` to allow using legacy generators that have
        yet to be converted to transformers.

        All processors are called using their asynchronous API, which is much faster for async
        processors and does not incur penalty on sync processors.
        """
        procs = list(self._procs.values())
        result = await procs[0].__anext__()
        for proc in procs[1:]:
            result = await proc.asend(result)
        return result

    def stateful_op(
        self,
        state: dict[str, tuple[typing.Any, int]] | None,
    ) -> tuple[
        dict[str, tuple[typing.Any, int]],
        MessageOutType | None,
    ]:
        """
        Produce one message through the chain using explicit, externally-held state.

        The first member is called with ``stateful_op(state_entry)`` if it has one, otherwise
        with ``__next__()``. Later members use ``stateful_op(state_entry, result)`` or ``send``.
        The caller's ``state`` dict is not modified.

        Returns:
            ``(new_state, result)``.

        Raises:
            AttributeError: If ``state`` is neither a dict (mapping) nor ``None``.
            KeyError: If ``state`` contains a key that is not a member of the chain.
        """
        state = state or {}
        try:
            state_keys = list(state.keys())
        except AttributeError as e:
            raise AttributeError(
                "state provided to stateful_op must be a dict or None"
            ) from e
        for key in state_keys:
            if key not in self._procs:
                raise KeyError(
                    f"Processor (name: {key}) in provided state not found in composite producer chain. "
                    f"Available keys: {list(self._procs.keys())}"
                )
        # Shallow copy so the caller's dict is never mutated in place.
        new_state = dict(state)
        labeled_procs = list(self._procs.items())
        prod_name, prod = labeled_procs[0]
        if _has_stateful_op(prod):
            new_state[prod_name], result = prod.stateful_op(new_state.get(prod_name, None))
        else:
            result = prod.__next__()
        for k, proc in labeled_procs[1:]:
            if _has_stateful_op(proc):
                new_state[k], result = proc.stateful_op(new_state.get(k, None), result)
            else:
                result = proc.send(result)
        return new_state, result
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
# --- Type variables for protocols and processors ---

# Any producer class.
ProducerType = typing.TypeVar("ProducerType", bound=BaseProducer)

# Any consumer class, stateless or stateful.
ConsumerType = typing.TypeVar(
    "ConsumerType",
    bound=BaseConsumer | BaseStatefulConsumer,
)

# Any transformer class; composite processor chains count as transformers too.
TransformerType = typing.TypeVar(
    "TransformerType",
    bound=(BaseTransformer | BaseStatefulTransformer | CompositeProcessor),
)
|
|
1023
|
-
AdaptiveTransformerType = typing.TypeVar(
|
|
1024
|
-
"AdaptiveTransformerType", bound=BaseAdaptiveTransformer
|
|
18
|
+
# Re-export everything from ezmsg.baseproc for backwards compatibility
|
|
19
|
+
from ezmsg.baseproc import ( # noqa: E402
|
|
20
|
+
# Protocols
|
|
21
|
+
AdaptiveTransformer,
|
|
22
|
+
# Type variables
|
|
23
|
+
AdaptiveTransformerType,
|
|
24
|
+
# Stateful classes
|
|
25
|
+
BaseAdaptiveTransformer,
|
|
26
|
+
# Unit classes
|
|
27
|
+
BaseAdaptiveTransformerUnit,
|
|
28
|
+
BaseAsyncTransformer,
|
|
29
|
+
# Base processor classes
|
|
30
|
+
BaseConsumer,
|
|
31
|
+
BaseConsumerUnit,
|
|
32
|
+
BaseProcessor,
|
|
33
|
+
BaseProcessorUnit,
|
|
34
|
+
BaseProducer,
|
|
35
|
+
BaseProducerUnit,
|
|
36
|
+
BaseStatefulConsumer,
|
|
37
|
+
BaseStatefulProcessor,
|
|
38
|
+
BaseStatefulProducer,
|
|
39
|
+
BaseStatefulTransformer,
|
|
40
|
+
BaseTransformer,
|
|
41
|
+
BaseTransformerUnit,
|
|
42
|
+
# Composite classes
|
|
43
|
+
CompositeProcessor,
|
|
44
|
+
CompositeProducer,
|
|
45
|
+
CompositeStateful,
|
|
46
|
+
Consumer,
|
|
47
|
+
ConsumerType,
|
|
48
|
+
GenAxisArray,
|
|
49
|
+
MessageInType,
|
|
50
|
+
MessageOutType,
|
|
51
|
+
Processor,
|
|
52
|
+
Producer,
|
|
53
|
+
ProducerType,
|
|
54
|
+
# Message types
|
|
55
|
+
SampleMessage,
|
|
56
|
+
SettingsType,
|
|
57
|
+
Stateful,
|
|
58
|
+
StatefulConsumer,
|
|
59
|
+
StatefulProcessor,
|
|
60
|
+
StatefulProducer,
|
|
61
|
+
StatefulTransformer,
|
|
62
|
+
StateType,
|
|
63
|
+
Transformer,
|
|
64
|
+
TransformerType,
|
|
65
|
+
# Type resolution helpers
|
|
66
|
+
_get_base_processor_message_in_type,
|
|
67
|
+
_get_base_processor_message_out_type,
|
|
68
|
+
_get_base_processor_settings_type,
|
|
69
|
+
_get_base_processor_state_type,
|
|
70
|
+
_get_processor_message_type,
|
|
71
|
+
# Type utilities
|
|
72
|
+
check_message_type_compatibility,
|
|
73
|
+
get_base_adaptive_transformer_type,
|
|
74
|
+
get_base_consumer_type,
|
|
75
|
+
get_base_producer_type,
|
|
76
|
+
get_base_transformer_type,
|
|
77
|
+
is_sample_message,
|
|
78
|
+
# Decorators
|
|
79
|
+
processor_state,
|
|
80
|
+
# Profiling
|
|
81
|
+
profile_subpub,
|
|
82
|
+
resolve_typevar,
|
|
1025
83
|
)
|
|
1026
84
|
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
""
|
|
1049
|
-
Base
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
""
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
while True:
|
|
1093
|
-
out = await self.producer.__acall__()
|
|
1094
|
-
if out is not None: # and math.prod(out.data.shape) > 0:
|
|
1095
|
-
yield self.OUTPUT_SIGNAL, out
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
class BaseProcessorUnit(ez.Unit, ABC, typing.Generic[SettingsType]):
    """
    Abstract ezmsg Unit that owns a processor built from its settings.

    This is the shared plumbing for consumer and transformer units. The processor is
    created in ``initialize`` and rebuilt whenever a settings message arrives on
    ``INPUT_SETTINGS``. Subclasses must implement ``create_processor``. In practice,
    derive from BaseConsumerUnit or BaseTransformerUnit rather than from this class.
    """

    INPUT_SETTINGS = ez.InputStream(SettingsType)

    async def initialize(self) -> None:
        # Build the processor once when the unit starts.
        self.create_processor()

    @abstractmethod
    def create_processor(self) -> None:
        """Instantiate ``self.processor`` from the unit's current settings."""

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: SettingsType) -> None:
        """
        Apply a new settings message, then rebuild the processor from scratch.

        Override in a subclass that needs finer control over whether its processor
        is reset when settings change.

        Args:
            msg: a settings message.
        """
        self.apply_settings(msg)  # type: ignore
        self.create_processor()
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
class BaseConsumerUnit(
    BaseProcessorUnit[SettingsType],
    ABC,
    typing.Generic[SettingsType, MessageInType, ConsumerType],
):
    """
    Base class for consumer units: units that receive messages but publish nothing.

    A concrete unit only needs its generic parameters and SETTINGS:

        class CustomUnit(BaseConsumerUnit[
            CustomConsumerSettings,    # SettingsType
            AxisArray,                 # MessageInType
            CustomConsumer,            # ConsumerType
        ]):
            SETTINGS = CustomConsumerSettings

    Here CustomConsumerSettings is an ez.Settings subclass, and CustomConsumer is a
    BaseConsumer or BaseStatefulConsumer implementation.
    """

    INPUT_SIGNAL = ez.InputStream(MessageInType)

    def create_processor(self):
        """Build the consumer (class resolved by get_base_consumer_type) from SETTINGS."""
        self.processor = get_base_consumer_type(self.__class__)(settings=self.SETTINGS)

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    async def on_signal(self, message: MessageInType):
        """
        Forward one incoming message to the consumer.

        Args:
            message: Input message to be consumed
        """
        await self.processor.__acall__(message)
|
|
1168
|
-
|
|
1169
|
-
|
|
1170
|
-
class BaseTransformerUnit(
    BaseProcessorUnit[SettingsType],
    ABC,
    typing.Generic[SettingsType, MessageInType, MessageOutType, TransformerType],
):
    """
    Base class for transformer units: units that turn each input message into an output.

    A concrete unit only needs its generic parameters and SETTINGS:

        class CustomUnit(BaseTransformerUnit[
            CustomTransformerSettings,    # SettingsType
            AxisArray,                    # MessageInType
            AxisArray,                    # MessageOutType
            CustomTransformer,            # TransformerType
        ]):
            SETTINGS = CustomTransformerSettings

    Here CustomTransformerSettings is an ez.Settings subclass, and CustomTransformer is
    a BaseTransformer, BaseStatefulTransformer, or CompositeProcessor implementation.
    """

    INPUT_SIGNAL = ez.InputStream(MessageInType)
    OUTPUT_SIGNAL = ez.OutputStream(MessageOutType)

    def create_processor(self):
        """Build the transformer (class resolved by get_base_transformer_type) from SETTINGS."""
        self.processor = get_base_transformer_type(self.__class__)(
            settings=self.SETTINGS
        )

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, message: MessageInType) -> typing.AsyncGenerator:
        # Transform the message; a None result means "nothing to publish".
        output = await self.processor.__acall__(message)
        if output is None:
            return
        yield self.OUTPUT_SIGNAL, output
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
|
-
class BaseAdaptiveTransformerUnit(
    BaseProcessorUnit[SettingsType],
    ABC,
    typing.Generic[
        SettingsType, MessageInType, MessageOutType, AdaptiveTransformerType
    ],
):
    """
    Base class for units that wrap an adaptive transformer.

    Messages on ``INPUT_SIGNAL`` are transformed and published on ``OUTPUT_SIGNAL``;
    ``None`` results are dropped. ``SampleMessage`` objects on ``INPUT_SAMPLE`` are
    passed to the transformer's ``apartial_fit`` (presumably to adapt/train it).
    """

    INPUT_SAMPLE = ez.InputStream(SampleMessage)
    INPUT_SIGNAL = ez.InputStream(MessageInType)
    OUTPUT_SIGNAL = ez.OutputStream(MessageOutType)

    def create_processor(self) -> None:
        """Build the adaptive transformer (class resolved by get_base_adaptive_transformer_type)."""
        transformer_cls = get_base_adaptive_transformer_type(self.__class__)
        self.processor = transformer_cls(settings=self.SETTINGS)

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, message: MessageInType) -> typing.AsyncGenerator:
        # Transform the message; a None result means "nothing to publish".
        output = await self.processor.__acall__(message)
        if output is None:
            return
        yield self.OUTPUT_SIGNAL, output

    @ez.subscriber(INPUT_SAMPLE)
    async def on_sample(self, msg: SampleMessage) -> None:
        # Hand the labelled sample to the transformer's partial-fit routine.
        await self.processor.apartial_fit(msg)
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
# Legacy class
|
|
1246
|
-
class GenAxisArray(ez.Unit):
    """
    Legacy ezmsg Unit that drives a generator-based AxisArray processor.

    Subclasses implement ``construct_generator`` and store a (primed) generator in
    ``self.STATE.gen``. Each incoming AxisArray is ``send``-ed to that generator, and
    non-empty results are published on ``OUTPUT_SIGNAL``.
    """

    STATE = GenState

    INPUT_SIGNAL = ez.InputStream(AxisArray)
    OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
    INPUT_SETTINGS = ez.InputStream(ez.Settings)

    async def initialize(self) -> None:
        self.construct_generator()

    # Method to be implemented by subclasses to construct the specific generator
    def construct_generator(self):
        raise NotImplementedError

    @ez.subscriber(INPUT_SETTINGS)
    async def on_settings(self, msg: ez.Settings) -> None:
        """
        Update unit settings and reset generator.
        Note: Not all units will require a full reset with new settings.
        Override this method to implement a selective reset.

        Args:
            msg: Instance of SETTINGS object.
        """
        self.apply_settings(msg)
        self.construct_generator()

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
        try:
            ret = self.STATE.gen.send(message)
            # Skip None (previously this raised AttributeError and logged a traceback)
            # and empty outputs.
            if ret is not None and math.prod(ret.data.shape) > 0:
                yield self.OUTPUT_SIGNAL, ret
        except (StopIteration, GeneratorExit):
            ez.logger.debug(f"Generator closed in {self.address}")
        except Exception:
            # Best-effort: keep the unit alive, but report the failure at ERROR level
            # (it was logged at INFO, which hides real processing errors).
            ez.logger.error(traceback.format_exc())
|
|
85
|
+
# Public API of this backwards-compatibility module; every name below is re-exported
# from ezmsg.baseproc by the import above.
__all__ = [
    # Protocols
    "Processor", "Producer", "Consumer", "Transformer",
    "StatefulProcessor", "StatefulProducer", "StatefulConsumer", "StatefulTransformer",
    "AdaptiveTransformer",
    # Type variables
    "MessageInType", "MessageOutType", "SettingsType", "StateType",
    "ProducerType", "ConsumerType", "TransformerType", "AdaptiveTransformerType",
    # Decorators
    "processor_state",
    # Base processor classes
    "BaseProcessor", "BaseProducer", "BaseConsumer", "BaseTransformer",
    # Stateful classes
    "Stateful",
    "BaseStatefulProcessor", "BaseStatefulProducer", "BaseStatefulConsumer",
    "BaseStatefulTransformer", "BaseAdaptiveTransformer", "BaseAsyncTransformer",
    # Composite classes
    "CompositeStateful", "CompositeProcessor", "CompositeProducer",
    # Unit classes
    "BaseProducerUnit", "BaseProcessorUnit", "BaseConsumerUnit",
    "BaseTransformerUnit", "BaseAdaptiveTransformerUnit", "GenAxisArray",
    # Type resolution helpers
    "get_base_producer_type", "get_base_consumer_type",
    "get_base_transformer_type", "get_base_adaptive_transformer_type",
    "_get_base_processor_settings_type", "_get_base_processor_message_in_type",
    "_get_base_processor_message_out_type", "_get_base_processor_state_type",
    "_get_processor_message_type",
    # Message types
    "SampleMessage", "is_sample_message",
    # Profiling
    "profile_subpub",
    # Type utilities
    "check_message_type_compatibility", "resolve_typevar",
]
|