ezmsg-baseproc 1.2.0__tar.gz → 1.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/PKG-INFO +1 -1
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/ProcessorsBase.md +2 -2
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/clockdriven.rst +2 -2
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/__init__.py +2 -2
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/__version__.py +2 -2
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/counter.py +2 -2
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/protocols.py +3 -4
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/stateful.py +23 -9
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/units.py +38 -7
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/util/message.py +19 -3
- ezmsg_baseproc-1.3.0/tests/conftest.py +18 -0
- ezmsg_baseproc-1.3.0/tests/helpers/__init__.py +0 -0
- ezmsg_baseproc-1.3.0/tests/helpers/util.py +24 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/tests/test_baseproc.py +30 -30
- ezmsg_baseproc-1.3.0/tests/test_clock_counter_system.py +229 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/.github/workflows/docs.yml +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/.github/workflows/python-publish.yml +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/.github/workflows/python-tests.yml +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/.gitignore +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/.pre-commit-config.yaml +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/LICENSE +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/README.md +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/Makefile +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/make.bat +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/_templates/autosummary/module.rst +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/api/index.rst +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/conf.py +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/adaptive.rst +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/checkpoint.rst +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/composite.rst +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/content-processors.rst +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/processor.rst +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/standalone.rst +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/stateful.rst +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/unit.rst +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/index.md +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/pyproject.toml +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/clock.py +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/clockdriven.py +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/composite.py +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/processor.py +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/util/__init__.py +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/util/asio.py +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/util/profile.py +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/src/ezmsg/baseproc/util/typeresolution.py +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/tests/test_clock.py +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/tests/test_clockdriven.py +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/tests/test_counter.py +0 -0
- {ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/tests/test_profile.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ezmsg-baseproc
|
|
3
|
-
Version: 1.2.0
|
|
3
|
+
Version: 1.3.0
|
|
4
4
|
Summary: Base processor classes and protocols for ezmsg signal processing pipelines
|
|
5
5
|
Author-email: Griffin Milsap <griffin.milsap@gmail.com>, Preston Peranich <pperanich@gmail.com>, Chadwick Boulay <chadwick.boulay@gmail.com>, Kyle McGraw <kmcgraw@blackrockneuro.com>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -80,9 +80,9 @@ do not inherit from `BaseStatefulProcessor` and `BaseStatefulProducer`. They acc
|
|
|
80
80
|
| 3 | `BaseConsumerUnit` | 1 | `ConsumerType` |
|
|
81
81
|
| 4 | `BaseTransformerUnit` | 1 | `TransformerType` |
|
|
82
82
|
| 5 | `BaseAdaptiveTransformerUnit` | 1 | `AdaptiveTransformerType` |
|
|
83
|
-
| 6 | `BaseClockDrivenProducerUnit` | 1 | `ClockDrivenProducerType` |
|
|
83
|
+
| 6 | `BaseClockDrivenUnit` | 1 | `ClockDrivenProducerType` |
|
|
84
84
|
|
|
85
|
-
Note, it is strongly recommended to use `BaseConsumerUnit`, `BaseTransformerUnit`, `BaseAdaptiveTransformerUnit`, or `BaseClockDrivenProducerUnit` for implementing concrete subclasses rather than `BaseProcessorUnit`.
|
|
85
|
+
Note, it is strongly recommended to use `BaseConsumerUnit`, `BaseTransformerUnit`, `BaseAdaptiveTransformerUnit`, or `BaseClockDrivenUnit` for implementing concrete subclasses rather than `BaseProcessorUnit`.
|
|
86
86
|
|
|
87
87
|
|
|
88
88
|
## Implementing a custom standalone processor
|
{ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/clockdriven.rst
RENAMED
|
@@ -42,7 +42,7 @@ Here's a complete example of a sine wave generator:
|
|
|
42
42
|
|
|
43
43
|
from ezmsg.baseproc import (
|
|
44
44
|
BaseClockDrivenProducer,
|
|
45
|
-
BaseClockDrivenProducerUnit,
|
|
45
|
+
BaseClockDrivenUnit,
|
|
46
46
|
ClockDrivenSettings,
|
|
47
47
|
ClockDrivenState,
|
|
48
48
|
processor_state,
|
|
@@ -124,7 +124,7 @@ Here's a complete example of a sine wave generator:
|
|
|
124
124
|
|
|
125
125
|
|
|
126
126
|
class SinGeneratorUnit(
|
|
127
|
-
BaseClockDrivenProducerUnit[SinGeneratorSettings, SinGenerator]
|
|
127
|
+
BaseClockDrivenUnit[SinGeneratorSettings, SinGenerator]
|
|
128
128
|
):
|
|
129
129
|
"""
|
|
130
130
|
ezmsg Unit wrapper for SinGenerator.
|
|
@@ -82,7 +82,7 @@ from .stateful import (
|
|
|
82
82
|
from .units import (
|
|
83
83
|
AdaptiveTransformerType,
|
|
84
84
|
BaseAdaptiveTransformerUnit,
|
|
85
|
-
BaseClockDrivenProducerUnit,
|
|
85
|
+
BaseClockDrivenUnit,
|
|
86
86
|
BaseConsumerUnit,
|
|
87
87
|
BaseProcessorUnit,
|
|
88
88
|
BaseProducerUnit,
|
|
@@ -158,7 +158,7 @@ __all__ = [
|
|
|
158
158
|
"BaseConsumerUnit",
|
|
159
159
|
"BaseTransformerUnit",
|
|
160
160
|
"BaseAdaptiveTransformerUnit",
|
|
161
|
-
"BaseClockDrivenProducerUnit",
|
|
161
|
+
"BaseClockDrivenUnit",
|
|
162
162
|
"GenAxisArray",
|
|
163
163
|
# Type resolution helpers
|
|
164
164
|
"get_base_producer_type",
|
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '1.2.0'
|
|
32
|
-
__version_tuple__ = version_tuple = (1, 2, 0)
|
|
31
|
+
__version__ = version = '1.3.0'
|
|
32
|
+
__version_tuple__ = version_tuple = (1, 3, 0)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -9,7 +9,7 @@ from .clockdriven import (
|
|
|
9
9
|
ClockDrivenState,
|
|
10
10
|
)
|
|
11
11
|
from .protocols import processor_state
|
|
12
|
-
from .units import BaseClockDrivenProducerUnit
|
|
12
|
+
from .units import BaseClockDrivenUnit
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class CounterSettings(ClockDrivenSettings):
|
|
@@ -57,7 +57,7 @@ class CounterTransformer(BaseClockDrivenProducer[CounterSettings, CounterTransfo
|
|
|
57
57
|
)
|
|
58
58
|
|
|
59
59
|
|
|
60
|
-
class Counter(BaseClockDrivenProducerUnit[CounterSettings, CounterTransformer]):
|
|
60
|
+
class Counter(BaseClockDrivenUnit[CounterSettings, CounterTransformer]):
|
|
61
61
|
"""
|
|
62
62
|
Transforms clock ticks into monotonically increasing counter values as AxisArray.
|
|
63
63
|
|
|
@@ -5,8 +5,7 @@ import typing
|
|
|
5
5
|
from dataclasses import dataclass
|
|
6
6
|
|
|
7
7
|
import ezmsg.core as ez
|
|
8
|
-
|
|
9
|
-
from .util.message import SampleMessage
|
|
8
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
10
9
|
|
|
11
10
|
# --- Processor state decorator ---
|
|
12
11
|
processor_state = functools.partial(dataclass, unsafe_hash=True, frozen=False, init=False)
|
|
@@ -138,7 +137,7 @@ class StatefulTransformer(
|
|
|
138
137
|
|
|
139
138
|
|
|
140
139
|
class AdaptiveTransformer(StatefulTransformer, typing.Protocol):
|
|
141
|
-
def partial_fit(self, message:
|
|
140
|
+
def partial_fit(self, message: AxisArray) -> None:
|
|
142
141
|
"""Update transformer state using labeled training data.
|
|
143
142
|
|
|
144
143
|
This method should update the internal state/parameters of the transformer
|
|
@@ -146,4 +145,4 @@ class AdaptiveTransformer(StatefulTransformer, typing.Protocol):
|
|
|
146
145
|
"""
|
|
147
146
|
...
|
|
148
147
|
|
|
149
|
-
async def apartial_fit(self, message:
|
|
148
|
+
async def apartial_fit(self, message: AxisArray) -> None: ...
|
|
@@ -4,6 +4,9 @@ import pickle
|
|
|
4
4
|
import typing
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
6
|
|
|
7
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
8
|
+
from ezmsg.util.messages.util import replace
|
|
9
|
+
|
|
7
10
|
from .processor import (
|
|
8
11
|
BaseProcessor,
|
|
9
12
|
BaseProducer,
|
|
@@ -256,7 +259,7 @@ class BaseStatefulTransformer(
|
|
|
256
259
|
class BaseAdaptiveTransformer(
|
|
257
260
|
BaseStatefulTransformer[
|
|
258
261
|
SettingsType,
|
|
259
|
-
MessageInType
|
|
262
|
+
MessageInType,
|
|
260
263
|
MessageOutType | None,
|
|
261
264
|
StateType,
|
|
262
265
|
],
|
|
@@ -264,30 +267,41 @@ class BaseAdaptiveTransformer(
|
|
|
264
267
|
typing.Generic[SettingsType, MessageInType, MessageOutType, StateType],
|
|
265
268
|
):
|
|
266
269
|
@abstractmethod
|
|
267
|
-
def partial_fit(self, message:
|
|
270
|
+
def partial_fit(self, message: AxisArray) -> None: ...
|
|
268
271
|
|
|
269
|
-
async def apartial_fit(self, message:
|
|
272
|
+
async def apartial_fit(self, message: AxisArray) -> None:
|
|
270
273
|
"""Override me if you need async partial fitting."""
|
|
271
274
|
return self.partial_fit(message)
|
|
272
275
|
|
|
273
|
-
def __call__(self, message: MessageInType
|
|
276
|
+
def __call__(self, message: MessageInType) -> MessageOutType | None:
|
|
274
277
|
"""
|
|
275
278
|
Adapt transformer with training data (and optionally labels)
|
|
276
|
-
in
|
|
279
|
+
in AxisArray with attrs["trigger"].
|
|
277
280
|
|
|
278
281
|
Args:
|
|
279
|
-
message: An
|
|
280
|
-
labels (y) in
|
|
281
|
-
data (X) in message.
|
|
282
|
+
message: An AxisArray with optional trigger in attrs["trigger"],
|
|
283
|
+
containing labels (y) in attrs["trigger"].value and
|
|
284
|
+
data (X) in message.data
|
|
282
285
|
|
|
283
286
|
Returns: None
|
|
284
287
|
"""
|
|
285
288
|
if is_sample_message(message):
|
|
289
|
+
if isinstance(message, SampleMessage):
|
|
290
|
+
# Auto-convert old format → new format
|
|
291
|
+
message = replace(
|
|
292
|
+
message.sample,
|
|
293
|
+
attrs={**message.sample.attrs, "trigger": message.trigger},
|
|
294
|
+
)
|
|
286
295
|
return self.partial_fit(message)
|
|
287
296
|
return super().__call__(message)
|
|
288
297
|
|
|
289
|
-
async def __acall__(self, message: MessageInType
|
|
298
|
+
async def __acall__(self, message: MessageInType) -> MessageOutType | None:
|
|
290
299
|
if is_sample_message(message):
|
|
300
|
+
if isinstance(message, SampleMessage):
|
|
301
|
+
message = replace(
|
|
302
|
+
message.sample,
|
|
303
|
+
attrs={**message.sample.attrs, "trigger": message.trigger},
|
|
304
|
+
)
|
|
291
305
|
return await self.apartial_fit(message)
|
|
292
306
|
return await super().__acall__(message)
|
|
293
307
|
|
|
@@ -3,10 +3,10 @@
|
|
|
3
3
|
import math
|
|
4
4
|
import traceback
|
|
5
5
|
import typing
|
|
6
|
+
import warnings
|
|
6
7
|
from abc import ABC, abstractmethod
|
|
7
8
|
|
|
8
9
|
import ezmsg.core as ez
|
|
9
|
-
from ezmsg.util.generator import GenState
|
|
10
10
|
from ezmsg.util.messages.axisarray import AxisArray, LinearAxis
|
|
11
11
|
|
|
12
12
|
from .clockdriven import BaseClockDrivenProducer
|
|
@@ -14,7 +14,6 @@ from .composite import CompositeProcessor
|
|
|
14
14
|
from .processor import BaseConsumer, BaseProducer, BaseTransformer
|
|
15
15
|
from .protocols import MessageInType, MessageOutType, SettingsType
|
|
16
16
|
from .stateful import BaseAdaptiveTransformer, BaseStatefulConsumer, BaseStatefulTransformer
|
|
17
|
-
from .util.message import SampleMessage
|
|
18
17
|
from .util.profile import profile_subpub
|
|
19
18
|
from .util.typeresolution import resolve_typevar
|
|
20
19
|
|
|
@@ -223,7 +222,7 @@ class BaseAdaptiveTransformerUnit(
|
|
|
223
222
|
ABC,
|
|
224
223
|
typing.Generic[SettingsType, MessageInType, MessageOutType, AdaptiveTransformerType],
|
|
225
224
|
):
|
|
226
|
-
INPUT_SAMPLE = ez.InputStream(
|
|
225
|
+
INPUT_SAMPLE = ez.InputStream(AxisArray)
|
|
227
226
|
INPUT_SIGNAL = ez.InputStream(MessageInType)
|
|
228
227
|
OUTPUT_SIGNAL = ez.OutputStream(MessageOutType)
|
|
229
228
|
|
|
@@ -242,11 +241,11 @@ class BaseAdaptiveTransformerUnit(
|
|
|
242
241
|
yield self.OUTPUT_SIGNAL, result
|
|
243
242
|
|
|
244
243
|
@ez.subscriber(INPUT_SAMPLE)
|
|
245
|
-
async def on_sample(self, msg:
|
|
244
|
+
async def on_sample(self, msg: AxisArray) -> None:
|
|
246
245
|
await self.processor.apartial_fit(msg)
|
|
247
246
|
|
|
248
247
|
|
|
249
|
-
class BaseClockDrivenProducerUnit(
|
|
248
|
+
class BaseClockDrivenUnit(
|
|
250
249
|
BaseProcessorUnit[SettingsType],
|
|
251
250
|
ABC,
|
|
252
251
|
typing.Generic[SettingsType, ClockDrivenProducerType],
|
|
@@ -260,7 +259,7 @@ class BaseClockDrivenProducerUnit(
|
|
|
260
259
|
|
|
261
260
|
Implement a new Unit as follows::
|
|
262
261
|
|
|
263
|
-
class SinGeneratorUnit(BaseClockDrivenProducerUnit[
|
|
262
|
+
class SinGeneratorUnit(BaseClockDrivenUnit[
|
|
264
263
|
SinGeneratorSettings, # SettingsType (must extend ClockDrivenSettings)
|
|
265
264
|
SinProducer, # ClockDrivenProducerType
|
|
266
265
|
]):
|
|
@@ -287,10 +286,42 @@ class BaseClockDrivenProducerUnit(
|
|
|
287
286
|
yield self.OUTPUT_SIGNAL, result
|
|
288
287
|
|
|
289
288
|
|
|
290
|
-
|
|
289
|
+
class GenState(ez.State):
|
|
290
|
+
"""
|
|
291
|
+
.. deprecated::
|
|
292
|
+
``GenState`` is deprecated. Define a local state class or use
|
|
293
|
+
``ezmsg.baseproc`` processor classes instead.
|
|
294
|
+
"""
|
|
295
|
+
|
|
296
|
+
gen: typing.Generator[typing.Any, typing.Any, None]
|
|
297
|
+
|
|
298
|
+
def __init_subclass__(cls, **kwargs):
|
|
299
|
+
super().__init_subclass__(**kwargs)
|
|
300
|
+
warnings.warn(
|
|
301
|
+
"GenState is deprecated. Define a local state class instead of subclassing GenState.",
|
|
302
|
+
DeprecationWarning,
|
|
303
|
+
stacklevel=2,
|
|
304
|
+
)
|
|
305
|
+
|
|
306
|
+
|
|
291
307
|
class GenAxisArray(ez.Unit):
|
|
308
|
+
"""
|
|
309
|
+
.. deprecated::
|
|
310
|
+
``GenAxisArray`` is deprecated. Use ``BaseTransformerUnit`` or
|
|
311
|
+
``BaseAdaptiveTransformerUnit`` from ``ezmsg.baseproc`` instead.
|
|
312
|
+
"""
|
|
313
|
+
|
|
292
314
|
STATE = GenState
|
|
293
315
|
|
|
316
|
+
def __init_subclass__(cls, **kwargs):
|
|
317
|
+
super().__init_subclass__(**kwargs)
|
|
318
|
+
warnings.warn(
|
|
319
|
+
"GenAxisArray is deprecated. Use BaseTransformerUnit or "
|
|
320
|
+
"BaseAdaptiveTransformerUnit from ezmsg.baseproc instead.",
|
|
321
|
+
DeprecationWarning,
|
|
322
|
+
stacklevel=2,
|
|
323
|
+
)
|
|
324
|
+
|
|
294
325
|
INPUT_SIGNAL = ez.InputStream(AxisArray)
|
|
295
326
|
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
|
|
296
327
|
INPUT_SETTINGS = ez.InputStream(ez.Settings)
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import time
|
|
2
2
|
import typing
|
|
3
|
+
import warnings
|
|
3
4
|
from dataclasses import dataclass, field
|
|
4
5
|
|
|
5
6
|
from ezmsg.util.messages.axisarray import AxisArray
|
|
@@ -19,13 +20,28 @@ class SampleTriggerMessage:
|
|
|
19
20
|
|
|
20
21
|
@dataclass
|
|
21
22
|
class SampleMessage:
|
|
23
|
+
"""
|
|
24
|
+
.. deprecated::
|
|
25
|
+
``SampleMessage`` is deprecated. Use ``AxisArray`` with
|
|
26
|
+
``attrs={"trigger": SampleTriggerMessage(...)}`` instead.
|
|
27
|
+
"""
|
|
28
|
+
|
|
22
29
|
trigger: SampleTriggerMessage
|
|
23
30
|
"""The time, window, and value (if any) associated with the trigger."""
|
|
24
31
|
|
|
25
32
|
sample: AxisArray
|
|
26
33
|
"""The data sampled around the trigger."""
|
|
27
34
|
|
|
35
|
+
def __post_init__(self):
|
|
36
|
+
warnings.warn(
|
|
37
|
+
"SampleMessage is deprecated. Use AxisArray with " "attrs={'trigger': SampleTriggerMessage(...)} instead.",
|
|
38
|
+
DeprecationWarning,
|
|
39
|
+
stacklevel=2,
|
|
40
|
+
)
|
|
41
|
+
|
|
28
42
|
|
|
29
|
-
def is_sample_message(message: typing.Any) ->
|
|
30
|
-
"""
|
|
31
|
-
|
|
43
|
+
def is_sample_message(message: typing.Any) -> bool:
|
|
44
|
+
"""Detect old SampleMessage OR new AxisArray-with-trigger."""
|
|
45
|
+
if isinstance(message, SampleMessage):
|
|
46
|
+
return True
|
|
47
|
+
return isinstance(message, AxisArray) and "trigger" in getattr(message, "attrs", {})
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""pytest configuration for ezmsg-baseproc tests."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
import pytest
|
|
7
|
+
|
|
8
|
+
# Add tests directory to path so 'tests.helpers' can be imported
|
|
9
|
+
_tests_dir = os.path.dirname(__file__)
|
|
10
|
+
_parent_dir = os.path.dirname(_tests_dir)
|
|
11
|
+
if _parent_dir not in sys.path:
|
|
12
|
+
sys.path.insert(0, _parent_dir)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@pytest.fixture
|
|
16
|
+
def test_name(request):
|
|
17
|
+
"""Provide the test name to test functions."""
|
|
18
|
+
return request.node.name
|
|
File without changes
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import tempfile
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_test_fn(test_name: str | None = None, extension: str = "txt") -> Path:
|
|
7
|
+
"""PYTEST compatible temporary test file creator"""
|
|
8
|
+
|
|
9
|
+
# Get current test name if we can..
|
|
10
|
+
if test_name is None:
|
|
11
|
+
test_name = os.environ.get("PYTEST_CURRENT_TEST")
|
|
12
|
+
if test_name is not None:
|
|
13
|
+
test_name = test_name.split(":")[-1].split(" ")[0]
|
|
14
|
+
else:
|
|
15
|
+
test_name = __name__
|
|
16
|
+
|
|
17
|
+
file_path = Path(tempfile.gettempdir())
|
|
18
|
+
file_path = file_path / Path(f"{test_name}.{extension}")
|
|
19
|
+
|
|
20
|
+
# Create the file
|
|
21
|
+
with open(file_path, "w"):
|
|
22
|
+
pass
|
|
23
|
+
|
|
24
|
+
return file_path
|
|
@@ -3,11 +3,11 @@
|
|
|
3
3
|
import dataclasses
|
|
4
4
|
import pickle
|
|
5
5
|
from types import NoneType
|
|
6
|
-
from typing import Any
|
|
7
|
-
from unittest.mock import MagicMock
|
|
6
|
+
from typing import Any
|
|
8
7
|
|
|
8
|
+
import numpy as np
|
|
9
9
|
import pytest
|
|
10
|
-
from ezmsg.util.
|
|
10
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
11
11
|
|
|
12
12
|
from ezmsg.baseproc import (
|
|
13
13
|
BaseAdaptiveTransformer,
|
|
@@ -22,7 +22,7 @@ from ezmsg.baseproc import (
|
|
|
22
22
|
BaseTransformer,
|
|
23
23
|
CompositeProcessor,
|
|
24
24
|
CompositeProducer,
|
|
25
|
-
|
|
25
|
+
SampleTriggerMessage,
|
|
26
26
|
_get_base_processor_message_in_type,
|
|
27
27
|
_get_base_processor_message_out_type,
|
|
28
28
|
_get_base_processor_settings_type,
|
|
@@ -136,11 +136,11 @@ class MockAdaptiveTransformer(BaseAdaptiveTransformer[MockSettings, MockMessageA
|
|
|
136
136
|
def _reset_state(self, message: MockMessageA) -> None:
|
|
137
137
|
self._state.iterations = 0
|
|
138
138
|
|
|
139
|
-
def _process(self, message: MockMessageA
|
|
139
|
+
def _process(self, message: MockMessageA) -> MockMessageB:
|
|
140
140
|
self._state.iterations += 1
|
|
141
141
|
return MockMessageB()
|
|
142
142
|
|
|
143
|
-
def partial_fit(self, message:
|
|
143
|
+
def partial_fit(self, message: AxisArray) -> None:
|
|
144
144
|
self._state.iterations += 1
|
|
145
145
|
|
|
146
146
|
|
|
@@ -251,18 +251,18 @@ class ChainedCompositeProcessorWithDeepProcessors(CompositeProcessor[MockSetting
|
|
|
251
251
|
}
|
|
252
252
|
|
|
253
253
|
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
254
|
+
class MockGeneratorTransformer(BaseTransformer[MockSettings, MockMessageA, MockMessageA]):
|
|
255
|
+
"""A mock transformer replacing the legacy @consumer generator."""
|
|
256
|
+
|
|
257
|
+
def _process(self, message: MockMessageA) -> MockMessageA:
|
|
258
|
+
return MockMessageA()
|
|
259
259
|
|
|
260
260
|
|
|
261
261
|
class MockGeneratorCompositeProcessor(CompositeProcessor[MockSettings, MockMessageA, MockMessageB]):
|
|
262
262
|
@staticmethod
|
|
263
263
|
def _initialize_processors(settings):
|
|
264
264
|
return {
|
|
265
|
-
"generator":
|
|
265
|
+
"generator": MockGeneratorTransformer(settings=settings),
|
|
266
266
|
"stateful_processor": MockStatefulProcessor(settings=settings),
|
|
267
267
|
}
|
|
268
268
|
|
|
@@ -333,27 +333,26 @@ class ChainedCompositeProducerWithDeepProcessors(CompositeProducer[MockSettings,
|
|
|
333
333
|
}
|
|
334
334
|
|
|
335
335
|
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
336
|
+
class MockGeneratorProducer(BaseProducer[MockSettings, MockMessageA]):
|
|
337
|
+
"""A mock producer replacing the legacy @consumer producer generator."""
|
|
338
|
+
|
|
339
|
+
async def _produce(self) -> MockMessageA:
|
|
340
|
+
return MockMessageA()
|
|
341
|
+
|
|
341
342
|
|
|
343
|
+
class MockGeneratorPassthroughTransformer(BaseTransformer[MockSettings, MockMessageA, MockMessageA]):
|
|
344
|
+
"""A mock transformer replacing the legacy unprimed generator."""
|
|
342
345
|
|
|
343
|
-
def
|
|
344
|
-
|
|
345
|
-
output = MockMessageA()
|
|
346
|
-
while True:
|
|
347
|
-
input = yield output
|
|
348
|
-
output = input or output
|
|
346
|
+
def _process(self, message: MockMessageA) -> MockMessageA:
|
|
347
|
+
return message or MockMessageA()
|
|
349
348
|
|
|
350
349
|
|
|
351
350
|
class MockGeneratorCompositeProducer(CompositeProducer[MockSettings, MockMessageB]):
|
|
352
351
|
@staticmethod
|
|
353
352
|
def _initialize_processors(settings):
|
|
354
353
|
return {
|
|
355
|
-
"generator":
|
|
356
|
-
"
|
|
354
|
+
"generator": MockGeneratorProducer(settings=settings),
|
|
355
|
+
"mock_generator_passthrough": MockGeneratorPassthroughTransformer(settings=settings),
|
|
357
356
|
"stateful_processor": MockStatefulProcessor(settings=settings),
|
|
358
357
|
}
|
|
359
358
|
|
|
@@ -758,10 +757,13 @@ class TestBaseStatefulTransformer:
|
|
|
758
757
|
assert new_state[0].iterations == 1
|
|
759
758
|
|
|
760
759
|
|
|
761
|
-
#
|
|
760
|
+
# Helper to create an AxisArray with trigger in attrs for testing BaseAdaptiveTransformer
|
|
762
761
|
def mock_sample_message():
|
|
763
|
-
|
|
764
|
-
|
|
762
|
+
return AxisArray(
|
|
763
|
+
data=np.zeros((1, 1)),
|
|
764
|
+
dims=["time", "ch"],
|
|
765
|
+
attrs={"trigger": SampleTriggerMessage()},
|
|
766
|
+
)
|
|
765
767
|
|
|
766
768
|
|
|
767
769
|
class TestBaseAdaptiveTransformer:
|
|
@@ -778,9 +780,7 @@ class TestBaseAdaptiveTransformer:
|
|
|
778
780
|
|
|
779
781
|
def test_call_with_sample_message(self):
|
|
780
782
|
transformer = MockAdaptiveTransformer()
|
|
781
|
-
# Create a sample message with a trigger attribute
|
|
782
783
|
sample_msg = mock_sample_message()
|
|
783
|
-
setattr(sample_msg, "trigger", None)
|
|
784
784
|
result = transformer(sample_msg)
|
|
785
785
|
assert result is None # partial_fit returns None
|
|
786
786
|
assert transformer.state.iterations == 1
|
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
"""Integration tests for Clock and Counter ezmsg systems."""
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
import os
|
|
5
|
+
from dataclasses import field
|
|
6
|
+
|
|
7
|
+
import ezmsg.core as ez
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pytest
|
|
10
|
+
from ezmsg.util.messagecodec import message_log
|
|
11
|
+
from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings
|
|
12
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
13
|
+
from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings
|
|
14
|
+
|
|
15
|
+
from ezmsg.baseproc import (
|
|
16
|
+
Clock,
|
|
17
|
+
ClockSettings,
|
|
18
|
+
Counter,
|
|
19
|
+
CounterSettings,
|
|
20
|
+
)
|
|
21
|
+
from tests.helpers.util import get_test_fn
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ClockTestSystemSettings(ez.Settings):
|
|
25
|
+
clock_settings: ClockSettings
|
|
26
|
+
log_settings: MessageLoggerSettings
|
|
27
|
+
term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ClockTestSystem(ez.Collection):
|
|
31
|
+
SETTINGS = ClockTestSystemSettings
|
|
32
|
+
|
|
33
|
+
CLOCK = Clock()
|
|
34
|
+
LOG = MessageLogger()
|
|
35
|
+
TERM = TerminateOnTotal()
|
|
36
|
+
|
|
37
|
+
def configure(self) -> None:
|
|
38
|
+
self.CLOCK.apply_settings(self.SETTINGS.clock_settings)
|
|
39
|
+
self.LOG.apply_settings(self.SETTINGS.log_settings)
|
|
40
|
+
self.TERM.apply_settings(self.SETTINGS.term_settings)
|
|
41
|
+
|
|
42
|
+
def network(self) -> ez.NetworkDefinition:
|
|
43
|
+
return (
|
|
44
|
+
(self.CLOCK.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE),
|
|
45
|
+
(self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE),
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@pytest.mark.parametrize("dispatch_rate", [math.inf, 2.0, 20.0])
|
|
50
|
+
def test_clock_system(
|
|
51
|
+
dispatch_rate: float,
|
|
52
|
+
test_name: str | None = None,
|
|
53
|
+
):
|
|
54
|
+
run_time = 1.0
|
|
55
|
+
n_target = 100 if math.isinf(dispatch_rate) else int(np.ceil(dispatch_rate * run_time))
|
|
56
|
+
test_filename = get_test_fn(test_name)
|
|
57
|
+
ez.logger.info(test_filename)
|
|
58
|
+
settings = ClockTestSystemSettings(
|
|
59
|
+
clock_settings=ClockSettings(dispatch_rate=dispatch_rate),
|
|
60
|
+
log_settings=MessageLoggerSettings(output=test_filename),
|
|
61
|
+
term_settings=TerminateOnTotalSettings(total=n_target),
|
|
62
|
+
)
|
|
63
|
+
system = ClockTestSystem(settings)
|
|
64
|
+
ez.run(SYSTEM=system)
|
|
65
|
+
|
|
66
|
+
# Collect result
|
|
67
|
+
messages = list(message_log(test_filename))
|
|
68
|
+
os.remove(test_filename)
|
|
69
|
+
|
|
70
|
+
# Clock produces LinearAxis with gain and offset
|
|
71
|
+
assert all(isinstance(m, AxisArray.LinearAxis) for m in messages)
|
|
72
|
+
assert len(messages) >= n_target
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class CounterTestSystemSettings(ez.Settings):
|
|
76
|
+
clock_settings: ClockSettings
|
|
77
|
+
counter_settings: CounterSettings
|
|
78
|
+
log_settings: MessageLoggerSettings
|
|
79
|
+
term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class CounterTestSystem(ez.Collection):
|
|
83
|
+
"""Counter must be driven by Clock in the new architecture."""
|
|
84
|
+
|
|
85
|
+
SETTINGS = CounterTestSystemSettings
|
|
86
|
+
|
|
87
|
+
CLOCK = Clock()
|
|
88
|
+
COUNTER = Counter()
|
|
89
|
+
LOG = MessageLogger()
|
|
90
|
+
TERM = TerminateOnTotal()
|
|
91
|
+
|
|
92
|
+
def configure(self) -> None:
|
|
93
|
+
self.CLOCK.apply_settings(self.SETTINGS.clock_settings)
|
|
94
|
+
self.COUNTER.apply_settings(self.SETTINGS.counter_settings)
|
|
95
|
+
self.LOG.apply_settings(self.SETTINGS.log_settings)
|
|
96
|
+
self.TERM.apply_settings(self.SETTINGS.term_settings)
|
|
97
|
+
|
|
98
|
+
def network(self) -> ez.NetworkDefinition:
|
|
99
|
+
return (
|
|
100
|
+
(self.CLOCK.OUTPUT_SIGNAL, self.COUNTER.INPUT_CLOCK),
|
|
101
|
+
(self.COUNTER.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE),
|
|
102
|
+
(self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE),
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@pytest.mark.parametrize(
|
|
107
|
+
"n_time, fs, dispatch_rate, mod",
|
|
108
|
+
[
|
|
109
|
+
(1, 10.0, math.inf, None), # AFAP mode
|
|
110
|
+
(20, 1000.0, 50.0, None), # Realtime mode (50 Hz dispatch = 20 samples/tick @ 1000 Hz)
|
|
111
|
+
(1, 1000.0, 100.0, 2**3), # 100 Hz dispatch with mod
|
|
112
|
+
(10, 10.0, 10.0, 2**3), # 10 Hz dispatch with mod
|
|
113
|
+
],
|
|
114
|
+
)
|
|
115
|
+
def test_counter_system(
|
|
116
|
+
n_time: int,
|
|
117
|
+
fs: float,
|
|
118
|
+
dispatch_rate: float,
|
|
119
|
+
mod: int | None,
|
|
120
|
+
test_name: str | None = None,
|
|
121
|
+
):
|
|
122
|
+
target_dur = 2.6 # 2.6 seconds per test
|
|
123
|
+
if math.isinf(dispatch_rate):
|
|
124
|
+
# AFAP mode - runs as fast as possible
|
|
125
|
+
target_messages = 100 # Fixed target for AFAP
|
|
126
|
+
else:
|
|
127
|
+
target_messages = int(target_dur * dispatch_rate)
|
|
128
|
+
|
|
129
|
+
test_filename = get_test_fn(test_name)
|
|
130
|
+
ez.logger.info(test_filename)
|
|
131
|
+
settings = CounterTestSystemSettings(
|
|
132
|
+
clock_settings=ClockSettings(dispatch_rate=dispatch_rate),
|
|
133
|
+
counter_settings=CounterSettings(
|
|
134
|
+
n_time=n_time,
|
|
135
|
+
fs=fs,
|
|
136
|
+
mod=mod,
|
|
137
|
+
),
|
|
138
|
+
log_settings=MessageLoggerSettings(
|
|
139
|
+
output=test_filename,
|
|
140
|
+
),
|
|
141
|
+
term_settings=TerminateOnTotalSettings(
|
|
142
|
+
total=target_messages,
|
|
143
|
+
),
|
|
144
|
+
)
|
|
145
|
+
system = CounterTestSystem(settings)
|
|
146
|
+
ez.run(SYSTEM=system)
|
|
147
|
+
|
|
148
|
+
# Collect result
|
|
149
|
+
messages: list[AxisArray] = [_ for _ in message_log(test_filename)]
|
|
150
|
+
os.remove(test_filename)
|
|
151
|
+
|
|
152
|
+
if math.isinf(dispatch_rate):
|
|
153
|
+
# The number of messages depends on how fast the computer is
|
|
154
|
+
target_messages = len(messages)
|
|
155
|
+
# This should be an equivalence assertion (==) but the use of TerminateOnTotal does
|
|
156
|
+
# not guarantee that MessageLogger will exit before an additional message is received.
|
|
157
|
+
# Let's just clip the last message if we exceed the target messages.
|
|
158
|
+
if len(messages) > target_messages:
|
|
159
|
+
messages = messages[:target_messages]
|
|
160
|
+
assert len(messages) >= target_messages
|
|
161
|
+
|
|
162
|
+
# Just do one quick data check (Counter now outputs 1D array)
|
|
163
|
+
agg = AxisArray.concatenate(*messages, dim="time")
|
|
164
|
+
target_samples = n_time * target_messages
|
|
165
|
+
expected_data = np.arange(target_samples)
|
|
166
|
+
if mod is not None:
|
|
167
|
+
expected_data = expected_data % mod
|
|
168
|
+
assert np.array_equal(agg.data, expected_data)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
@pytest.mark.parametrize(
|
|
172
|
+
"clock_rate, fs, n_time",
|
|
173
|
+
[
|
|
174
|
+
(10.0, 1000.0, 100), # 10 Hz clock, fs=1000, n_time=100 (fixed)
|
|
175
|
+
(20.0, 500.0, None), # 20 Hz clock, fs=500, n_time derived (25 samples per tick)
|
|
176
|
+
(5.0, 1000.0, None), # 5 Hz clock, fs=1000, n_time derived (200 samples per tick)
|
|
177
|
+
],
|
|
178
|
+
)
|
|
179
|
+
def test_counter_with_external_clock(
|
|
180
|
+
clock_rate: float,
|
|
181
|
+
fs: float,
|
|
182
|
+
n_time: int | None,
|
|
183
|
+
test_name: str | None = None,
|
|
184
|
+
):
|
|
185
|
+
"""Test Counter driven by external Clock (now the standard pattern)."""
|
|
186
|
+
target_messages = 20
|
|
187
|
+
test_filename = get_test_fn(test_name)
|
|
188
|
+
ez.logger.info(test_filename)
|
|
189
|
+
|
|
190
|
+
# This now uses the same CounterTestSystem since all counters need clocks
|
|
191
|
+
settings = CounterTestSystemSettings(
|
|
192
|
+
clock_settings=ClockSettings(dispatch_rate=clock_rate),
|
|
193
|
+
counter_settings=CounterSettings(
|
|
194
|
+
fs=fs,
|
|
195
|
+
n_time=n_time,
|
|
196
|
+
),
|
|
197
|
+
log_settings=MessageLoggerSettings(output=test_filename),
|
|
198
|
+
term_settings=TerminateOnTotalSettings(total=target_messages),
|
|
199
|
+
)
|
|
200
|
+
system = CounterTestSystem(settings)
|
|
201
|
+
ez.run(SYSTEM=system)
|
|
202
|
+
|
|
203
|
+
# Collect result
|
|
204
|
+
messages: list[AxisArray] = [_ for _ in message_log(test_filename)]
|
|
205
|
+
os.remove(test_filename)
|
|
206
|
+
|
|
207
|
+
assert len(messages) >= target_messages
|
|
208
|
+
|
|
209
|
+
# Verify each message has correct sample rate (gain = 1/fs)
|
|
210
|
+
for msg in messages:
|
|
211
|
+
assert msg.axes["time"].gain == 1.0 / fs
|
|
212
|
+
|
|
213
|
+
# Verify data continuity
|
|
214
|
+
messages = messages[:target_messages] # Trim to target
|
|
215
|
+
agg = AxisArray.concatenate(*messages, dim="time")
|
|
216
|
+
|
|
217
|
+
# Expected samples per tick
|
|
218
|
+
if n_time is not None:
|
|
219
|
+
expected_samples_per_tick = n_time
|
|
220
|
+
else:
|
|
221
|
+
expected_samples_per_tick = int(fs / clock_rate)
|
|
222
|
+
|
|
223
|
+
expected_total = expected_samples_per_tick * target_messages
|
|
224
|
+
# Allow for fractional sample accumulation variance
|
|
225
|
+
assert abs(len(agg.data) - expected_total) <= target_messages
|
|
226
|
+
|
|
227
|
+
# Counter values should be sequential (0, 1, 2, ...)
|
|
228
|
+
expected_data = np.arange(len(agg.data))
|
|
229
|
+
assert np.array_equal(agg.data, expected_data)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/adaptive.rst
RENAMED
|
File without changes
|
{ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/checkpoint.rst
RENAMED
|
File without changes
|
{ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/composite.rst
RENAMED
|
File without changes
|
|
File without changes
|
{ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/processor.rst
RENAMED
|
File without changes
|
{ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/standalone.rst
RENAMED
|
File without changes
|
{ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/stateful.rst
RENAMED
|
File without changes
|
{ezmsg_baseproc-1.2.0 → ezmsg_baseproc-1.3.0}/docs/source/guides/how-tos/processors/unit.rst
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|