nwb-video-widgets 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nwb_video_widgets/__init__.py +16 -0
- nwb_video_widgets/_utils.py +328 -0
- nwb_video_widgets/dandi_pose_widget.py +356 -0
- nwb_video_widgets/dandi_video_widget.py +155 -0
- nwb_video_widgets/local_pose_widget.py +334 -0
- nwb_video_widgets/local_video_widget.py +130 -0
- nwb_video_widgets/pose_widget.css +624 -0
- nwb_video_widgets/pose_widget.js +798 -0
- nwb_video_widgets/video_widget.css +484 -0
- nwb_video_widgets/video_widget.js +566 -0
- nwb_video_widgets/video_widget.py +170 -0
- nwb_video_widgets-0.1.0.dist-info/METADATA +174 -0
- nwb_video_widgets-0.1.0.dist-info/RECORD +15 -0
- nwb_video_widgets-0.1.0.dist-info/WHEEL +4 -0
- nwb_video_widgets-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Interactive Jupyter widgets for NWB video and pose visualization."""
|
|
2
|
+
|
|
3
|
+
from nwb_video_widgets.dandi_pose_widget import NWBDANDIPoseEstimationWidget
|
|
4
|
+
from nwb_video_widgets.dandi_video_widget import NWBDANDIVideoPlayer
|
|
5
|
+
from nwb_video_widgets.local_pose_widget import NWBLocalPoseEstimationWidget
|
|
6
|
+
from nwb_video_widgets.local_video_widget import NWBLocalVideoPlayer
|
|
7
|
+
from nwb_video_widgets.video_widget import NWBFileVideoPlayer
|
|
8
|
+
|
|
9
|
+
# Package version string.
__version__ = "0.1.0"

# Public names exported by ``from nwb_video_widgets import *``.
__all__ = [
    "NWBLocalVideoPlayer",
    "NWBDANDIVideoPlayer",
    "NWBFileVideoPlayer",
    "NWBLocalPoseEstimationWidget",
    "NWBDANDIPoseEstimationWidget",
]
|
|
@@ -0,0 +1,328 @@
|
|
|
1
|
+
"""Shared utilities for NWB video widgets."""
|
|
2
|
+
|
|
3
|
+
import socket
|
|
4
|
+
import threading
|
|
5
|
+
from functools import partial
|
|
6
|
+
from http.server import HTTPServer, SimpleHTTPRequestHandler
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
from pynwb import NWBFile
|
|
10
|
+
from pynwb.image import ImageSeries
|
|
11
|
+
|
|
12
|
+
# Module-wide registry of running video file servers. Keys are resolved
# directory path strings; values are the (server, port) pair that serves that
# directory, so each directory gets at most one server per process.
_video_servers: dict[str, tuple[HTTPServer, int]] = dict()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def discover_video_series(nwbfile: NWBFile) -> dict[str, ImageSeries]:
    """Collect the acquisition ImageSeries that reference external video files.

    Only ``nwbfile.acquisition`` is scanned. An object qualifies when it is an
    ``ImageSeries`` whose ``external_file`` field is set.

    Parameters
    ----------
    nwbfile : NWBFile
        NWB file whose acquisition group is searched.

    Returns
    -------
    dict[str, ImageSeries]
        Series name -> ImageSeries, in acquisition iteration order.
    """
    return {
        series_name: candidate
        for series_name, candidate in nwbfile.acquisition.items()
        if isinstance(candidate, ImageSeries) and candidate.external_file is not None
    }
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def get_video_timestamps(nwbfile: NWBFile) -> dict[str, list[float]]:
    """Return per-video timestamp lists for every external-file ImageSeries.

    Parameters
    ----------
    nwbfile : NWBFile
        NWB file containing video ImageSeries in acquisition.

    Returns
    -------
    dict[str, list[float]]
        Video name -> timestamps as Python floats. A series without explicit
        timestamps yields a single-element list holding its ``starting_time``,
        or ``[0.0]`` when that is unset as well.
    """

    def _series_times(series: ImageSeries) -> list[float]:
        # Explicit timestamps win; otherwise only the start time is known.
        if series.timestamps is not None:
            return [float(t) for t in series.timestamps[:]]
        if series.starting_time is not None:
            return [float(series.starting_time)]
        return [0.0]

    return {
        series_name: _series_times(series)
        for series_name, series in discover_video_series(nwbfile).items()
    }
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def get_video_info(nwbfile: NWBFile) -> dict[str, dict]:
    """Extract video time range information from all ImageSeries in an NWB file.

    Parameters
    ----------
    nwbfile : NWBFile
        NWB file containing video ImageSeries in acquisition

    Returns
    -------
    dict[str, dict]
        Mapping of video names to info dictionaries with keys:
        - start: float, start time in seconds
        - end: float, end time in seconds
        - frames: int, number of frames

    Notes
    -----
    A series with a non-empty ``timestamps`` array reports its first and last
    values and its length. Otherwise (no timestamps, or an empty timestamps
    array) the range collapses to ``starting_time`` (0.0 if that is unset too)
    with ``frames`` reported as 1, because the end time and frame count cannot
    be determined from the start time alone.
    """
    info = {}

    for name, series in discover_video_series(nwbfile).items():
        timestamps = series.timestamps[:] if series.timestamps is not None else []

        if len(timestamps) > 0:
            start = float(timestamps[0])
            end = float(timestamps[-1])
            frames = len(timestamps)
        else:
            # No usable timestamps (missing or empty array): indexing [0]/[-1]
            # would raise, so report a zero-length range at the start time.
            start = float(series.starting_time) if series.starting_time is not None else 0.0
            end = start
            frames = 1

        info[name] = {
            "start": start,
            "end": end,
            "frames": frames,
        }

    return info
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class _RangeRequestHandler(SimpleHTTPRequestHandler):
    """HTTP request handler with CORS headers and Range request support for video streaming.

    ``SimpleHTTPRequestHandler`` ignores the ``Range`` header and always sends
    the whole file with status 200. This subclass answers a single byte-range
    request with ``206 Partial Content`` (or ``416`` when the range lies outside
    the file) and sends exactly the requested bytes. It also adds permissive
    CORS headers to every response.
    """

    # Number of body bytes still to send for the current 206 response; None
    # means "copy the whole file" (ordinary responses from the base class).
    _range_remaining = None

    # Read size used while streaming a partial response.
    _COPY_CHUNK_SIZE = 64 * 1024

    @staticmethod
    def _parse_byte_range(range_header, file_size):
        """Parse a single-range ``Range: bytes=...`` header value.

        Parameters
        ----------
        range_header : str
            Raw value of the ``Range`` request header.
        file_size : int
            Size of the requested file in bytes.

        Returns
        -------
        tuple[int, int] or None
            Inclusive ``(start, end)`` offsets with ``end`` clamped to the last
            byte of the file. The result may be unsatisfiable
            (``start >= file_size`` or ``start > end``); the caller answers 416.
            ``None`` means the header is malformed, uses a unit other than
            ``bytes``, or asks for several ranges. In that case the header is
            ignored and the full file is served.
        """
        unit, sep, spec = range_header.strip().partition("=")
        if not sep or unit.strip().lower() != "bytes" or "," in spec:
            return None

        first, sep, last = spec.strip().partition("-")
        first, last = first.strip(), last.strip()
        if not sep or (first and not first.isdecimal()) or (last and not last.isdecimal()):
            return None

        if not first:
            # Suffix range "bytes=-N": the final N bytes of the file.
            if not last:
                return None
            return max(file_size - int(last), 0), file_size - 1

        start = int(first)
        if not last:
            # Open-ended range "bytes=N-": from N to end of file.
            return start, file_size - 1

        end = int(last)
        if end < start:
            # last-byte-pos < first-byte-pos is syntactically invalid: ignore it.
            return None
        return start, min(end, file_size - 1)

    def send_head(self):
        """Send headers for GET/HEAD, answering single byte-range requests with 206.

        Returns
        -------
        file object or None
            Binary file positioned at the first byte to send (the base class's
            ``do_GET`` passes it to :meth:`copyfile` and closes it), or ``None``
            when there is no body to send.
        """
        # Reset per request so a previous range can never limit a later response.
        self._range_remaining = None

        path = self.translate_path(self.path)
        if not Path(path).is_file():
            return super().send_head()

        range_header = self.headers.get("Range")
        if not range_header:
            return super().send_head()

        file_size = Path(path).stat().st_size
        byte_range = self._parse_byte_range(range_header, file_size)
        if byte_range is None:
            # No usable Range header - serve the full file.
            return super().send_head()

        start, end = byte_range
        if start >= file_size or start > end:
            self.send_response(416)  # Range Not Satisfiable
            self.send_header("Content-Range", f"bytes */{file_size}")
            self.send_header("Content-Length", "0")
            self.end_headers()
            return None

        try:
            f = open(path, "rb")
        except OSError:
            self.send_error(404, "File not found")
            return None

        content_length = end - start + 1
        try:
            f.seek(start)
            self.send_response(206)  # Partial Content
            self.send_header("Content-Type", self.guess_type(path))
            self.send_header("Content-Length", str(content_length))
            self.send_header("Content-Range", f"bytes {start}-{end}/{file_size}")
            self.send_header("Accept-Ranges", "bytes")
            self.send_header("Access-Control-Allow-Origin", "*")
            self.send_header("Access-Control-Allow-Methods", "GET, OPTIONS, HEAD")
            self.send_header("Access-Control-Allow-Headers", "Range")
            self.send_header("Access-Control-Expose-Headers", "Content-Range, Content-Length")
            self.end_headers()
        except BaseException:
            # Don't leak the handle if the client vanished mid-headers.
            f.close()
            raise

        self._range_remaining = content_length
        return f

    def copyfile(self, source, outputfile):
        """Copy the response body, limited to the requested range for 206 responses."""
        remaining = self._range_remaining
        if remaining is None:
            super().copyfile(source, outputfile)
            return

        # The base implementation copies to EOF, which would send more bytes
        # than the Content-Length promised; stop after exactly the range.
        while remaining > 0:
            chunk = source.read(min(self._COPY_CHUNK_SIZE, remaining))
            if not chunk:
                break
            outputfile.write(chunk)
            remaining -= len(chunk)

    def end_headers(self):
        """Add CORS headers to all responses."""
        # Only add them if not already added (the 206 path sends its own).
        buffered = b"".join(getattr(self, "_headers_buffer", []))
        if b"Access-Control-Allow-Origin" not in buffered:
            self.send_header("Access-Control-Allow-Origin", "*")
            self.send_header("Access-Control-Allow-Methods", "GET, OPTIONS, HEAD")
            self.send_header("Access-Control-Allow-Headers", "Range")
            self.send_header("Accept-Ranges", "bytes")
        super().end_headers()

    def do_OPTIONS(self):
        """Handle CORS preflight requests."""
        self.send_response(200)
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Methods", "GET, OPTIONS, HEAD")
        self.send_header("Access-Control-Allow-Headers", "Range")
        self.send_header("Access-Control-Max-Age", "86400")
        self.end_headers()

    def log_message(self, format, *args):
        """Suppress logging to avoid cluttering notebook output."""
        pass

    def handle(self):
        """Handle requests, suppressing connection reset errors."""
        try:
            super().handle()
        except (ConnectionResetError, BrokenPipeError):
            # Browser closed connection early - this is normal during video seeking
            pass
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _find_free_port() -> int:
    """Find a free TCP port on the loopback interface.

    The probe binds to ``127.0.0.1``, the interface the video server listens
    on. Binding the wildcard address would check a different interface and can
    trigger OS firewall prompts. The port is released before returning, so
    another process could still claim it before the caller binds it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(("127.0.0.1", 0))
        return probe.getsockname()[1]
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def start_video_server(directory: Path) -> int:
    """Start an HTTP server to serve video files from a directory.

    If a server is already running for this directory, returns its port.

    Parameters
    ----------
    directory : Path or str
        Directory containing video files to serve. It is resolved to an
        absolute path, so later changes of the working directory do not
        affect which files are served.

    Returns
    -------
    int
        Port number the server is listening on (bound to 127.0.0.1).
    """
    from http.server import ThreadingHTTPServer

    resolved = Path(directory).resolve()
    dir_key = str(resolved)

    # Return existing server port if already running
    existing = _video_servers.get(dir_key)
    if existing is not None:
        return existing[1]

    # Pass the absolute path: SimpleHTTPRequestHandler joins request paths onto
    # `directory`, so a relative one would break if the notebook's cwd changes.
    handler = partial(_RangeRequestHandler, directory=dir_key)

    # Port 0 makes the OS assign a free port atomically at bind time, avoiding
    # the race between probing for a free port and binding it. A threading
    # server lets concurrent range requests (seeking, several videos) proceed
    # in parallel instead of queueing behind one long transfer.
    server = ThreadingHTTPServer(("127.0.0.1", 0), handler)
    port = server.server_address[1]

    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()

    _video_servers[dir_key] = (server, port)
    return port
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def discover_pose_estimation_cameras(nwbfile: NWBFile) -> dict:
    """Find the PoseEstimation containers stored under ``processing['pose_estimation']``.

    Containers are matched by class name (``"PoseEstimation"``). Other data
    interfaces in that module, such as Skeletons, are therefore skipped.

    Parameters
    ----------
    nwbfile : NWBFile
        NWB file to search for pose estimation data.

    Returns
    -------
    dict
        Camera name -> PoseEstimation object. Empty when the file has no
        ``pose_estimation`` processing module.
    """
    processing_modules = nwbfile.processing
    if "pose_estimation" not in processing_modules:
        return {}

    interfaces = processing_modules["pose_estimation"].data_interfaces
    return {
        label: container
        for label, container in interfaces.items()
        if type(container).__name__ == "PoseEstimation"
    }
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def get_camera_to_video_mapping(nwbfile: NWBFile) -> dict[str, str]:
    """Pair pose estimation cameras with video series by naming convention.

    A camera named ``X`` is paired with the video series ``"Video" + X``:
    - 'LeftCamera' -> 'VideoLeftCamera'
    - 'BodyCamera' -> 'VideoBodyCamera'

    Cameras whose expected video series is absent are left out.

    Parameters
    ----------
    nwbfile : NWBFile
        NWB file containing pose estimation and video data.

    Returns
    -------
    dict[str, str]
        Camera name -> video series name.
    """
    camera_names = discover_pose_estimation_cameras(nwbfile)
    known_videos = discover_video_series(nwbfile)

    pairs = ((camera, "Video" + camera) for camera in camera_names)
    return {camera: video for camera, video in pairs if video in known_videos}
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def get_pose_estimation_info(nwbfile: NWBFile) -> dict[str, dict]:
    """Extract pose estimation info for all cameras in an NWB file.

    Timing is taken from the camera's first PoseEstimationSeries. Explicit
    ``timestamps`` are used when present. A regularly sampled series (no
    timestamps, non-zero ``rate``) gets its range from ``starting_time``
    (default 0.0), ``rate`` and the number of data rows. Otherwise, including
    an empty timestamps array, the range is reported as 0.0-0.0 with 0 frames.

    Parameters
    ----------
    nwbfile : NWBFile
        NWB file containing pose estimation in processing['pose_estimation']

    Returns
    -------
    dict[str, dict]
        Mapping of camera names to info dictionaries with keys:
        - start: float, start time in seconds
        - end: float, end time in seconds
        - frames: int, number of frames
        - keypoints: list[str], names of keypoints
    """
    info = {}

    for camera_name, pose_estimation in discover_pose_estimation_cameras(nwbfile).items():
        series_by_name = pose_estimation.pose_estimation_series

        # Keypoint names with the "PoseEstimationSeries" suffix removed
        keypoint_names = [
            name.replace("PoseEstimationSeries", "") for name in series_by_name.keys()
        ]

        start, end, frames = 0.0, 0.0, 0
        first_series = next(iter(series_by_name.values()), None)
        if first_series is not None:
            if first_series.timestamps is not None:
                timestamps = first_series.timestamps[:]
                # Guard: an empty array would make [0]/[-1] raise IndexError.
                if len(timestamps) > 0:
                    start = float(timestamps[0])
                    end = float(timestamps[-1])
                    frames = len(timestamps)
            elif first_series.rate:
                # Regularly sampled: frame i occurs at starting_time + i / rate.
                frames = len(first_series.data)
                start = float(first_series.starting_time or 0.0)
                end = start + (frames - 1) / float(first_series.rate) if frames > 0 else start

        info[camera_name] = {
            "start": start,
            "end": end,
            "frames": frames,
            "keypoints": keypoint_names,
        }

    return info
|
|
@@ -0,0 +1,356 @@
|
|
|
1
|
+
"""DANDI NWB pose estimation video overlay widget."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import pathlib
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import TYPE_CHECKING, Optional
|
|
8
|
+
|
|
9
|
+
import anywidget
|
|
10
|
+
import matplotlib.colors as mcolors
|
|
11
|
+
import matplotlib.pyplot as plt
|
|
12
|
+
import numpy as np
|
|
13
|
+
import traitlets
|
|
14
|
+
from pynwb import NWBFile
|
|
15
|
+
|
|
16
|
+
from nwb_video_widgets._utils import (
|
|
17
|
+
discover_pose_estimation_cameras,
|
|
18
|
+
discover_video_series,
|
|
19
|
+
get_pose_estimation_info,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from dandi.dandiapi import RemoteAsset
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class NWBDANDIPoseEstimationWidget(anywidget.AnyWidget):
    """Video player with pose estimation overlay for DANDI-hosted NWB files.

    Overlays pose-estimation keypoints (e.g. from DeepLabCut) on streaming
    video, with camera selection via a settings panel.

    This widget discovers PoseEstimation containers in processing['pose_estimation']
    and resolves video paths to S3 URLs via the DANDI API. An interactive
    settings panel allows users to select which camera to display and which
    video to pair with it. Pose data for a camera is loaded only when that
    camera is first selected.

    Supports two common NWB patterns:
    1. Single file: both videos and pose estimation in same NWB file
    2. Split files: videos in raw NWB file, pose estimation in processed file

    Parameters
    ----------
    asset : RemoteAsset
        DANDI asset object for the processed NWB file containing pose estimation.
        The dandiset_id and asset path are extracted from this object.
    nwbfile : pynwb.NWBFile, optional
        Pre-loaded NWB file containing pose estimation. If not provided, the widget
        will load the NWB file via streaming from `asset`.
    video_asset : RemoteAsset, optional
        DANDI asset object for the raw NWB file containing videos. If not provided,
        videos are assumed to be accessible relative to `asset`.
    video_nwbfile : pynwb.NWBFile, optional
        Pre-loaded NWB file containing video ImageSeries. If not provided but
        `video_asset` is provided, the widget will extract video URLs from `video_asset`.
        If neither is provided, videos are assumed to be in `nwbfile`.
    keypoint_colors : str or dict, default 'tab10'
        Either a matplotlib colormap name (e.g., 'tab10', 'Set1', 'Paired') for
        automatic color assignment, or a dict mapping keypoint names to hex colors
        (e.g., {'LeftPaw': '#FF0000', 'RightPaw': '#00FF00'}). With a dict,
        keypoints missing from it are colored from the 'tab10' colormap.
    default_camera : str, optional
        Camera to display initially. If it is not an available camera, no camera
        is preselected and the settings panel is shown.

    Example
    -------
    Single file (videos + pose in same file):

    >>> from dandi.dandiapi import DandiAPIClient
    >>> client = DandiAPIClient()
    >>> dandiset = client.get_dandiset("000409", "draft")
    >>> asset = dandiset.get_asset_by_path("sub-.../sub-..._combined.nwb")
    >>> widget = NWBDANDIPoseEstimationWidget(asset=asset)
    >>> display(widget)

    Split files (videos in raw, pose in processed):

    >>> raw_asset = dandiset.get_asset_by_path("sub-.../sub-..._desc-raw.nwb")
    >>> processed_asset = dandiset.get_asset_by_path("sub-.../sub-..._desc-processed.nwb")
    >>> widget = NWBDANDIPoseEstimationWidget(
    ...     asset=processed_asset,
    ...     video_asset=raw_asset,
    ... )
    >>> display(widget)

    With pre-loaded NWB files (avoids re-loading):

    >>> widget = NWBDANDIPoseEstimationWidget(
    ...     asset=processed_asset,
    ...     nwbfile=nwbfile_processed,
    ...     video_asset=raw_asset,
    ...     video_nwbfile=nwbfile_raw,
    ... )

    Raises
    ------
    ValueError
        If the NWB file holding pose estimation has no
        ``processing['pose_estimation']`` module.
    """

    selected_camera = traitlets.Unicode("").tag(sync=True)
    available_cameras = traitlets.List([]).tag(sync=True)
    available_cameras_info = traitlets.Dict({}).tag(sync=True)

    # Video selection - users explicitly match cameras to videos
    available_videos = traitlets.List([]).tag(sync=True)
    available_videos_info = traitlets.Dict({}).tag(sync=True)
    video_name_to_url = traitlets.Dict({}).tag(sync=True)  # Video name -> URL mapping
    camera_to_video = traitlets.Dict({}).tag(sync=True)  # Camera -> video name mapping

    settings_open = traitlets.Bool(True).tag(sync=True)

    # Pose data for cameras - loaded lazily when selected
    all_camera_data = traitlets.Dict({}).tag(sync=True)

    # Loading state for progress indicator
    loading = traitlets.Bool(False).tag(sync=True)

    show_labels = traitlets.Bool(True).tag(sync=True)
    visible_keypoints = traitlets.Dict({}).tag(sync=True)

    _esm = pathlib.Path(__file__).parent / "pose_widget.js"
    _css = pathlib.Path(__file__).parent / "pose_widget.css"

    def __init__(
        self,
        asset: RemoteAsset,
        nwbfile: Optional[NWBFile] = None,
        video_asset: Optional[RemoteAsset] = None,
        video_nwbfile: Optional[NWBFile] = None,
        keypoint_colors: str | dict[str, str] = "tab10",
        default_camera: Optional[str] = None,
        **kwargs,
    ):
        # Load NWB file if not provided (for pose estimation)
        if nwbfile is None:
            nwbfile = self._load_nwbfile_from_dandi(asset)

        # Validate before any further (potentially remote) loading so a file
        # without pose data fails fast.
        if "pose_estimation" not in nwbfile.processing:
            raise ValueError("NWB file does not contain pose_estimation processing module")
        pose_estimation = nwbfile.processing["pose_estimation"]

        # Determine video source
        # Priority: video_nwbfile > video_asset > nwbfile
        if video_nwbfile is not None:
            video_source_nwbfile = video_nwbfile
        elif video_asset is not None:
            video_source_nwbfile = self._load_nwbfile_from_dandi(video_asset)
        else:
            video_source_nwbfile = nwbfile

        # External video paths are resolved relative to this asset's folder
        video_source_asset = video_asset if video_asset is not None else asset
        video_urls = self._get_video_urls_from_dandi(video_source_nwbfile, video_source_asset)

        # Parse keypoint_colors: colormap name, or explicit {keypoint: color}
        if isinstance(keypoint_colors, str):
            colormap_name = keypoint_colors
            custom_colors = {}
        else:
            colormap_name = "tab10"
            custom_colors = keypoint_colors

        # Get all PoseEstimation containers (excludes Skeletons and other metadata)
        pose_containers = discover_pose_estimation_cameras(nwbfile)
        available_cameras = list(pose_containers.keys())

        # Get camera info for settings panel display
        available_cameras_info = get_pose_estimation_info(nwbfile)

        # Videos whose URL was resolved (sorted alphabetically)
        available_videos = sorted(video_urls.keys())
        available_videos_info = self._get_video_info(video_source_nwbfile)

        # Select default camera - start with empty to show settings
        if default_camera and default_camera in available_cameras:
            selected_camera = default_camera
        else:
            selected_camera = ""

        # Store references for lazy loading (not synced to JS)
        self._pose_estimation = pose_estimation
        self._cmap = plt.get_cmap(colormap_name)
        self._custom_colors = custom_colors

        super().__init__(
            selected_camera=selected_camera,
            available_cameras=available_cameras,
            available_cameras_info=available_cameras_info,
            available_videos=available_videos,
            available_videos_info=available_videos_info,
            video_name_to_url=video_urls,  # sent to JS for URL resolution
            camera_to_video={},  # users explicitly select videos
            all_camera_data={},  # Start empty, load lazily
            visible_keypoints={},  # Populated as cameras are loaded
            settings_open=True,
            **kwargs,
        )

    @traitlets.observe("selected_camera")
    def _on_camera_selected(self, change):
        """Load pose data lazily when a camera is selected."""
        camera_name = change["new"]
        if not camera_name or camera_name in self.all_camera_data:
            return  # Already loaded or no camera selected

        # Signal loading start
        self.loading = True

        try:
            camera_data = self._load_camera_pose_data(
                self._pose_estimation, camera_name, self._cmap, self._custom_colors
            )

            # Update all_camera_data (must create new dict for traitlets to detect change)
            self.all_camera_data = {**self.all_camera_data, camera_name: camera_data}

            # Add any new keypoints to visible_keypoints (visible by default)
            new_keypoints = {**self.visible_keypoints}
            for name in camera_data["keypoint_metadata"].keys():
                if name not in new_keypoints:
                    new_keypoints[name] = True
            if new_keypoints != self.visible_keypoints:
                self.visible_keypoints = new_keypoints
        finally:
            # Signal loading complete
            self.loading = False

    @staticmethod
    def _load_nwbfile_from_dandi(asset: RemoteAsset) -> NWBFile:
        """Load an NWB file from DANDI via streaming.

        The remote file, the h5py file and the NWBHDF5IO are intentionally not
        closed here: the returned NWBFile reads datasets lazily through them.
        """
        import h5py
        import remfile
        from pynwb import NWBHDF5IO

        s3_url = asset.get_content_url(follow_redirects=1, strip_query=True)

        remote_file = remfile.File(s3_url)
        h5_file = h5py.File(remote_file, "r")
        io = NWBHDF5IO(file=h5_file, load_namespaces=True)
        return io.read()

    @staticmethod
    def _get_video_info(nwbfile: NWBFile) -> dict[str, dict]:
        """Get start/end time and frame count for every external-file video series.

        Uses explicit timestamps when present, else ``starting_time`` + ``rate``
        over the rows of ``data``. Series with no derivable times report
        ``{"start": 0, "end": 0, "frames": 0}``.
        """
        video_series = discover_video_series(nwbfile)
        info = {}

        for name, series in video_series.items():
            timestamps = None
            if series.timestamps is not None:
                timestamps = series.timestamps[:]
            elif series.starting_time is not None and series.rate is not None:
                n_frames = series.data.shape[0] if hasattr(series.data, "shape") else 0
                timestamps = np.arange(n_frames) / series.rate + series.starting_time

            if timestamps is not None and len(timestamps) > 0:
                info[name] = {
                    "start": float(timestamps[0]),
                    "end": float(timestamps[-1]),
                    "frames": len(timestamps),
                }
            else:
                info[name] = {"start": 0, "end": 0, "frames": 0}

        return info

    @staticmethod
    def _get_video_urls_from_dandi(
        nwbfile: NWBFile,
        asset: RemoteAsset,
    ) -> dict[str, str]:
        """Extract video S3 URLs from NWB file using DANDI API.

        Each ImageSeries' first ``external_file`` entry is resolved relative to
        the folder of ``asset`` within its dandiset. Videos that cannot be found
        in the dandiset are skipped.
        """
        import posixpath

        from dandi.dandiapi import DandiAPIClient
        from dandi.exceptions import NotFoundError

        video_series = discover_video_series(nwbfile)
        if not video_series:
            return {}

        # Reuse the asset's own client when it has one, so the same archive
        # instance (and credentials) are used; otherwise fall back to a default client.
        client = getattr(asset, "client", None)
        if client is None:
            client = DandiAPIClient()
        dandiset = client.get_dandiset(asset.dandiset_id, asset.version_id)

        # DANDI asset paths are always "/"-separated, so use posixpath rather
        # than the OS-specific pathlib.Path (which emits "\" on Windows).
        nwb_parent = posixpath.dirname(asset.path)
        video_urls = {}

        for name, series in video_series.items():
            external_path = series.external_file[0]
            if isinstance(external_path, bytes):
                external_path = external_path.decode("utf-8")
            # Accept Windows-style separators; treat a leading "/" as relative
            # to the NWB file's folder.
            external_path = external_path.replace("\\", "/").lstrip("/")
            # normpath resolves "./" and "../" segments correctly (unlike
            # lstrip("./"), which also eats the dots of "../" and ".hidden/").
            full_path = posixpath.normpath(posixpath.join(nwb_parent, external_path))

            try:
                video_asset = dandiset.get_asset_by_path(full_path)
            except NotFoundError:
                # A missing video should not prevent the other videos from loading.
                continue
            if video_asset is not None:
                video_urls[name] = video_asset.get_content_url(
                    follow_redirects=1, strip_query=True
                )

        return video_urls

    @staticmethod
    def _series_timestamps(series, n_frames: int) -> list[float]:
        """Return per-frame times for a pose series as a JSON-serializable list.

        Uses explicit ``timestamps`` when present. Otherwise times are rebuilt
        from ``starting_time`` (default 0.0) and ``rate`` for ``n_frames``
        frames. Returns an empty list when neither is available.
        """
        if series.timestamps is not None:
            return np.asarray(series.timestamps[:], dtype=float).tolist()
        rate = getattr(series, "rate", None)
        if not rate:
            return []
        start = float(series.starting_time or 0.0)
        return (start + np.arange(n_frames) / float(rate)).tolist()

    @staticmethod
    def _load_camera_pose_data(
        pose_estimation, camera_name: str, cmap, custom_colors: dict
    ) -> dict:
        """Load pose data for a single camera.

        Returns a dict with:
        - keypoint_metadata: {name: {color, label}}
        - pose_coordinates: {name: [[x, y] or None, ...]} as JSON-serializable
          lists; frames where x or y is NaN are None. Only the first two data
          columns are used, so 3D (x, y, z) series are also accepted.
        - timestamps: [t0, t1, ...] from the first keypoint series
        """
        camera_pose = pose_estimation[camera_name]
        series_items = list(camera_pose.pose_estimation_series.items())
        n_kp = len(series_items)

        metadata = {}
        coordinates = {}
        timestamps = None

        for index, (series_name, series) in enumerate(series_items):
            short_name = series_name.replace("PoseEstimationSeries", "")

            xy = np.asarray(series.data[:], dtype=float)
            if xy.size == 0:
                xy = xy.reshape(0, 2)
            xy = xy[:, :2]
            # Vectorized NaN detection instead of a per-frame Python loop
            missing = np.isnan(xy).any(axis=1)
            coordinates[short_name] = [
                None if is_missing else point
                for point, is_missing in zip(xy.tolist(), missing.tolist())
            ]

            if timestamps is None:
                timestamps = NWBDANDIPoseEstimationWidget._series_timestamps(series, len(xy))

            # Assign color from custom dict or colormap
            if short_name in custom_colors:
                color = custom_colors[short_name]
            else:
                # Small (qualitative) colormaps are indexed and cycled; large
                # ones are sampled evenly across [0, 1].
                if hasattr(cmap, "N") and cmap.N < 256:
                    rgba = cmap(index % cmap.N)
                else:
                    rgba = cmap(index / max(n_kp - 1, 1))
                color = mcolors.to_hex(rgba)

            metadata[short_name] = {"color": color, "label": short_name}

        return {
            "keypoint_metadata": metadata,
            "pose_coordinates": coordinates,
            "timestamps": timestamps,
        }
|