ezmsg-baseproc 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,209 @@
1
+ """Base processor classes for ezmsg (non-stateful)."""
2
+
3
+ import typing
4
+ from abc import ABC, abstractmethod
5
+
6
+ from .protocols import MessageInType, MessageOutType, SettingsType
7
+ from .util.asio import run_coroutine_sync
8
+ from .util.typeresolution import resolve_typevar
9
+
10
+
11
def _get_base_processor_settings_type(cls: type) -> type:
    """
    Resolve the concrete type bound to the ``SettingsType`` type variable for ``cls``.

    :param cls: The processor/producer class whose generic parameters are inspected.
    :return: The resolved settings class.
    :raises TypeError: If ``resolve_typevar`` cannot resolve ``SettingsType`` for ``cls``;
        the original error is chained as ``__cause__``.
    """
    try:
        return resolve_typevar(cls, SettingsType)
    except TypeError as e:
        raise TypeError(
            f"Could not resolve settings type for {cls}. "
            "Ensure that the class is properly annotated with a SettingsType."
        ) from e
19
+
20
+
21
def _get_base_processor_message_in_type(cls: type) -> type:
    """
    Resolve the concrete type bound to the ``MessageInType`` type variable for ``cls``.

    :param cls: The processor class whose generic parameters are inspected.
    :return: The resolved input message type.
    :raises TypeError: If ``resolve_typevar`` cannot resolve ``MessageInType`` for ``cls``;
        the original error is chained as ``__cause__``.
    """
    try:
        return resolve_typevar(cls, MessageInType)
    except TypeError as e:
        # Same pattern as the settings resolver: add context, keep the exception type.
        raise TypeError(
            f"Could not resolve input message type for {cls}. "
            "Ensure that the class is properly annotated with a MessageInType."
        ) from e
23
+
24
+
25
def _get_base_processor_message_out_type(cls: type) -> type:
    """
    Resolve the concrete type bound to the ``MessageOutType`` type variable for ``cls``.

    :param cls: The processor/producer class whose generic parameters are inspected.
    :return: The resolved output message type.
    :raises TypeError: If ``resolve_typevar`` cannot resolve ``MessageOutType`` for ``cls``;
        the original error is chained as ``__cause__``.
    """
    try:
        return resolve_typevar(cls, MessageOutType)
    except TypeError as e:
        # Same pattern as the settings resolver: add context, keep the exception type.
        raise TypeError(
            f"Could not resolve output message type for {cls}. "
            "Ensure that the class is properly annotated with a MessageOutType."
        ) from e
27
+
28
+
29
def _unify_settings(obj: typing.Any, settings: object | None, *args, **kwargs) -> typing.Any:
    """
    Build or validate the settings object for a processor/producer being initialized.

    When ``settings`` is None, it is resolved in this order:

    1. If the first positional argument is already an instance of the class's settings
       type, it is used as-is (any further args/kwargs are ignored).
    2. Otherwise, if any args/kwargs were given, they are forwarded to the settings
       type's constructor.
    3. Otherwise, the settings type is default-constructed.

    When ``settings`` is provided explicitly, ``args`` and ``kwargs`` are ignored.

    :param obj: The instance being initialized; its class determines the settings type.
    :param settings: An explicit settings object, or None.
    :return: A settings object that is an instance of the resolved settings type.
    :raises TypeError: If the settings type cannot be resolved, or if the resulting
        settings object is not an instance of it.
    """
    settings_type = _get_base_processor_settings_type(obj.__class__)

    if settings is None:
        if args and isinstance(args[0], settings_type):
            settings = args[0]
        elif args or kwargs:
            settings = settings_type(*args, **kwargs)
        else:
            settings = settings_type()

    # Explicit check rather than ``assert`` so validation still happens under ``python -O``.
    if not isinstance(settings, settings_type):
        raise TypeError("Settings must be of type " + str(settings_type))
    return settings
42
+
43
+
44
class BaseProcessor(ABC, typing.Generic[SettingsType, MessageInType, MessageOutType]):
    """
    Base class for processors. You probably do not want to inherit from this class directly.
    Refer instead to the more specific base classes.

    * Use :obj:`BaseConsumer` for ops that do not return a result, or :obj:`BaseTransformer`
      for ops that do.
    * Use :obj:`BaseStatefulProcessor` and its children for operations that require state.

    Note that `BaseProcessor` and its children are sync by default. If you need async by default, then
    override the async methods and call them from the sync methods. Look to `BaseProducer` for examples of
    calling async methods from sync methods.
    """

    # Assigned in ``__init__`` from ``_unify_settings``.
    settings: SettingsType

    @classmethod
    def get_settings_type(cls) -> type[SettingsType]:
        """Return the concrete settings class bound to ``SettingsType`` for this class."""
        return _get_base_processor_settings_type(cls)

    @classmethod
    def get_message_type(cls, dir: str) -> typing.Any:
        """
        Return the concrete message type for one direction of this processor.

        :param dir: ``"in"`` for the input message type or ``"out"`` for the output message type.
        :return: The type bound to ``MessageInType`` or ``MessageOutType`` respectively.
        :raises ValueError: If ``dir`` is neither ``"in"`` nor ``"out"``.
        """
        if dir == "in":
            return _get_base_processor_message_in_type(cls)
        elif dir == "out":
            return _get_base_processor_message_out_type(cls)
        else:
            raise ValueError(f"Invalid direction: {dir}. Use 'in' or 'out'.")

    def __init__(self, *args, settings: SettingsType | None = None, **kwargs) -> None:
        """
        Resolve ``self.settings`` via ``_unify_settings``.

        :param args: Either a settings instance as the first positional argument, or positional
            arguments forwarded to the settings type's constructor.
        :param settings: An explicit settings instance; when given, ``args``/``kwargs`` are ignored.
        :param kwargs: Keyword arguments forwarded to the settings type's constructor.
        """
        self.settings = _unify_settings(self, settings, *args, **kwargs)

    @abstractmethod
    def _process(self, message: typing.Any) -> typing.Any:
        """Process ``message`` synchronously; implemented by subclasses."""
        ...

    async def _aprocess(self, message: typing.Any) -> typing.Any:
        """Override this for native async processing. The default defers to :meth:`_process`."""
        return self._process(message)

    def __call__(self, message: typing.Any) -> typing.Any:
        # Note: We use the indirection to `_process` because this allows us to
        # modify __call__ in derived classes with common functionality while
        # minimizing the boilerplate code in derived classes as they only need to
        # implement `_process`.
        return self._process(message)

    async def __acall__(self, message: typing.Any) -> typing.Any:
        """
        Asynchronous counterpart of ``__call__``.

        Python has no syntax that dispatches to ``__acall__`` (``await obj(message)`` would invoke
        the synchronous ``__call__``), so call it explicitly: ``await obj.__acall__(message)``.
        """
        return await self._aprocess(message)

    def send(self, message: typing.Any) -> typing.Any:
        """Alias for ``__call__``. Considered deprecated; prefer calling the processor directly."""
        return self(message)

    async def asend(self, message: typing.Any) -> typing.Any:
        """Alias for ``__acall__``. Considered deprecated; prefer ``await obj.__acall__(message)``."""
        return await self.__acall__(message)
102
+
103
+
104
class BaseProducer(ABC, typing.Generic[SettingsType, MessageOutType]):
    """
    Base class for producers -- processors that generate messages without consuming inputs.

    Note that `BaseProducer` and its children are async by default, and the sync methods simply wrap
    the async methods. This is the opposite of :obj:`BaseProcessor` and its children which are sync by default.
    These classes are designed this way because it is highly likely that a producer, which (probably) does not
    receive inputs, will require some sort of IO which will benefit from being async.

    A producer is also an endless iterator: use ``next(producer)`` / ``for msg in producer``
    synchronously, or ``await anext(producer)`` / ``async for msg in producer`` asynchronously.
    The base class itself never signals exhaustion.
    """

    # Assigned in ``__init__`` from ``_unify_settings`` (same as ``BaseProcessor.settings``).
    settings: SettingsType

    @classmethod
    def get_settings_type(cls) -> type[SettingsType]:
        """Return the concrete settings class bound to ``SettingsType`` for this class."""
        return _get_base_processor_settings_type(cls)

    @classmethod
    def get_message_type(cls, dir: str) -> type[MessageOutType] | None:
        """
        Return the concrete message type for one direction of this producer.

        :param dir: ``"in"`` or ``"out"``.
        :return: ``None`` for ``"in"`` (producers consume no input); the type bound to
            ``MessageOutType`` for ``"out"``.
        :raises ValueError: If ``dir`` is neither ``"in"`` nor ``"out"``.
        """
        if dir == "out":
            return _get_base_processor_message_out_type(cls)
        elif dir == "in":
            return None
        else:
            raise ValueError(f"Invalid direction: {dir}. Use 'in' or 'out'.")

    def __init__(self, *args, settings: SettingsType | None = None, **kwargs) -> None:
        """Resolve ``self.settings`` from ``settings`` or ``args``/``kwargs`` via ``_unify_settings``."""
        self.settings = _unify_settings(self, settings, *args, **kwargs)

    @abstractmethod
    async def _produce(self) -> MessageOutType:
        """Produce the next message; implemented by subclasses."""
        ...

    async def __acall__(self) -> MessageOutType:
        """Produce the next message asynchronously."""
        return await self._produce()

    def __call__(self) -> MessageOutType:
        """Produce the next message synchronously by running :meth:`__acall__` to completion."""
        # Warning: This is a bit slow. Override this method in derived classes if performance is critical.
        return run_coroutine_sync(self.__acall__())

    def __iter__(self) -> typing.Iterator[MessageOutType]:
        # Make self an iterator.
        return self

    def __next__(self) -> MessageOutType:
        # So this can be used as a generator.
        return self()

    def __aiter__(self) -> typing.AsyncIterator[MessageOutType]:
        # Required (together with ``__anext__``) for ``async for msg in producer``.
        return self

    async def __anext__(self) -> MessageOutType:
        # So this can be used as an async generator.
        return await self.__acall__()
151
+
152
+
153
class BaseConsumer(
    BaseProcessor[SettingsType, MessageInType, None],
    ABC,
    typing.Generic[SettingsType, MessageInType],
):
    """
    Base class for consumers: processors that take messages in and hand nothing back.

    Relative to :obj:`BaseProcessor`, this class only narrows the type annotations so the output
    is ``None``, and reports ``None`` as its output message type.
    ``send`` and ``asend`` are deliberately left un-narrowed since they are deprecated.
    """

    @classmethod
    def get_message_type(cls, dir: str) -> type[MessageInType] | None:
        """
        Look up the message type for direction ``dir``.

        ``"in"`` resolves the bound ``MessageInType``; ``"out"`` is always ``None`` because
        consumers emit nothing. Any other value raises ``ValueError``.
        """
        if dir not in ("in", "out"):
            raise ValueError(f"Invalid direction: {dir}. Use 'in' or 'out'.")
        return _get_base_processor_message_in_type(cls) if dir == "in" else None

    def __call__(self, message: MessageInType) -> None:
        return super().__call__(message)

    async def __acall__(self, message: MessageInType) -> None:
        return await super().__acall__(message)

    @abstractmethod
    def _process(self, message: MessageInType) -> None:
        """Consume ``message`` synchronously; implemented by subclasses."""
        ...

    async def _aprocess(self, message: MessageInType) -> None:
        """Override this for native async processing; by default it defers to :meth:`_process`."""
        outcome = self._process(message)
        return outcome
185
+
186
+
187
class BaseTransformer(
    BaseProcessor[SettingsType, MessageInType, MessageOutType],
    ABC,
    typing.Generic[SettingsType, MessageInType, MessageOutType],
):
    """
    Base class for transformers: processors that accept a message and return an output message.

    The only difference from :obj:`BaseProcessor` is the narrower type annotations, which state
    that an output is always produced. ``send`` and ``asend`` keep their inherited annotations
    because they are deprecated.
    """

    def __call__(self, message: MessageInType) -> MessageOutType:
        return super().__call__(message)

    async def __acall__(self, message: MessageInType) -> MessageOutType:
        return await super().__acall__(message)

    @abstractmethod
    def _process(self, message: MessageInType) -> MessageOutType:
        """Transform ``message`` synchronously; implemented by subclasses."""
        ...

    async def _aprocess(self, message: MessageInType) -> MessageOutType:
        """Override this for native async processing; by default it defers to :meth:`_process`."""
        transformed = self._process(message)
        return transformed
@@ -0,0 +1,147 @@
1
+ """Protocol definitions and type variables for ezmsg processors."""
2
+
3
+ import functools
4
+ import typing
5
+ from dataclasses import dataclass
6
+
7
+ from .util.message import SampleMessage
8
+
9
# --- Processor state decorator ---
# Class decorator for processor state containers: a mutable dataclass with no generated
# ``__init__`` and a field-based ``__hash__``. Because ``unsafe_hash=True`` is combined with
# ``frozen=False``, the hash follows the current field values, so a state object must not be
# mutated while it is stored in a set or used as a dict key.
processor_state = functools.partial(
    dataclass,
    unsafe_hash=True,
    frozen=False,
    init=False,
)

# --- Type variables for protocols and processors ---
MessageInType = typing.TypeVar("MessageInType")  # type of messages a processor accepts
MessageOutType = typing.TypeVar("MessageOutType")  # type of messages a processor emits
SettingsType = typing.TypeVar("SettingsType")  # type of the processor's settings object
StateType = typing.TypeVar("StateType")  # type of a stateful processor's state object
17
+
18
+
19
+ # --- Protocols for processors ---
20
class Processor(typing.Protocol[SettingsType, MessageInType, MessageOutType]):
    """
    Protocol for processors.
    You probably will not implement this protocol directly.
    Refer instead to the less ambiguous Consumer and Transformer protocols, and the base classes
    (e.g. ``BaseConsumer`` and ``BaseTransformer``) which implement them.

    Note: Python has no syntax that dispatches to ``__acall__`` (``await obj(message)`` would
    invoke the synchronous ``__call__``), so the async path must be called explicitly:
    ``await obj.__acall__(message)``.
    """

    def __call__(self, message: typing.Any) -> typing.Any:
        """Process ``message`` synchronously and return the result."""
        ...

    async def __acall__(self, message: typing.Any) -> typing.Any:
        """Process ``message`` asynchronously and return the result."""
        ...
33
+
34
+
35
class Producer(typing.Protocol[SettingsType, MessageOutType]):
    """Structural type for message producers: called with no input, they return a new message."""

    async def __acall__(self) -> MessageOutType:
        """Asynchronously produce a message."""
        ...

    def __call__(self) -> MessageOutType:
        """Synchronously produce a message."""
        ...
42
+
43
+
44
class Consumer(Processor[SettingsType, MessageInType, None], typing.Protocol):
    """Structural type for consumers: they accept messages and return nothing."""

    async def __acall__(self, message: MessageInType) -> None:
        """Asynchronously consume ``message``."""
        ...

    def __call__(self, message: MessageInType) -> None:
        """Synchronously consume ``message``."""
        ...
51
+
52
+
53
class Transformer(Processor[SettingsType, MessageInType, MessageOutType], typing.Protocol):
    """
    Protocol for transformers that receive a message and return a result.

    The result is typed as ``MessageOutType``, which is a separate type variable from
    ``MessageInType``; input and output need not be of the same class.
    """

    def __call__(self, message: MessageInType) -> MessageOutType:
        """Transform ``message`` synchronously."""
        ...

    async def __acall__(self, message: MessageInType) -> MessageOutType:
        """Transform ``message`` asynchronously."""
        ...
58
+
59
+
60
class StatefulProcessor(typing.Protocol[SettingsType, MessageInType, MessageOutType, StateType]):
    """
    Base protocol for *stateful* message processors.

    This protocol is rarely implemented directly; the StatefulConsumer and StatefulTransformer
    protocols describe the two concrete variants more precisely.
    """

    @property
    def state(self) -> StateType:
        """Current processor state."""
        ...

    @state.setter
    def state(self, state: StateType | bytes | None) -> None:
        """Assign a new state (the annotation also admits ``bytes`` and ``None``)."""
        ...

    def stateful_op(
        self,
        state: typing.Any,
        message: typing.Any,
    ) -> tuple[typing.Any, typing.Any]:
        """Process ``message`` against an explicit ``state``; returns a 2-tuple."""
        ...

    def __call__(self, message: typing.Any) -> typing.Any:
        """Process ``message`` synchronously."""
        ...

    async def __acall__(self, message: typing.Any) -> typing.Any:
        """Process ``message`` asynchronously."""
        ...
81
+
82
+
83
class StatefulProducer(typing.Protocol[SettingsType, MessageOutType, StateType]):
    """Structural type for stateful producers, which emit messages without taking any input."""

    @property
    def state(self) -> StateType:
        """Current producer state."""
        ...

    @state.setter
    def state(self, state: StateType | bytes | None) -> None:
        """Assign a new state (the annotation also admits ``bytes`` and ``None``)."""
        ...

    def stateful_op(
        self,
        state: typing.Any,
    ) -> tuple[typing.Any, typing.Any]:
        """Produce against an explicit ``state``; returns a 2-tuple."""
        ...

    def __call__(self) -> MessageOutType:
        """Produce a message synchronously."""
        ...

    async def __acall__(self) -> MessageOutType:
        """Produce a message asynchronously."""
        ...
99
+
100
+
101
class StatefulConsumer(StatefulProcessor[SettingsType, MessageInType, None, StateType], typing.Protocol):
    """Protocol specifically for processors that consume messages without producing output."""

    def __call__(self, message: MessageInType) -> None: ...
    async def __acall__(self, message: MessageInType) -> None: ...

    def stateful_op(
        self,
        state: tuple[StateType, int],
        message: MessageInType,
    ) -> tuple[tuple[StateType, int], None]:
        """
        Consume ``message`` given an explicit ``(state, int)`` pair and return the updated pair
        together with ``None``.

        Note: The return type is still a tuple even though the second entry is always None.
        This is intentional so we can use the same protocol for both consumers and transformers,
        and chain them together in a pipeline (e.g., `CompositeProcessor`).
        """
        ...
118
+
119
+
120
class StatefulTransformer(
    StatefulProcessor[SettingsType, MessageInType, MessageOutType, StateType],
    typing.Protocol,
):
    """Protocol for stateful processors that turn each input message into an output message."""

    def stateful_op(
        self,
        state: tuple[StateType, int],
        message: MessageInType,
    ) -> tuple[tuple[StateType, int], MessageOutType]:
        """Transform ``message`` given an explicit ``(state, int)`` pair; returns the new pair and the output."""
        ...

    def __call__(self, message: MessageInType) -> MessageOutType:
        """Transform ``message`` synchronously."""
        ...

    async def __acall__(self, message: MessageInType) -> MessageOutType:
        """Transform ``message`` asynchronously."""
        ...
136
+
137
+
138
class AdaptiveTransformer(
    StatefulTransformer[SettingsType, MessageInType, MessageOutType, StateType],
    typing.Protocol,
):
    """
    Protocol for stateful transformers that can also be trained incrementally with labeled
    data delivered as :obj:`SampleMessage`.

    The base protocol is parameterized explicitly so this protocol stays generic over the same
    settings, message, and state type variables as :obj:`StatefulTransformer`.
    """

    def partial_fit(self, message: SampleMessage) -> None:
        """Update transformer state using labeled training data.

        This method should update the internal state/parameters of the transformer
        based on the provided labeled samples, without performing any transformation.
        """
        ...

    async def apartial_fit(self, message: SampleMessage) -> None:
        """Asynchronous counterpart of :meth:`partial_fit`."""
        ...
@@ -0,0 +1,323 @@
1
+ """Stateful processor base classes for ezmsg."""
2
+
3
+ import pickle
4
+ import typing
5
+ from abc import ABC, abstractmethod
6
+
7
+ from .processor import (
8
+ BaseProcessor,
9
+ BaseProducer,
10
+ _get_base_processor_message_in_type,
11
+ )
12
+ from .protocols import MessageInType, MessageOutType, SettingsType, StateType
13
+ from .util.asio import run_coroutine_sync
14
+ from .util.message import SampleMessage, is_sample_message
15
+ from .util.typeresolution import resolve_typevar
16
+
17
+
18
def _get_base_processor_state_type(cls: type) -> type:
    """
    Look up the concrete class bound to ``StateType`` for ``cls``.

    :raises TypeError: When ``resolve_typevar`` fails; the underlying error is chained.
    """
    try:
        resolved = resolve_typevar(cls, StateType)
    except TypeError as err:
        raise TypeError(
            f"Could not resolve state type for {cls}. Ensure that the class is properly annotated with a StateType."
        ) from err
    return resolved
25
+
26
+
27
class Stateful(ABC, typing.Generic[StateType]):
    """
    Mixin class for stateful processors. DO NOT use this class directly.
    Used to enforce that the processor/producer has a state attribute and stateful_op method.
    """

    # The state object; assigned in ``__init__`` of the stateful base classes that use this mixin.
    _state: StateType

    @classmethod
    def get_state_type(cls) -> type[StateType]:
        """Return the concrete state class bound to ``StateType`` for this class."""
        return _get_base_processor_state_type(cls)

    @property
    def state(self) -> StateType:
        """The current state object."""
        return self._state

    @state.setter
    def state(self, state: StateType | bytes | None) -> None:
        """
        Replace the current state.

        * ``None`` is ignored; the current state is kept.
        * ``bytes`` are unpickled with :func:`pickle.loads` and the result becomes the state.
        * Anything else is stored as-is.

        .. warning:: Unpickling can execute arbitrary code. Only assign bytes that come from a
           trusted source.
        """
        if state is None:
            return
        if isinstance(state, bytes):
            # SECURITY(review): never route bytes from an untrusted source through this setter.
            self._state = pickle.loads(state)
        else:
            self._state = state  # type: ignore

    def _hash_message(self, message: typing.Any) -> int:
        """
        Return a hash of the message metadata that decides when the state must be reset.

        Callers (e.g. ``BaseStatefulProcessor``) compare the returned value with the hash they
        stored for the previous message and call ``_reset_state`` when the two differ. Override
        this to return a value that changes whenever the message properties your state depends
        on (e.g. shape or sample rate) change.

        This method is not abstract because some processors only need to reset once and are
        otherwise insensitive to the message structure. For example, an activation function that
        benefits greatly from pre-computed values should do this computation in `_reset_state` and
        attach those values to the processor state, but if it operates elementwise on the input it
        doesn't care if the incoming data changes shape or sample rate, so it never needs to reset again.

        Stateful processors start with a stored hash of ``-1``, so returning the constant ``0``
        here forces exactly one reset, on the first message.
        """
        return 0

    @abstractmethod
    def _reset_state(self, *args: typing.Any, **kwargs: typing.Any) -> None:
        """
        Reset internal state based on
        - new message metadata (processors), or
        - after first call (producers).
        """
        ...

    @abstractmethod
    def stateful_op(self, *args: typing.Any, **kwargs: typing.Any) -> tuple:
        """Run one step against an explicitly supplied state and return a tuple."""
        ...
79
+
80
+
81
class BaseStatefulProcessor(
    BaseProcessor[SettingsType, MessageInType, MessageOutType],
    Stateful[StateType],
    ABC,
    typing.Generic[SettingsType, MessageInType, MessageOutType, StateType],
):
    """
    Shared machinery for stateful processors.

    Subclass one of the more specific bases rather than this one: BaseStatefulConsumer for
    operations that do not return a result, or BaseStatefulTransformer for operations that do.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Sentinel meaning "no message seen yet"; the default ``_hash_message`` returns 0,
        # so the first message triggers ``_reset_state``.
        self._hash = -1
        self._state: StateType = self.__class__.get_state_type()()
        # TODO: Enforce that StateType has .hash: int field.

    @abstractmethod
    def _reset_state(self, message: typing.Any) -> None:
        """
        Reset internal state based on new message metadata.
        This method will only be called when there is a significant change in the message metadata,
        such as sample rate or shape (criteria defined by `_hash_message`), and not for every message,
        so use it to do all the expensive pre-allocation and caching of variables that can speed up
        the processing of subsequent messages in `_process`.
        """
        ...

    @abstractmethod
    def _process(self, message: typing.Any) -> typing.Any: ...

    def _reset_state_if_hash_changed(self, message: typing.Any) -> None:
        """Call ``_reset_state`` and record the new hash when ``message`` hashes differently."""
        new_hash = self._hash_message(message)
        if new_hash != self._hash:
            self._reset_state(message)
            self._hash = new_hash

    def __call__(self, message: typing.Any) -> typing.Any:
        self._reset_state_if_hash_changed(message)
        return self._process(message)

    async def __acall__(self, message: typing.Any) -> typing.Any:
        self._reset_state_if_hash_changed(message)
        return await self._aprocess(message)

    def stateful_op(
        self,
        state: tuple[StateType, int] | None,
        message: typing.Any,
    ) -> tuple[tuple[StateType, int], typing.Any]:
        """
        Load an optional ``(state, hash)`` pair, process ``message``, and return
        ``((state, hash), result)``. With ``state=None`` the current state is used.
        """
        if state is not None:
            restored_state, restored_hash = state
            self.state = restored_state  # goes through the ``state`` property setter
            self._hash = restored_hash
        output = self(message)
        return (self.state, self._hash), output
139
+
140
+
141
class BaseStatefulProducer(
    BaseProducer[SettingsType, MessageOutType],
    Stateful[StateType],
    ABC,
    typing.Generic[SettingsType, MessageOutType, StateType],
):
    """
    Base class implementing common stateful producer functionality.
    Examples of stateful producers are things that require counters, clocks,
    or to cycle through a set of values.

    Unlike BaseStatefulProcessor, this class does not hash messages because there
    are no input messages. We still use self._hash to simply track the transition from
    initialization (.hash == -1) to state reset (.hash == 0).
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)  # .settings
        # -1 means "_reset_state has not run yet"; see ``__acall__``.
        self._hash = -1
        state_type = self.__class__.get_state_type()
        self._state: StateType = state_type()

    @abstractmethod
    def _reset_state(self) -> None:
        """
        Reset internal state upon first call.
        """
        ...

    async def __acall__(self) -> MessageOutType:
        """Produce the next message, running ``_reset_state`` once before the first production."""
        if self._hash == -1:
            self._reset_state()
            self._hash = 0
        return await self._produce()

    def stateful_op(
        self,
        state: tuple[StateType, int] | None,
    ) -> tuple[tuple[StateType, int], MessageOutType]:
        """
        Load an optional ``(state, hash)`` pair, produce one message synchronously, and return
        ``((state, hash), message)``. With ``state=None`` the current state is used.
        """
        if state is not None:
            self.state, self._hash = state  # Update state via setter
        result = self()  # Uses synchronous call
        return (self.state, self._hash), result
184
+
185
+
186
class BaseStatefulConsumer(
    BaseStatefulProcessor[SettingsType, MessageInType, None, StateType],
    ABC,
    typing.Generic[SettingsType, MessageInType, StateType],
):
    """
    Base class for stateful consumers: processors that update their state from incoming
    messages but return nothing. Apart from reporting ``None`` as the output message type,
    only the type annotations differ from BaseStatefulProcessor.
    """

    @classmethod
    def get_message_type(cls, dir: str) -> type[MessageInType] | None:
        """Resolve the input type for ``"in"``, ``None`` for ``"out"``; anything else is a ValueError."""
        if dir == "out":
            return None
        if dir == "in":
            return _get_base_processor_message_in_type(cls)
        raise ValueError(f"Invalid direction: {dir}. Use 'in' or 'out'.")

    def __call__(self, message: MessageInType) -> None:
        return super().__call__(message)

    async def __acall__(self, message: MessageInType) -> None:
        return await super().__acall__(message)

    @abstractmethod
    def _process(self, message: MessageInType) -> None:
        """Consume ``message`` synchronously; implemented by subclasses."""
        ...

    async def _aprocess(self, message: MessageInType) -> None:
        outcome = self._process(message)
        return outcome

    def stateful_op(
        self,
        state: tuple[StateType, int] | None,
        message: MessageInType,
    ) -> tuple[tuple[StateType, int], None]:
        """Like the parent ``stateful_op`` but always returns ``None`` as the result slot."""
        new_state = super().stateful_op(state, message)[0]
        return new_state, None
224
+
225
+
226
class BaseStatefulTransformer(
    BaseStatefulProcessor[SettingsType, MessageInType, MessageOutType, StateType],
    ABC,
    typing.Generic[SettingsType, MessageInType, MessageOutType, StateType],
):
    """
    Base class for stateful transformers: processors that maintain state and return an output
    for every input. Only the type annotations differ from BaseStatefulProcessor.
    """

    def __call__(self, message: MessageInType) -> MessageOutType:
        return super().__call__(message)

    async def __acall__(self, message: MessageInType) -> MessageOutType:
        return await super().__acall__(message)

    @abstractmethod
    def _process(self, message: MessageInType) -> MessageOutType:
        """Transform ``message`` synchronously; implemented by subclasses."""
        ...

    async def _aprocess(self, message: MessageInType) -> MessageOutType:
        transformed = self._process(message)
        return transformed

    def stateful_op(
        self,
        state: tuple[StateType, int] | None,
        message: MessageInType,
    ) -> tuple[tuple[StateType, int], MessageOutType]:
        """Same as the parent ``stateful_op``; re-declared only for its return annotation."""
        return super().stateful_op(state, message)
254
+
255
+
256
class BaseAdaptiveTransformer(
    BaseStatefulTransformer[
        SettingsType,
        MessageInType | SampleMessage,
        MessageOutType | None,
        StateType,
    ],
    ABC,
    typing.Generic[SettingsType, MessageInType, MessageOutType, StateType],
):
    """
    Base class for stateful transformers that can also be trained incrementally.

    Messages recognized by ``is_sample_message`` are routed to :meth:`partial_fit` (training);
    all other messages follow the normal stateful transform path.
    """

    @abstractmethod
    def partial_fit(self, message: SampleMessage) -> None:
        """Update the transformer state from a labeled :obj:`SampleMessage`; implemented by subclasses."""
        ...

    async def apartial_fit(self, message: SampleMessage) -> None:
        """Override me if you need async partial fitting."""
        return self.partial_fit(message)

    def __call__(self, message: MessageInType | SampleMessage) -> MessageOutType | None:
        """
        Either adapt (train) or transform, depending on the message.

        If ``message`` is a SampleMessage, it is passed to :meth:`partial_fit` and that method's
        return value (``None`` per its signature) is returned. The hash check / ``_reset_state``
        is *not* performed on this path, so ``partial_fit`` may be called before ``_reset_state``
        has ever run. A SampleMessage carries the training data (X) in ``message.sample.data`` and
        optional labels (y) in ``message.trigger.value.data``.

        Any other message is transformed via the inherited stateful path and the result returned.
        """
        if is_sample_message(message):
            return self.partial_fit(message)
        return super().__call__(message)

    async def __acall__(self, message: MessageInType | SampleMessage) -> MessageOutType | None:
        """Async counterpart of ``__call__``: SampleMessages go to :meth:`apartial_fit`."""
        if is_sample_message(message):
            return await self.apartial_fit(message)
        return await super().__acall__(message)
293
+
294
+
295
class BaseAsyncTransformer(
    BaseStatefulTransformer[SettingsType, MessageInType, MessageOutType, StateType],
    ABC,
    typing.Generic[SettingsType, MessageInType, MessageOutType, StateType],
):
    """
    This reverses the priority of async and sync methods from :obj:`BaseStatefulTransformer`.
    Whereas in :obj:`BaseStatefulTransformer` the async methods simply call the sync methods,
    here the sync methods call the async methods, more similar to :obj:`BaseStatefulProducer`.
    Subclasses implement :meth:`_aprocess`.
    """

    def _process(self, message: MessageInType) -> MessageOutType:
        # Sync fallback: drive the async implementation to completion.
        return run_coroutine_sync(self._aprocess(message))

    @abstractmethod
    async def _aprocess(self, message: MessageInType) -> MessageOutType: ...

    def __call__(self, message: MessageInType) -> MessageOutType:
        # Route the sync entry point through ``__acall__`` so the hash check, optional state
        # reset, and ``_aprocess`` all run inside a single coroutine.
        return run_coroutine_sync(self.__acall__(message))

    async def __acall__(self, message: MessageInType) -> MessageOutType:
        # Python has no syntax that dispatches to ``__acall__``; callers must use
        # ``await obj.__acall__(message)`` explicitly.
        # The inherited implementation (BaseStatefulProcessor.__acall__) hashes the message,
        # calls ``_reset_state`` when the hash changes, then awaits ``self._aprocess``.
        return await super().__acall__(message)