ethograph 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ethograph might be problematic. Click here for more details.
- ethograph/__init__.py +115 -0
- ethograph/__main__.py +3 -0
- ethograph/assets/icon.ico +0 -0
- ethograph/assets/icon.png +0 -0
- ethograph/assets/menu.json +27 -0
- ethograph/cli.py +71 -0
- ethograph/crowlab/io_matlab.py +264 -0
- ethograph/crowlab/legacy.py +389 -0
- ethograph/datasets.py +314 -0
- ethograph/features/__init__.py +0 -0
- ethograph/features/audio_changepoints.py +176 -0
- ethograph/features/changepoints.py +501 -0
- ethograph/features/energy.py +389 -0
- ethograph/features/movement.py +509 -0
- ethograph/features/neural.py +219 -0
- ethograph/features/oscillatory.py +55 -0
- ethograph/features/preprocessing.py +184 -0
- ethograph/gui/__init__.py +35 -0
- ethograph/gui/app_constants.py +157 -0
- ethograph/gui/app_state.py +1173 -0
- ethograph/gui/audio_player.py +153 -0
- ethograph/gui/dialog_busy_progress.py +119 -0
- ethograph/gui/dialog_function_params.py +1065 -0
- ethograph/gui/dialog_pose_video_matcher.py +289 -0
- ethograph/gui/dialog_screen_recorder.py +419 -0
- ethograph/gui/dialog_select_template.py +269 -0
- ethograph/gui/dialog_video_downsample.py +292 -0
- ethograph/gui/label_drawing_mixin.py +407 -0
- ethograph/gui/make_pretty.py +290 -0
- ethograph/gui/napari.yaml +12 -0
- ethograph/gui/notify.py +56 -0
- ethograph/gui/plots_audiotrace.py +208 -0
- ethograph/gui/plots_base.py +416 -0
- ethograph/gui/plots_container.py +1064 -0
- ethograph/gui/plots_ephystrace.py +1687 -0
- ethograph/gui/plots_heatmap.py +543 -0
- ethograph/gui/plots_lineplot.py +412 -0
- ethograph/gui/plots_overlay.py +316 -0
- ethograph/gui/plots_psth.py +409 -0
- ethograph/gui/plots_raster.py +283 -0
- ethograph/gui/plots_space.py +942 -0
- ethograph/gui/plots_spectrogram.py +323 -0
- ethograph/gui/pose_render.py +664 -0
- ethograph/gui/shortcuts.py +302 -0
- ethograph/gui/templates/wizard_nwb_codegen.j2 +109 -0
- ethograph/gui/video_manager.py +602 -0
- ethograph/gui/video_sync.py +261 -0
- ethograph/gui/widget_trials.py +444 -0
- ethograph/gui/widgets_changepoints.py +1426 -0
- ethograph/gui/widgets_data.py +2602 -0
- ethograph/gui/widgets_ephys.py +2603 -0
- ethograph/gui/widgets_help.py +257 -0
- ethograph/gui/widgets_io.py +1765 -0
- ethograph/gui/widgets_labels.py +1366 -0
- ethograph/gui/widgets_meta.py +515 -0
- ethograph/gui/widgets_navigation.py +947 -0
- ethograph/gui/widgets_plot_settings.py +833 -0
- ethograph/gui/widgets_psth.py +757 -0
- ethograph/gui/widgets_transform.py +53 -0
- ethograph/gui/wizard_boris.py +317 -0
- ethograph/gui/wizard_media_files.py +1124 -0
- ethograph/gui/wizard_multi_builder.py +186 -0
- ethograph/gui/wizard_multi_codegen.py +231 -0
- ethograph/gui/wizard_multi_tabs.py +1121 -0
- ethograph/gui/wizard_multi_timeline.py +846 -0
- ethograph/gui/wizard_multi_trials.py +564 -0
- ethograph/gui/wizard_nwb.py +508 -0
- ethograph/gui/wizard_overview.py +565 -0
- ethograph/gui/wizard_single.py +1149 -0
- ethograph/io/catalog.py +953 -0
- ethograph/io/data_loader.py +670 -0
- ethograph/io/dataset.py +282 -0
- ethograph/io/metadata_table.py +279 -0
- ethograph/io/nwb_alignment.py +1367 -0
- ethograph/io/nwb_import.py +170 -0
- ethograph/io/plot_sources.py +511 -0
- ethograph/io/pynapple.py +226 -0
- ethograph/io/time_model.py +512 -0
- ethograph/io/time_sources.py +131 -0
- ethograph/io/trialtree.py +572 -0
- ethograph/io/validation.py +289 -0
- ethograph/labels/__init__.py +11 -0
- ethograph/labels/boris.py +323 -0
- ethograph/labels/converters.py +425 -0
- ethograph/labels/crowsetta_format.py +120 -0
- ethograph/labels/export.py +145 -0
- ethograph/labels/intervals.py +708 -0
- ethograph/labels/ml.py +505 -0
- ethograph/labels/plots.py +336 -0
- ethograph/labels/predictions.py +257 -0
- ethograph/labels/tsv_store.py +270 -0
- ethograph/model/batch_gen.py +135 -0
- ethograph/model/cetnet_encoder.py +589 -0
- ethograph/model/dataset.py +302 -0
- ethograph/model/eval_metrics.py +244 -0
- ethograph/model/eval_plotting.py +465 -0
- ethograph/shortcuts.py +126 -0
- ethograph/utils/__init__.py +0 -0
- ethograph/utils/arraytools.py +361 -0
- ethograph/utils/audio.py +65 -0
- ethograph/utils/download.py +625 -0
- ethograph/utils/nwb.py +156 -0
- ethograph/utils/paths.py +277 -0
- ethograph/utils/qt.py +198 -0
- ethograph/utils/sequences.py +129 -0
- ethograph/utils/stream_durations.py +168 -0
- ethograph/utils/xr_utils.py +165 -0
- ethograph/video_features/base_extractor.py +122 -0
- ethograph/video_features/checkpoint/S3D_kinetics400_torchified.pt +0 -0
- ethograph/video_features/extract_s3d.py +117 -0
- ethograph/video_features/s3d.py +357 -0
- ethograph/video_features/s3d.yml +16 -0
- ethograph/video_features/transforms.py +309 -0
- ethograph/video_features/utils.py +168 -0
- ethograph-0.1.3.dist-info/METADATA +92 -0
- ethograph-0.1.3.dist-info/RECORD +120 -0
- ethograph-0.1.3.dist-info/WHEEL +5 -0
- ethograph-0.1.3.dist-info/entry_points.txt +8 -0
- ethograph-0.1.3.dist-info/licenses/LICENSE +28 -0
- ethograph-0.1.3.dist-info/top_level.txt +1 -0
ethograph/__init__.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
2
|
+
|
|
3
|
+
try:
    __version__ = version("ethograph")
except PackageNotFoundError:
    # No installed distribution metadata was found (e.g. running from a
    # source checkout). Expose a placeholder so that
    # ``ethograph.__version__`` is always defined instead of raising
    # AttributeError on access.
    __version__ = "0+unknown"
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
from ethograph.io.trialtree import TrialTree
|
|
11
|
+
from ethograph.io.dataset import (
|
|
12
|
+
add_angle_rgb_to_ds,
|
|
13
|
+
add_changepoints_to_ds,
|
|
14
|
+
downsample_trialtree,
|
|
15
|
+
)
|
|
16
|
+
from ethograph.utils.xr_utils import get_ds_duration, get_time_coord, sel_valid
|
|
17
|
+
from ethograph.utils.paths import get_project_root
|
|
18
|
+
from ethograph.io.catalog import (
|
|
19
|
+
DataCatalog,
|
|
20
|
+
DataLoader,
|
|
21
|
+
PlotData,
|
|
22
|
+
PynappleLoader,
|
|
23
|
+
XarrayLoader,
|
|
24
|
+
catalog_from_pynapple,
|
|
25
|
+
catalog_from_xarray,
|
|
26
|
+
)
|
|
27
|
+
from ethograph.io.pynapple import load_nap_data
|
|
28
|
+
from ethograph.io.nwb_alignment import (
|
|
29
|
+
NWBAlignment,
|
|
30
|
+
align_media_per_trial,
|
|
31
|
+
align_media_from_streams,
|
|
32
|
+
)
|
|
33
|
+
from ethograph.datasets import sample_data
|
|
34
|
+
from ethograph.io.time_model import (
|
|
35
|
+
SourceCollection,
|
|
36
|
+
TimeRange,
|
|
37
|
+
TimeSource,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def open(path: str) -> TrialTree:
    """Open a saved NetCDF file as a TrialTree.

    Convenience wrapper that forwards to
    :meth:`TrialTree.open <ethograph.io.trialtree.TrialTree.open>`.

    Parameters
    ----------
    path : str or Path
        Location of a ``.nc`` file previously saved with ``dt.save()``.

    Returns
    -------
    TrialTree
        The object returned by ``TrialTree.open(path)``.

    Examples
    --------
    >>> import ethograph as eto
    >>> dt = eto.open("experiment.nc")
    >>> dt.trials
    [1, 2, 3]
    >>> ds = dt.itrial(0)
    """
    tree = TrialTree.open(path)
    return tree
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def from_datasets(datasets: list) -> TrialTree:
    """Assemble a TrialTree from per-trial xarray Datasets.

    Convenience wrapper that forwards to
    :meth:`TrialTree.from_datasets <ethograph.io.trialtree.TrialTree.from_datasets>`.

    Every dataset is expected to carry a unique trial identifier in
    ``attrs["trial"]``.

    Parameters
    ----------
    datasets : list[xarray.Dataset]
        One Dataset per trial.

    Returns
    -------
    TrialTree
        The object returned by ``TrialTree.from_datasets(datasets)``.

    Examples
    --------
    >>> import xarray as xr, numpy as np, ethograph as eto
    >>> trials = []
    >>> for i in range(1, 4):
    ...     ds = xr.Dataset({"speed": xr.DataArray(np.random.rand(300), dims=["time"])})
    ...     ds.attrs["trial"] = i
    ...     trials.append(ds)
    >>> dt = eto.from_datasets(trials)
    >>> dt.trials
    [1, 2, 3]
    """
    tree = TrialTree.from_datasets(datasets)
    return tree
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def from_continuous(ds, epochs) -> TrialTree:
    """Create a TrialTree from one continuous recording and a set of trial epochs.

    Convenience wrapper that forwards to :meth:`TrialTree.from_continuous`.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset holding the full recording.
    epochs : pandas.DataFrame or pynapple.IntervalSet
        Trial boundaries. A DataFrame must provide the columns ``trial``,
        ``start_time`` and ``stop_time``.

    Returns
    -------
    TrialTree
        The object returned by ``TrialTree.from_continuous(ds, epochs)``.
    """
    tree = TrialTree.from_continuous(ds, epochs)
    return tree
|
ethograph/__main__.py
ADDED
|
Binary file
|
|
Binary file
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
{
|
|
2
|
+
"$schema": "https://json-schema.org/draft-07/schema",
|
|
3
|
+
"$id": "https://schemas.conda.io/menuinst-1.schema.json",
|
|
4
|
+
"menu_name": "ethograph",
|
|
5
|
+
"menu_items": [
|
|
6
|
+
{
|
|
7
|
+
"name": "ethograph",
|
|
8
|
+
"command": ["{{ PYTHON }}", "-m", "ethograph", "launch"],
|
|
9
|
+
"activate": true,
|
|
10
|
+
"terminal": true,
|
|
11
|
+
"platforms": {
|
|
12
|
+
"win": {
|
|
13
|
+
"desktop": true,
|
|
14
|
+
"icon": "{{ MENU_DIR }}/icon.ico",
|
|
15
|
+
"command": ["{{ PYTHON }}", "-m", "ethograph", "launch", "--keep-open"]
|
|
16
|
+
},
|
|
17
|
+
"linux": {
|
|
18
|
+
"Categories": ["Science", "Education"],
|
|
19
|
+
"icon": "{{ MENU_DIR }}/icon.png"
|
|
20
|
+
},
|
|
21
|
+
"osx": {
|
|
22
|
+
"icon": "{{ MENU_DIR }}/icon.png"
|
|
23
|
+
}
|
|
24
|
+
}
|
|
25
|
+
}
|
|
26
|
+
]
|
|
27
|
+
}
|
ethograph/cli.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
"""Command-line interface for ethograph."""
|
|
3
|
+
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import sys
|
|
7
|
+
import warnings
|
|
8
|
+
|
|
9
|
+
# Silence noisy DeprecationWarnings from dependencies before any later import
# can trigger them. The filters are registered in the same order as before.
for _noisy_module in (r"logging", r"vispy\.", r"numpy\."):
    warnings.filterwarnings(
        "ignore", category=DeprecationWarning, module=_noisy_module
    )
del _noisy_module  # do not leave the loop variable in the module namespace

# The PyOpenGL info message is emitted through logging rather than warnings,
# so it is quieted by raising that logger's level.
_opengl_logger = logging.getLogger("OpenGL.acceleratesupport")
_opengl_logger.setLevel(logging.WARNING)
del _opengl_logger
|
|
16
|
+
|
|
17
|
+
def _ensure_qt_plugins():
    """Set ``QT_PLUGIN_PATH`` for conda-forge Qt installs (needed by menuinst shortcuts).

    Does nothing when ``QT_PLUGIN_PATH`` already has a non-empty value.
    Otherwise the known plugin locations below ``sys.prefix`` are probed in
    order, and the first one containing a ``platforms`` sub-directory is
    exported as ``QT_PLUGIN_PATH``. If none qualifies, the environment is
    left untouched.
    """
    if not os.environ.get("QT_PLUGIN_PATH"):
        prefix = sys.prefix
        layouts = (
            ("Library", "plugins"),     # Windows conda-forge
            ("lib", "qt5", "plugins"),  # Linux conda-forge
            ("lib", "qt", "plugins"),   # macOS conda-forge
        )
        for parts in layouts:
            plugin_dir = os.path.join(prefix, *parts)
            if os.path.isdir(os.path.join(plugin_dir, "platforms")):
                os.environ["QT_PLUGIN_PATH"] = plugin_dir
                break
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def launch():
    """Launch the ethograph GUI.

    Configures root logging at INFO level, lowers napari's logger to
    WARNING, makes sure the Qt plugin path is set, then opens a napari
    viewer with the ethograph ``MetaWidget`` docked and runs napari.
    """
    logging.basicConfig(
        format="%(name)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )
    napari_logger = logging.getLogger("napari")
    napari_logger.setLevel(logging.WARNING)
    _ensure_qt_plugins()

    # Imported here, after the Qt plugin path has been handled above.
    import napari
    from ethograph.gui.widgets_meta import MetaWidget

    viewer = napari.Viewer()
    meta_widget = MetaWidget(viewer)
    viewer.window.add_dock_widget(meta_widget, name="ethograph GUI")
    napari.run()
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def main():
    """Main CLI entry point.

    Reads the sub-command from ``sys.argv[1]``:

    * ``launch`` starts the GUI via :func:`launch`.
    * ``shortcut`` calls ``ethograph.shortcuts.install_shortcut`` and exits
      with its return value.

    With no sub-command a usage summary is printed; with an unrecognised one
    an error is printed. Both cases exit with status 1.
    """
    argv = sys.argv
    if len(argv) < 2:
        usage_lines = (
            "Usage: ethograph <command>",
            "Commands:",
            " launch Launch the ethograph GUI",
            " shortcut Install desktop/Start Menu shortcut",
        )
        for line in usage_lines:
            print(line)
        sys.exit(1)

    command = argv[1]

    if command == "launch":
        launch()
        return

    if command == "shortcut":
        # Imported only when needed, as in the original command dispatch.
        from ethograph.shortcuts import install_shortcut
        sys.exit(install_shortcut())

    print(f"Unknown command: {command}")
    print("Available commands: launch, shortcut")
    sys.exit(1)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
"""Custom code for importing data from lab-internal matlab codes to python"""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import xarray as xr
|
|
9
|
+
from scipy.io import loadmat
|
|
10
|
+
|
|
11
|
+
import ethograph as eto
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def get_all_trials_path_info(all_trials_path):
    """
    Extract subject_id, session_date, session_number, and dataset_name from a path.

    The path is searched for an ``id-<subject>`` component and a
    ``date-<YYYYMMDD>_<NN>`` component (eight digits, underscore, two digits).
    Backslashes are treated as path separators.

    Args:
        all_trials_path (str | os.PathLike): Path to parse. ``pathlib`` paths
            are accepted as well as plain strings.
    Returns:
        subject_id (str), session_date (str), session_number (str), dataset_name (str)

        Any part that cannot be found is returned as an empty string.
        ``dataset_name`` is always ``f"{session_date}-{session_number}_{subject_id}"``,
        so it is ``"-_"`` when nothing matches.
    """
    # Local import so this function does not depend on a module-level
    # ``import os``.
    import os

    # os.fspath lets Path objects through; calling ``.replace`` directly on a
    # Path would invoke its rename method instead of a string substitution.
    path = os.fspath(all_trials_path).replace('\\', '/')

    id_match = re.search(r'id-([^/\\]+)', path)
    subject_id = id_match.group(1) if id_match else ''

    date_sess_match = re.search(r'date-(\d{8})_(\d{2})', path)
    if date_sess_match:
        session_date, session_number = date_sess_match.groups()
    else:
        session_date = session_number = ''

    dataset_name = f'{session_date}-{session_number}_{subject_id}'

    return subject_id, session_date, session_number, dataset_name
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def update_dt_with_matlab_pulse_onsets(
    dt: xr.DataTree,
    matlab_path: str | Path,
) -> xr.DataTree:
    """
    Add per-trial ``pulse_onsets`` from a MATLAB AllTrials structure to a DataTree.

    For every child node of ``dt`` whose dataset has a ``trial`` attribute
    matching a ``trial_num`` in the MATLAB ``AllTrials`` structure, the
    trial's ``pulse_info.pulse_onsets`` is stored as a ``pulse_onsets``
    variable along the ``time`` dimension. Nodes without a ``trial``
    attribute, or without a matching MATLAB trial, are left unchanged.

    Parameters
    ----------
    dt : xr.DataTree
        The DataTree containing trial data. It is updated in place.
    matlab_path : str | Path
        Path to the MATLAB file containing AllTrials structure

    Returns
    -------
    xr.DataTree
        The same ``dt`` object, with ``pulse_onsets`` added to matched trials

    Raises
    ------
    FileNotFoundError
        If ``matlab_path`` does not exist
    ValueError
        If the file has no ``AllTrials`` structure, or if a trial's
        ``pulse_onsets`` length differs from the dataset's ``time`` length
    """
    matlab_path = Path(matlab_path)

    if not matlab_path.exists():
        raise FileNotFoundError(f"MATLAB file not found: {matlab_path}")

    # Load MATLAB data
    mat_data = loadmat(
        matlab_path,
        squeeze_me=True,
        struct_as_record=False,
    )

    if 'AllTrials' not in mat_data:
        raise ValueError(f"'AllTrials' structure not found in {matlab_path}")

    # squeeze_me=True can collapse a one-element struct array into a bare
    # struct object, which is not iterable; atleast_1d restores a sequence.
    all_trials = np.atleast_1d(mat_data['AllTrials'])

    # Create lookup for efficient trial matching
    matlab_trials_dict = {trial.trial_num: trial for trial in all_trials}

    # Iterate over a snapshot because nodes are re-assigned inside the loop.
    for node_name, node in list(dt.children.items()):
        if node.ds is None or 'trial' not in node.ds.attrs:
            continue

        trial_num = node.ds.attrs['trial']
        if trial_num not in matlab_trials_dict:
            continue

        trial = matlab_trials_dict[trial_num]

        # Create mutable copy of the dataset
        ds = node.to_dataset().copy()

        # atleast_1d also covers a single onset squeezed to a scalar.
        pulse_onsets = np.atleast_1d(trial.pulse_info.pulse_onsets)

        # Raise explicitly instead of asserting: asserts are removed under
        # ``python -O`` and this is a data-validation check.
        if len(pulse_onsets) != len(ds.time):
            raise ValueError(
                f"pulse_onsets length ({len(pulse_onsets)}) does not match "
                f"time length ({len(ds.time)}) for trial {trial_num}"
            )

        ds['pulse_onsets'] = ('time', pulse_onsets)

        # Assign back to the DataTree node
        dt[node_name] = ds

    return dt
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def update_nc_with_matlab_trials(
    dt: xr.DataTree,
    matlab_path: str | Path,
) -> xr.DataTree:
    """
    Update trial datasets in a DataTree with information from a MATLAB AllTrials structure.

    For every child node of ``dt`` whose dataset has a ``trial`` attribute
    matching a ``trial_num`` in ``AllTrials``:

    * ``poscat`` and ``num_pellets`` are copied from ``trial.info`` into the
      dataset attributes when present;
    * a ``boundary_events`` variable with an ``events`` coordinate
      (``disp_out``, ``disp_in``, ``box_in``, ``box_out``) is added when
      boundary events can be extracted;
    * a ``labels`` variable of zeros with dims ``(time, individuals)`` is
      created, replacing any existing ``labels`` variable, and, if the dataset
      has a ``bird`` attribute and the trial has ``label_infos.beakTip``,
      that bird's labels are filled in.

    Nodes without a ``trial`` attribute, or without a matching MATLAB trial,
    are left unchanged.

    Parameters
    ----------
    dt : xr.DataTree
        The DataTree containing trial data. It is updated in place.
    matlab_path : str | Path
        Path to the MATLAB file containing AllTrials structure

    Returns
    -------
    xr.DataTree
        The same ``dt`` object, with MATLAB trial information added

    Raises
    ------
    FileNotFoundError
        If ``matlab_path`` does not exist
    ValueError
        If the file has no ``AllTrials`` structure
    """
    matlab_path = Path(matlab_path)

    if not matlab_path.exists():
        raise FileNotFoundError(f"MATLAB file not found: {matlab_path}")

    # Load MATLAB data
    mat_data = loadmat(
        matlab_path,
        squeeze_me=True,
        struct_as_record=False,
    )

    if 'AllTrials' not in mat_data:
        raise ValueError(f"'AllTrials' structure not found in {matlab_path}")

    # squeeze_me=True can collapse a one-element struct array into a bare
    # struct object, which is not iterable; atleast_1d restores a sequence.
    all_trials = np.atleast_1d(mat_data['AllTrials'])

    # Create lookup for efficient trial matching
    matlab_trials_dict = {trial.trial_num: trial for trial in all_trials}

    # Iterate over a snapshot because nodes are re-assigned inside the loop.
    for node_name, node in list(dt.children.items()):
        if node.ds is None or 'trial' not in node.ds.attrs:
            continue

        trial_num = node.ds.attrs['trial']
        if trial_num not in matlab_trials_dict:
            continue

        trial = matlab_trials_dict[trial_num]

        # Work on a mutable copy and write it back at the end, as
        # update_dt_with_matlab_pulse_onsets does. ``assign_coords`` below
        # returns a new Dataset, so without the write-back the tree would
        # never receive the changes.
        ds = node.to_dataset().copy()

        # Update attributes
        if hasattr(trial.info, 'poscat'):
            ds.attrs['poscat'] = trial.info.poscat
        if hasattr(trial.info, 'num_pellets'):
            ds.attrs['num_pellets'] = trial.info.num_pellets

        # Process boundary events
        event_data = _extract_boundary_events(trial.info)
        if event_data is not None:
            ds['boundary_events'] = ('events', event_data)
            ds = ds.assign_coords(
                events=('events', ['disp_out', 'disp_in', 'box_in', 'box_out'])
            )

        # Initialize labels
        n_time = len(ds.coords['time'])
        n_individuals = len(ds.coords['individuals'])
        ds['labels'] = (
            ('time', 'individuals'),
            np.zeros((n_time, n_individuals)),
        )

        # Update labels if available
        has_beak_labels = (
            hasattr(trial, 'label_infos')
            and hasattr(trial.label_infos, 'beakTip')
        )
        if has_beak_labels and 'bird' in ds.attrs:
            bird = ds.attrs['bird']
            labels = trial.label_infos.beakTip.labels
            ds['labels'].loc[dict(individuals=bird)] = labels

        # Assign back to the DataTree node
        dt[node_name] = ds

    return dt
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _extract_boundary_events(trial_info) -> Optional[np.ndarray]:
    """
    Extract boundary events from trial info structure.

    Parameters
    ----------
    trial_info : object
        MATLAB trial info structure. ``stick_in_out_disp`` supplies the
        first two events and ``first_in_last_out`` (optional) the last two.

    Returns
    -------
    Optional[np.ndarray]
        Float array of length 4 ordered as
        ``[disp_out, disp_in, box_in, box_out]``, or None if ``trial_info``
        has no ``stick_in_out_disp`` attribute. Each value is the source
        value converted to ``int`` minus 1 (presumably converting MATLAB's
        1-based indices to 0-based; confirm against the data producer).
        Entries that are missing or cannot be converted are NaN.
    """
    if not hasattr(trial_info, 'stick_in_out_disp'):
        return None

    def _shifted(values, position):
        """Return ``int(values[position]) - 1``, or NaN if unavailable."""
        try:
            return int(values[position]) - 1
        except (IndexError, TypeError, ValueError, OverflowError):
            # Missing element, non-indexable value (e.g. a squeezed scalar
            # or None), NaN or infinity: record the event as unknown.
            return np.nan

    # A missing ``first_in_last_out`` becomes None, which yields NaN for
    # both box events.
    event_sources = (
        trial_info.stick_in_out_disp,                     # disp_out, disp_in
        getattr(trial_info, 'first_in_last_out', None),   # box_in, box_out
    )

    event_data = np.full(4, np.nan)
    for offset, values in zip((0, 2), event_sources):
        # Positions 0 and 1 are read explicitly. Earlier code indexed
        # ``first_in_last_out`` with -2/-1, which only matched 0/1 for
        # length-2 input.
        for position in (0, 1):
            event_data[offset + position] = _shifted(values, position)

    return event_data
|