ezmsg-simbiophys 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,125 @@
1
+ """ezmsg-simbiophys: Signal simulation and synthesis for ezmsg."""
2
+
3
+ from .__version__ import __version__ as __version__
4
+
5
+ # Clock
6
+ from .clock import (
7
+ Clock,
8
+ ClockProducer,
9
+ ClockSettings,
10
+ ClockState,
11
+ aclock,
12
+ clock,
13
+ )
14
+
15
+ # Cosine Tuning
16
+ from .cosine_tuning import (
17
+ CosineTuningParams,
18
+ CosineTuningSettings,
19
+ CosineTuningState,
20
+ CosineTuningTransformer,
21
+ CosineTuningUnit,
22
+ )
23
+
24
+ # Counter
25
+ from .counter import (
26
+ Counter,
27
+ CounterProducer,
28
+ CounterSettings,
29
+ CounterState,
30
+ acounter,
31
+ )
32
+
33
+ # Dynamic Colored Noise
34
+ from .dynamic_colored_noise import (
35
+ ColoredNoiseFilterState,
36
+ DynamicColoredNoiseSettings,
37
+ DynamicColoredNoiseState,
38
+ DynamicColoredNoiseTransformer,
39
+ DynamicColoredNoiseUnit,
40
+ compute_kasdin_coefficients,
41
+ )
42
+
43
+ # EEG
44
+ from .eeg import (
45
+ EEGSynth,
46
+ EEGSynthSettings,
47
+ )
48
+
49
+ # Noise
50
+ from .noise import (
51
+ NoiseSettings,
52
+ PinkNoise,
53
+ PinkNoiseProducer,
54
+ PinkNoiseSettings,
55
+ RandomGenerator,
56
+ RandomGeneratorSettings,
57
+ RandomTransformer,
58
+ WhiteNoise,
59
+ WhiteNoiseProducer,
60
+ WhiteNoiseSettings,
61
+ )
62
+
63
+ # Oscillator
64
+ from .oscillator import (
65
+ Oscillator,
66
+ OscillatorProducer,
67
+ OscillatorSettings,
68
+ SinGenerator,
69
+ SinGeneratorSettings,
70
+ SinTransformer,
71
+ sin,
72
+ )
73
+
74
# Public API of the package: every name imported from the submodules above is
# re-exported here, grouped by the submodule it comes from.
__all__ = [
    # Version
    "__version__",
    # Clock
    "Clock",
    "ClockProducer",
    "ClockSettings",
    "ClockState",
    "aclock",
    "clock",
    # Counter
    "Counter",
    "CounterProducer",
    "CounterSettings",
    "CounterState",
    "acounter",
    # Oscillator
    "Oscillator",
    "OscillatorProducer",
    "OscillatorSettings",
    "SinGenerator",
    "SinGeneratorSettings",
    "SinTransformer",
    "sin",
    # Noise
    "NoiseSettings",
    "PinkNoise",
    "PinkNoiseProducer",
    "PinkNoiseSettings",
    "RandomGenerator",
    "RandomGeneratorSettings",
    "RandomTransformer",
    "WhiteNoise",
    "WhiteNoiseProducer",
    "WhiteNoiseSettings",
    # EEG
    "EEGSynth",
    "EEGSynthSettings",
    # Cosine Tuning
    "CosineTuningParams",
    "CosineTuningSettings",
    "CosineTuningState",
    "CosineTuningTransformer",
    "CosineTuningUnit",
    # Dynamic Colored Noise
    "ColoredNoiseFilterState",
    "DynamicColoredNoiseSettings",
    "DynamicColoredNoiseState",
    "DynamicColoredNoiseTransformer",
    "DynamicColoredNoiseUnit",
    "compute_kasdin_coefficients",
]
@@ -0,0 +1,34 @@
1
+ # file generated by setuptools-scm
2
+ # don't change, don't track in version control
3
+
4
# Version metadata exported by this module. NOTE: setuptools-scm regenerates this
# file, so any edits (including these comments) are lost on the next build.
__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# A plain local False (not typing.TYPE_CHECKING), so the `typing` imports in the
# first branch are never executed at runtime.
# NOTE(review): presumably static type checkers special-case the name
# TYPE_CHECKING and analyse the first branch — confirm for the checker in use.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
    COMMIT_ID = Union[str, None]
else:
    # Runtime placeholders; only referenced by the annotations below.
    VERSION_TUPLE = object
    COMMIT_ID = object

# Declared types of the exported names (annotations only; values assigned below).
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID

# The package version, exposed under both dunder and plain names.
__version__ = version = '1.0.0'
__version_tuple__ = version_tuple = (1, 0, 0)

# No commit id is embedded in this file.
__commit_id__ = commit_id = None
@@ -0,0 +1,60 @@
1
+ """Base classes for counter-based producers."""
2
+
3
+ import traceback
4
+ import typing
5
+
6
+ import ezmsg.core as ez
7
+ from ezmsg.baseproc import BaseProducerUnit, MessageInType, MessageOutType, ProducerType, SettingsType
8
+ from ezmsg.baseproc.util.profile import profile_subpub
9
+
10
+ from .counter import CounterProducer
11
+
12
+
13
class BaseCounterFirstProducerUnit(
    BaseProducerUnit[SettingsType, MessageOutType, ProducerType],
    typing.Generic[SettingsType, MessageInType, MessageOutType, ProducerType],
):
    """
    Base class for units whose primary processor is a composite producer with a CounterProducer as the first
    processor (producer) in the chain.

    Dispatch is selected by ``self.producer.settings.dispatch_rate``:

    * ``"ext_clock"``: one output is produced for every message received on ``INPUT_SIGNAL``
      (see :meth:`on_signal`).
    * any other value: :meth:`produce` runs the producer in a free-running loop, restarting
      whenever the counter's ``new_generator`` event is set.
    """

    INPUT_SIGNAL = ez.InputStream(MessageInType)

    def create_producer(self):
        """Create the producer via the base class, then cache its first leaf processor as ``self._counter``."""
        super().create_producer()

        def recurse_get_counter(proc) -> CounterProducer:
            # Composite processors expose their children in ``_procs``; descend into the
            # first child until reaching a processor that has no ``_procs``.
            if hasattr(proc, "_procs"):
                return recurse_get_counter(list(proc._procs.values())[0])
            return proc

        self._counter = recurse_get_counter(self.producer)

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, _: ez.Flag):
        """In ``"ext_clock"`` mode, produce and publish one output per incoming message; otherwise do nothing."""
        if self.producer.settings.dispatch_rate == "ext_clock":
            out = await self.producer.__acall__()
            yield self.OUTPUT_SIGNAL, out

    @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
    async def produce(self) -> typing.AsyncGenerator:
        """Free-running publish loop used when the dispatch rate is not ``"ext_clock"``.

        Any exception ends the loop; it is logged at ERROR level rather than propagated.
        """
        try:
            counter_state = self._counter.state
            while True:
                # Wait (on every cycle) until the counter signals a new generator run,
                # then consume that signal.
                await counter_state.new_generator.wait()
                counter_state.new_generator.clear()

                if self.producer.settings.dispatch_rate == "ext_clock":
                    # Output is driven by on_signal in this mode; go back to waiting.
                    continue

                # Not using an external clock: produce until the counter signals a new run.
                while not counter_state.new_generator.is_set():
                    out = await self.producer.__acall__()
                    yield self.OUTPUT_SIGNAL, out
        except Exception:
            # The producer loop has terminated at this point; report it at ERROR level so
            # the failure is visible instead of being buried in INFO output.
            ez.logger.error(traceback.format_exc())
@@ -0,0 +1,99 @@
1
+ """Clock generator for timing control."""
2
+
3
+ import asyncio
4
+ import time
5
+ import typing
6
+ from dataclasses import field
7
+
8
+ import ezmsg.core as ez
9
+ from ezmsg.baseproc import BaseProducerUnit, BaseStatefulProducer, processor_state
10
+
11
+
12
class ClockSettings(ez.Settings):
    """Settings for clock generator."""

    # NOTE(review): ClockProducer in this module only paces ticks when this value is an
    # int or float (interpreted as Hz). Any other value, including None and 'realtime',
    # makes it tick without sleeping. The "external clock" meaning of None is not
    # implemented by ClockProducer itself — confirm how callers rely on it.
    dispatch_rate: float | str | None = None
    """Dispatch rate in Hz, 'realtime', or None for external clock"""
17
+
18
+
19
@processor_state
class ClockState:
    """State for clock generator."""

    # Wall-clock time (time.time) at creation/reset; tick n (1-based) is scheduled
    # at t_0 + n / dispatch_rate by ClockProducer.
    t_0: float = field(default_factory=time.time)  # Start time
    # Count of ticks emitted since the last reset.
    n_dispatch: int = 0  # Number of dispatches
25
+
26
+
27
class ClockProducer(BaseStatefulProducer[ClockSettings, ez.Flag, ClockState]):
    """
    Produces clock ticks at specified rate.
    Can be used to drive periodic operations.

    When ``settings.dispatch_rate`` is an int or float, tick ``n`` (1-based) is released no
    earlier than ``state.t_0 + n / dispatch_rate``. Anchoring every tick to ``t_0`` keeps sleep
    jitter from accumulating. Any other value produces ticks without waiting.
    """

    def _reset_state(self) -> None:
        """Reset internal state."""
        self._state.t_0 = time.time()
        self._state.n_dispatch = 0

    def _seconds_until_next_tick(self) -> float:
        """Return how long to wait before emitting the next tick (``0.0`` when no wait is needed).

        Shared by the sync and async paths so their pacing cannot drift apart.
        """
        rate = self.settings.dispatch_rate
        if not isinstance(rate, (int, float)):
            # Non-numeric dispatch_rate: tick as fast as possible.
            return 0.0
        target_time = self.state.t_0 + (self.state.n_dispatch + 1) / rate
        return max(0.0, target_time - time.time())

    def __call__(self) -> ez.Flag:
        """Synchronous clock production. We override __call__ (which uses run_coroutine_sync)
        to avoid async overhead."""
        if self._hash == -1:
            self._reset_state()
            self._hash = 0

        delay = self._seconds_until_next_tick()
        if delay > 0:
            time.sleep(delay)

        self.state.n_dispatch += 1
        return ez.Flag()

    async def _produce(self) -> ez.Flag:
        """Generate next clock tick."""
        delay = self._seconds_until_next_tick()
        if delay > 0:
            await asyncio.sleep(delay)

        self.state.n_dispatch += 1
        return ez.Flag()
66
+
67
+
68
def aclock(dispatch_rate: float | str | None) -> ClockProducer:
    """
    Construct a clock producer that emits :obj:`ez.Flag` ticks at a specified rate.

    The returned producer can be called synchronously or used from async code.

    Args:
        dispatch_rate: Tick rate in Hz. Any non-numeric value (e.g. ``None`` or
            ``'realtime'``) makes the producer tick without waiting. The annotation
            matches :obj:`ClockSettings.dispatch_rate`.

    Returns:
        A :obj:`ClockProducer` object.
    """
    return ClockProducer(ClockSettings(dispatch_rate=dispatch_rate))
76
+
77
+
78
# Same function object as `aclock` (not a wrapper). ClockProducer defines both a
# synchronous __call__ and an async _produce, so one factory serves both call styles.
clock = aclock
"""
Alias for :obj:`aclock` expected by synchronous methods. `ClockProducer` can be used in sync or async.
"""
82
+
83
+
84
class Clock(
    BaseProducerUnit[
        ClockSettings,  # SettingsType
        ez.Flag,  # MessageType
        ClockProducer,  # ProducerType
    ]
):
    """Unit that publishes the ticks of a :obj:`ClockProducer` on ``OUTPUT_SIGNAL``."""

    SETTINGS = ClockSettings

    @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
    async def produce(self) -> typing.AsyncGenerator:
        """Publish every truthy tick from the producer, indefinitely."""
        # Overridden (instead of using the base loop) so falsy outputs are dropped.
        while True:
            if tick := await self.producer.__acall__():
                yield self.OUTPUT_SIGNAL, tick
@@ -0,0 +1,249 @@
1
+ """Cosine tuning model for neural encoding of velocity/movement.
2
+
3
+ Implements the offset model from "Decoding arm speed during reaching"
4
+ (https://ncbi.nlm.nih.gov/pmc/articles/PMC6286377/):
5
+
6
+ firing_rate = b0 + m * |v| * cos(θ - θ_pd) + bs * |v|
7
+
8
+ Where:
9
+ - b0: baseline firing rate
10
+ - m: directional modulation depth
11
+ - θ: velocity direction (angle)
12
+ - θ_pd: preferred direction
13
+ - bs: speed modulation (non-directional)
14
+ - |v|: velocity magnitude (speed)
15
+
16
+ For spike generation from firing rates, use EventsFromRatesTransformer
17
+ from ezmsg-event.
18
+ """
19
+
20
+ from dataclasses import dataclass
21
+ from pathlib import Path
22
+
23
+ import ezmsg.core as ez
24
+ import numpy as np
25
+ import numpy.typing as npt
26
+ from ezmsg.baseproc import (
27
+ BaseStatefulTransformer,
28
+ BaseTransformerUnit,
29
+ processor_state,
30
+ )
31
+ from ezmsg.util.messages.axisarray import AxisArray
32
+ from ezmsg.util.messages.util import replace
33
+
34
+
35
+ @dataclass
36
+ class CosineTuningParams:
37
+ """Parameters for cosine tuning model.
38
+
39
+ All arrays must have the same shape (n_units,).
40
+
41
+ Attributes:
42
+ b0: Baseline firing rate (Hz) for each unit.
43
+ m: Directional modulation depth for each unit.
44
+ pd: Preferred direction (radians) for each unit.
45
+ bs: Speed modulation (non-directional) for each unit.
46
+ """
47
+
48
+ b0: npt.NDArray[np.floating]
49
+ m: npt.NDArray[np.floating]
50
+ pd: npt.NDArray[np.floating]
51
+ bs: npt.NDArray[np.floating]
52
+
53
+ def __post_init__(self):
54
+ """Validate that all parameters have consistent shapes."""
55
+ if not (self.b0.shape == self.m.shape == self.pd.shape == self.bs.shape):
56
+ raise ValueError("All parameters must have the same shape")
57
+ if self.b0.ndim != 1:
58
+ raise ValueError("Parameters must be 1D arrays")
59
+ if len(self.b0) < 1:
60
+ raise ValueError("Parameters must have length >= 1")
61
+
62
+ @property
63
+ def n_units(self) -> int:
64
+ """Number of neural units."""
65
+ return len(self.b0)
66
+
67
+ @classmethod
68
+ def from_file(
69
+ cls,
70
+ filepath: str | Path,
71
+ n_units: int | None = None,
72
+ weight_gain: float = 1.0,
73
+ ) -> "CosineTuningParams":
74
+ """Load parameters from a .npz file.
75
+
76
+ Args:
77
+ filepath: Path to .npz file containing b0, m, pd, bs arrays.
78
+ n_units: Number of units to use. If None, uses all units in file.
79
+ weight_gain: Scaling factor applied to m and bs parameters.
80
+
81
+ Returns:
82
+ CosineTuningParams instance.
83
+ """
84
+ params = np.load(filepath)
85
+
86
+ b0 = np.asarray(params["b0"], dtype=np.float64)
87
+ m = np.asarray(params["m"], dtype=np.float64)
88
+ pd = np.asarray(params["pd"], dtype=np.float64)
89
+ bs = np.asarray(params["bs"], dtype=np.float64)
90
+
91
+ if n_units is not None:
92
+ b0 = b0[:n_units]
93
+ m = m[:n_units]
94
+ pd = pd[:n_units]
95
+ bs = bs[:n_units]
96
+
97
+ m = m * weight_gain
98
+ bs = bs * weight_gain
99
+
100
+ return cls(b0=b0, m=m, pd=pd, bs=bs)
101
+
102
+ @classmethod
103
+ def from_random(
104
+ cls,
105
+ n_units: int,
106
+ baseline_hz: float = 10.0,
107
+ modulation_hz: float = 20.0,
108
+ speed_modulation_hz: float = 0.0,
109
+ seed: int | None = None,
110
+ ) -> "CosineTuningParams":
111
+ """Generate random tuning parameters.
112
+
113
+ Args:
114
+ n_units: Number of neural units.
115
+ baseline_hz: Baseline firing rate (Hz) for all units.
116
+ modulation_hz: Directional modulation depth for all units.
117
+ speed_modulation_hz: Speed modulation (non-directional) for all units.
118
+ seed: Random seed for reproducibility.
119
+
120
+ Returns:
121
+ CosineTuningParams instance with random preferred directions.
122
+ """
123
+ rng = np.random.default_rng(seed)
124
+
125
+ return cls(
126
+ b0=np.full(n_units, baseline_hz, dtype=np.float64),
127
+ m=np.full(n_units, modulation_hz, dtype=np.float64),
128
+ pd=rng.uniform(0.0, 2.0 * np.pi, size=n_units).astype(np.float64),
129
+ bs=np.full(n_units, speed_modulation_hz, dtype=np.float64),
130
+ )
131
+
132
+
133
class CosineTuningSettings(ez.Settings):
    """Settings for CosineTuningTransformer.

    Either `model_file` OR the random generation parameters should be specified.
    If `model_file` is provided, parameters are loaded from file.
    Otherwise, parameters are randomly generated.
    """

    # File-based parameters
    # When set, CosineTuningTransformer loads every unit in the file; n_units below
    # is not applied in that case.
    model_file: str | None = None
    """Path to .npz file with tuning parameters (b0, m, pd, bs)."""

    # Multiplies m and bs (not b0 or pd) after loading.
    weight_gain: float = 1.0
    """Scaling factor for m and bs when loading from file."""

    # Random generation parameters
    n_units: int = 50
    """Number of neural units (used if model_file is None)."""

    baseline_hz: float = 10.0
    """Baseline firing rate in Hz (used if model_file is None)."""

    modulation_hz: float = 20.0
    """Directional modulation depth in Hz (used if model_file is None)."""

    speed_modulation_hz: float = 0.0
    """Speed modulation (non-directional) in Hz (used if model_file is None)."""

    # Seeds np.random.default_rng; only the preferred directions are drawn randomly.
    seed: int | None = None
    """Random seed for reproducibility (used if model_file is None)."""

    # Output settings
    # Applied as a lower bound (np.maximum) to the computed rates.
    min_rate: float = 0.0
    """Minimum firing rate (Hz). Rates are clipped to this value."""
167
+
168
+
169
@processor_state
class CosineTuningState:
    """State for cosine tuning transformer."""

    # None until CosineTuningTransformer._reset_state loads (model_file) or
    # randomly generates the parameters.
    params: CosineTuningParams | None = None
    """Tuning curve parameters."""
175
+
176
+
177
class CosineTuningTransformer(BaseStatefulTransformer[CosineTuningSettings, AxisArray, AxisArray, CosineTuningState]):
    """Transform 2D velocity into firing rates using cosine tuning model.

    Input: AxisArray with shape (n_samples, 2) containing velocity (vx, vy).
    Output: AxisArray with shape (n_samples, n_units) containing firing rates (Hz).
    The output keeps the input's first dim (its name and axis). The consumed velocity
    dim is replaced by a "ch" coordinate axis labelled ``unit0 .. unit{n_units-1}``.

    The model implements:
        rate = b0 + m * |v| * cos(θ - θ_pd) + bs * |v|

    Rates are clipped from below at ``settings.min_rate``.

    For spike generation, chain with EventsFromRatesTransformer from ezmsg-event.
    """

    def _reset_state(self, message: AxisArray) -> None:
        """Initialize tuning parameters from ``settings.model_file`` or, if unset, randomly."""
        if self.settings.model_file is not None:
            self.state.params = CosineTuningParams.from_file(
                self.settings.model_file,
                n_units=None,  # Use all units from file
                weight_gain=self.settings.weight_gain,
            )
        else:
            self.state.params = CosineTuningParams.from_random(
                n_units=self.settings.n_units,
                baseline_hz=self.settings.baseline_hz,
                modulation_hz=self.settings.modulation_hz,
                speed_modulation_hz=self.settings.speed_modulation_hz,
                seed=self.settings.seed,
            )

    def _process(self, message: AxisArray) -> AxisArray:
        """Transform velocity to firing rates.

        Raises:
            ValueError: If ``message.data`` does not have shape (n_samples, 2).
        """
        v = np.asarray(message.data, dtype=np.float64)

        if v.ndim != 2 or v.shape[1] != 2:
            raise ValueError(f"Expected velocity with shape (n_samples, 2), got {v.shape}")

        # Extract velocity components
        vx = v[:, 0]
        vy = v[:, 1]

        # Calculate speed (magnitude) and direction (angle)
        speed = np.hypot(vx, vy)[:, np.newaxis]  # (n_samples, 1)
        theta = np.arctan2(vy, vx)[:, np.newaxis]  # (n_samples, 1)

        # Get parameters as row vectors for broadcasting
        params = self.state.params
        b0 = params.b0[np.newaxis, :]  # (1, n_units)
        m = params.m[np.newaxis, :]  # (1, n_units)
        pd = params.pd[np.newaxis, :]  # (1, n_units)
        bs = params.bs[np.newaxis, :]  # (1, n_units)

        # Compute firing rates: b0 + m * |v| * cos(θ - θ_pd) + bs * |v|
        rates = b0 + m * speed * np.cos(theta - pd) + bs * speed

        # Clip to minimum rate
        rates = np.maximum(rates, self.settings.min_rate)

        # Create channel axis
        ch_labels = np.array([f"unit{i}" for i in range(params.n_units)])
        ch_axis = AxisArray.CoordinateAxis(data=ch_labels, dims=["ch"])

        # Keep the input's sample-dim name rather than hard-coding "time", so the output's
        # first dim still refers to the input's axis for it. Fall back to "time" if the
        # input carries no dim names.
        in_dims = list(message.dims)
        sample_dim = in_dims[0] if in_dims else "time"
        # The velocity-component dim is consumed by the model; drop its axis so the
        # output does not carry an axis for a dim it no longer has.
        vel_dim = in_dims[1] if len(in_dims) > 1 else None
        axes = {name: ax for name, ax in message.axes.items() if name != vel_dim}
        axes["ch"] = ch_axis

        return replace(
            message,
            data=rates,
            dims=[sample_dim, "ch"],
            axes=axes,
        )
244
+
245
+
246
class CosineTuningUnit(BaseTransformerUnit[CosineTuningSettings, AxisArray, AxisArray, CosineTuningTransformer]):
    """Unit wrapper for CosineTuningTransformer."""

    # Settings class for this unit; presumably consumed by the BaseTransformerUnit
    # machinery to build the CosineTuningTransformer — confirm in ezmsg.baseproc.
    SETTINGS = CosineTuningSettings