ezmsg-baseproc 1.2.1__tar.gz → 1.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/PKG-INFO +1 -1
  2. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/__version__.py +2 -2
  3. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/protocols.py +3 -4
  4. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/stateful.py +23 -9
  5. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/units.py +36 -5
  6. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/util/message.py +19 -3
  7. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/tests/test_baseproc.py +30 -30
  8. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/.github/workflows/docs.yml +0 -0
  9. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/.github/workflows/python-publish.yml +0 -0
  10. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/.github/workflows/python-tests.yml +0 -0
  11. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/.gitignore +0 -0
  12. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/.pre-commit-config.yaml +0 -0
  13. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/LICENSE +0 -0
  14. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/README.md +0 -0
  15. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/docs/Makefile +0 -0
  16. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/docs/make.bat +0 -0
  17. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/docs/source/_templates/autosummary/module.rst +0 -0
  18. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/docs/source/api/index.rst +0 -0
  19. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/docs/source/conf.py +0 -0
  20. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/docs/source/guides/ProcessorsBase.md +0 -0
  21. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/adaptive.rst +0 -0
  22. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/checkpoint.rst +0 -0
  23. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/clockdriven.rst +0 -0
  24. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/composite.rst +0 -0
  25. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/content-processors.rst +0 -0
  26. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/processor.rst +0 -0
  27. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/standalone.rst +0 -0
  28. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/stateful.rst +0 -0
  29. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/unit.rst +0 -0
  30. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/docs/source/index.md +0 -0
  31. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/pyproject.toml +0 -0
  32. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/__init__.py +0 -0
  33. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/clock.py +0 -0
  34. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/clockdriven.py +0 -0
  35. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/composite.py +0 -0
  36. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/counter.py +0 -0
  37. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/processor.py +0 -0
  38. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/util/__init__.py +0 -0
  39. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/util/asio.py +0 -0
  40. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/util/profile.py +0 -0
  41. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/util/typeresolution.py +0 -0
  42. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/tests/conftest.py +0 -0
  43. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/tests/helpers/__init__.py +0 -0
  44. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/tests/helpers/util.py +0 -0
  45. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/tests/test_clock.py +0 -0
  46. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/tests/test_clock_counter_system.py +0 -0
  47. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/tests/test_clockdriven.py +0 -0
  48. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/tests/test_counter.py +0 -0
  49. {ezmsg_baseproc-1.2.1 → ezmsg_baseproc-1.3.0}/tests/test_profile.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ezmsg-baseproc
3
- Version: 1.2.1
3
+ Version: 1.3.0
4
4
  Summary: Base processor classes and protocols for ezmsg signal processing pipelines
5
5
  Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
6
6
  License-Expression: MIT
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '1.2.1'
32
- __version_tuple__ = version_tuple = (1, 2, 1)
31
+ __version__ = version = '1.3.0'
32
+ __version_tuple__ = version_tuple = (1, 3, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -5,8 +5,7 @@ import typing
5
5
  from dataclasses import dataclass
6
6
 
7
7
  import ezmsg.core as ez
8
-
9
- from .util.message import SampleMessage
8
+ from ezmsg.util.messages.axisarray import AxisArray
10
9
 
11
10
  # --- Processor state decorator ---
12
11
  processor_state = functools.partial(dataclass, unsafe_hash=True, frozen=False, init=False)
@@ -138,7 +137,7 @@ class StatefulTransformer(
138
137
 
139
138
 
140
139
  class AdaptiveTransformer(StatefulTransformer, typing.Protocol):
141
- def partial_fit(self, message: SampleMessage) -> None:
140
+ def partial_fit(self, message: AxisArray) -> None:
142
141
  """Update transformer state using labeled training data.
143
142
 
144
143
  This method should update the internal state/parameters of the transformer
@@ -146,4 +145,4 @@ class AdaptiveTransformer(StatefulTransformer, typing.Protocol):
146
145
  """
147
146
  ...
148
147
 
149
- async def apartial_fit(self, message: SampleMessage) -> None: ...
148
+ async def apartial_fit(self, message: AxisArray) -> None: ...
@@ -4,6 +4,9 @@ import pickle
4
4
  import typing
5
5
  from abc import ABC, abstractmethod
6
6
 
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+ from ezmsg.util.messages.util import replace
9
+
7
10
  from .processor import (
8
11
  BaseProcessor,
9
12
  BaseProducer,
@@ -256,7 +259,7 @@ class BaseStatefulTransformer(
256
259
  class BaseAdaptiveTransformer(
257
260
  BaseStatefulTransformer[
258
261
  SettingsType,
259
- MessageInType | SampleMessage,
262
+ MessageInType,
260
263
  MessageOutType | None,
261
264
  StateType,
262
265
  ],
@@ -264,30 +267,41 @@ class BaseAdaptiveTransformer(
264
267
  typing.Generic[SettingsType, MessageInType, MessageOutType, StateType],
265
268
  ):
266
269
  @abstractmethod
267
- def partial_fit(self, message: SampleMessage) -> None: ...
270
+ def partial_fit(self, message: AxisArray) -> None: ...
268
271
 
269
- async def apartial_fit(self, message: SampleMessage) -> None:
272
+ async def apartial_fit(self, message: AxisArray) -> None:
270
273
  """Override me if you need async partial fitting."""
271
274
  return self.partial_fit(message)
272
275
 
273
- def __call__(self, message: MessageInType | SampleMessage) -> MessageOutType | None:
276
+ def __call__(self, message: MessageInType) -> MessageOutType | None:
274
277
  """
275
278
  Adapt transformer with training data (and optionally labels)
276
- in SampleMessage
279
+ in AxisArray with attrs["trigger"].
277
280
 
278
281
  Args:
279
- message: An instance of SampleMessage with optional
280
- labels (y) in message.trigger.value.data and
281
- data (X) in message.sample.data
282
+ message: An AxisArray with optional trigger in attrs["trigger"],
283
+ containing labels (y) in attrs["trigger"].value and
284
+ data (X) in message.data
282
285
 
283
286
  Returns: None
284
287
  """
285
288
  if is_sample_message(message):
289
+ if isinstance(message, SampleMessage):
290
+ # Auto-convert old format → new format
291
+ message = replace(
292
+ message.sample,
293
+ attrs={**message.sample.attrs, "trigger": message.trigger},
294
+ )
286
295
  return self.partial_fit(message)
287
296
  return super().__call__(message)
288
297
 
289
- async def __acall__(self, message: MessageInType | SampleMessage) -> MessageOutType | None:
298
+ async def __acall__(self, message: MessageInType) -> MessageOutType | None:
290
299
  if is_sample_message(message):
300
+ if isinstance(message, SampleMessage):
301
+ message = replace(
302
+ message.sample,
303
+ attrs={**message.sample.attrs, "trigger": message.trigger},
304
+ )
291
305
  return await self.apartial_fit(message)
292
306
  return await super().__acall__(message)
293
307
 
@@ -3,10 +3,10 @@
3
3
  import math
4
4
  import traceback
5
5
  import typing
6
+ import warnings
6
7
  from abc import ABC, abstractmethod
7
8
 
8
9
  import ezmsg.core as ez
9
- from ezmsg.util.generator import GenState
10
10
  from ezmsg.util.messages.axisarray import AxisArray, LinearAxis
11
11
 
12
12
  from .clockdriven import BaseClockDrivenProducer
@@ -14,7 +14,6 @@ from .composite import CompositeProcessor
14
14
  from .processor import BaseConsumer, BaseProducer, BaseTransformer
15
15
  from .protocols import MessageInType, MessageOutType, SettingsType
16
16
  from .stateful import BaseAdaptiveTransformer, BaseStatefulConsumer, BaseStatefulTransformer
17
- from .util.message import SampleMessage
18
17
  from .util.profile import profile_subpub
19
18
  from .util.typeresolution import resolve_typevar
20
19
 
@@ -223,7 +222,7 @@ class BaseAdaptiveTransformerUnit(
223
222
  ABC,
224
223
  typing.Generic[SettingsType, MessageInType, MessageOutType, AdaptiveTransformerType],
225
224
  ):
226
- INPUT_SAMPLE = ez.InputStream(SampleMessage)
225
+ INPUT_SAMPLE = ez.InputStream(AxisArray)
227
226
  INPUT_SIGNAL = ez.InputStream(MessageInType)
228
227
  OUTPUT_SIGNAL = ez.OutputStream(MessageOutType)
229
228
 
@@ -242,7 +241,7 @@ class BaseAdaptiveTransformerUnit(
242
241
  yield self.OUTPUT_SIGNAL, result
243
242
 
244
243
  @ez.subscriber(INPUT_SAMPLE)
245
- async def on_sample(self, msg: SampleMessage) -> None:
244
+ async def on_sample(self, msg: AxisArray) -> None:
246
245
  await self.processor.apartial_fit(msg)
247
246
 
248
247
 
@@ -287,10 +286,42 @@ class BaseClockDrivenUnit(
287
286
  yield self.OUTPUT_SIGNAL, result
288
287
 
289
288
 
290
- # Legacy class
289
+ class GenState(ez.State):
290
+ """
291
+ .. deprecated::
292
+ ``GenState`` is deprecated. Define a local state class or use
293
+ ``ezmsg.baseproc`` processor classes instead.
294
+ """
295
+
296
+ gen: typing.Generator[typing.Any, typing.Any, None]
297
+
298
+ def __init_subclass__(cls, **kwargs):
299
+ super().__init_subclass__(**kwargs)
300
+ warnings.warn(
301
+ "GenState is deprecated. Define a local state class instead of subclassing GenState.",
302
+ DeprecationWarning,
303
+ stacklevel=2,
304
+ )
305
+
306
+
291
307
  class GenAxisArray(ez.Unit):
308
+ """
309
+ .. deprecated::
310
+ ``GenAxisArray`` is deprecated. Use ``BaseTransformerUnit`` or
311
+ ``BaseAdaptiveTransformerUnit`` from ``ezmsg.baseproc`` instead.
312
+ """
313
+
292
314
  STATE = GenState
293
315
 
316
+ def __init_subclass__(cls, **kwargs):
317
+ super().__init_subclass__(**kwargs)
318
+ warnings.warn(
319
+ "GenAxisArray is deprecated. Use BaseTransformerUnit or "
320
+ "BaseAdaptiveTransformerUnit from ezmsg.baseproc instead.",
321
+ DeprecationWarning,
322
+ stacklevel=2,
323
+ )
324
+
294
325
  INPUT_SIGNAL = ez.InputStream(AxisArray)
295
326
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
296
327
  INPUT_SETTINGS = ez.InputStream(ez.Settings)
@@ -1,5 +1,6 @@
1
1
  import time
2
2
  import typing
3
+ import warnings
3
4
  from dataclasses import dataclass, field
4
5
 
5
6
  from ezmsg.util.messages.axisarray import AxisArray
@@ -19,13 +20,28 @@ class SampleTriggerMessage:
19
20
 
20
21
  @dataclass
21
22
  class SampleMessage:
23
+ """
24
+ .. deprecated::
25
+ ``SampleMessage`` is deprecated. Use ``AxisArray`` with
26
+ ``attrs={"trigger": SampleTriggerMessage(...)}`` instead.
27
+ """
28
+
22
29
  trigger: SampleTriggerMessage
23
30
  """The time, window, and value (if any) associated with the trigger."""
24
31
 
25
32
  sample: AxisArray
26
33
  """The data sampled around the trigger."""
27
34
 
35
+ def __post_init__(self):
36
+ warnings.warn(
37
+ "SampleMessage is deprecated. Use AxisArray with " "attrs={'trigger': SampleTriggerMessage(...)} instead.",
38
+ DeprecationWarning,
39
+ stacklevel=2,
40
+ )
41
+
28
42
 
29
- def is_sample_message(message: typing.Any) -> typing.TypeGuard[SampleMessage]:
30
- """Check if the message is a SampleMessage."""
31
- return hasattr(message, "trigger")
43
+ def is_sample_message(message: typing.Any) -> bool:
44
+ """Detect old SampleMessage OR new AxisArray-with-trigger."""
45
+ if isinstance(message, SampleMessage):
46
+ return True
47
+ return isinstance(message, AxisArray) and "trigger" in getattr(message, "attrs", {})
@@ -3,11 +3,11 @@
3
3
  import dataclasses
4
4
  import pickle
5
5
  from types import NoneType
6
- from typing import Any, Generator
7
- from unittest.mock import MagicMock
6
+ from typing import Any
8
7
 
8
+ import numpy as np
9
9
  import pytest
10
- from ezmsg.util.generator import consumer
10
+ from ezmsg.util.messages.axisarray import AxisArray
11
11
 
12
12
  from ezmsg.baseproc import (
13
13
  BaseAdaptiveTransformer,
@@ -22,7 +22,7 @@ from ezmsg.baseproc import (
22
22
  BaseTransformer,
23
23
  CompositeProcessor,
24
24
  CompositeProducer,
25
- SampleMessage,
25
+ SampleTriggerMessage,
26
26
  _get_base_processor_message_in_type,
27
27
  _get_base_processor_message_out_type,
28
28
  _get_base_processor_settings_type,
@@ -136,11 +136,11 @@ class MockAdaptiveTransformer(BaseAdaptiveTransformer[MockSettings, MockMessageA
136
136
  def _reset_state(self, message: MockMessageA) -> None:
137
137
  self._state.iterations = 0
138
138
 
139
- def _process(self, message: MockMessageA | SampleMessage) -> MockMessageB:
139
+ def _process(self, message: MockMessageA) -> MockMessageB:
140
140
  self._state.iterations += 1
141
141
  return MockMessageB()
142
142
 
143
- def partial_fit(self, message: SampleMessage) -> None:
143
+ def partial_fit(self, message: AxisArray) -> None:
144
144
  self._state.iterations += 1
145
145
 
146
146
 
@@ -251,18 +251,18 @@ class ChainedCompositeProcessorWithDeepProcessors(CompositeProcessor[MockSetting
251
251
  }
252
252
 
253
253
 
254
- @consumer
255
- def mock_generator() -> Generator[MockMessageA, MockMessageA, None]:
256
- """A mock generator function for testing purposes."""
257
- while True:
258
- yield MockMessageA()
254
+ class MockGeneratorTransformer(BaseTransformer[MockSettings, MockMessageA, MockMessageA]):
255
+ """A mock transformer replacing the legacy @consumer generator."""
256
+
257
+ def _process(self, message: MockMessageA) -> MockMessageA:
258
+ return MockMessageA()
259
259
 
260
260
 
261
261
  class MockGeneratorCompositeProcessor(CompositeProcessor[MockSettings, MockMessageA, MockMessageB]):
262
262
  @staticmethod
263
263
  def _initialize_processors(settings):
264
264
  return {
265
- "generator": mock_generator(),
265
+ "generator": MockGeneratorTransformer(settings=settings),
266
266
  "stateful_processor": MockStatefulProcessor(settings=settings),
267
267
  }
268
268
 
@@ -333,27 +333,26 @@ class ChainedCompositeProducerWithDeepProcessors(CompositeProducer[MockSettings,
333
333
  }
334
334
 
335
335
 
336
- @consumer
337
- def mock_producer_generator() -> Generator[MockMessageA, None, None]:
338
- """A mock generator function for testing purposes."""
339
- while True:
340
- yield MockMessageA()
336
+ class MockGeneratorProducer(BaseProducer[MockSettings, MockMessageA]):
337
+ """A mock producer replacing the legacy @consumer producer generator."""
338
+
339
+ async def _produce(self) -> MockMessageA:
340
+ return MockMessageA()
341
+
341
342
 
343
+ class MockGeneratorPassthroughTransformer(BaseTransformer[MockSettings, MockMessageA, MockMessageA]):
344
+ """A mock transformer replacing the legacy unprimed generator."""
342
345
 
343
- def mock_generator_unprimed() -> Generator[MockMessageA, MockMessageA, None]:
344
- """A mock generator function for testing purposes."""
345
- output = MockMessageA()
346
- while True:
347
- input = yield output
348
- output = input or output
346
+ def _process(self, message: MockMessageA) -> MockMessageA:
347
+ return message or MockMessageA()
349
348
 
350
349
 
351
350
  class MockGeneratorCompositeProducer(CompositeProducer[MockSettings, MockMessageB]):
352
351
  @staticmethod
353
352
  def _initialize_processors(settings):
354
353
  return {
355
- "generator": mock_producer_generator(),
356
- "mock_generator_unprimed": mock_generator_unprimed(),
354
+ "generator": MockGeneratorProducer(settings=settings),
355
+ "mock_generator_passthrough": MockGeneratorPassthroughTransformer(settings=settings),
357
356
  "stateful_processor": MockStatefulProcessor(settings=settings),
358
357
  }
359
358
 
@@ -758,10 +757,13 @@ class TestBaseStatefulTransformer:
758
757
  assert new_state[0].iterations == 1
759
758
 
760
759
 
761
- # Mock SampleMessage for testing BaseAdaptiveTransformer
760
+ # Helper to create an AxisArray with trigger in attrs for testing BaseAdaptiveTransformer
762
761
  def mock_sample_message():
763
- sample_message = MagicMock(spec=SampleMessage)
764
- return sample_message
762
+ return AxisArray(
763
+ data=np.zeros((1, 1)),
764
+ dims=["time", "ch"],
765
+ attrs={"trigger": SampleTriggerMessage()},
766
+ )
765
767
 
766
768
 
767
769
  class TestBaseAdaptiveTransformer:
@@ -778,9 +780,7 @@ class TestBaseAdaptiveTransformer:
778
780
 
779
781
  def test_call_with_sample_message(self):
780
782
  transformer = MockAdaptiveTransformer()
781
- # Create a sample message with a trigger attribute
782
783
  sample_msg = mock_sample_message()
783
- setattr(sample_msg, "trigger", None)
784
784
  result = transformer(sample_msg)
785
785
  assert result is None # partial_fit returns None
786
786
  assert transformer.state.iterations == 1
File without changes
File without changes