ezmsg-baseproc 1.2.0.tar.gz → 1.3.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/PKG-INFO +1 -1
  2. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/ProcessorsBase.md +2 -2
  3. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/clockdriven.rst +2 -2
  4. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/__init__.py +2 -2
  5. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/__version__.py +2 -2
  6. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/counter.py +2 -2
  7. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/protocols.py +3 -4
  8. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/stateful.py +23 -9
  9. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/units.py +38 -7
  10. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/util/message.py +19 -3
  11. ezmsg_baseproc-1.3.0/tests/conftest.py +18 -0
  12. ezmsg_baseproc-1.3.0/tests/helpers/__init__.py +0 -0
  13. ezmsg_baseproc-1.3.0/tests/helpers/util.py +24 -0
  14. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/tests/test_baseproc.py +30 -30
  15. ezmsg_baseproc-1.3.0/tests/test_clock_counter_system.py +229 -0
  16. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/.github/workflows/docs.yml +0 -0
  17. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/.github/workflows/python-publish.yml +0 -0
  18. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/.github/workflows/python-tests.yml +0 -0
  19. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/.gitignore +0 -0
  20. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/.pre-commit-config.yaml +0 -0
  21. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/LICENSE +0 -0
  22. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/README.md +0 -0
  23. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/Makefile +0 -0
  24. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/make.bat +0 -0
  25. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/_templates/autosummary/module.rst +0 -0
  26. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/api/index.rst +0 -0
  27. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/conf.py +0 -0
  28. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/adaptive.rst +0 -0
  29. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/checkpoint.rst +0 -0
  30. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/composite.rst +0 -0
  31. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/content-processors.rst +0 -0
  32. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/processor.rst +0 -0
  33. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/standalone.rst +0 -0
  34. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/stateful.rst +0 -0
  35. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/unit.rst +0 -0
  36. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/index.md +0 -0
  37. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/pyproject.toml +0 -0
  38. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/clock.py +0 -0
  39. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/clockdriven.py +0 -0
  40. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/composite.py +0 -0
  41. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/processor.py +0 -0
  42. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/util/__init__.py +0 -0
  43. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/util/asio.py +0 -0
  44. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/util/profile.py +0 -0
  45. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/util/typeresolution.py +0 -0
  46. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/tests/test_clock.py +0 -0
  47. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/tests/test_clockdriven.py +0 -0
  48. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/tests/test_counter.py +0 -0
  49. {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/tests/test_profile.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ezmsg-baseproc
3
- Version: 1.2.0
3
+ Version: 1.3.0
4
4
  Summary: Base processor classes and protocols for ezmsg signal processing pipelines
5
5
  Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
6
6
  License-Expression: MIT
@@ -80,9 +80,9 @@ do not inherit from `BaseStatefulProcessor` and `BaseStatefulProducer`. They acc
80
80
  | 3 | `BaseConsumerUnit` | 1 | `ConsumerType` |
81
81
  | 4 | `BaseTransformerUnit` | 1 | `TransformerType` |
82
82
  | 5 | `BaseAdaptiveTransformerUnit` | 1 | `AdaptiveTransformerType` |
83
- | 6 | `BaseClockDrivenProducerUnit` | 1 | `ClockDrivenProducerType` |
83
+ | 6 | `BaseClockDrivenUnit` | 1 | `ClockDrivenProducerType` |
84
84
 
85
- Note, it is strongly recommended to use `BaseConsumerUnit`, `BaseTransformerUnit`, `BaseAdaptiveTransformerUnit`, or `BaseClockDrivenProducerUnit` for implementing concrete subclasses rather than `BaseProcessorUnit`.
85
+ Note, it is strongly recommended to use `BaseConsumerUnit`, `BaseTransformerUnit`, `BaseAdaptiveTransformerUnit`, or `BaseClockDrivenUnit` for implementing concrete subclasses rather than `BaseProcessorUnit`.
86
86
 
87
87
 
88
88
  ## Implementing a custom standalone processor
@@ -42,7 +42,7 @@ Here's a complete example of a sine wave generator:
42
42
 
43
43
  from ezmsg.baseproc import (
44
44
  BaseClockDrivenProducer,
45
- BaseClockDrivenProducerUnit,
45
+ BaseClockDrivenUnit,
46
46
  ClockDrivenSettings,
47
47
  ClockDrivenState,
48
48
  processor_state,
@@ -124,7 +124,7 @@ Here's a complete example of a sine wave generator:
124
124
 
125
125
 
126
126
  class SinGeneratorUnit(
127
- BaseClockDrivenProducerUnit[SinGeneratorSettings, SinGenerator]
127
+ BaseClockDrivenUnit[SinGeneratorSettings, SinGenerator]
128
128
  ):
129
129
  """
130
130
  ezmsg Unit wrapper for SinGenerator.
@@ -82,7 +82,7 @@ from .stateful import (
82
82
  from .units import (
83
83
  AdaptiveTransformerType,
84
84
  BaseAdaptiveTransformerUnit,
85
- BaseClockDrivenProducerUnit,
85
+ BaseClockDrivenUnit,
86
86
  BaseConsumerUnit,
87
87
  BaseProcessorUnit,
88
88
  BaseProducerUnit,
@@ -158,7 +158,7 @@ __all__ = [
158
158
  "BaseConsumerUnit",
159
159
  "BaseTransformerUnit",
160
160
  "BaseAdaptiveTransformerUnit",
161
- "BaseClockDrivenProducerUnit",
161
+ "BaseClockDrivenUnit",
162
162
  "GenAxisArray",
163
163
  # Type resolution helpers
164
164
  "get_base_producer_type",
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '1.2.0'
32
- __version_tuple__ = version_tuple = (1, 2, 0)
31
+ __version__ = version = '1.3.0'
32
+ __version_tuple__ = version_tuple = (1, 3, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -9,7 +9,7 @@ from .clockdriven import (
9
9
  ClockDrivenState,
10
10
  )
11
11
  from .protocols import processor_state
12
- from .units import BaseClockDrivenProducerUnit
12
+ from .units import BaseClockDrivenUnit
13
13
 
14
14
 
15
15
  class CounterSettings(ClockDrivenSettings):
@@ -57,7 +57,7 @@ class CounterTransformer(BaseClockDrivenProducer[CounterSettings, CounterTransfo
57
57
  )
58
58
 
59
59
 
60
- class Counter(BaseClockDrivenProducerUnit[CounterSettings, CounterTransformer]):
60
+ class Counter(BaseClockDrivenUnit[CounterSettings, CounterTransformer]):
61
61
  """
62
62
  Transforms clock ticks into monotonically increasing counter values as AxisArray.
63
63
 
@@ -5,8 +5,7 @@ import typing
5
5
  from dataclasses import dataclass
6
6
 
7
7
  import ezmsg.core as ez
8
-
9
- from .util.message import SampleMessage
8
+ from ezmsg.util.messages.axisarray import AxisArray
10
9
 
11
10
  # --- Processor state decorator ---
12
11
  processor_state = functools.partial(dataclass, unsafe_hash=True, frozen=False, init=False)
@@ -138,7 +137,7 @@ class StatefulTransformer(
138
137
 
139
138
 
140
139
  class AdaptiveTransformer(StatefulTransformer, typing.Protocol):
141
- def partial_fit(self, message: SampleMessage) -> None:
140
+ def partial_fit(self, message: AxisArray) -> None:
142
141
  """Update transformer state using labeled training data.
143
142
 
144
143
  This method should update the internal state/parameters of the transformer
@@ -146,4 +145,4 @@ class AdaptiveTransformer(StatefulTransformer, typing.Protocol):
146
145
  """
147
146
  ...
148
147
 
149
- async def apartial_fit(self, message: SampleMessage) -> None: ...
148
+ async def apartial_fit(self, message: AxisArray) -> None: ...
@@ -4,6 +4,9 @@ import pickle
4
4
  import typing
5
5
  from abc import ABC, abstractmethod
6
6
 
7
+ from ezmsg.util.messages.axisarray import AxisArray
8
+ from ezmsg.util.messages.util import replace
9
+
7
10
  from .processor import (
8
11
  BaseProcessor,
9
12
  BaseProducer,
@@ -256,7 +259,7 @@ class BaseStatefulTransformer(
256
259
  class BaseAdaptiveTransformer(
257
260
  BaseStatefulTransformer[
258
261
  SettingsType,
259
- MessageInType | SampleMessage,
262
+ MessageInType,
260
263
  MessageOutType | None,
261
264
  StateType,
262
265
  ],
@@ -264,30 +267,41 @@ class BaseAdaptiveTransformer(
264
267
  typing.Generic[SettingsType, MessageInType, MessageOutType, StateType],
265
268
  ):
266
269
  @abstractmethod
267
- def partial_fit(self, message: SampleMessage) -> None: ...
270
+ def partial_fit(self, message: AxisArray) -> None: ...
268
271
 
269
- async def apartial_fit(self, message: SampleMessage) -> None:
272
+ async def apartial_fit(self, message: AxisArray) -> None:
270
273
  """Override me if you need async partial fitting."""
271
274
  return self.partial_fit(message)
272
275
 
273
- def __call__(self, message: MessageInType | SampleMessage) -> MessageOutType | None:
276
+ def __call__(self, message: MessageInType) -> MessageOutType | None:
274
277
  """
275
278
  Adapt transformer with training data (and optionally labels)
276
- in SampleMessage
279
+ in AxisArray with attrs["trigger"].
277
280
 
278
281
  Args:
279
- message: An instance of SampleMessage with optional
280
- labels (y) in message.trigger.value.data and
281
- data (X) in message.sample.data
282
+ message: An AxisArray with optional trigger in attrs["trigger"],
283
+ containing labels (y) in attrs["trigger"].value and
284
+ data (X) in message.data
282
285
 
283
286
  Returns: None
284
287
  """
285
288
  if is_sample_message(message):
289
+ if isinstance(message, SampleMessage):
290
+ # Auto-convert old format → new format
291
+ message = replace(
292
+ message.sample,
293
+ attrs={**message.sample.attrs, "trigger": message.trigger},
294
+ )
286
295
  return self.partial_fit(message)
287
296
  return super().__call__(message)
288
297
 
289
- async def __acall__(self, message: MessageInType | SampleMessage) -> MessageOutType | None:
298
+ async def __acall__(self, message: MessageInType) -> MessageOutType | None:
290
299
  if is_sample_message(message):
300
+ if isinstance(message, SampleMessage):
301
+ message = replace(
302
+ message.sample,
303
+ attrs={**message.sample.attrs, "trigger": message.trigger},
304
+ )
291
305
  return await self.apartial_fit(message)
292
306
  return await super().__acall__(message)
293
307
 
@@ -3,10 +3,10 @@
3
3
  import math
4
4
  import traceback
5
5
  import typing
6
+ import warnings
6
7
  from abc import ABC, abstractmethod
7
8
 
8
9
  import ezmsg.core as ez
9
- from ezmsg.util.generator import GenState
10
10
  from ezmsg.util.messages.axisarray import AxisArray, LinearAxis
11
11
 
12
12
  from .clockdriven import BaseClockDrivenProducer
@@ -14,7 +14,6 @@ from .composite import CompositeProcessor
14
14
  from .processor import BaseConsumer, BaseProducer, BaseTransformer
15
15
  from .protocols import MessageInType, MessageOutType, SettingsType
16
16
  from .stateful import BaseAdaptiveTransformer, BaseStatefulConsumer, BaseStatefulTransformer
17
- from .util.message import SampleMessage
18
17
  from .util.profile import profile_subpub
19
18
  from .util.typeresolution import resolve_typevar
20
19
 
@@ -223,7 +222,7 @@ class BaseAdaptiveTransformerUnit(
223
222
  ABC,
224
223
  typing.Generic[SettingsType, MessageInType, MessageOutType, AdaptiveTransformerType],
225
224
  ):
226
- INPUT_SAMPLE = ez.InputStream(SampleMessage)
225
+ INPUT_SAMPLE = ez.InputStream(AxisArray)
227
226
  INPUT_SIGNAL = ez.InputStream(MessageInType)
228
227
  OUTPUT_SIGNAL = ez.OutputStream(MessageOutType)
229
228
 
@@ -242,11 +241,11 @@ class BaseAdaptiveTransformerUnit(
242
241
  yield self.OUTPUT_SIGNAL, result
243
242
 
244
243
  @ez.subscriber(INPUT_SAMPLE)
245
- async def on_sample(self, msg: SampleMessage) -> None:
244
+ async def on_sample(self, msg: AxisArray) -> None:
246
245
  await self.processor.apartial_fit(msg)
247
246
 
248
247
 
249
- class BaseClockDrivenProducerUnit(
248
+ class BaseClockDrivenUnit(
250
249
  BaseProcessorUnit[SettingsType],
251
250
  ABC,
252
251
  typing.Generic[SettingsType, ClockDrivenProducerType],
@@ -260,7 +259,7 @@ class BaseClockDrivenProducerUnit(
260
259
 
261
260
  Implement a new Unit as follows::
262
261
 
263
- class SinGeneratorUnit(BaseClockDrivenProducerUnit[
262
+ class SinGeneratorUnit(BaseClockDrivenUnit[
264
263
  SinGeneratorSettings, # SettingsType (must extend ClockDrivenSettings)
265
264
  SinProducer, # ClockDrivenProducerType
266
265
  ]):
@@ -287,10 +286,42 @@ class BaseClockDrivenProducerUnit(
287
286
  yield self.OUTPUT_SIGNAL, result
288
287
 
289
288
 
290
- # Legacy class
289
+ class GenState(ez.State):
290
+ """
291
+ .. deprecated::
292
+ ``GenState`` is deprecated. Define a local state class or use
293
+ ``ezmsg.baseproc`` processor classes instead.
294
+ """
295
+
296
+ gen: typing.Generator[typing.Any, typing.Any, None]
297
+
298
+ def __init_subclass__(cls, **kwargs):
299
+ super().__init_subclass__(**kwargs)
300
+ warnings.warn(
301
+ "GenState is deprecated. Define a local state class instead of subclassing GenState.",
302
+ DeprecationWarning,
303
+ stacklevel=2,
304
+ )
305
+
306
+
291
307
  class GenAxisArray(ez.Unit):
308
+ """
309
+ .. deprecated::
310
+ ``GenAxisArray`` is deprecated. Use ``BaseTransformerUnit`` or
311
+ ``BaseAdaptiveTransformerUnit`` from ``ezmsg.baseproc`` instead.
312
+ """
313
+
292
314
  STATE = GenState
293
315
 
316
+ def __init_subclass__(cls, **kwargs):
317
+ super().__init_subclass__(**kwargs)
318
+ warnings.warn(
319
+ "GenAxisArray is deprecated. Use BaseTransformerUnit or "
320
+ "BaseAdaptiveTransformerUnit from ezmsg.baseproc instead.",
321
+ DeprecationWarning,
322
+ stacklevel=2,
323
+ )
324
+
294
325
  INPUT_SIGNAL = ez.InputStream(AxisArray)
295
326
  OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
296
327
  INPUT_SETTINGS = ez.InputStream(ez.Settings)
@@ -1,5 +1,6 @@
1
1
  import time
2
2
  import typing
3
+ import warnings
3
4
  from dataclasses import dataclass, field
4
5
 
5
6
  from ezmsg.util.messages.axisarray import AxisArray
@@ -19,13 +20,28 @@ class SampleTriggerMessage:
19
20
 
20
21
  @dataclass
21
22
  class SampleMessage:
23
+ """
24
+ .. deprecated::
25
+ ``SampleMessage`` is deprecated. Use ``AxisArray`` with
26
+ ``attrs={"trigger": SampleTriggerMessage(...)}`` instead.
27
+ """
28
+
22
29
  trigger: SampleTriggerMessage
23
30
  """The time, window, and value (if any) associated with the trigger."""
24
31
 
25
32
  sample: AxisArray
26
33
  """The data sampled around the trigger."""
27
34
 
35
+ def __post_init__(self):
36
+ warnings.warn(
37
+ "SampleMessage is deprecated. Use AxisArray with " "attrs={'trigger': SampleTriggerMessage(...)} instead.",
38
+ DeprecationWarning,
39
+ stacklevel=2,
40
+ )
41
+
28
42
 
29
- def is_sample_message(message: typing.Any) -> typing.TypeGuard[SampleMessage]:
30
- """Check if the message is a SampleMessage."""
31
- return hasattr(message, "trigger")
43
+ def is_sample_message(message: typing.Any) -> bool:
44
+ """Detect old SampleMessage OR new AxisArray-with-trigger."""
45
+ if isinstance(message, SampleMessage):
46
+ return True
47
+ return isinstance(message, AxisArray) and "trigger" in getattr(message, "attrs", {})
@@ -0,0 +1,18 @@
1
+ """pytest configuration for ezmsg-baseproc tests."""
2
+
3
+ import os
4
+ import sys
5
+
6
+ import pytest
7
+
8
+ # Add tests directory to path so 'tests.helpers' can be imported
9
+ _tests_dir = os.path.dirname(__file__)
10
+ _parent_dir = os.path.dirname(_tests_dir)
11
+ if _parent_dir not in sys.path:
12
+ sys.path.insert(0, _parent_dir)
13
+
14
+
15
+ @pytest.fixture
16
+ def test_name(request):
17
+ """Provide the test name to test functions."""
18
+ return request.node.name
File without changes
@@ -0,0 +1,24 @@
1
+ import os
2
+ import tempfile
3
+ from pathlib import Path
4
+
5
+
6
+ def get_test_fn(test_name: str | None = None, extension: str = "txt") -> Path:
7
+ """PYTEST compatible temporary test file creator"""
8
+
9
+ # Get current test name if we can..
10
+ if test_name is None:
11
+ test_name = os.environ.get("PYTEST_CURRENT_TEST")
12
+ if test_name is not None:
13
+ test_name = test_name.split(":")[-1].split(" ")[0]
14
+ else:
15
+ test_name = __name__
16
+
17
+ file_path = Path(tempfile.gettempdir())
18
+ file_path = file_path / Path(f"{test_name}.{extension}")
19
+
20
+ # Create the file
21
+ with open(file_path, "w"):
22
+ pass
23
+
24
+ return file_path
@@ -3,11 +3,11 @@
3
3
  import dataclasses
4
4
  import pickle
5
5
  from types import NoneType
6
- from typing import Any, Generator
7
- from unittest.mock import MagicMock
6
+ from typing import Any
8
7
 
8
+ import numpy as np
9
9
  import pytest
10
- from ezmsg.util.generator import consumer
10
+ from ezmsg.util.messages.axisarray import AxisArray
11
11
 
12
12
  from ezmsg.baseproc import (
13
13
  BaseAdaptiveTransformer,
@@ -22,7 +22,7 @@ from ezmsg.baseproc import (
22
22
  BaseTransformer,
23
23
  CompositeProcessor,
24
24
  CompositeProducer,
25
- SampleMessage,
25
+ SampleTriggerMessage,
26
26
  _get_base_processor_message_in_type,
27
27
  _get_base_processor_message_out_type,
28
28
  _get_base_processor_settings_type,
@@ -136,11 +136,11 @@ class MockAdaptiveTransformer(BaseAdaptiveTransformer[MockSettings, MockMessageA
136
136
  def _reset_state(self, message: MockMessageA) -> None:
137
137
  self._state.iterations = 0
138
138
 
139
- def _process(self, message: MockMessageA | SampleMessage) -> MockMessageB:
139
+ def _process(self, message: MockMessageA) -> MockMessageB:
140
140
  self._state.iterations += 1
141
141
  return MockMessageB()
142
142
 
143
- def partial_fit(self, message: SampleMessage) -> None:
143
+ def partial_fit(self, message: AxisArray) -> None:
144
144
  self._state.iterations += 1
145
145
 
146
146
 
@@ -251,18 +251,18 @@ class ChainedCompositeProcessorWithDeepProcessors(CompositeProcessor[MockSetting
251
251
  }
252
252
 
253
253
 
254
- @consumer
255
- def mock_generator() -> Generator[MockMessageA, MockMessageA, None]:
256
- """A mock generator function for testing purposes."""
257
- while True:
258
- yield MockMessageA()
254
+ class MockGeneratorTransformer(BaseTransformer[MockSettings, MockMessageA, MockMessageA]):
255
+ """A mock transformer replacing the legacy @consumer generator."""
256
+
257
+ def _process(self, message: MockMessageA) -> MockMessageA:
258
+ return MockMessageA()
259
259
 
260
260
 
261
261
  class MockGeneratorCompositeProcessor(CompositeProcessor[MockSettings, MockMessageA, MockMessageB]):
262
262
  @staticmethod
263
263
  def _initialize_processors(settings):
264
264
  return {
265
- "generator": mock_generator(),
265
+ "generator": MockGeneratorTransformer(settings=settings),
266
266
  "stateful_processor": MockStatefulProcessor(settings=settings),
267
267
  }
268
268
 
@@ -333,27 +333,26 @@ class ChainedCompositeProducerWithDeepProcessors(CompositeProducer[MockSettings,
333
333
  }
334
334
 
335
335
 
336
- @consumer
337
- def mock_producer_generator() -> Generator[MockMessageA, None, None]:
338
- """A mock generator function for testing purposes."""
339
- while True:
340
- yield MockMessageA()
336
+ class MockGeneratorProducer(BaseProducer[MockSettings, MockMessageA]):
337
+ """A mock producer replacing the legacy @consumer producer generator."""
338
+
339
+ async def _produce(self) -> MockMessageA:
340
+ return MockMessageA()
341
+
341
342
 
343
+ class MockGeneratorPassthroughTransformer(BaseTransformer[MockSettings, MockMessageA, MockMessageA]):
344
+ """A mock transformer replacing the legacy unprimed generator."""
342
345
 
343
- def mock_generator_unprimed() -> Generator[MockMessageA, MockMessageA, None]:
344
- """A mock generator function for testing purposes."""
345
- output = MockMessageA()
346
- while True:
347
- input = yield output
348
- output = input or output
346
+ def _process(self, message: MockMessageA) -> MockMessageA:
347
+ return message or MockMessageA()
349
348
 
350
349
 
351
350
  class MockGeneratorCompositeProducer(CompositeProducer[MockSettings, MockMessageB]):
352
351
  @staticmethod
353
352
  def _initialize_processors(settings):
354
353
  return {
355
- "generator": mock_producer_generator(),
356
- "mock_generator_unprimed": mock_generator_unprimed(),
354
+ "generator": MockGeneratorProducer(settings=settings),
355
+ "mock_generator_passthrough": MockGeneratorPassthroughTransformer(settings=settings),
357
356
  "stateful_processor": MockStatefulProcessor(settings=settings),
358
357
  }
359
358
 
@@ -758,10 +757,13 @@ class TestBaseStatefulTransformer:
758
757
  assert new_state[0].iterations == 1
759
758
 
760
759
 
761
- # Mock SampleMessage for testing BaseAdaptiveTransformer
760
+ # Helper to create an AxisArray with trigger in attrs for testing BaseAdaptiveTransformer
762
761
  def mock_sample_message():
763
- sample_message = MagicMock(spec=SampleMessage)
764
- return sample_message
762
+ return AxisArray(
763
+ data=np.zeros((1, 1)),
764
+ dims=["time", "ch"],
765
+ attrs={"trigger": SampleTriggerMessage()},
766
+ )
765
767
 
766
768
 
767
769
  class TestBaseAdaptiveTransformer:
@@ -778,9 +780,7 @@ class TestBaseAdaptiveTransformer:
778
780
 
779
781
  def test_call_with_sample_message(self):
780
782
  transformer = MockAdaptiveTransformer()
781
- # Create a sample message with a trigger attribute
782
783
  sample_msg = mock_sample_message()
783
- setattr(sample_msg, "trigger", None)
784
784
  result = transformer(sample_msg)
785
785
  assert result is None # partial_fit returns None
786
786
  assert transformer.state.iterations == 1
@@ -0,0 +1,229 @@
1
+ """Integration tests for Clock and Counter ezmsg systems."""
2
+
3
+ import math
4
+ import os
5
+ from dataclasses import field
6
+
7
+ import ezmsg.core as ez
8
+ import numpy as np
9
+ import pytest
10
+ from ezmsg.util.messagecodec import message_log
11
+ from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings
12
+ from ezmsg.util.messages.axisarray import AxisArray
13
+ from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings
14
+
15
+ from ezmsg.baseproc import (
16
+ Clock,
17
+ ClockSettings,
18
+ Counter,
19
+ CounterSettings,
20
+ )
21
+ from tests.helpers.util import get_test_fn
22
+
23
+
24
+ class ClockTestSystemSettings(ez.Settings):
25
+ clock_settings: ClockSettings
26
+ log_settings: MessageLoggerSettings
27
+ term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings)
28
+
29
+
30
+ class ClockTestSystem(ez.Collection):
31
+ SETTINGS = ClockTestSystemSettings
32
+
33
+ CLOCK = Clock()
34
+ LOG = MessageLogger()
35
+ TERM = TerminateOnTotal()
36
+
37
+ def configure(self) -> None:
38
+ self.CLOCK.apply_settings(self.SETTINGS.clock_settings)
39
+ self.LOG.apply_settings(self.SETTINGS.log_settings)
40
+ self.TERM.apply_settings(self.SETTINGS.term_settings)
41
+
42
+ def network(self) -> ez.NetworkDefinition:
43
+ return (
44
+ (self.CLOCK.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE),
45
+ (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE),
46
+ )
47
+
48
+
49
+ @pytest.mark.parametrize("dispatch_rate", [math.inf, 2.0, 20.0])
50
+ def test_clock_system(
51
+ dispatch_rate: float,
52
+ test_name: str | None = None,
53
+ ):
54
+ run_time = 1.0
55
+ n_target = 100 if math.isinf(dispatch_rate) else int(np.ceil(dispatch_rate * run_time))
56
+ test_filename = get_test_fn(test_name)
57
+ ez.logger.info(test_filename)
58
+ settings = ClockTestSystemSettings(
59
+ clock_settings=ClockSettings(dispatch_rate=dispatch_rate),
60
+ log_settings=MessageLoggerSettings(output=test_filename),
61
+ term_settings=TerminateOnTotalSettings(total=n_target),
62
+ )
63
+ system = ClockTestSystem(settings)
64
+ ez.run(SYSTEM=system)
65
+
66
+ # Collect result
67
+ messages = list(message_log(test_filename))
68
+ os.remove(test_filename)
69
+
70
+ # Clock produces LinearAxis with gain and offset
71
+ assert all(isinstance(m, AxisArray.LinearAxis) for m in messages)
72
+ assert len(messages) >= n_target
73
+
74
+
75
+ class CounterTestSystemSettings(ez.Settings):
76
+ clock_settings: ClockSettings
77
+ counter_settings: CounterSettings
78
+ log_settings: MessageLoggerSettings
79
+ term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings)
80
+
81
+
82
+ class CounterTestSystem(ez.Collection):
83
+ """Counter must be driven by Clock in the new architecture."""
84
+
85
+ SETTINGS = CounterTestSystemSettings
86
+
87
+ CLOCK = Clock()
88
+ COUNTER = Counter()
89
+ LOG = MessageLogger()
90
+ TERM = TerminateOnTotal()
91
+
92
+ def configure(self) -> None:
93
+ self.CLOCK.apply_settings(self.SETTINGS.clock_settings)
94
+ self.COUNTER.apply_settings(self.SETTINGS.counter_settings)
95
+ self.LOG.apply_settings(self.SETTINGS.log_settings)
96
+ self.TERM.apply_settings(self.SETTINGS.term_settings)
97
+
98
+ def network(self) -> ez.NetworkDefinition:
99
+ return (
100
+ (self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_CLOCK),
101
+ (self.COUNTER.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE),
102
+ (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE),
103
+ )
104
+
105
+
106
+ @pytest.mark.parametrize(
107
+ "n_time, fs, dispatch_rate, mod",
108
+ [
109
+ (1, 10.0, math.inf, None), # AFAP mode
110
+ (20, 1000.0, 50.0, None), # Realtime mode (50 Hz dispatch = 20 samples/tick @ 1000 Hz)
111
+ (1, 1000.0, 100.0, 2**3), # 100 Hz dispatch with mod
112
+ (10, 10.0, 10.0, 2**3), # 10 Hz dispatch with mod
113
+ ],
114
+ )
115
+ def test_counter_system(
116
+ n_time: int,
117
+ fs: float,
118
+ dispatch_rate: float,
119
+ mod: int | None,
120
+ test_name: str | None = None,
121
+ ):
122
+ target_dur = 2.6 # 2.6 seconds per test
123
+ if math.isinf(dispatch_rate):
124
+ # AFAP mode - runs as fast as possible
125
+ target_messages = 100 # Fixed target for AFAP
126
+ else:
127
+ target_messages = int(target_dur * dispatch_rate)
128
+
129
+ test_filename = get_test_fn(test_name)
130
+ ez.logger.info(test_filename)
131
+ settings = CounterTestSystemSettings(
132
+ clock_settings=ClockSettings(dispatch_rate=dispatch_rate),
133
+ counter_settings=CounterSettings(
134
+ n_time=n_time,
135
+ fs=fs,
136
+ mod=mod,
137
+ ),
138
+ log_settings=MessageLoggerSettings(
139
+ output=test_filename,
140
+ ),
141
+ term_settings=TerminateOnTotalSettings(
142
+ total=target_messages,
143
+ ),
144
+ )
145
+ system = CounterTestSystem(settings)
146
+ ez.run(SYSTEM=system)
147
+
148
+ # Collect result
149
+ messages: list[AxisArray] = [_ for _ in message_log(test_filename)]
150
+ os.remove(test_filename)
151
+
152
+ if math.isinf(dispatch_rate):
153
+ # The number of messages depends on how fast the computer is
154
+ target_messages = len(messages)
155
+ # This should be an equivalence assertion (==) but the use of TerminateOnTotal does
156
+ # not guarantee that MessageLogger will exit before an additional message is received.
157
+ # Let's just clip the last message if we exceed the target messages.
158
+ if len(messages) > target_messages:
159
+ messages = messages[:target_messages]
160
+ assert len(messages) >= target_messages
161
+
162
+ # Just do one quick data check (Counter now outputs 1D array)
163
+ agg = AxisArray.concatenate(*messages, dim="time")
164
+ target_samples = n_time * target_messages
165
+ expected_data = np.arange(target_samples)
166
+ if mod is not None:
167
+ expected_data = expected_data % mod
168
+ assert np.array_equal(agg.data, expected_data)
169
+
170
+
171
+ @pytest.mark.parametrize(
172
+ "clock_rate, fs, n_time",
173
+ [
174
+ (10.0, 1000.0, 100), # 10 Hz clock, fs=1000, n_time=100 (fixed)
175
+ (20.0, 500.0, None), # 20 Hz clock, fs=500, n_time derived (25 samples per tick)
176
+ (5.0, 1000.0, None), # 5 Hz clock, fs=1000, n_time derived (200 samples per tick)
177
+ ],
178
+ )
179
+ def test_counter_with_external_clock(
180
+ clock_rate: float,
181
+ fs: float,
182
+ n_time: int | None,
183
+ test_name: str | None = None,
184
+ ):
185
+ """Test Counter driven by external Clock (now the standard pattern)."""
186
+ target_messages = 20
187
+ test_filename = get_test_fn(test_name)
188
+ ez.logger.info(test_filename)
189
+
190
+ # This now uses the same CounterTestSystem since all counters need clocks
191
+ settings = CounterTestSystemSettings(
192
+ clock_settings=ClockSettings(dispatch_rate=clock_rate),
193
+ counter_settings=CounterSettings(
194
+ fs=fs,
195
+ n_time=n_time,
196
+ ),
197
+ log_settings=MessageLoggerSettings(output=test_filename),
198
+ term_settings=TerminateOnTotalSettings(total=target_messages),
199
+ )
200
+ system = CounterTestSystem(settings)
201
+ ez.run(SYSTEM=system)
202
+
203
+ # Collect result
204
+ messages: list[AxisArray] = [_ for _ in message_log(test_filename)]
205
+ os.remove(test_filename)
206
+
207
+ assert len(messages) >= target_messages
208
+
209
+ # Verify each message has correct sample rate (gain = 1/fs)
210
+ for msg in messages:
211
+ assert msg.axes["time"].gain == 1.0 / fs
212
+
213
+ # Verify data continuity
214
+ messages = messages[:target_messages] # Trim to target
215
+ agg = AxisArray.concatenate(*messages, dim="time")
216
+
217
+ # Expected samples per tick
218
+ if n_time is not None:
219
+ expected_samples_per_tick = n_time
220
+ else:
221
+ expected_samples_per_tick = int(fs / clock_rate)
222
+
223
+ expected_total = expected_samples_per_tick * target_messages
224
+ # Allow for fractional sample accumulation variance
225
+ assert abs(len(agg.data) - expected_total) <= target_messages
226
+
227
+ # Counter values should be sequential (0, 1, 2, ...)
228
+ expected_data = np.arange(len(agg.data))
229
+ assert np.array_equal(agg.data, expected_data)
File without changes
File without changes