ezmsg-simbiophys 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ezmsg/simbiophys/__init__.py +125 -0
- ezmsg/simbiophys/__version__.py +34 -0
- ezmsg/simbiophys/_base.py +60 -0
- ezmsg/simbiophys/clock.py +99 -0
- ezmsg/simbiophys/cosine_tuning.py +249 -0
- ezmsg/simbiophys/counter.py +224 -0
- ezmsg/simbiophys/dynamic_colored_noise.py +352 -0
- ezmsg/simbiophys/eeg.py +73 -0
- ezmsg/simbiophys/noise.py +122 -0
- ezmsg/simbiophys/oscillator.py +133 -0
- ezmsg_simbiophys-1.0.0.dist-info/METADATA +46 -0
- ezmsg_simbiophys-1.0.0.dist-info/RECORD +14 -0
- ezmsg_simbiophys-1.0.0.dist-info/WHEEL +4 -0
- ezmsg_simbiophys-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
"""ezmsg-simbiophys: Signal simulation and synthesis for ezmsg."""
|
|
2
|
+
|
|
3
|
+
from .__version__ import __version__ as __version__
|
|
4
|
+
|
|
5
|
+
# Clock
|
|
6
|
+
from .clock import (
|
|
7
|
+
Clock,
|
|
8
|
+
ClockProducer,
|
|
9
|
+
ClockSettings,
|
|
10
|
+
ClockState,
|
|
11
|
+
aclock,
|
|
12
|
+
clock,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
# Cosine Tuning
|
|
16
|
+
from .cosine_tuning import (
|
|
17
|
+
CosineTuningParams,
|
|
18
|
+
CosineTuningSettings,
|
|
19
|
+
CosineTuningState,
|
|
20
|
+
CosineTuningTransformer,
|
|
21
|
+
CosineTuningUnit,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
# Counter
|
|
25
|
+
from .counter import (
|
|
26
|
+
Counter,
|
|
27
|
+
CounterProducer,
|
|
28
|
+
CounterSettings,
|
|
29
|
+
CounterState,
|
|
30
|
+
acounter,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# Dynamic Colored Noise
|
|
34
|
+
from .dynamic_colored_noise import (
|
|
35
|
+
ColoredNoiseFilterState,
|
|
36
|
+
DynamicColoredNoiseSettings,
|
|
37
|
+
DynamicColoredNoiseState,
|
|
38
|
+
DynamicColoredNoiseTransformer,
|
|
39
|
+
DynamicColoredNoiseUnit,
|
|
40
|
+
compute_kasdin_coefficients,
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
# EEG
|
|
44
|
+
from .eeg import (
|
|
45
|
+
EEGSynth,
|
|
46
|
+
EEGSynthSettings,
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
# Noise
|
|
50
|
+
from .noise import (
|
|
51
|
+
NoiseSettings,
|
|
52
|
+
PinkNoise,
|
|
53
|
+
PinkNoiseProducer,
|
|
54
|
+
PinkNoiseSettings,
|
|
55
|
+
RandomGenerator,
|
|
56
|
+
RandomGeneratorSettings,
|
|
57
|
+
RandomTransformer,
|
|
58
|
+
WhiteNoise,
|
|
59
|
+
WhiteNoiseProducer,
|
|
60
|
+
WhiteNoiseSettings,
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
# Oscillator
|
|
64
|
+
from .oscillator import (
|
|
65
|
+
Oscillator,
|
|
66
|
+
OscillatorProducer,
|
|
67
|
+
OscillatorSettings,
|
|
68
|
+
SinGenerator,
|
|
69
|
+
SinGeneratorSettings,
|
|
70
|
+
SinTransformer,
|
|
71
|
+
sin,
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
__all__ = [
|
|
75
|
+
# Version
|
|
76
|
+
"__version__",
|
|
77
|
+
# Clock
|
|
78
|
+
"Clock",
|
|
79
|
+
"ClockProducer",
|
|
80
|
+
"ClockSettings",
|
|
81
|
+
"ClockState",
|
|
82
|
+
"aclock",
|
|
83
|
+
"clock",
|
|
84
|
+
# Counter
|
|
85
|
+
"Counter",
|
|
86
|
+
"CounterProducer",
|
|
87
|
+
"CounterSettings",
|
|
88
|
+
"CounterState",
|
|
89
|
+
"acounter",
|
|
90
|
+
# Oscillator
|
|
91
|
+
"Oscillator",
|
|
92
|
+
"OscillatorProducer",
|
|
93
|
+
"OscillatorSettings",
|
|
94
|
+
"SinGenerator",
|
|
95
|
+
"SinGeneratorSettings",
|
|
96
|
+
"SinTransformer",
|
|
97
|
+
"sin",
|
|
98
|
+
# Noise
|
|
99
|
+
"NoiseSettings",
|
|
100
|
+
"PinkNoise",
|
|
101
|
+
"PinkNoiseProducer",
|
|
102
|
+
"PinkNoiseSettings",
|
|
103
|
+
"RandomGenerator",
|
|
104
|
+
"RandomGeneratorSettings",
|
|
105
|
+
"RandomTransformer",
|
|
106
|
+
"WhiteNoise",
|
|
107
|
+
"WhiteNoiseProducer",
|
|
108
|
+
"WhiteNoiseSettings",
|
|
109
|
+
# EEG
|
|
110
|
+
"EEGSynth",
|
|
111
|
+
"EEGSynthSettings",
|
|
112
|
+
# Cosine Tuning
|
|
113
|
+
"CosineTuningParams",
|
|
114
|
+
"CosineTuningSettings",
|
|
115
|
+
"CosineTuningState",
|
|
116
|
+
"CosineTuningTransformer",
|
|
117
|
+
"CosineTuningUnit",
|
|
118
|
+
# Dynamic Colored Noise
|
|
119
|
+
"ColoredNoiseFilterState",
|
|
120
|
+
"DynamicColoredNoiseSettings",
|
|
121
|
+
"DynamicColoredNoiseState",
|
|
122
|
+
"DynamicColoredNoiseTransformer",
|
|
123
|
+
"DynamicColoredNoiseUnit",
|
|
124
|
+
"compute_kasdin_coefficients",
|
|
125
|
+
]
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# file generated by setuptools-scm
|
|
2
|
+
# don't change, don't track in version control
|
|
3
|
+
|
|
4
|
+
__all__ = [
|
|
5
|
+
"__version__",
|
|
6
|
+
"__version_tuple__",
|
|
7
|
+
"version",
|
|
8
|
+
"version_tuple",
|
|
9
|
+
"__commit_id__",
|
|
10
|
+
"commit_id",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
TYPE_CHECKING = False
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from typing import Tuple
|
|
16
|
+
from typing import Union
|
|
17
|
+
|
|
18
|
+
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
|
19
|
+
COMMIT_ID = Union[str, None]
|
|
20
|
+
else:
|
|
21
|
+
VERSION_TUPLE = object
|
|
22
|
+
COMMIT_ID = object
|
|
23
|
+
|
|
24
|
+
version: str
|
|
25
|
+
__version__: str
|
|
26
|
+
__version_tuple__: VERSION_TUPLE
|
|
27
|
+
version_tuple: VERSION_TUPLE
|
|
28
|
+
commit_id: COMMIT_ID
|
|
29
|
+
__commit_id__: COMMIT_ID
|
|
30
|
+
|
|
31
|
+
__version__ = version = '1.0.0'
|
|
32
|
+
__version_tuple__ = version_tuple = (1, 0, 0)
|
|
33
|
+
|
|
34
|
+
__commit_id__ = commit_id = None
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Base classes for counter-based producers."""
|
|
2
|
+
|
|
3
|
+
import traceback
|
|
4
|
+
import typing
|
|
5
|
+
|
|
6
|
+
import ezmsg.core as ez
|
|
7
|
+
from ezmsg.baseproc import BaseProducerUnit, MessageInType, MessageOutType, ProducerType, SettingsType
|
|
8
|
+
from ezmsg.baseproc.util.profile import profile_subpub
|
|
9
|
+
|
|
10
|
+
from .counter import CounterProducer
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class BaseCounterFirstProducerUnit(
    BaseProducerUnit[SettingsType, MessageOutType, ProducerType],
    typing.Generic[SettingsType, MessageInType, MessageOutType, ProducerType],
):
    """
    Base class for units whose primary processor is a composite producer with a CounterProducer as the first
    processor (producer) in the chain.

    The unit publishes on ``OUTPUT_SIGNAL`` from two places:

    * :meth:`on_signal` produces one output per message received on ``INPUT_SIGNAL``, but only while the
      producer's ``dispatch_rate`` setting equals ``"ext_clock"``.
    * :meth:`produce` free-runs the producer whenever ``dispatch_rate`` is anything else.
    """

    INPUT_SIGNAL = ez.InputStream(MessageInType)

    def create_producer(self) -> None:
        """Create the producer via the base class, then locate and cache the leading counter."""
        super().create_producer()

        def recurse_get_counter(proc) -> CounterProducer:
            # Descend through nested composites: while the object exposes `_procs`, follow the
            # first entry of that mapping. The first object without `_procs` is returned as-is.
            # NOTE(review): nothing here checks that the result really is a CounterProducer.
            if hasattr(proc, "_procs"):
                return recurse_get_counter(list(proc._procs.values())[0])
            return proc

        self._counter = recurse_get_counter(self.producer)

    @ez.subscriber(INPUT_SIGNAL, zero_copy=True)
    @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
    @profile_subpub(trace_oldest=False)
    async def on_signal(self, _: ez.Flag):
        """Produce and publish one output per incoming message when externally clocked.

        The incoming message is used only as a trigger; its content is ignored. When
        ``dispatch_rate`` is not ``"ext_clock"`` nothing is published from here.
        """
        if self.producer.settings.dispatch_rate == "ext_clock":
            out = await self.producer.__acall__()
            yield self.OUTPUT_SIGNAL, out

    @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
    async def produce(self) -> typing.AsyncGenerator:
        """Free-run the producer and publish its output while not externally clocked.

        Any exception raised inside the loop is logged and ends this generator.
        """
        try:
            counter_state = self._counter.state
            while True:
                # Once-only, enter the generator loop
                # `new_generator` is waited on, cleared and polled below — presumably an
                # asyncio.Event owned by the counter state; confirm against CounterState.
                await counter_state.new_generator.wait()
                counter_state.new_generator.clear()

                if self.producer.settings.dispatch_rate == "ext_clock":
                    # We shouldn't even be here. Cycle around and wait on the event again.
                    continue

                # We are not using an external clock. Run the generator.
                # The loop exits as soon as `new_generator` is set again, returning to the wait above.
                while not counter_state.new_generator.is_set():
                    out = await self.producer.__acall__()
                    yield self.OUTPUT_SIGNAL, out
        except Exception:
            # NOTE(review): the traceback is emitted at INFO level and the exception is not
            # re-raised, so a failing producer stops publishing quietly — consider a higher level.
            ez.logger.info(traceback.format_exc())
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
"""Clock generator for timing control."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import time
|
|
5
|
+
import typing
|
|
6
|
+
from dataclasses import field
|
|
7
|
+
|
|
8
|
+
import ezmsg.core as ez
|
|
9
|
+
from ezmsg.baseproc import BaseProducerUnit, BaseStatefulProducer, processor_state
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ClockSettings(ez.Settings):
    """Settings for clock generator."""

    # As consumed by ClockProducer: an int/float value is treated as ticks per second and
    # paces every tick; any other value (a string or None) applies no pacing at all.
    dispatch_rate: float | str | None = None
    """Dispatch rate in Hz, 'realtime', or None for external clock"""
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@processor_state
class ClockState:
    """State for clock generator."""

    # Wall-clock reference (time.time()) from which tick deadlines are measured.
    t_0: float = field(default_factory=time.time)  # Start time
    # Ticks emitted so far; the next tick is due at t_0 + (n_dispatch + 1) / dispatch_rate.
    n_dispatch: int = 0  # Number of dispatches
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ClockProducer(BaseStatefulProducer[ClockSettings, ez.Flag, ClockState]):
    """
    Emit an ``ez.Flag`` tick on every invocation, paced by ``settings.dispatch_rate``.

    With a numeric rate, tick ``k`` (1-based) is due at ``t_0 + k / dispatch_rate``, so the
    schedule is anchored to the start time rather than to the previous tick. With any
    non-numeric rate the producer returns immediately on every call.
    """

    def _reset_state(self) -> None:
        """Restart the schedule: zero the tick count and stamp a fresh start time."""
        self._state.n_dispatch = 0
        self._state.t_0 = time.time()

    def _seconds_until_next_tick(self) -> float:
        """Return how long to wait before the next tick; zero or negative means "do not wait"."""
        rate = self.settings.dispatch_rate
        if not isinstance(rate, (int, float)):
            # No manual dispatch rate configured: run as fast as possible.
            return 0.0
        deadline = self.state.t_0 + (self.state.n_dispatch + 1) / rate
        return deadline - time.time()

    def __call__(self) -> ez.Flag:
        """
        Blocking variant of the tick.

        Implemented directly (instead of going through the coroutine path) so that synchronous
        callers pay no async overhead.
        """
        # -1 is handled here as "state not yet initialised" — presumably the base class's
        # sentinel for the state hash; confirm against BaseStatefulProducer.
        if self._hash == -1:
            self._reset_state()
            self._hash = 0

        delay = self._seconds_until_next_tick()
        if delay > 0:
            time.sleep(delay)

        self.state.n_dispatch += 1
        return ez.Flag()

    async def _produce(self) -> ez.Flag:
        """Awaitable variant of the tick: sleeps cooperatively until the tick is due."""
        delay = self._seconds_until_next_tick()
        if delay > 0:
            await asyncio.sleep(delay)

        self.state.n_dispatch += 1
        return ez.Flag()
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def aclock(dispatch_rate: float | str | None) -> ClockProducer:
    """
    Construct a :obj:`ClockProducer` configured with the given dispatch rate.

    Args:
        dispatch_rate: Passed unchanged to :obj:`ClockSettings`. A number is interpreted by
            the producer as ticks per second; any other value disables pacing.

    Returns:
        A :obj:`ClockProducer` object.
    """
    return ClockProducer(ClockSettings(dispatch_rate=dispatch_rate))


clock = aclock
"""
Alias for :obj:`aclock` expected by synchronous methods. `ClockProducer` can be used in sync or async.
"""
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class Clock(
    BaseProducerUnit[
        ClockSettings,  # SettingsType
        ez.Flag,  # MessageType
        ClockProducer,  # ProducerType
    ]
):
    """Unit that repeatedly awaits its :obj:`ClockProducer` and publishes each result on ``OUTPUT_SIGNAL``."""

    SETTINGS = ClockSettings

    @ez.publisher(BaseProducerUnit.OUTPUT_SIGNAL)
    async def produce(self) -> typing.AsyncGenerator:
        """Publish producer output forever, skipping any result that is falsy."""
        # Override so that we do not yield when `out` is False-like.
        while True:
            out = await self.producer.__acall__()
            if out:
                yield self.OUTPUT_SIGNAL, out
|
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
"""Cosine tuning model for neural encoding of velocity/movement.
|
|
2
|
+
|
|
3
|
+
Implements the offset model from "Decoding arm speed during reaching"
|
|
4
|
+
(https://ncbi.nlm.nih.gov/pmc/articles/PMC6286377/):
|
|
5
|
+
|
|
6
|
+
firing_rate = b0 + m * |v| * cos(θ - θ_pd) + bs * |v|
|
|
7
|
+
|
|
8
|
+
Where:
|
|
9
|
+
- b0: baseline firing rate
|
|
10
|
+
- m: directional modulation depth
|
|
11
|
+
- θ: velocity direction (angle)
|
|
12
|
+
- θ_pd: preferred direction
|
|
13
|
+
- bs: speed modulation (non-directional)
|
|
14
|
+
- |v|: velocity magnitude (speed)
|
|
15
|
+
|
|
16
|
+
For spike generation from firing rates, use EventsFromRatesTransformer
|
|
17
|
+
from ezmsg-event.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from dataclasses import dataclass
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
|
|
23
|
+
import ezmsg.core as ez
|
|
24
|
+
import numpy as np
|
|
25
|
+
import numpy.typing as npt
|
|
26
|
+
from ezmsg.baseproc import (
|
|
27
|
+
BaseStatefulTransformer,
|
|
28
|
+
BaseTransformerUnit,
|
|
29
|
+
processor_state,
|
|
30
|
+
)
|
|
31
|
+
from ezmsg.util.messages.axisarray import AxisArray
|
|
32
|
+
from ezmsg.util.messages.util import replace
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
|
|
36
|
+
class CosineTuningParams:
|
|
37
|
+
"""Parameters for cosine tuning model.
|
|
38
|
+
|
|
39
|
+
All arrays must have the same shape (n_units,).
|
|
40
|
+
|
|
41
|
+
Attributes:
|
|
42
|
+
b0: Baseline firing rate (Hz) for each unit.
|
|
43
|
+
m: Directional modulation depth for each unit.
|
|
44
|
+
pd: Preferred direction (radians) for each unit.
|
|
45
|
+
bs: Speed modulation (non-directional) for each unit.
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
b0: npt.NDArray[np.floating]
|
|
49
|
+
m: npt.NDArray[np.floating]
|
|
50
|
+
pd: npt.NDArray[np.floating]
|
|
51
|
+
bs: npt.NDArray[np.floating]
|
|
52
|
+
|
|
53
|
+
def __post_init__(self):
|
|
54
|
+
"""Validate that all parameters have consistent shapes."""
|
|
55
|
+
if not (self.b0.shape == self.m.shape == self.pd.shape == self.bs.shape):
|
|
56
|
+
raise ValueError("All parameters must have the same shape")
|
|
57
|
+
if self.b0.ndim != 1:
|
|
58
|
+
raise ValueError("Parameters must be 1D arrays")
|
|
59
|
+
if len(self.b0) < 1:
|
|
60
|
+
raise ValueError("Parameters must have length >= 1")
|
|
61
|
+
|
|
62
|
+
@property
|
|
63
|
+
def n_units(self) -> int:
|
|
64
|
+
"""Number of neural units."""
|
|
65
|
+
return len(self.b0)
|
|
66
|
+
|
|
67
|
+
@classmethod
|
|
68
|
+
def from_file(
|
|
69
|
+
cls,
|
|
70
|
+
filepath: str | Path,
|
|
71
|
+
n_units: int | None = None,
|
|
72
|
+
weight_gain: float = 1.0,
|
|
73
|
+
) -> "CosineTuningParams":
|
|
74
|
+
"""Load parameters from a .npz file.
|
|
75
|
+
|
|
76
|
+
Args:
|
|
77
|
+
filepath: Path to .npz file containing b0, m, pd, bs arrays.
|
|
78
|
+
n_units: Number of units to use. If None, uses all units in file.
|
|
79
|
+
weight_gain: Scaling factor applied to m and bs parameters.
|
|
80
|
+
|
|
81
|
+
Returns:
|
|
82
|
+
CosineTuningParams instance.
|
|
83
|
+
"""
|
|
84
|
+
params = np.load(filepath)
|
|
85
|
+
|
|
86
|
+
b0 = np.asarray(params["b0"], dtype=np.float64)
|
|
87
|
+
m = np.asarray(params["m"], dtype=np.float64)
|
|
88
|
+
pd = np.asarray(params["pd"], dtype=np.float64)
|
|
89
|
+
bs = np.asarray(params["bs"], dtype=np.float64)
|
|
90
|
+
|
|
91
|
+
if n_units is not None:
|
|
92
|
+
b0 = b0[:n_units]
|
|
93
|
+
m = m[:n_units]
|
|
94
|
+
pd = pd[:n_units]
|
|
95
|
+
bs = bs[:n_units]
|
|
96
|
+
|
|
97
|
+
m = m * weight_gain
|
|
98
|
+
bs = bs * weight_gain
|
|
99
|
+
|
|
100
|
+
return cls(b0=b0, m=m, pd=pd, bs=bs)
|
|
101
|
+
|
|
102
|
+
@classmethod
|
|
103
|
+
def from_random(
|
|
104
|
+
cls,
|
|
105
|
+
n_units: int,
|
|
106
|
+
baseline_hz: float = 10.0,
|
|
107
|
+
modulation_hz: float = 20.0,
|
|
108
|
+
speed_modulation_hz: float = 0.0,
|
|
109
|
+
seed: int | None = None,
|
|
110
|
+
) -> "CosineTuningParams":
|
|
111
|
+
"""Generate random tuning parameters.
|
|
112
|
+
|
|
113
|
+
Args:
|
|
114
|
+
n_units: Number of neural units.
|
|
115
|
+
baseline_hz: Baseline firing rate (Hz) for all units.
|
|
116
|
+
modulation_hz: Directional modulation depth for all units.
|
|
117
|
+
speed_modulation_hz: Speed modulation (non-directional) for all units.
|
|
118
|
+
seed: Random seed for reproducibility.
|
|
119
|
+
|
|
120
|
+
Returns:
|
|
121
|
+
CosineTuningParams instance with random preferred directions.
|
|
122
|
+
"""
|
|
123
|
+
rng = np.random.default_rng(seed)
|
|
124
|
+
|
|
125
|
+
return cls(
|
|
126
|
+
b0=np.full(n_units, baseline_hz, dtype=np.float64),
|
|
127
|
+
m=np.full(n_units, modulation_hz, dtype=np.float64),
|
|
128
|
+
pd=rng.uniform(0.0, 2.0 * np.pi, size=n_units).astype(np.float64),
|
|
129
|
+
bs=np.full(n_units, speed_modulation_hz, dtype=np.float64),
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class CosineTuningSettings(ez.Settings):
    """Settings for CosineTuningTransformer.

    Either `model_file` OR the random generation parameters should be specified.
    If `model_file` is provided, parameters are loaded from file.
    Otherwise, parameters are randomly generated.

    When loading from file the transformer uses every unit stored in the file;
    `n_units` is not applied in that case.
    """

    # File-based parameters
    model_file: str | None = None
    """Path to .npz file with tuning parameters (b0, m, pd, bs)."""

    weight_gain: float = 1.0
    """Scaling factor for m and bs when loading from file."""

    # Random generation parameters
    n_units: int = 50
    """Number of neural units (used if model_file is None)."""

    baseline_hz: float = 10.0
    """Baseline firing rate in Hz (used if model_file is None)."""

    modulation_hz: float = 20.0
    """Directional modulation depth in Hz (used if model_file is None)."""

    speed_modulation_hz: float = 0.0
    """Speed modulation (non-directional) in Hz (used if model_file is None)."""

    seed: int | None = None
    """Random seed for reproducibility (used if model_file is None)."""

    # Output settings
    # Applied as a lower bound only (element-wise maximum); there is no upper clip.
    min_rate: float = 0.0
    """Minimum firing rate (Hz). Rates are clipped to this value."""
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
@processor_state
class CosineTuningState:
    """State for cosine tuning transformer."""

    # None until the transformer's _reset_state populates it from file or at random.
    params: CosineTuningParams | None = None
    """Tuning curve parameters."""
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class CosineTuningTransformer(BaseStatefulTransformer[CosineTuningSettings, AxisArray, AxisArray, CosineTuningState]):
    """Map 2D velocity samples onto per-unit firing rates with a cosine tuning model.

    Input: AxisArray whose data has shape (n_samples, 2), holding velocity as (vx, vy).
    Output: AxisArray whose data has shape (n_samples, n_units), holding firing rates (Hz).

    Each rate is computed as:
        rate = b0 + m * |v| * cos(θ - θ_pd) + bs * |v|
    and then floored at ``settings.min_rate``.

    To turn the rates into spikes, chain with EventsFromRatesTransformer from ezmsg-event.
    """

    def _reset_state(self, message: AxisArray) -> None:
        """Build the tuning parameters, from file when configured and at random otherwise."""
        cfg = self.settings
        if cfg.model_file is None:
            self.state.params = CosineTuningParams.from_random(
                n_units=cfg.n_units,
                baseline_hz=cfg.baseline_hz,
                modulation_hz=cfg.modulation_hz,
                speed_modulation_hz=cfg.speed_modulation_hz,
                seed=cfg.seed,
            )
            return

        # n_units=None keeps every unit stored in the file.
        self.state.params = CosineTuningParams.from_file(
            cfg.model_file,
            n_units=None,
            weight_gain=cfg.weight_gain,
        )

    def _process(self, message: AxisArray) -> AxisArray:
        """Convert one velocity message into a firing-rate message.

        Raises:
            ValueError: If the message data is not of shape (n_samples, 2).
        """
        v = np.asarray(message.data, dtype=np.float64)
        if not (v.ndim == 2 and v.shape[1] == 2):
            raise ValueError(f"Expected velocity with shape (n_samples, 2), got {v.shape}")

        # Polar form of the velocity, kept as columns so they broadcast against unit rows.
        vx, vy = v[:, 0], v[:, 1]
        speed = np.hypot(vx, vy)[:, None]  # (n_samples, 1)
        theta = np.arctan2(vy, vx)[:, None]  # (n_samples, 1)

        tuning = self.state.params
        baseline = tuning.b0[None, :]  # (1, n_units)
        depth = tuning.m[None, :]  # (1, n_units)
        preferred = tuning.pd[None, :]  # (1, n_units)
        speed_gain = tuning.bs[None, :]  # (1, n_units)

        # b0 + m * |v| * cos(θ - θ_pd) + bs * |v|, floored at the configured minimum rate.
        directional = depth * speed * np.cos(theta - preferred)
        nondirectional = speed_gain * speed
        rates = np.maximum(baseline + directional + nondirectional, self.settings.min_rate)

        # One label per unit on the output channel axis.
        labels = [f"unit{i}" for i in range(tuning.n_units)]
        out_axes = dict(message.axes)
        out_axes["ch"] = AxisArray.CoordinateAxis(data=np.array(labels), dims=["ch"])

        return replace(
            message,
            data=rates,
            dims=["time", "ch"],
            axes=out_axes,
        )
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
class CosineTuningUnit(BaseTransformerUnit[CosineTuningSettings, AxisArray, AxisArray, CosineTuningTransformer]):
    """Unit wrapper for CosineTuningTransformer.

    Declares only the settings type; all stream wiring is inherited from BaseTransformerUnit.
    """

    SETTINGS = CosineTuningSettings
|