aind-behavior-utils 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3 @@
1
+ """Init package"""
2
+
3
+ __version__ = "0.3.1"
@@ -0,0 +1 @@
1
+ """Plotting utilities for behavior data visualization."""
@@ -0,0 +1,60 @@
1
+ """Plotting utilities for behavior data."""
2
+
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from matplotlib.figure import Figure
8
+
9
+
10
def plot_array(
    data_array: np.ndarray,
    ylim: Optional[Tuple[float, float]] = None,
    xlim: Optional[Tuple[float, float]] = None,
    ylabel: str = "",
    xlabel: str = "",
    title: str = "",
    aspect: Optional[Union[str, float]] = None,
) -> Figure:
    """Draw ``data_array`` on a fresh figure and hand the figure back.

    Parameters
    ----------
    data_array : numpy.ndarray
        The values to draw; passed unchanged to ``Axes.plot``.
    ylim : Optional[Tuple[float, float]]
        ``(lower, upper)`` bounds for the y-axis. Skipped when falsy.
    xlim : Optional[Tuple[float, float]]
        ``(lower, upper)`` bounds for the x-axis. Skipped when falsy.
    ylabel : str
        Text for the y-axis label.
    xlabel : str
        Text for the x-axis label.
    title : str
        Text for the axes title.
    aspect : Optional[Union[str, float]]
        Aspect setting forwarded to ``Axes.set_aspect``. Skipped when
        falsy.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure. It is created through ``pyplot``, so
        the caller decides when to show, save or close it.
    """
    figure, axes = plt.subplots()

    # Explicit limits, when given, are set before the data is drawn.
    for apply_limits, bounds in ((axes.set_ylim, ylim), (axes.set_xlim, xlim)):
        if bounds:
            apply_limits(*bounds)

    axes.set_ylabel(ylabel)
    axes.set_xlabel(xlabel)
    axes.set_title(title)

    axes.plot(data_array)

    if aspect:
        axes.set_aspect(aspect)

    figure.tight_layout()
    return figure
@@ -0,0 +1,5 @@
1
+ """Stimulus pickle file parsing and analysis utilities."""
2
+
3
+ from aind_behavior_utils.stimulus.camstim_dataset import CamstimDataset
4
+
5
+ __all__ = ["CamstimDataset"]
@@ -0,0 +1,267 @@
1
+ """Stimulus pickle file parsing utilities.
2
+
3
+ Provides utilities for loading and parsing Camstim stimulus pickle files,
4
+ extracting frame timing, wheel encoder data, and quality control metrics.
5
+
6
+ The primary interface is the :class:`CamstimDataset` class, which wraps
7
+ a loaded stimulus pickle dictionary and resolves its internal structure
8
+ (foraging vs behavior item groups) once at construction.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import Any, Dict, Optional
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+
18
# Radius (cm) used to calibrate wheel encoder readings into distance.
WHEEL_RADIUS = 5.5036

# Frame rate (Hz) assumed when none is stored in, or derivable from, the data.
DEFAULT_FPS = 60.0

# Multiply a duration in milliseconds by this factor to obtain seconds.
MS_TO_S = 1e-3
26
+
27
+
28
class CamstimDataset:
    """Wrapper around a camstim stimulus pickle file dictionary.

    Resolves the internal data layout (``foraging`` vs ``behavior``
    item groups) once at construction so that downstream accessors
    never need to repeat the lookup.

    Parameters
    ----------
    data : Dict[str, Any]
        The loaded stimulus pickle file dictionary.

    Examples
    --------
    >>> dset = CamstimDataset(pkl_data)
    >>> print(dset.fps)
    60.0
    >>> print(dset.stim_frame_count)
    120
    >>> speed = dset.running_speed_array
    >>> artifacts = dset.get_nb_wheel_artifacts()
    """

    def __init__(self, data: Dict[str, Any]) -> None:
        """Store the raw dictionary and resolve its item group once."""
        self.data = data
        self._items: Optional[Dict[str, Any]] = self._resolve_items()

    @classmethod
    def from_file(cls, path: str) -> "CamstimDataset":
        """Load a pickle file and return a CamstimDataset instance.

        Parameters
        ----------
        path : str
            The path to the pickle file.

        Returns
        -------
        CamstimDataset
            A new instance wrapping the loaded data.
        """
        # NOTE(security): unpickling can execute arbitrary code. Only load
        # stimulus files that come from a trusted source.
        with open(path, "rb") as f:
            data = pd.read_pickle(f)
        return cls(data)

    def _resolve_items(self) -> Optional[Dict[str, Any]]:
        """Return the inner item-group dict (foraging or behavior).

        ``"foraging"`` is checked before ``"behavior"``; the first group
        present under ``data["items"]`` wins.

        Returns
        -------
        Optional[Dict[str, Any]]
            The resolved item group dictionary, or None if not found.
        """
        # ``or {}`` also covers an explicit ``"items": None`` entry, which
        # would otherwise fail the membership test below with a TypeError.
        items = self.data.get("items") or {}
        for group in ("foraging", "behavior"):
            if group in items:
                return items[group]
        return None

    @property
    def fps(self) -> float:
        """Frames per second.

        Reads from the top-level ``"fps"`` key when available,
        otherwise computes from ``intervalsms`` (rounded to one
        decimal). Falls back to :data:`DEFAULT_FPS` when neither source
        is usable, including when ``intervalsms`` is empty or its mean
        is not a positive finite number.

        Returns
        -------
        float
            The frames per second.
        """
        if "fps" in self.data:
            return float(self.data["fps"])
        if self._items is None or "intervalsms" not in self._items:
            return DEFAULT_FPS

        intervals = np.asarray(self._items["intervalsms"])
        if intervals.size == 0:
            # The mean of an empty array is NaN; use the default instead.
            return DEFAULT_FPS

        mean_interval_ms = np.mean(intervals)
        if not np.isfinite(mean_interval_ms) or mean_interval_ms <= 0:
            # A zero or invalid mean would yield an infinite or NaN rate.
            return DEFAULT_FPS
        return float(round(1 / (mean_interval_ms * MS_TO_S), 1))

    @property
    def stage(self) -> Optional[str]:
        """Stimulus stage name.

        Returns
        -------
        Optional[str]
            The stage name from ``data["params"]["stage"]``, or
            ``None`` if absent.
        """
        try:
            return self.data["params"]["stage"]
        except (KeyError, TypeError):
            # TypeError covers a ``params`` value that is not subscriptable.
            return None

    @property
    def intervals_ms(self) -> list:
        """Inter-frame intervals in milliseconds.

        Returns
        -------
        list
            The ``intervalsms`` entry of the resolved item group,
            returned as stored (no copy is made).

        Raises
        ------
        KeyError
            If no ``intervalsms`` data is found.
        """
        if self._items is not None and "intervalsms" in self._items:
            return self._items["intervalsms"]
        raise KeyError("Could not find intervalsms in pickle file.")

    @property
    def stim_frame_count(self) -> int:
        """Number of stimulus frames.

        The frame count is the length of the ``intervalsms`` array plus one,
        since each interval sits between two frames.

        Returns
        -------
        int
            The number of stimulus frames.

        Raises
        ------
        KeyError
            If no ``intervalsms`` data is found.
        """
        return len(self.intervals_ms) + 1

    @property
    def running_speed_array(self) -> np.ndarray:
        """Running speed in cm/s derived from wheel encoder data.

        Locates the wheel encoder ``dx`` array and multiplies it by
        :attr:`fps` and by ``2 * pi * WHEEL_RADIUS / 360``. The ``/ 360``
        factor suggests ``dx`` is an angle in degrees per frame; confirm
        against the acquisition code.

        Returns
        -------
        numpy.ndarray
            Array of running speed values in cm/s.

        Raises
        ------
        KeyError
            If no encoder ``dx`` data is found in the pickle file.
        NotImplementedError
            If the pickle file format is unrecognised.
        """
        speed_dtheta = self._resolve_encoder_dx()
        return speed_dtheta * self.fps * (2 * np.pi * WHEEL_RADIUS / 360)

    def _resolve_encoder_dx(self) -> np.ndarray:
        """Locate the encoder ``dx`` array.

        When an item group (foraging or behavior) was resolved, only its
        first encoder entry is consulted. The top-level ``"dx"`` key is
        used only when no item group exists.

        Returns
        -------
        numpy.ndarray
            The encoder ``dx`` array.

        Raises
        ------
        KeyError
            If an item group exists but holds no encoder ``dx`` data.
        NotImplementedError
            If the pickle file format is unrecognised.
        """
        if self._items is not None:
            try:
                return np.array(self._items["encoders"][0]["dx"])
            except (KeyError, IndexError, TypeError) as err:
                # Chain the original failure so the missing piece is
                # visible in the traceback.
                raise KeyError(
                    "Could not find running speed data in pickle file."
                ) from err
        if "dx" in self.data:
            return np.array(self.data["dx"])
        raise NotImplementedError(
            "Encountered unknown format for stimulus pickle file."
        )

    def get_nb_wheel_artifacts(self, threshold: float = 100) -> int:
        """Count speed values exceeding threshold.

        Artifacts are defined as absolute speed values that exceed
        the given threshold, typically indicating encoder glitches or
        physical wheel slips.

        Parameters
        ----------
        threshold : float, optional
            Speed threshold in cm/s. Default is 100.

        Returns
        -------
        int
            The number of points whose absolute speed is strictly
            greater than ``threshold``.
        """
        return int(np.sum(np.abs(self.running_speed_array) > threshold))
235
+
236
+
237
def load_pkl_file(path: str) -> Dict[str, Any]:
    """Read a stimulus pickle file and return its contents unmodified.

    Parameters
    ----------
    path : str
        Location of the pickle file on disk.

    Returns
    -------
    Dict[str, Any]
        Whatever object the pickle file holds; for stimulus files this
        is the raw data dictionary.

    Notes
    -----
    Unpickling can execute arbitrary code, so only open files that come
    from a trusted source.
    """
    with open(path, "rb") as handle:
        raw_data = pd.read_pickle(handle)
    return raw_data
252
+
253
+
254
def get_stim_frame_count(pkl_data: Dict[str, Any]) -> int:
    """Return the stimulus frame count for a raw pickle data dictionary.

    Convenience shortcut that wraps ``pkl_data`` in a
    :class:`CamstimDataset` and reads its ``stim_frame_count``.

    Parameters
    ----------
    pkl_data : Dict[str, Any]
        The raw stimulus pickle data dictionary.

    Returns
    -------
    int
        The number of stimulus frames.
    """
    dataset = CamstimDataset(pkl_data)
    return dataset.stim_frame_count
@@ -0,0 +1,131 @@
1
+ """Wheel QC image and metric calculation utilities.
2
+
3
+ This module provides high-level functions for quality control analysis of
4
+ wheel rotation data from stimulus pickle files. These are convenience wrappers
5
+ around the CamstimDataset class (data parsing and metrics) and the plotting
6
+ module (visualization).
7
+
8
+ The module focuses on wheel encoder data, which is typically stored in the
9
+ stimulus pickle file under items.foraging or items.behavior. Functions here
10
+ combine multiple core utilities to produce either visual plots or numerical
11
+ metrics suitable for QC assessment.
12
+ """
13
+
14
+ from typing import Any, Dict, Union
15
+
16
+
17
+ import aind_behavior_utils.plotting.plots as plots
18
+ from aind_behavior_utils.stimulus.camstim_dataset import CamstimDataset
19
+
20
+
21
def _resolve_pkl(
    pkl_input: Union[str, Dict[str, Any], CamstimDataset],
) -> CamstimDataset:
    """Resolve pickle input to a CamstimDataset instance.

    Parameters
    ----------
    pkl_input : Union[str, Dict[str, Any], CamstimDataset]
        A file path (``str`` or any ``os.PathLike`` such as
        ``pathlib.Path``), an already-loaded dictionary, or an existing
        CamstimDataset instance.

    Returns
    -------
    CamstimDataset
        The input itself when it already is a CamstimDataset, otherwise
        a new instance built from the file or dictionary.
    """
    import os

    if isinstance(pkl_input, CamstimDataset):
        return pkl_input
    # Path objects used to fall through to the dictionary branch and fail
    # later with an AttributeError; treat them as file paths instead.
    if isinstance(pkl_input, (str, os.PathLike)):
        return CamstimDataset.from_file(os.fspath(pkl_input))
    return CamstimDataset(pkl_input)
42
+
43
+
44
def calculate_qc_images(
    pkl_input: Union[str, Dict[str, Any], CamstimDataset],
) -> dict:
    """Build the wheel QC figures for a stimulus pickle.

    High-level wrapper: the input is resolved to a CamstimDataset, its
    running speed is read once, and two line plots are produced with
    the plotting module.

    Parameters
    ----------
    pkl_input : Union[str, Dict[str, Any], CamstimDataset]
        Path to a stimulus pickle file, an already-loaded pickle data
        dictionary, or an existing CamstimDataset instance.

    Returns
    -------
    dict
        Two matplotlib Figure objects keyed by name:
        - 'wheel_speed_plot': instantaneous wheel speed in cm/s for
          every stimulus frame.
        - 'wheel_traveled_distance_plot': cumulative wheel traveled
          distance in meters.

    Notes
    -----
    Relies on:
    - CamstimDataset.running_speed_array for the speed trace (cm/s)
    - CamstimDataset.fps for the frame rate
    - plotting.plot_array to create each figure
    """
    dataset = _resolve_pkl(pkl_input)
    speed_cm_per_s = dataset.running_speed_array

    figures = {}
    figures["wheel_speed_plot"] = plots.plot_array(
        speed_cm_per_s,
        xlabel="Frame #",
        ylabel="Wheel Speed (cm/s)",
        title="wheel_speed_plot",
    )

    # Summing speed over frames and dividing by fps integrates to cm;
    # the extra factor of 100 converts cm to m.
    distance_m = speed_cm_per_s.cumsum() / (100 * dataset.fps)
    figures["wheel_traveled_distance_plot"] = plots.plot_array(
        distance_m,
        xlabel="Frame #",
        ylabel="Traveled Distance (m)",
        title="wheel_traveled_distance_plot",
    )
    return figures
94
+
95
+
96
def calculate_qc_metrics(
    pkl_input: Union[str, Dict[str, Any], CamstimDataset],
) -> dict:
    """Compute numerical wheel QC metrics for a stimulus pickle.

    High-level wrapper: the input is resolved to a CamstimDataset and
    its artifact counter is called with the default threshold.

    Parameters
    ----------
    pkl_input : Union[str, Dict[str, Any], CamstimDataset]
        Path to a stimulus pickle file, an already-loaded pickle data
        dictionary, or an existing CamstimDataset instance.

    Returns
    -------
    dict
        QC metrics keyed by name:
        - 'wheel_artifacts': number of speed samples whose absolute
          value exceeds 100 cm/s, which may indicate encoder
          glitches or physical wheel slips.

    Notes
    -----
    Relies on:
    - CamstimDataset.running_speed_array for the speed trace (cm/s)
    - CamstimDataset.get_nb_wheel_artifacts to count the outliers

    See Also
    --------
    calculate_qc_images : Generate visual QC plots
    """
    artifact_count = _resolve_pkl(pkl_input).get_nb_wheel_artifacts()
    return {"wheel_artifacts": artifact_count}
@@ -0,0 +1 @@
1
+ """Sync HDF5 dataset parsing and analysis utilities."""
@@ -0,0 +1,97 @@
1
+ """Utilities for resolving legacy sync line label variants.
2
+
3
+ Different experimental setups use different names for the same signal
4
+ (e.g. ``"stim_vsync"`` vs ``"vsync_stim"``). This module maps
5
+ canonical signal names to their known variants so that downstream code
6
+ can refer to signals by a single, stable name.
7
+ """
8
+
9
+ from typing import Dict, List, Optional
10
+
11
# Canonical signal name -> labels under which that signal may appear in a
# sync file. Order within each list matters: lookups take the first match.
LINE_LABEL_VARIANTS: Dict[str, List[str]] = {
    "behavior_monitoring": [
        "behavior_monitoring",
        "cam1_exposure",
        "cam1",
        "beh_cam_frame_readout",
    ],
    "eye_tracking": [
        "eye_tracking",
        "cam2_exposure",
        "cam2",
        "eye_cam_frame_readout",
    ],
    "face_tracking": [
        "face_tracking",
        "face_cam_frame_readout",
    ],
    "photodiode": [
        "photodiode",
        "stim_photodiode",
    ],
    # The two entries below do not list their canonical name as a variant.
    "physio": [
        "2p_vsync",
        "vsync_2p",
    ],
    "visual_stim": [
        "stim_vsync",
        "vsync_stim",
    ],
}
29
+
30
+
31
def build_line_label_map(
    line_labels: List[str],
    variants: Optional[Dict[str, List[str]]] = None,
) -> Dict[str, str]:
    """Map each canonical signal name to the label a sync file uses.

    Variants are tried in list order and the first one present in
    ``line_labels`` is kept. Canonical names with no matching variant
    are left out of the result.

    Parameters
    ----------
    line_labels : List[str]
        Line labels present in a sync file.
    variants : Optional[Dict[str, List[str]]]
        Canonical-name-to-variant-list mapping. Defaults to
        ``LINE_LABEL_VARIANTS``.

    Returns
    -------
    Dict[str, str]
        Canonical name to the actual label found in ``line_labels``.
    """
    lookup = LINE_LABEL_VARIANTS if variants is None else variants
    label_map: Dict[str, str] = {}
    for canonical, candidates in lookup.items():
        present = [name for name in candidates if name in line_labels]
        if present:
            # Candidate order is preserved, so index 0 is the first match.
            label_map[canonical] = present[0]
    return label_map
62
+
63
+
64
def resolve_line_label(
    line: str,
    label_map: Dict[str, str],
    line_labels: List[str],
) -> str:
    """Translate a line name into the label actually used by a sync file.

    ``line`` is first looked up in ``label_map`` as a canonical name;
    when it is not a key there, it is taken as a direct label. Either
    way the outcome must be one of ``line_labels``.

    Parameters
    ----------
    line : str
        Canonical name or direct line label.
    label_map : Dict[str, str]
        Mapping from canonical names to actual labels, as returned by
        :func:`build_line_label_map`.
    line_labels : List[str]
        Line labels present in the sync file.

    Returns
    -------
    str
        The actual line label found in ``line_labels``.

    Raises
    ------
    ValueError
        If ``line`` cannot be resolved to a label in ``line_labels``.
    """
    candidate = label_map[line] if line in label_map else line
    if candidate not in line_labels:
        raise ValueError(f"'{line}' not found in line labels or label map.")
    return candidate