ezmsg-baseproc 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/baseproc/__init__.py +2 -2
- ezmsg/baseproc/__version__.py +2 -2
- ezmsg/baseproc/counter.py +2 -2
- ezmsg/baseproc/protocols.py +3 -4
- ezmsg/baseproc/stateful.py +23 -9
- ezmsg/baseproc/units.py +38 -7
- ezmsg/baseproc/util/message.py +19 -3
- {ezmsg_baseproc-1.2.0.dist-info → ezmsg_baseproc-1.3.0.dist-info}/METADATA +1 -1
- ezmsg_baseproc-1.3.0.dist-info/RECORD +19 -0
- ezmsg_baseproc-1.2.0.dist-info/RECORD +0 -19
- {ezmsg_baseproc-1.2.0.dist-info → ezmsg_baseproc-1.3.0.dist-info}/WHEEL +0 -0
- {ezmsg_baseproc-1.2.0.dist-info → ezmsg_baseproc-1.3.0.dist-info}/licenses/LICENSE +0 -0
ezmsg/baseproc/__init__.py
CHANGED
|
@@ -82,7 +82,7 @@ from .stateful import (
|
|
|
82
82
|
from .units import (
|
|
83
83
|
AdaptiveTransformerType,
|
|
84
84
|
BaseAdaptiveTransformerUnit,
|
|
85
|
-
|
|
85
|
+
BaseClockDrivenUnit,
|
|
86
86
|
BaseConsumerUnit,
|
|
87
87
|
BaseProcessorUnit,
|
|
88
88
|
BaseProducerUnit,
|
|
@@ -158,7 +158,7 @@ __all__ = [
|
|
|
158
158
|
"BaseConsumerUnit",
|
|
159
159
|
"BaseTransformerUnit",
|
|
160
160
|
"BaseAdaptiveTransformerUnit",
|
|
161
|
-
"
|
|
161
|
+
"BaseClockDrivenUnit",
|
|
162
162
|
"GenAxisArray",
|
|
163
163
|
# Type resolution helpers
|
|
164
164
|
"get_base_producer_type",
|
ezmsg/baseproc/__version__.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '1.
|
|
32
|
-
__version_tuple__ = version_tuple = (1,
|
|
31
|
+
__version__ = version = '1.3.0'
|
|
32
|
+
__version_tuple__ = version_tuple = (1, 3, 0)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
ezmsg/baseproc/counter.py
CHANGED
|
@@ -9,7 +9,7 @@ from .clockdriven import (
|
|
|
9
9
|
ClockDrivenState,
|
|
10
10
|
)
|
|
11
11
|
from .protocols import processor_state
|
|
12
|
-
from .units import
|
|
12
|
+
from .units import BaseClockDrivenUnit
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class CounterSettings(ClockDrivenSettings):
|
|
@@ -57,7 +57,7 @@ class CounterTransformer(BaseClockDrivenProducer[CounterSettings, CounterTransfo
|
|
|
57
57
|
)
|
|
58
58
|
|
|
59
59
|
|
|
60
|
-
class Counter(
|
|
60
|
+
class Counter(BaseClockDrivenUnit[CounterSettings, CounterTransformer]):
|
|
61
61
|
"""
|
|
62
62
|
Transforms clock ticks into monotonically increasing counter values as AxisArray.
|
|
63
63
|
|
ezmsg/baseproc/protocols.py
CHANGED
|
@@ -5,8 +5,7 @@ import typing
|
|
|
5
5
|
from dataclasses import dataclass
|
|
6
6
|
|
|
7
7
|
import ezmsg.core as ez
|
|
8
|
-
|
|
9
|
-
from .util.message import SampleMessage
|
|
8
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
9
|
|
|
11
10
|
# --- Processor state decorator ---
|
|
12
11
|
processor_state = functools.partial(dataclass, unsafe_hash=True, frozen=False, init=False)
|
|
@@ -138,7 +137,7 @@ class StatefulTransformer(
|
|
|
138
137
|
|
|
139
138
|
|
|
140
139
|
class AdaptiveTransformer(StatefulTransformer, typing.Protocol):
|
|
141
|
-
def partial_fit(self, message:
|
|
140
|
+
def partial_fit(self, message: AxisArray) -> None:
|
|
142
141
|
"""Update transformer state using labeled training data.
|
|
143
142
|
|
|
144
143
|
This method should update the internal state/parameters of the transformer
|
|
@@ -146,4 +145,4 @@ class AdaptiveTransformer(StatefulTransformer, typing.Protocol):
|
|
|
146
145
|
"""
|
|
147
146
|
...
|
|
148
147
|
|
|
149
|
-
async def apartial_fit(self, message:
|
|
148
|
+
async def apartial_fit(self, message: AxisArray) -> None: ...
|
ezmsg/baseproc/stateful.py
CHANGED
|
@@ -4,6 +4,9 @@ import pickle
|
|
|
4
4
|
import typing
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
6
|
|
|
7
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
|
+
from ezmsg.util.messages.util import replace
|
|
9
|
+
|
|
7
10
|
from .processor import (
|
|
8
11
|
BaseProcessor,
|
|
9
12
|
BaseProducer,
|
|
@@ -256,7 +259,7 @@ class BaseStatefulTransformer(
|
|
|
256
259
|
class BaseAdaptiveTransformer(
|
|
257
260
|
BaseStatefulTransformer[
|
|
258
261
|
SettingsType,
|
|
259
|
-
MessageInType
|
|
262
|
+
MessageInType,
|
|
260
263
|
MessageOutType | None,
|
|
261
264
|
StateType,
|
|
262
265
|
],
|
|
@@ -264,30 +267,41 @@ class BaseAdaptiveTransformer(
|
|
|
264
267
|
typing.Generic[SettingsType, MessageInType, MessageOutType, StateType],
|
|
265
268
|
):
|
|
266
269
|
@abstractmethod
|
|
267
|
-
def partial_fit(self, message:
|
|
270
|
+
def partial_fit(self, message: AxisArray) -> None: ...
|
|
268
271
|
|
|
269
|
-
async def apartial_fit(self, message:
|
|
272
|
+
async def apartial_fit(self, message: AxisArray) -> None:
|
|
270
273
|
"""Override me if you need async partial fitting."""
|
|
271
274
|
return self.partial_fit(message)
|
|
272
275
|
|
|
273
|
-
def __call__(self, message: MessageInType
|
|
276
|
+
def __call__(self, message: MessageInType) -> MessageOutType | None:
|
|
274
277
|
"""
|
|
275
278
|
Adapt transformer with training data (and optionally labels)
|
|
276
|
-
in
|
|
279
|
+
in AxisArray with attrs["trigger"].
|
|
277
280
|
|
|
278
281
|
Args:
|
|
279
|
-
message: An
|
|
280
|
-
labels (y) in
|
|
281
|
-
data (X) in message.
|
|
282
|
+
message: An AxisArray with optional trigger in attrs["trigger"],
|
|
283
|
+
containing labels (y) in attrs["trigger"].value and
|
|
284
|
+
data (X) in message.data
|
|
282
285
|
|
|
283
286
|
Returns: None
|
|
284
287
|
"""
|
|
285
288
|
if is_sample_message(message):
|
|
289
|
+
if isinstance(message, SampleMessage):
|
|
290
|
+
# Auto-convert old format → new format
|
|
291
|
+
message = replace(
|
|
292
|
+
message.sample,
|
|
293
|
+
attrs={**message.sample.attrs, "trigger": message.trigger},
|
|
294
|
+
)
|
|
286
295
|
return self.partial_fit(message)
|
|
287
296
|
return super().__call__(message)
|
|
288
297
|
|
|
289
|
-
async def __acall__(self, message: MessageInType
|
|
298
|
+
async def __acall__(self, message: MessageInType) -> MessageOutType | None:
|
|
290
299
|
if is_sample_message(message):
|
|
300
|
+
if isinstance(message, SampleMessage):
|
|
301
|
+
message = replace(
|
|
302
|
+
message.sample,
|
|
303
|
+
attrs={**message.sample.attrs, "trigger": message.trigger},
|
|
304
|
+
)
|
|
291
305
|
return await self.apartial_fit(message)
|
|
292
306
|
return await super().__acall__(message)
|
|
293
307
|
|
ezmsg/baseproc/units.py
CHANGED
|
@@ -3,10 +3,10 @@
|
|
|
3
3
|
import math
|
|
4
4
|
import traceback
|
|
5
5
|
import typing
|
|
6
|
+
import warnings
|
|
6
7
|
from abc import ABC, abstractmethod
|
|
7
8
|
|
|
8
9
|
import ezmsg.core as ez
|
|
9
|
-
from ezmsg.util.generator import GenState
|
|
10
10
|
from ezmsg.util.messages.axisarray import AxisArray, LinearAxis
|
|
11
11
|
|
|
12
12
|
from .clockdriven import BaseClockDrivenProducer
|
|
@@ -14,7 +14,6 @@ from .composite import CompositeProcessor
|
|
|
14
14
|
from .processor import BaseConsumer, BaseProducer, BaseTransformer
|
|
15
15
|
from .protocols import MessageInType, MessageOutType, SettingsType
|
|
16
16
|
from .stateful import BaseAdaptiveTransformer, BaseStatefulConsumer, BaseStatefulTransformer
|
|
17
|
-
from .util.message import SampleMessage
|
|
18
17
|
from .util.profile import profile_subpub
|
|
19
18
|
from .util.typeresolution import resolve_typevar
|
|
20
19
|
|
|
@@ -223,7 +222,7 @@ class BaseAdaptiveTransformerUnit(
|
|
|
223
222
|
ABC,
|
|
224
223
|
typing.Generic[SettingsType, MessageInType, MessageOutType, AdaptiveTransformerType],
|
|
225
224
|
):
|
|
226
|
-
INPUT_SAMPLE = ez.InputStream(
|
|
225
|
+
INPUT_SAMPLE = ez.InputStream(AxisArray)
|
|
227
226
|
INPUT_SIGNAL = ez.InputStream(MessageInType)
|
|
228
227
|
OUTPUT_SIGNAL = ez.OutputStream(MessageOutType)
|
|
229
228
|
|
|
@@ -242,11 +241,11 @@ class BaseAdaptiveTransformerUnit(
|
|
|
242
241
|
yield self.OUTPUT_SIGNAL, result
|
|
243
242
|
|
|
244
243
|
@ez.subscriber(INPUT_SAMPLE)
|
|
245
|
-
async def on_sample(self, msg:
|
|
244
|
+
async def on_sample(self, msg: AxisArray) -> None:
|
|
246
245
|
await self.processor.apartial_fit(msg)
|
|
247
246
|
|
|
248
247
|
|
|
249
|
-
class
|
|
248
|
+
class BaseClockDrivenUnit(
|
|
250
249
|
BaseProcessorUnit[SettingsType],
|
|
251
250
|
ABC,
|
|
252
251
|
typing.Generic[SettingsType, ClockDrivenProducerType],
|
|
@@ -260,7 +259,7 @@ class BaseClockDrivenProducerUnit(
|
|
|
260
259
|
|
|
261
260
|
Implement a new Unit as follows::
|
|
262
261
|
|
|
263
|
-
class SinGeneratorUnit(
|
|
262
|
+
class SinGeneratorUnit(BaseClockDrivenUnit[
|
|
264
263
|
SinGeneratorSettings, # SettingsType (must extend ClockDrivenSettings)
|
|
265
264
|
SinProducer, # ClockDrivenProducerType
|
|
266
265
|
]):
|
|
@@ -287,10 +286,42 @@ class BaseClockDrivenProducerUnit(
|
|
|
287
286
|
yield self.OUTPUT_SIGNAL, result
|
|
288
287
|
|
|
289
288
|
|
|
290
|
-
|
|
289
|
+
class GenState(ez.State):
|
|
290
|
+
"""
|
|
291
|
+
.. deprecated::
|
|
292
|
+
``GenState`` is deprecated. Define a local state class or use
|
|
293
|
+
``ezmsg.baseproc`` processor classes instead.
|
|
294
|
+
"""
|
|
295
|
+
|
|
296
|
+
gen: typing.Generator[typing.Any, typing.Any, None]
|
|
297
|
+
|
|
298
|
+
def __init_subclass__(cls, **kwargs):
|
|
299
|
+
super().__init_subclass__(**kwargs)
|
|
300
|
+
warnings.warn(
|
|
301
|
+
"GenState is deprecated. Define a local state class instead of subclassing GenState.",
|
|
302
|
+
DeprecationWarning,
|
|
303
|
+
stacklevel=2,
|
|
304
|
+
)
|
|
305
|
+
|
|
306
|
+
|
|
291
307
|
class GenAxisArray(ez.Unit):
|
|
308
|
+
"""
|
|
309
|
+
.. deprecated::
|
|
310
|
+
``GenAxisArray`` is deprecated. Use ``BaseTransformerUnit`` or
|
|
311
|
+
``BaseAdaptiveTransformerUnit`` from ``ezmsg.baseproc`` instead.
|
|
312
|
+
"""
|
|
313
|
+
|
|
292
314
|
STATE = GenState
|
|
293
315
|
|
|
316
|
+
def __init_subclass__(cls, **kwargs):
|
|
317
|
+
super().__init_subclass__(**kwargs)
|
|
318
|
+
warnings.warn(
|
|
319
|
+
"GenAxisArray is deprecated. Use BaseTransformerUnit or "
|
|
320
|
+
"BaseAdaptiveTransformerUnit from ezmsg.baseproc instead.",
|
|
321
|
+
DeprecationWarning,
|
|
322
|
+
stacklevel=2,
|
|
323
|
+
)
|
|
324
|
+
|
|
294
325
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
295
326
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
296
327
|
INPUT_SETTINGS = ez.InputStream(ez.Settings)
|
ezmsg/baseproc/util/message.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import time
|
|
2
2
|
import typing
|
|
3
|
+
import warnings
|
|
3
4
|
from dataclasses import dataclass, field
|
|
4
5
|
|
|
5
6
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
@@ -19,13 +20,28 @@ class SampleTriggerMessage:
|
|
|
19
20
|
|
|
20
21
|
@dataclass
|
|
21
22
|
class SampleMessage:
|
|
23
|
+
"""
|
|
24
|
+
.. deprecated::
|
|
25
|
+
``SampleMessage`` is deprecated. Use ``AxisArray`` with
|
|
26
|
+
``attrs={"trigger": SampleTriggerMessage(...)}`` instead.
|
|
27
|
+
"""
|
|
28
|
+
|
|
22
29
|
trigger: SampleTriggerMessage
|
|
23
30
|
"""The time, window, and value (if any) associated with the trigger."""
|
|
24
31
|
|
|
25
32
|
sample: AxisArray
|
|
26
33
|
"""The data sampled around the trigger."""
|
|
27
34
|
|
|
35
|
+
def __post_init__(self):
|
|
36
|
+
warnings.warn(
|
|
37
|
+
"SampleMessage is deprecated. Use AxisArray with " "attrs={'trigger': SampleTriggerMessage(...)} instead.",
|
|
38
|
+
DeprecationWarning,
|
|
39
|
+
stacklevel=2,
|
|
40
|
+
)
|
|
41
|
+
|
|
28
42
|
|
|
29
|
-
def is_sample_message(message: typing.Any) ->
|
|
30
|
-
"""
|
|
31
|
-
|
|
43
|
+
def is_sample_message(message: typing.Any) -> bool:
|
|
44
|
+
"""Detect old SampleMessage OR new AxisArray-with-trigger."""
|
|
45
|
+
if isinstance(message, SampleMessage):
|
|
46
|
+
return True
|
|
47
|
+
return isinstance(message, AxisArray) and "trigger" in getattr(message, "attrs", {})
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ezmsg-baseproc
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.3.0
|
|
4
4
|
Summary: Base processor classes and protocols for ezmsg signal processing pipelines
|
|
5
5
|
Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
ezmsg/baseproc/__init__.py,sha256=DhINT6FB8iFdQk7RaNe8JY6YBZdv2FehuMfv8lWBN1w,4739
|
|
2
|
+
ezmsg/baseproc/__version__.py,sha256=0Oc4EBzGTJOvXX0Vym4evglW1NQPpe8RLn8TdxsKzfs,704
|
|
3
|
+
ezmsg/baseproc/clock.py,sha256=jRJGxaWWBek503OcGRYsW4Qc7Lt4m-tVl1wn_v-qCIk,3254
|
|
4
|
+
ezmsg/baseproc/clockdriven.py,sha256=ckPZSVHZfYfjFRHDCERUWjDwyQgOu-aRTNsJ3EDFBaI,6161
|
|
5
|
+
ezmsg/baseproc/composite.py,sha256=Lin4K_rmS2Tnxt-m8daP-PUyeeqL4id5JkVh-AUNrQw,14901
|
|
6
|
+
ezmsg/baseproc/counter.py,sha256=FGW-Uu0PDHa6AJFjMC7GtiXC3LulBbyP1Z_ae895F3s,2152
|
|
7
|
+
ezmsg/baseproc/processor.py,sha256=Ir9FtNuVG4yc-frwNxoYrlld99ff1mXwwGWaHxEJ6tY,8056
|
|
8
|
+
ezmsg/baseproc/protocols.py,sha256=ECgFW48Mma7rRcW6b4QVVUzAMzXL-2dAiRIUv_ZG9nw,5109
|
|
9
|
+
ezmsg/baseproc/stateful.py,sha256=sbChz1soxgW3IjDHKpU4gxsiJ3JzT3eDIDMvLkPr5F4,11995
|
|
10
|
+
ezmsg/baseproc/units.py,sha256=1ds69VuWJtmCLy_UeAcFHO_4iExe2HxbdCJ5PgIy4Wc,12995
|
|
11
|
+
ezmsg/baseproc/util/__init__.py,sha256=hvMUJOBuqioER50GZ5-GZiQbQ9NtQYEze13ZlR2jbMA,37
|
|
12
|
+
ezmsg/baseproc/util/asio.py,sha256=0sF5oDc58DSLlcEgoUpNiqjjcbqnZhjSpQrXn6IdosM,4960
|
|
13
|
+
ezmsg/baseproc/util/message.py,sha256=8eMB-OYwDkRIbzoCX30rfodNWCYYVcT-HQosuqoj2a8,1434
|
|
14
|
+
ezmsg/baseproc/util/profile.py,sha256=MOQDsFsW6ddXT0uAOgytW3aK_AZW5ieA16Pz2hWuE2o,6189
|
|
15
|
+
ezmsg/baseproc/util/typeresolution.py,sha256=5on4QcrYd1rxsRoDEqivNjuWT5BkU-Wg7XdTNaOircI,3485
|
|
16
|
+
ezmsg_baseproc-1.3.0.dist-info/METADATA,sha256=Yy16gYyLvtjFcyEgUrCq1gbj-Jq-b7oK8KDfeS3rMNg,3415
|
|
17
|
+
ezmsg_baseproc-1.3.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
18
|
+
ezmsg_baseproc-1.3.0.dist-info/licenses/LICENSE,sha256=BDD8rfac1Ur7mp0_3izEdr6fHgSA3Or6U1Kb0ZAWsow,1066
|
|
19
|
+
ezmsg_baseproc-1.3.0.dist-info/RECORD,,
|
|
@@ -1,19 +0,0 @@
|
|
|
1
|
-
ezmsg/baseproc/__init__.py,sha256=cSkyXkfPvIiAxe6maHiZvJpFbdRw3xTEitBClXeehec,4755
|
|
2
|
-
ezmsg/baseproc/__version__.py,sha256=-uLONazCO1SzFfcY-K6A1keL--LIVfTYccGX6ciADac,704
|
|
3
|
-
ezmsg/baseproc/clock.py,sha256=jRJGxaWWBek503OcGRYsW4Qc7Lt4m-tVl1wn_v-qCIk,3254
|
|
4
|
-
ezmsg/baseproc/clockdriven.py,sha256=ckPZSVHZfYfjFRHDCERUWjDwyQgOu-aRTNsJ3EDFBaI,6161
|
|
5
|
-
ezmsg/baseproc/composite.py,sha256=Lin4K_rmS2Tnxt-m8daP-PUyeeqL4id5JkVh-AUNrQw,14901
|
|
6
|
-
ezmsg/baseproc/counter.py,sha256=kcBPiVxMPULp4ojnVESNw7mn_4v0xSODfASHrL83GtM,2168
|
|
7
|
-
ezmsg/baseproc/processor.py,sha256=Ir9FtNuVG4yc-frwNxoYrlld99ff1mXwwGWaHxEJ6tY,8056
|
|
8
|
-
ezmsg/baseproc/protocols.py,sha256=O3Qp0ymE9Ovlmh8t22v-lMmFzuWK0D93REAYMnJV3xA,5106
|
|
9
|
-
ezmsg/baseproc/stateful.py,sha256=-jjAZIyJA5eiTECi1fSfazfqgv__RtyqPp1ZvLFFIDI,11424
|
|
10
|
-
ezmsg/baseproc/units.py,sha256=byFijVLEZFO145HE74sZk1_qpCu6nFjB8-vSYz9Grds,12077
|
|
11
|
-
ezmsg/baseproc/util/__init__.py,sha256=hvMUJOBuqioER50GZ5-GZiQbQ9NtQYEze13ZlR2jbMA,37
|
|
12
|
-
ezmsg/baseproc/util/asio.py,sha256=0sF5oDc58DSLlcEgoUpNiqjjcbqnZhjSpQrXn6IdosM,4960
|
|
13
|
-
ezmsg/baseproc/util/message.py,sha256=l_b1b6bXX8N6VF9RbUELzsHs73cKkDURBdIr0lt3CY0,909
|
|
14
|
-
ezmsg/baseproc/util/profile.py,sha256=MOQDsFsW6ddXT0uAOgytW3aK_AZW5ieA16Pz2hWuE2o,6189
|
|
15
|
-
ezmsg/baseproc/util/typeresolution.py,sha256=5on4QcrYd1rxsRoDEqivNjuWT5BkU-Wg7XdTNaOircI,3485
|
|
16
|
-
ezmsg_baseproc-1.2.0.dist-info/METADATA,sha256=H2jTn5VSw0pEwQWLdd3Hu0G1OCbbKM-On4AmKyyyfm4,3415
|
|
17
|
-
ezmsg_baseproc-1.2.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
18
|
-
ezmsg_baseproc-1.2.0.dist-info/licenses/LICENSE,sha256=BDD8rfac1Ur7mp0_3izEdr6fHgSA3Or6U1Kb0ZAWsow,1066
|
|
19
|
-
ezmsg_baseproc-1.2.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|