ethograph 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ethograph might be problematic. Click here for more details.

Files changed (120) hide show
  1. ethograph/__init__.py +115 -0
  2. ethograph/__main__.py +3 -0
  3. ethograph/assets/icon.ico +0 -0
  4. ethograph/assets/icon.png +0 -0
  5. ethograph/assets/menu.json +27 -0
  6. ethograph/cli.py +71 -0
  7. ethograph/crowlab/io_matlab.py +264 -0
  8. ethograph/crowlab/legacy.py +389 -0
  9. ethograph/datasets.py +314 -0
  10. ethograph/features/__init__.py +0 -0
  11. ethograph/features/audio_changepoints.py +176 -0
  12. ethograph/features/changepoints.py +501 -0
  13. ethograph/features/energy.py +389 -0
  14. ethograph/features/movement.py +509 -0
  15. ethograph/features/neural.py +219 -0
  16. ethograph/features/oscillatory.py +55 -0
  17. ethograph/features/preprocessing.py +184 -0
  18. ethograph/gui/__init__.py +35 -0
  19. ethograph/gui/app_constants.py +157 -0
  20. ethograph/gui/app_state.py +1173 -0
  21. ethograph/gui/audio_player.py +153 -0
  22. ethograph/gui/dialog_busy_progress.py +119 -0
  23. ethograph/gui/dialog_function_params.py +1065 -0
  24. ethograph/gui/dialog_pose_video_matcher.py +289 -0
  25. ethograph/gui/dialog_screen_recorder.py +419 -0
  26. ethograph/gui/dialog_select_template.py +269 -0
  27. ethograph/gui/dialog_video_downsample.py +292 -0
  28. ethograph/gui/label_drawing_mixin.py +407 -0
  29. ethograph/gui/make_pretty.py +290 -0
  30. ethograph/gui/napari.yaml +12 -0
  31. ethograph/gui/notify.py +56 -0
  32. ethograph/gui/plots_audiotrace.py +208 -0
  33. ethograph/gui/plots_base.py +416 -0
  34. ethograph/gui/plots_container.py +1064 -0
  35. ethograph/gui/plots_ephystrace.py +1687 -0
  36. ethograph/gui/plots_heatmap.py +543 -0
  37. ethograph/gui/plots_lineplot.py +412 -0
  38. ethograph/gui/plots_overlay.py +316 -0
  39. ethograph/gui/plots_psth.py +409 -0
  40. ethograph/gui/plots_raster.py +283 -0
  41. ethograph/gui/plots_space.py +942 -0
  42. ethograph/gui/plots_spectrogram.py +323 -0
  43. ethograph/gui/pose_render.py +664 -0
  44. ethograph/gui/shortcuts.py +302 -0
  45. ethograph/gui/templates/wizard_nwb_codegen.j2 +109 -0
  46. ethograph/gui/video_manager.py +602 -0
  47. ethograph/gui/video_sync.py +261 -0
  48. ethograph/gui/widget_trials.py +444 -0
  49. ethograph/gui/widgets_changepoints.py +1426 -0
  50. ethograph/gui/widgets_data.py +2602 -0
  51. ethograph/gui/widgets_ephys.py +2603 -0
  52. ethograph/gui/widgets_help.py +257 -0
  53. ethograph/gui/widgets_io.py +1765 -0
  54. ethograph/gui/widgets_labels.py +1366 -0
  55. ethograph/gui/widgets_meta.py +515 -0
  56. ethograph/gui/widgets_navigation.py +947 -0
  57. ethograph/gui/widgets_plot_settings.py +833 -0
  58. ethograph/gui/widgets_psth.py +757 -0
  59. ethograph/gui/widgets_transform.py +53 -0
  60. ethograph/gui/wizard_boris.py +317 -0
  61. ethograph/gui/wizard_media_files.py +1124 -0
  62. ethograph/gui/wizard_multi_builder.py +186 -0
  63. ethograph/gui/wizard_multi_codegen.py +231 -0
  64. ethograph/gui/wizard_multi_tabs.py +1121 -0
  65. ethograph/gui/wizard_multi_timeline.py +846 -0
  66. ethograph/gui/wizard_multi_trials.py +564 -0
  67. ethograph/gui/wizard_nwb.py +508 -0
  68. ethograph/gui/wizard_overview.py +565 -0
  69. ethograph/gui/wizard_single.py +1149 -0
  70. ethograph/io/catalog.py +953 -0
  71. ethograph/io/data_loader.py +670 -0
  72. ethograph/io/dataset.py +282 -0
  73. ethograph/io/metadata_table.py +279 -0
  74. ethograph/io/nwb_alignment.py +1367 -0
  75. ethograph/io/nwb_import.py +170 -0
  76. ethograph/io/plot_sources.py +511 -0
  77. ethograph/io/pynapple.py +226 -0
  78. ethograph/io/time_model.py +512 -0
  79. ethograph/io/time_sources.py +131 -0
  80. ethograph/io/trialtree.py +572 -0
  81. ethograph/io/validation.py +289 -0
  82. ethograph/labels/__init__.py +11 -0
  83. ethograph/labels/boris.py +323 -0
  84. ethograph/labels/converters.py +425 -0
  85. ethograph/labels/crowsetta_format.py +120 -0
  86. ethograph/labels/export.py +145 -0
  87. ethograph/labels/intervals.py +708 -0
  88. ethograph/labels/ml.py +505 -0
  89. ethograph/labels/plots.py +336 -0
  90. ethograph/labels/predictions.py +257 -0
  91. ethograph/labels/tsv_store.py +270 -0
  92. ethograph/model/batch_gen.py +135 -0
  93. ethograph/model/cetnet_encoder.py +589 -0
  94. ethograph/model/dataset.py +302 -0
  95. ethograph/model/eval_metrics.py +244 -0
  96. ethograph/model/eval_plotting.py +465 -0
  97. ethograph/shortcuts.py +126 -0
  98. ethograph/utils/__init__.py +0 -0
  99. ethograph/utils/arraytools.py +361 -0
  100. ethograph/utils/audio.py +65 -0
  101. ethograph/utils/download.py +625 -0
  102. ethograph/utils/nwb.py +156 -0
  103. ethograph/utils/paths.py +277 -0
  104. ethograph/utils/qt.py +198 -0
  105. ethograph/utils/sequences.py +129 -0
  106. ethograph/utils/stream_durations.py +168 -0
  107. ethograph/utils/xr_utils.py +165 -0
  108. ethograph/video_features/base_extractor.py +122 -0
  109. ethograph/video_features/checkpoint/S3D_kinetics400_torchified.pt +0 -0
  110. ethograph/video_features/extract_s3d.py +117 -0
  111. ethograph/video_features/s3d.py +357 -0
  112. ethograph/video_features/s3d.yml +16 -0
  113. ethograph/video_features/transforms.py +309 -0
  114. ethograph/video_features/utils.py +168 -0
  115. ethograph-0.1.3.dist-info/METADATA +92 -0
  116. ethograph-0.1.3.dist-info/RECORD +120 -0
  117. ethograph-0.1.3.dist-info/WHEEL +5 -0
  118. ethograph-0.1.3.dist-info/entry_points.txt +8 -0
  119. ethograph-0.1.3.dist-info/licenses/LICENSE +28 -0
  120. ethograph-0.1.3.dist-info/top_level.txt +1 -0
ethograph/__init__.py ADDED
@@ -0,0 +1,115 @@
1
+ from importlib.metadata import PackageNotFoundError, version
2
+
3
try:
    __version__ = version("ethograph")
except PackageNotFoundError:
    # The package is not installed, so no distribution metadata is available.
    # Bind a placeholder so ``ethograph.__version__`` is always defined
    # instead of raising AttributeError.
    __version__ = "unknown"
8
+
9
+
10
+ from ethograph.io.trialtree import TrialTree
11
+ from ethograph.io.dataset import (
12
+ add_angle_rgb_to_ds,
13
+ add_changepoints_to_ds,
14
+ downsample_trialtree,
15
+ )
16
+ from ethograph.utils.xr_utils import get_ds_duration, get_time_coord, sel_valid
17
+ from ethograph.utils.paths import get_project_root
18
+ from ethograph.io.catalog import (
19
+ DataCatalog,
20
+ DataLoader,
21
+ PlotData,
22
+ PynappleLoader,
23
+ XarrayLoader,
24
+ catalog_from_pynapple,
25
+ catalog_from_xarray,
26
+ )
27
+ from ethograph.io.pynapple import load_nap_data
28
+ from ethograph.io.nwb_alignment import (
29
+ NWBAlignment,
30
+ align_media_per_trial,
31
+ align_media_from_streams,
32
+ )
33
+ from ethograph.datasets import sample_data
34
+ from ethograph.io.time_model import (
35
+ SourceCollection,
36
+ TimeRange,
37
+ TimeSource,
38
+ )
39
+
40
+
41
def open(path: str) -> TrialTree:
    """Open a TrialTree that was previously written to a NetCDF file.

    Convenience alias for
    :meth:`TrialTree.open <ethograph.io.trialtree.TrialTree.open>`.

    Parameters
    ----------
    path : str or Path
        Location of a ``.nc`` file produced by ``dt.save()``.

    Returns
    -------
    TrialTree
        Whatever ``TrialTree.open`` returns for ``path``.

    Examples
    --------
    >>> import ethograph as eto
    >>> dt = eto.open("experiment.nc")
    >>> dt.trials
    [1, 2, 3]
    >>> ds = dt.itrial(0)
    """
    tree = TrialTree.open(path)
    return tree
64
+
65
+
66
def from_datasets(datasets: list) -> TrialTree:
    """Assemble a TrialTree out of a list of per-trial xarray Datasets.

    Convenience alias for
    :meth:`TrialTree.from_datasets <ethograph.io.trialtree.TrialTree.from_datasets>`.

    Every dataset must carry a unique trial identifier in ``attrs["trial"]``.

    Parameters
    ----------
    datasets : list[xarray.Dataset]
        One Dataset per trial.

    Returns
    -------
    TrialTree
        Whatever ``TrialTree.from_datasets`` builds from ``datasets``.

    Examples
    --------
    >>> import xarray as xr, numpy as np, ethograph as eto
    >>> trials = []
    >>> for i in range(1, 4):
    ...     ds = xr.Dataset({"speed": xr.DataArray(np.random.rand(300), dims=["time"])})
    ...     ds.attrs["trial"] = i
    ...     trials.append(ds)
    >>> dt = eto.from_datasets(trials)
    >>> dt.trials
    [1, 2, 3]
    """
    tree = TrialTree.from_datasets(datasets)
    return tree
96
+
97
+
98
def from_continuous(ds, epochs) -> TrialTree:
    """Split one continuous recording into a TrialTree using trial epochs.

    Convenience alias for :meth:`TrialTree.from_continuous`.

    Parameters
    ----------
    ds : xarray.Dataset
        The complete, unsegmented recording.
    epochs : pandas.DataFrame or pynapple.IntervalSet
        Trial boundaries. A DataFrame needs the columns ``trial``,
        ``start_time`` and ``stop_time``.

    Returns
    -------
    TrialTree
        Whatever ``TrialTree.from_continuous`` builds from ``ds`` and ``epochs``.
    """
    tree = TrialTree.from_continuous(ds, epochs)
    return tree
ethograph/__main__.py ADDED
@@ -0,0 +1,3 @@
1
# Entry point for ``python -m ethograph``: hand control to the CLI dispatcher.
import ethograph.cli as _cli

_cli.main()
ethograph/assets/icon.ico ADDED
Binary file
ethograph/assets/icon.png ADDED
Binary file
ethograph/assets/menu.json ADDED
@@ -0,0 +1,27 @@
1
+ {
2
+ "$schema": "https://json-schema.org/draft-07/schema",
3
+ "$id": "https://schemas.conda.io/menuinst-1.schema.json",
4
+ "menu_name": "ethograph",
5
+ "menu_items": [
6
+ {
7
+ "name": "ethograph",
8
+ "command": ["{{ PYTHON }}", "-m", "ethograph", "launch"],
9
+ "activate": true,
10
+ "terminal": true,
11
+ "platforms": {
12
+ "win": {
13
+ "desktop": true,
14
+ "icon": "{{ MENU_DIR }}/icon.ico",
15
+ "command": ["{{ PYTHON }}", "-m", "ethograph", "launch", "--keep-open"]
16
+ },
17
+ "linux": {
18
+ "Categories": ["Science", "Education"],
19
+ "icon": "{{ MENU_DIR }}/icon.png"
20
+ },
21
+ "osx": {
22
+ "icon": "{{ MENU_DIR }}/icon.png"
23
+ }
24
+ }
25
+ }
26
+ ]
27
+ }
ethograph/cli.py ADDED
@@ -0,0 +1,71 @@
1
+ #!/usr/bin/env python
2
+ """Command-line interface for ethograph."""
3
+
4
+ import logging
5
+ import os
6
+ import sys
7
+ import warnings
8
+
9
# Silence DeprecationWarnings from these dependencies before any later import
# gets a chance to emit them. Each pattern is a regex matched against the
# start of the module name the warning originates from.
for _pattern in (r"logging", r"vispy\.", r"numpy\."):
    warnings.filterwarnings("ignore", category=DeprecationWarning, module=_pattern)
del _pattern

# PyOpenGL reports its acceleration status through logging rather than the
# warnings machinery, so it is quieted at the logger level instead.
logging.getLogger("OpenGL.acceleratesupport").setLevel(logging.WARNING)
16
+
17
def _ensure_qt_plugins():
    """Point QT_PLUGIN_PATH at a conda-forge Qt plugin directory if unset.

    Needed when launched from menuinst shortcuts. A non-empty existing value
    is respected. Otherwise the first candidate directory under ``sys.prefix``
    that contains a ``platforms`` subfolder is used. If none matches, the
    environment is left unchanged.
    """
    if os.environ.get("QT_PLUGIN_PATH"):
        return

    # Checked in order: Windows, Linux, macOS conda-forge layouts.
    relative_layouts = (
        ("Library", "plugins"),
        ("lib", "qt5", "plugins"),
        ("lib", "qt", "plugins"),
    )
    plugin_dirs = (os.path.join(sys.prefix, *parts) for parts in relative_layouts)
    match = next(
        (d for d in plugin_dirs if os.path.isdir(os.path.join(d, "platforms"))),
        None,
    )
    if match is not None:
        os.environ["QT_PLUGIN_PATH"] = match
30
+
31
+
32
def launch():
    """Open a napari viewer with the ethograph dock widget and run it.

    Sets up INFO-level root logging (napari's own logger is capped at
    WARNING), fixes QT_PLUGIN_PATH for conda-forge installs, then hands
    control to ``napari.run()``.
    """
    logging.basicConfig(
        format="%(name)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )
    logging.getLogger("napari").setLevel(logging.WARNING)
    _ensure_qt_plugins()

    # Heavy GUI imports are deferred until the GUI is actually requested.
    import napari
    from ethograph.gui.widgets_meta import MetaWidget

    viewer = napari.Viewer()
    meta_widget = MetaWidget(viewer)
    viewer.window.add_dock_widget(meta_widget, name="ethograph GUI")
    napari.run()
46
+
47
+
48
def main():
    """Main CLI entry point: dispatch on the sub-command in ``sys.argv[1]``.

    Commands
    --------
    launch
        Start the GUI via :func:`launch`.
    shortcut
        Install a desktop/Start Menu shortcut. The process exits with the
        value returned by ``ethograph.shortcuts.install_shortcut()``.
    -h, --help
        Print usage to stdout and return normally.

    A missing or unknown command prints usage/diagnostics to stderr and exits
    with status 1. Any arguments after the command (e.g. ``--keep-open``,
    which the Windows menu shortcut passes) are currently ignored.
    """
    usage = (
        "Usage: ethograph <command>\n"
        "Commands:\n"
        " launch Launch the ethograph GUI\n"
        " shortcut Install desktop/Start Menu shortcut"
    )

    if len(sys.argv) < 2:
        # Usage shown because of an error belongs on stderr, not stdout.
        print(usage, file=sys.stderr)
        sys.exit(1)

    command = sys.argv[1]

    if command in ("-h", "--help"):
        print(usage)
        return

    if command == "launch":
        launch()
    elif command == "shortcut":
        from ethograph.shortcuts import install_shortcut
        sys.exit(install_shortcut())
    else:
        print(f"Unknown command: {command}", file=sys.stderr)
        print("Available commands: launch, shortcut", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
ethograph/crowlab/io_matlab.py ADDED
@@ -0,0 +1,264 @@
1
+ """Custom code for importing data from lab-internal matlab codes to python"""
2
+
3
+ import re
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ import numpy as np
8
+ import xarray as xr
9
+ from scipy.io import loadmat
10
+
11
+ import ethograph as eto
12
+
13
+
14
def get_all_trials_path_info(all_trials_path):
    """
    Parse subject and session identifiers out of an AllTrials file path.

    Backslashes are normalised to forward slashes first. The subject is the
    text after ``id-`` up to the next path separator. The session date and
    number come from a ``date-YYYYMMDD_NN`` segment. Any piece that cannot
    be found is returned as an empty string.

    Args:
        all_trials_path (str): Path string to parse.
    Returns:
        tuple of str: (subject_id, session_date, session_number, dataset_name),
        where dataset_name is ``"{session_date}-{session_number}_{subject_id}"``.
    """
    normalized = all_trials_path.replace('\\', '/')

    subject_match = re.search(r'id-([^/\\]+)', normalized)
    subject_id = subject_match.group(1) if subject_match is not None else ''

    session_match = re.search(r'date-(\d{8})_(\d{2})', normalized)
    if session_match is None:
        session_date, session_number = '', ''
    else:
        session_date, session_number = session_match.groups()

    dataset_name = f'{session_date}-{session_number}_{subject_id}'
    return subject_id, session_date, session_number, dataset_name
38
+
39
+
40
+
41
def update_dt_with_matlab_pulse_onsets(
    dt: xr.DataTree,
    matlab_path: str | Path,
) -> xr.DataTree:
    """
    Attach per-sample ``pulse_onsets`` from a MATLAB AllTrials file to a DataTree.

    Each child node of ``dt`` whose dataset has a ``trial`` attribute is
    matched against ``AllTrials(i).trial_num``. For every match, the MATLAB
    ``pulse_info.pulse_onsets`` vector is stored as a ``pulse_onsets``
    variable along the ``time`` dimension. Nodes without a ``trial``
    attribute, or whose trial has no MATLAB counterpart, are left untouched.

    Parameters
    ----------
    dt : xr.DataTree
        Tree of per-trial datasets, one child node per trial. Matching nodes
        are replaced in ``dt`` itself, and the same tree is returned.
    matlab_path : str | Path
        Path to the MATLAB file containing the AllTrials structure.

    Returns
    -------
    xr.DataTree
        ``dt``, updated with the MATLAB pulse onsets.

    Raises
    ------
    FileNotFoundError
        If ``matlab_path`` does not exist.
    ValueError
        If the file has no ``AllTrials`` variable, or if a trial's
        ``pulse_onsets`` length differs from its ``time`` coordinate length.
    """
    matlab_path = Path(matlab_path)

    if not matlab_path.exists():
        raise FileNotFoundError(f"MATLAB file not found: {matlab_path}")

    # Load MATLAB data
    mat_data = loadmat(
        matlab_path,
        squeeze_me=True,
        struct_as_record=False
    )

    if 'AllTrials' not in mat_data:
        raise ValueError(f"'AllTrials' structure not found in {matlab_path}")

    # With squeeze_me=True, a file holding a single trial yields one bare
    # struct instead of an array of structs. atleast_1d makes it iterable.
    all_trials = np.atleast_1d(mat_data['AllTrials'])

    # Create lookup for efficient trial matching
    matlab_trials_dict = {
        trial.trial_num: trial
        for trial in all_trials
    }

    # Iterate over a snapshot: nodes are reassigned into ``dt`` inside the loop.
    for node_name, node in list(dt.children.items()):
        if node.ds is None or 'trial' not in node.ds.attrs:
            continue

        trial_num = node.ds.attrs['trial']
        if trial_num not in matlab_trials_dict:
            continue

        trial = matlab_trials_dict[trial_num]

        # Create mutable copy of the dataset
        ds = node.to_dataset().copy()

        # squeeze_me turns a one-sample vector into a scalar; restore 1-D.
        pulse_onsets = np.atleast_1d(trial.pulse_info.pulse_onsets)

        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if len(pulse_onsets) != len(ds.time):
            raise ValueError(
                f"Trial {trial_num}: pulse_onsets has {len(pulse_onsets)} "
                f"samples but the dataset has {len(ds.time)} time points"
            )

        ds['pulse_onsets'] = ('time', pulse_onsets)

        # Assign back to the DataTree node
        dt[node_name] = ds

    return dt
122
+
123
+
124
def update_nc_with_matlab_trials(
    dt: xr.DataTree,
    matlab_path: str | Path,
) -> xr.DataTree:
    """
    Update per-trial datasets in a DataTree from a MATLAB AllTrials structure.

    Each child node whose dataset has a ``trial`` attribute is matched against
    ``AllTrials(i).trial_num``. For every match:

    * ``info.poscat`` / ``info.num_pellets`` are copied into ``ds.attrs``
      when present;
    * a ``boundary_events`` variable (dimension ``events`` with coordinates
      ``disp_out, disp_in, box_in, box_out``) is added when
      ``info.stick_in_out_disp`` is present;
    * a zero-initialised ``labels`` variable over ``(time, individuals)`` is
      created. If ``label_infos.beakTip`` exists and the dataset has a
      ``bird`` attribute, that individual's column is filled with
      ``label_infos.beakTip.labels``.

    Parameters
    ----------
    dt : xr.DataTree
        The DataTree containing trial data. Matching nodes are replaced in
        ``dt`` itself, and the same tree is returned.
    matlab_path : str | Path
        Path to the MATLAB file containing the AllTrials structure.

    Returns
    -------
    xr.DataTree
        ``dt``, updated with the MATLAB trial information.

    Raises
    ------
    FileNotFoundError
        If ``matlab_path`` does not exist.
    ValueError
        If the file has no ``AllTrials`` variable.
    """
    matlab_path = Path(matlab_path)

    if not matlab_path.exists():
        raise FileNotFoundError(f"MATLAB file not found: {matlab_path}")

    # Load MATLAB data
    mat_data = loadmat(
        matlab_path,
        squeeze_me=True,
        struct_as_record=False
    )

    if 'AllTrials' not in mat_data:
        raise ValueError(f"'AllTrials' structure not found in {matlab_path}")

    # With squeeze_me=True, a single-trial file yields a bare struct rather
    # than an array of structs. atleast_1d makes it iterable in both cases.
    all_trials = np.atleast_1d(mat_data['AllTrials'])

    # Create lookup for efficient trial matching
    matlab_trials_dict = {
        trial.trial_num: trial
        for trial in all_trials
    }

    # Iterate over a snapshot: nodes are reassigned into ``dt`` inside the loop.
    for node_name, node in list(dt.children.items()):
        if node.ds is None or 'trial' not in node.ds.attrs:
            continue

        trial_num = node.ds.attrs['trial']
        if trial_num not in matlab_trials_dict:
            continue

        trial = matlab_trials_dict[trial_num]

        # ``node.ds`` is a read-only view of the node. Edit a mutable copy and
        # write it back at the end, otherwise no update reaches ``dt``.
        ds = node.to_dataset().copy()

        # A trial without an ``info`` struct simply gets no attrs/events.
        info = getattr(trial, 'info', None)

        # Update attributes
        if hasattr(info, 'poscat'):
            ds.attrs['poscat'] = info.poscat
        if hasattr(info, 'num_pellets'):
            ds.attrs['num_pellets'] = info.num_pellets

        # Process boundary events
        event_data = _extract_boundary_events(info)
        if event_data is not None:
            ds['boundary_events'] = ('events', event_data)
            ds = ds.assign_coords(
                events=('events', ['disp_out', 'disp_in', 'box_in', 'box_out'])
            )

        # Initialize labels
        n_time = len(ds.coords['time'])
        n_individuals = len(ds.coords['individuals'])
        ds['labels'] = (
            ('time', 'individuals'),
            np.zeros((n_time, n_individuals))
        )

        # Update labels if available
        label_infos = getattr(trial, 'label_infos', None)
        if hasattr(label_infos, 'beakTip') and 'bird' in ds.attrs:
            bird = ds.attrs['bird']
            ds['labels'].loc[dict(individuals=bird)] = label_infos.beakTip.labels

        # Assign back to the DataTree node
        dt[node_name] = ds

    return dt
224
+
225
+
226
def _extract_boundary_events(trial_info) -> Optional[np.ndarray]:
    """
    Extract boundary events from a MATLAB trial info structure.

    Each event value is converted with ``int(value) - 1``, presumably turning
    MATLAB's 1-based sample indices into 0-based ones. A value that is
    missing, non-numeric or NaN becomes ``np.nan``.

    Parameters
    ----------
    trial_info : object
        MATLAB trial info structure (e.g. a ``scipy.io`` ``mat_struct``).
        ``stick_in_out_disp`` supplies disp_out/disp_in at positions 0/1.
        The optional ``first_in_last_out`` supplies box_in/box_out at
        positions 0/1.

    Returns
    -------
    Optional[np.ndarray]
        Float array ``[disp_out, disp_in, box_in, box_out]``, or None if
        ``trial_info`` has no ``stick_in_out_disp`` field.
    """
    if not hasattr(trial_info, 'stick_in_out_disp'):
        return None

    def _to_zero_based(values, position):
        # Scalars (from squeeze_me), short arrays and NaN all map to NaN.
        try:
            return int(values[position]) - 1
        except (IndexError, TypeError, ValueError):
            return np.nan

    event_data = np.full(4, np.nan)

    # Extract disp_out and disp_in
    disp_out_in = trial_info.stick_in_out_disp
    event_data[0] = _to_zero_based(disp_out_in, 0)
    event_data[1] = _to_zero_based(disp_out_in, 1)

    # Extract box_in and box_out from the leading two entries. The previous
    # code indexed [-2]/[-1], which picked the wrong elements whenever the
    # array was not exactly length 2. Missing field -> both stay NaN.
    if hasattr(trial_info, 'first_in_last_out'):
        first_in_last_out = trial_info.first_in_last_out
        event_data[2] = _to_zero_based(first_in_last_out, 0)
        event_data[3] = _to_zero_based(first_in_last_out, 1)

    return event_data