ezmsg-baseproc 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -82,7 +82,7 @@ from .stateful import (
82
82
  from .units import (
83
83
  AdaptiveTransformerType,
84
84
  BaseAdaptiveTransformerUnit,
85
- BaseClockDrivenProducerUnit,
85
+ BaseClockDrivenUnit,
86
86
  BaseConsumerUnit,
87
87
  BaseProcessorUnit,
88
88
  BaseProducerUnit,
@@ -158,7 +158,7 @@ __all__ = [
158
158
  "BaseConsumerUnit",
159
159
  "BaseTransformerUnit",
160
160
  "BaseAdaptiveTransformerUnit",
161
- "BaseClockDrivenProducerUnit",
161
+ "BaseClockDrivenUnit",
162
162
  "GenAxisArray",
163
163
  # Type resolution helpers
164
164
  "get_base_producer_type",
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '1.2.0'
32
- __version_tuple__ = version_tuple = (1, 2, 0)
31
+ __version__ = version = '1.3.0'
32
+ __version_tuple__ = version_tuple = (1, 3, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
ezmsg/baseproc/counter.py CHANGED
@@ -9,7 +9,7 @@ from .clockdriven import (
9
9
  ClockDrivenState,
10
10
  )
11
11
  from .protocols import processor_state
12
- from .units import BaseClockDrivenProducerUnit
12
+ from .units import BaseClockDrivenUnit
13
13
 
14
14
 
15
15
  class CounterSettings(ClockDrivenSettings):
@@ -57,7 +57,7 @@ class CounterTransformer(BaseClockDrivenProducer[CounterSettings, CounterTransfo
57
57
  )
58
58
 
59
59
 
60
- class Counter(BaseClockDrivenProducerUnit[CounterSettings, CounterTransformer]):
60
+ class Counter(BaseClockDrivenUnit[CounterSettings, CounterTransformer]):
61
61
  """
62
62
  Transforms clock ticks into monotonically increasing counter values as AxisArray.
63
63
 
@@ -5,8 +5,7 @@ import typing
5
5
  from dataclasses import dataclass
6
6
 
7
7
  import ezmsg.core as ez
8
-
9
- from .util.message import SampleMessage
8
+ from ezmsg.util.messages.axisarray import AxisArray
10
9
 
11
10
  # --- Processor state decorator ---
12
11
  processor_state = functools.partial(dataclass, unsafe_hash=True, frozen=False, init=False)
@@ -138,7 +137,7 @@ class StatefulTransformer(
138
137
 
139
138
 
140
139
  class AdaptiveTransformer(StatefulTransformer, typing.Protocol):
141
- def partial_fit(self, message: SampleMessage) -> None:
140
+ def partial_fit(self, message: AxisArray) -> None:
142
141
  """Update transformer state using labeled training data.
143
142
 
144
143
  This method should update the internal state/parameters of the transformer
@@ -146,4 +145,4 @@ class AdaptiveTransformer(StatefulTransformer, typing.Protocol):
146
145
  """
147
146
  ...
148
147
 
149
- async def apartial_fit(self, message: SampleMessage) -> None: ...
148
+ async def apartial_fit(self, message: AxisArray) -> None: ...
@@ -4,6 +4,9 @@ import pickle
4
4
  import typing
5
5
  from abc import ABC, abstractmethod
6
6
 
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+ from ezmsg.util.messages.util import replace
9
+
7
10
  from .processor import (
8
11
  BaseProcessor,
9
12
  BaseProducer,
@@ -256,7 +259,7 @@ class BaseStatefulTransformer(
256
259
  class BaseAdaptiveTransformer(
257
260
  BaseStatefulTransformer[
258
261
  SettingsType,
259
- MessageInType | SampleMessage,
262
+ MessageInType,
260
263
  MessageOutType | None,
261
264
  StateType,
262
265
  ],
@@ -264,30 +267,41 @@ class BaseAdaptiveTransformer(
264
267
  typing.Generic[SettingsType, MessageInType, MessageOutType, StateType],
265
268
  ):
266
269
  @abstractmethod
267
- def partial_fit(self, message: SampleMessage) -> None: ...
270
+ def partial_fit(self, message: AxisArray) -> None: ...
268
271
 
269
- async def apartial_fit(self, message: SampleMessage) -> None:
272
+ async def apartial_fit(self, message: AxisArray) -> None:
270
273
  """Override me if you need async partial fitting."""
271
274
  return self.partial_fit(message)
272
275
 
273
- def __call__(self, message: MessageInType | SampleMessage) -> MessageOutType | None:
276
+ def __call__(self, message: MessageInType) -> MessageOutType | None:
274
277
  """
275
278
  Adapt transformer with training data (and optionally labels)
276
- in SampleMessage
279
+ in AxisArray with attrs["trigger"].
277
280
 
278
281
  Args:
279
- message: An instance of SampleMessage with optional
280
- labels (y) in message.trigger.value.data and
281
- data (X) in message.sample.data
282
+ message: An AxisArray with optional trigger in attrs["trigger"],
283
+ containing labels (y) in attrs["trigger"].value and
284
+ data (X) in message.data
282
285
 
283
286
  Returns: None
284
287
  """
285
288
  if is_sample_message(message):
289
+ if isinstance(message, SampleMessage):
290
+ # Auto-convert old format → new format
291
+ message = replace(
292
+ message.sample,
293
+ attrs={**message.sample.attrs, "trigger": message.trigger},
294
+ )
286
295
  return self.partial_fit(message)
287
296
  return super().__call__(message)
288
297
 
289
- async def __acall__(self, message: MessageInType | SampleMessage) -> MessageOutType | None:
298
+ async def __acall__(self, message: MessageInType) -> MessageOutType | None:
290
299
  if is_sample_message(message):
300
+ if isinstance(message, SampleMessage):
301
+ message = replace(
302
+ message.sample,
303
+ attrs={**message.sample.attrs, "trigger": message.trigger},
304
+ )
291
305
  return await self.apartial_fit(message)
292
306
  return await super().__acall__(message)
293
307
 
ezmsg/baseproc/units.py CHANGED
@@ -3,10 +3,10 @@
3
3
  import math
4
4
  import traceback
5
5
  import typing
6
+ import warnings
6
7
  from abc import ABC, abstractmethod
7
8
 
8
9
  import ezmsg.core as ez
9
- from ezmsg.util.generator import GenState
10
10
  from ezmsg.util.messages.axisarray import AxisArray, LinearAxis
11
11
 
12
12
  from .clockdriven import BaseClockDrivenProducer
@@ -14,7 +14,6 @@ from .composite import CompositeProcessor
14
14
  from .processor import BaseConsumer, BaseProducer, BaseTransformer
15
15
  from .protocols import MessageInType, MessageOutType, SettingsType
16
16
  from .stateful import BaseAdaptiveTransformer, BaseStatefulConsumer, BaseStatefulTransformer
17
- from .util.message import SampleMessage
18
17
  from .util.profile import profile_subpub
19
18
  from .util.typeresolution import resolve_typevar
20
19
 
@@ -223,7 +222,7 @@ class BaseAdaptiveTransformerUnit(
223
222
  ABC,
224
223
  typing.Generic[SettingsType, MessageInType, MessageOutType, AdaptiveTransformerType],
225
224
  ):
226
- INPUT_SAMPLE = ez.InputStream(SampleMessage)
225
+ INPUT_SAMPLE = ez.InputStream(AxisArray)
227
226
  INPUT_SIGNAL = ez.InputStream(MessageInType)
228
227
  OUTPUT_SIGNAL = ez.OutputStream(MessageOutType)
229
228
 
@@ -242,11 +241,11 @@ class BaseAdaptiveTransformerUnit(
242
241
  yield self.OUTPUT_SIGNAL, result
243
242
 
244
243
  @ez.subscriber(INPUT_SAMPLE)
245
- async def on_sample(self, msg: SampleMessage) -> None:
244
+ async def on_sample(self, msg: AxisArray) -> None:
246
245
  await self.processor.apartial_fit(msg)
247
246
 
248
247
 
249
- class BaseClockDrivenProducerUnit(
248
+ class BaseClockDrivenUnit(
250
249
  BaseProcessorUnit[SettingsType],
251
250
  ABC,
252
251
  typing.Generic[SettingsType, ClockDrivenProducerType],
@@ -260,7 +259,7 @@ class BaseClockDrivenProducerUnit(
260
259
 
261
260
  Implement a new Unit as follows::
262
261
 
263
- class SinGeneratorUnit(BaseClockDrivenProducerUnit[
262
+ class SinGeneratorUnit(BaseClockDrivenUnit[
264
263
  SinGeneratorSettings, # SettingsType (must extend ClockDrivenSettings)
265
264
  SinProducer, # ClockDrivenProducerType
266
265
  ]):
@@ -287,10 +286,42 @@ class BaseClockDrivenProducerUnit(
287
286
  yield self.OUTPUT_SIGNAL, result
288
287
 
289
288
 
290
- # Legacy class
289
+ class GenState(ez.State):
290
+ """
291
+ .. deprecated::
292
+ ``GenState`` is deprecated. Define a local state class or use
293
+ ``ezmsg.baseproc`` processor classes instead.
294
+ """
295
+
296
+ gen: typing.Generator[typing.Any, typing.Any, None]
297
+
298
+ def __init_subclass__(cls, **kwargs):
299
+ super().__init_subclass__(**kwargs)
300
+ warnings.warn(
301
+ "GenState is deprecated. Define a local state class instead of subclassing GenState.",
302
+ DeprecationWarning,
303
+ stacklevel=2,
304
+ )
305
+
306
+
291
307
  class GenAxisArray(ez.Unit):
308
+ """
309
+ .. deprecated::
310
+ ``GenAxisArray`` is deprecated. Use ``BaseTransformerUnit`` or
311
+ ``BaseAdaptiveTransformerUnit`` from ``ezmsg.baseproc`` instead.
312
+ """
313
+
292
314
  STATE = GenState
293
315
 
316
+ def __init_subclass__(cls, **kwargs):
317
+ super().__init_subclass__(**kwargs)
318
+ warnings.warn(
319
+ "GenAxisArray is deprecated. Use BaseTransformerUnit or "
320
+ "BaseAdaptiveTransformerUnit from ezmsg.baseproc instead.",
321
+ DeprecationWarning,
322
+ stacklevel=2,
323
+ )
324
+
294
325
  INPUT_SIGNAL = ez.InputStream(AxisArray)
295
326
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
296
327
  INPUT_SETTINGS = ez.InputStream(ez.Settings)
@@ -1,5 +1,6 @@
1
1
  import time
2
2
  import typing
3
+ import warnings
3
4
  from dataclasses import dataclass, field
4
5
 
5
6
  from ezmsg.util.messages.axisarray import AxisArray
@@ -19,13 +20,28 @@ class SampleTriggerMessage:
19
20
 
20
21
  @dataclass
21
22
  class SampleMessage:
23
+ """
24
+ .. deprecated::
25
+ ``SampleMessage`` is deprecated. Use ``AxisArray`` with
26
+ ``attrs={"trigger": SampleTriggerMessage(...)}`` instead.
27
+ """
28
+
22
29
  trigger: SampleTriggerMessage
23
30
  """The time, window, and value (if any) associated with the trigger."""
24
31
 
25
32
  sample: AxisArray
26
33
  """The data sampled around the trigger."""
27
34
 
35
+ def __post_init__(self):
36
+ warnings.warn(
37
+ "SampleMessage is deprecated. Use AxisArray with " "attrs={'trigger': SampleTriggerMessage(...)} instead.",
38
+ DeprecationWarning,
39
+ stacklevel=2,
40
+ )
41
+
28
42
 
29
- def is_sample_message(message: typing.Any) -> typing.TypeGuard[SampleMessage]:
30
- """Check if the message is a SampleMessage."""
31
- return hasattr(message, "trigger")
43
+ def is_sample_message(message: typing.Any) -> bool:
44
+ """Detect old SampleMessage OR new AxisArray-with-trigger."""
45
+ if isinstance(message, SampleMessage):
46
+ return True
47
+ return isinstance(message, AxisArray) and "trigger" in getattr(message, "attrs", {})
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ezmsg-baseproc
3
- Version: 1.2.0
3
+ Version: 1.3.0
4
4
  Summary: Base processor classes and protocols for ezmsg signal processing pipelines
5
5
  Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
6
6
  License-Expression: MIT
@@ -0,0 +1,19 @@
1
+ ezmsg/baseproc/__init__.py,sha256=DhINT6FB8iFdQk7RaNe8JY6YBZdv2FehuMfv8lWBN1w,4739
2
+ ezmsg/baseproc/__version__.py,sha256=0Oc4EBzGTJOvXX0Vym4evglW1NQPpe8RLn8TdxsKzfs,704
3
+ ezmsg/baseproc/clock.py,sha256=jRJGxaWWBek503OcGRYsW4Qc7Lt4m-tVl1wn_v-qCIk,3254
4
+ ezmsg/baseproc/clockdriven.py,sha256=ckPZSVHZfYfjFRHDCERUWjDwyQgOu-aRTNsJ3EDFBaI,6161
5
+ ezmsg/baseproc/composite.py,sha256=Lin4K_rmS2Tnxt-m8daP-PUyeeqL4id5JkVh-AUNrQw,14901
6
+ ezmsg/baseproc/counter.py,sha256=FGW-Uu0PDHa6AJFjMC7GtiXC3LulBbyP1Z_ae895F3s,2152
7
+ ezmsg/baseproc/processor.py,sha256=Ir9FtNuVG4yc-frwNxoYrlld99ff1mXwwGWaHxEJ6tY,8056
8
+ ezmsg/baseproc/protocols.py,sha256=ECgFW48Mma7rRcW6b4QVVUzAMzXL-2dAiRIUv_ZG9nw,5109
9
+ ezmsg/baseproc/stateful.py,sha256=sbChz1soxgW3IjDHKpU4gxsiJ3JzT3eDIDMvLkPr5F4,11995
10
+ ezmsg/baseproc/units.py,sha256=1ds69VuWJtmCLy_UeAcFHO_4iExe2HxbdCJ5PgIy4Wc,12995
11
+ ezmsg/baseproc/util/__init__.py,sha256=hvMUJOBuqioER50GZ5-GZiQbQ9NtQYEze13ZlR2jbMA,37
12
+ ezmsg/baseproc/util/asio.py,sha256=0sF5oDc58DSLlcEgoUpNiqjjcbqnZhjSpQrXn6IdosM,4960
13
+ ezmsg/baseproc/util/message.py,sha256=8eMB-OYwDkRIbzoCX30rfodNWCYYVcT-HQosuqoj2a8,1434
14
+ ezmsg/baseproc/util/profile.py,sha256=MOQDsFsW6ddXT0uAOgytW3aK_AZW5ieA16Pz2hWuE2o,6189
15
+ ezmsg/baseproc/util/typeresolution.py,sha256=5on4QcrYd1rxsRoDEqivNjuWT5BkU-Wg7XdTNaOircI,3485
16
+ ezmsg_baseproc-1.3.0.dist-info/METADATA,sha256=Yy16gYyLvtjFcyEgUrCq1gbj-Jq-b7oK8KDfeS3rMNg,3415
17
+ ezmsg_baseproc-1.3.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
18
+ ezmsg_baseproc-1.3.0.dist-info/licenses/LICENSE,sha256=BDD8rfac1Ur7mp0_3izEdr6fHgSA3Or6U1Kb0ZAWsow,1066
19
+ ezmsg_baseproc-1.3.0.dist-info/RECORD,,
@@ -1,19 +0,0 @@
1
- ezmsg/baseproc/__init__.py,sha256=cSkyXkfPvIiAxe6maHiZvJpFbdRw3xTEitBClXeehec,4755
2
- ezmsg/baseproc/__version__.py,sha256=-uLONazCO1SzFfcY-K6A1keL--LIVfTYccGX6ciADac,704
3
- ezmsg/baseproc/clock.py,sha256=jRJGxaWWBek503OcGRYsW4Qc7Lt4m-tVl1wn_v-qCIk,3254
4
- ezmsg/baseproc/clockdriven.py,sha256=ckPZSVHZfYfjFRHDCERUWjDwyQgOu-aRTNsJ3EDFBaI,6161
5
- ezmsg/baseproc/composite.py,sha256=Lin4K_rmS2Tnxt-m8daP-PUyeeqL4id5JkVh-AUNrQw,14901
6
- ezmsg/baseproc/counter.py,sha256=kcBPiVxMPULp4ojnVESNw7mn_4v0xSODfASHrL83GtM,2168
7
- ezmsg/baseproc/processor.py,sha256=Ir9FtNuVG4yc-frwNxoYrlld99ff1mXwwGWaHxEJ6tY,8056
8
- ezmsg/baseproc/protocols.py,sha256=O3Qp0ymE9Ovlmh8t22v-lMmFzuWK0D93REAYMnJV3xA,5106
9
- ezmsg/baseproc/stateful.py,sha256=-jjAZIyJA5eiTECi1fSfazfqgv__RtyqPp1ZvLFFIDI,11424
10
- ezmsg/baseproc/units.py,sha256=byFijVLEZFO145HE74sZk1_qpCu6nFjB8-vSYz9Grds,12077
11
- ezmsg/baseproc/util/__init__.py,sha256=hvMUJOBuqioER50GZ5-GZiQbQ9NtQYEze13ZlR2jbMA,37
12
- ezmsg/baseproc/util/asio.py,sha256=0sF5oDc58DSLlcEgoUpNiqjjcbqnZhjSpQrXn6IdosM,4960
13
- ezmsg/baseproc/util/message.py,sha256=l_b1b6bXX8N6VF9RbUELzsHs73cKkDURBdIr0lt3CY0,909
14
- ezmsg/baseproc/util/profile.py,sha256=MOQDsFsW6ddXT0uAOgytW3aK_AZW5ieA16Pz2hWuE2o,6189
15
- ezmsg/baseproc/util/typeresolution.py,sha256=5on4QcrYd1rxsRoDEqivNjuWT5BkU-Wg7XdTNaOircI,3485
16
- ezmsg_baseproc-1.2.0.dist-info/METADATA,sha256=H2jTn5VSw0pEwQWLdd3Hu0G1OCbbKM-On4AmKyyyfm4,3415
17
- ezmsg_baseproc-1.2.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
18
- ezmsg_baseproc-1.2.0.dist-info/licenses/LICENSE,sha256=BDD8rfac1Ur7mp0_3izEdr6fHgSA3Or6U1Kb0ZAWsow,1066
19
- ezmsg_baseproc-1.2.0.dist-info/RECORD,,