bithuman 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bithuman/__init__.py +13 -0
- bithuman/_version.py +1 -0
- bithuman/api.py +164 -0
- bithuman/audio/__init__.py +19 -0
- bithuman/audio/audio.py +396 -0
- bithuman/audio/hparams.py +108 -0
- bithuman/audio/utils.py +255 -0
- bithuman/config.py +88 -0
- bithuman/engine/__init__.py +15 -0
- bithuman/engine/auth.py +335 -0
- bithuman/engine/compression.py +257 -0
- bithuman/engine/enums.py +16 -0
- bithuman/engine/image_ops.py +192 -0
- bithuman/engine/inference.py +108 -0
- bithuman/engine/knn.py +58 -0
- bithuman/engine/video_data.py +391 -0
- bithuman/engine/video_reader.py +168 -0
- bithuman/lib/__init__.py +1 -0
- bithuman/lib/audio_encoder.onnx +45631 -28
- bithuman/lib/generator.py +763 -0
- bithuman/lib/pth2h5.py +106 -0
- bithuman/plugins/__init__.py +0 -0
- bithuman/plugins/stt.py +185 -0
- bithuman/runtime.py +1004 -0
- bithuman/runtime_async.py +469 -0
- bithuman/service/__init__.py +9 -0
- bithuman/service/client.py +788 -0
- bithuman/service/messages.py +210 -0
- bithuman/service/server.py +759 -0
- bithuman/utils/__init__.py +43 -0
- bithuman/utils/agent.py +359 -0
- bithuman/utils/fps_controller.py +90 -0
- bithuman/utils/image.py +41 -0
- bithuman/utils/unzip.py +38 -0
- bithuman/video_graph/__init__.py +16 -0
- bithuman/video_graph/action_trigger.py +83 -0
- bithuman/video_graph/driver_video.py +482 -0
- bithuman/video_graph/navigator.py +736 -0
- bithuman/video_graph/trigger.py +90 -0
- bithuman/video_graph/video_script.py +344 -0
- bithuman-1.0.2.dist-info/METADATA +37 -0
- bithuman-1.0.2.dist-info/RECORD +44 -0
- bithuman-1.0.2.dist-info/WHEEL +5 -0
- bithuman-1.0.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,482 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import List, Optional, Tuple
|
|
7
|
+
|
|
8
|
+
import cv2
|
|
9
|
+
import networkx as nx
|
|
10
|
+
import numpy as np
|
|
11
|
+
from loguru import logger
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(frozen=True, order=True)
class Frame:
    """Immutable reference to a single frame of a named video.

    Instances are hashable and compare/sort field by field, i.e. by
    ``(video_name, frame_index)``.
    """

    # Name of the video this frame belongs to.
    video_name: str
    # Position of the frame inside that video.
    frame_index: int
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(frozen=True, order=True)
class NodeID:
    """Immutable, orderable identifier of a node: one frame of one video file.

    Ordering and equality use ``(video_name, video_hash, frame_index)``.
    """

    # Name of the video the node belongs to.
    video_name: str
    # Hex digest identifying the video file contents.
    video_hash: str
    # Frame position of the node inside the video.
    frame_index: int

    def __repr__(self) -> str:
        # Only the first 8 hash characters are shown to keep labels short.
        short_hash = self.video_hash[:8]
        return f"{self.video_name}_{short_hash}_{self.frame_index}"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class DriverVideo:
    """A video file described as a list of frames plus navigation nodes.

    On construction the file is hashed (MD5), its frame count and resolution
    are read through OpenCV, one ``Frame`` is created per frame and ``NodeID``
    nodes are placed every ``stride`` frames (the last frame always gets one).
    """

    def __init__(
        self,
        name: str,
        video_path: str,
        video_data_path: Optional[str] = None,
        num_frames: Optional[int] = None,
        *,
        stride: int = 10,
        single_direction: bool = False,
        stop_on_user_speech: bool = False,
        stop_on_agent_speech: bool = False,
        lip_sync_required: bool = True,
    ) -> None:
        """Read the video metadata and build the frame and node lists.

        Args:
            name: Name stored on every ``Frame``/``NodeID`` of this video.
            video_path: Path of the video file.
            video_data_path: Path of the companion data file. When falsy, a
                matching ``.h5``/``.pth`` file is looked up next to the video.
            num_frames: Number of leading frames to use; defaults to the frame
                count reported by OpenCV.
            stride: Distance in frames between two nodes. A value ``<= 0``
                yields nodes on the first and last frame only.
            single_direction: If True, the graph only has forward edges and
                backward frame collection returns nothing.
            stop_on_user_speech: Stored as an attribute; not used in this class.
            stop_on_agent_speech: Stored as an attribute; not used in this class.
            lip_sync_required: Stored as an attribute; not used in this class.

        Raises:
            AssertionError: If ``num_frames`` exceeds the frame count of the file.
        """
        self.video_name = name

        self.video_path: str = video_path

        # read the video hash (streamed, so large videos are not loaded at once)
        self.video_hash = self._md5_of_file(video_path)

        # update the video_data_path
        self.video_data_path: Optional[str] = video_data_path or _find_video_data_path(
            video_path, self.video_hash
        )

        # read the number of frames
        cap = cv2.VideoCapture(video_path)
        try:
            total_num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # (width, height) in pixels
            self.resolution = (
                int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            )
        finally:
            cap.release()
        if num_frames is not None:
            assert num_frames <= total_num_frames, (
                f"num_frames {num_frames} > total_num_frames {total_num_frames}"
            )
        else:
            num_frames = total_num_frames
        self.num_frames = num_frames
        self.frames = [Frame(self.video_name, i) for i in range(num_frames)]

        # The transition points, default to all nodes
        self.single_direction = single_direction
        self.init_nodes(stride)
        self.transition_nodes = self.nodes

        # Stop on user speech or agent speech
        self.stop_on_user_speech = stop_on_user_speech
        self.stop_on_agent_speech = stop_on_agent_speech
        self.lip_sync_required = lip_sync_required

    @staticmethod
    def _md5_of_file(path: str, chunk_size: int = 1 << 20) -> str:
        """Return the MD5 hex digest of the file at ``path``, read in chunks."""
        digest = hashlib.md5()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                digest.update(chunk)
        return digest.hexdigest()

    @property
    def video_id(self) -> str:
        """Video name followed by the first 8 characters of the file hash."""
        return f"{self.video_name}_{self.video_hash[:8]}"

    def get_frame_wh(self, scale_size: Optional[int] = None) -> Tuple[int, int]:
        """Return the frame size as ``(width, height)``.

        When ``scale_size`` is given, both sides are scaled by the same factor
        so that the larger side becomes ``scale_size`` (results are truncated
        to ``int``).
        """
        if scale_size is None:
            return self.resolution

        # scale max dimension to `scale_size`
        scale = scale_size / max(self.resolution)
        return (
            int(self.resolution[0] * scale),
            int(self.resolution[1] * scale),
        )

    def get_first_frame(self, scale_size: Optional[int] = None) -> np.ndarray:
        """Decode and return the first frame, optionally resized.

        Raises:
            ValueError: If the first frame cannot be read.
        """
        cap = cv2.VideoCapture(str(self.video_path))
        try:
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            ret, frame = cap.read()
            if not ret:
                raise ValueError("Failed to read the first frame")

            if scale_size is not None:
                frame = cv2.resize(frame, self.get_frame_wh(scale_size))
            return frame
        finally:
            cap.release()

    def init_nodes(self, stride: int) -> None:
        """(Re)build ``self.nodes`` with one node every ``stride`` frames.

        A node for the last frame is appended when the stride does not land
        on it. ``stride <= 0`` means a single step spanning the whole video.
        """
        if stride <= 0:
            # one node for whole video
            stride = self.num_frames
        self.stride = stride
        self.nodes = [
            NodeID(self.video_name, self.video_hash, i)
            for i in range(0, self.num_frames, stride)
        ]
        # Last frame
        if self.nodes[-1].frame_index != self.num_frames - 1:
            self.nodes.append(
                NodeID(self.video_name, self.video_hash, self.num_frames - 1)
            )

    def remove_nodes(
        self,
        *,
        indices: Optional[List[int]] = None,
        frame_indices: Optional[List[int]] = None,
    ) -> int:
        """Remove nodes by list position and/or by frame index.

        Args:
            indices: Positions in the current ``self.nodes`` list to drop.
            frame_indices: Frame indices whose nodes are dropped.

        Returns:
            The number of nodes that were removed.
        """
        new_nodes = self.nodes

        if indices:
            # Build the lookup set once instead of once per node.
            drop_positions = set(indices)
            new_nodes = [
                node for idx, node in enumerate(new_nodes) if idx not in drop_positions
            ]

        if frame_indices:
            drop_frames = set(frame_indices)
            new_nodes = [
                node for node in new_nodes if node.frame_index not in drop_frames
            ]

        removed = len(self.nodes) - len(new_nodes)
        self.nodes = new_nodes
        self.transition_nodes = self.nodes
        return removed

    def insert_nodes(self, frame_indices: List[int]) -> None:
        """Add nodes at ``frame_indices`` (indices ``>= num_frames`` are ignored).

        The node list is de-duplicated and re-sorted, and the transition nodes
        are reset to the full node list.
        """
        new_nodes = [
            NodeID(self.video_name, self.video_hash, idx)
            for idx in frame_indices
            if idx < self.num_frames
        ]
        self.nodes = sorted(set(self.nodes + new_nodes))
        self.transition_nodes = self.nodes

    @property
    def transition_frames(self) -> List[Frame]:
        """Frames located at the transition nodes."""
        return [self.frames[node.frame_index] for node in self.transition_nodes]

    def draw_nodes(
        self, nodes: List[NodeID], n_cols: int = 2, image_width: int = 720
    ) -> np.ndarray:
        """Draw the frames of the nodes."""
        frames = [self.frames[node.frame_index] for node in nodes]
        node_indices = [self.nodes.index(node) for node in nodes]
        labels = [f"{idx}: {node}" for idx, node in zip(node_indices, nodes)]
        # NOTE(review): `draw_frames` is not defined in this file; confirm it is
        # provided elsewhere, otherwise this call raises AttributeError.
        return self.draw_frames(frames, labels, n_cols, image_width)

    def as_graph(self) -> nx.DiGraph:
        """Create a graph from the video nodes.

        Consecutive nodes are linked by a forward edge whose ``distance`` is
        the number of frames between them. Unless ``single_direction`` is set,
        the matching backward edge is added as well.
        """
        graph = nx.DiGraph()
        graph.add_nodes_from(self.nodes)
        # Add edges between consecutive nodes
        for first, second in zip(self.nodes, self.nodes[1:]):
            distance = second.frame_index - first.frame_index

            # Add forward edge
            graph.add_edge(first, second, distance=distance)

            # Add backward edge if not in single direction mode
            if not self.single_direction:
                graph.add_edge(second, first, distance=distance)

        return graph

    def collect_frames(
        self,
        source: NodeID | int,
        target: NodeID | int,
        allow_multi_steps: bool = True,
    ) -> List[Frame]:
        """Collect frames between two nodes.

        Include the frame of the `source` and exclude the frame of the `target`.
        If the distance between the two nodes is larger than one step,
        and `allow_multi_steps` is False, return only the frame of the `source`.
        # TODO: write filler info to the edge and remove `allow_multi_steps`

        Args:
            source: Start node, or its position in ``self.nodes``.
            target: End node, or its position in ``self.nodes``.
            allow_multi_steps: Whether spans longer than ``self.stride`` frames
                are expanded into all intermediate frames.

        Returns:
            List[Frame]: The collected frames. Empty when `source` is not from
            this video, or when moving backward in single-direction mode.
        """
        if isinstance(source, int):
            source = self.nodes[source]
        if isinstance(target, int):
            target = self.nodes[target]

        source_is_foreign = source.video_hash != self.video_hash
        target_is_foreign = target.video_hash != self.video_hash

        if source_is_foreign and target_is_foreign:
            logger.warning(
                f"Both nodes are not from this video: "
                f"{source} and {target} != {self.video_hash}"
            )
            return []

        if source_is_foreign:
            # Jump from another video to this video, return empty list
            return []

        if target_is_foreign or (
            not allow_multi_steps
            and abs(source.frame_index - target.frame_index) > self.stride
        ):
            # Jump from this video to another video
            # or the distance between the two nodes is larger than one step
            # assume this is a "jumper connection"
            return [self.frames[source.frame_index]]

        # Both nodes are from this video
        step = 1 if source.frame_index < target.frame_index else -1
        if self.single_direction and step == -1:
            # Single direction video can only move forward
            return []
        return self.frames[source.frame_index : target.frame_index : step]
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
class LoopingVideo(DriverVideo):
    """A driver video that is played in a loop between two frame indices."""

    def __init__(
        self,
        name: str,
        video_path: str,
        video_data_path: Optional[str] = None,
        num_frames: Optional[int] = None,
        *,
        stride: int = 10,
        single_direction: bool = False,
        stop_on_user_speech: bool = False,
        stop_on_agent_speech: bool = False,
        lip_sync_required: bool = True,
        loop_between: Tuple[Optional[int], Optional[int]] = (0, None),
    ) -> None:
        """A video that loops between two frames.

        Args:
            loop_between: Tuple of (start, end) frame indices to loop between.
                None means start or end of video.
            single_direction: If True, video only plays forward from start to end,
                then jumps back to start. If False, video plays back and forth.

        The remaining arguments are forwarded unchanged to ``DriverVideo``.
        """
        super().__init__(
            name,
            video_path=video_path,
            video_data_path=video_data_path,
            num_frames=num_frames,
            stride=stride,
            single_direction=single_direction,
            stop_on_user_speech=stop_on_user_speech,
            stop_on_agent_speech=stop_on_agent_speech,
            lip_sync_required=lip_sync_required,
        )

        self.loop_direction = 1  # 1 for forward, -1 for backward
        self._loop_start_node = None
        self._loop_end_node = None
        self.set_loop_between(*loop_between)

    def set_loop_between(self, start: Optional[int], end: Optional[int]) -> int:
        """Set the loop range and make sure both ends exist as nodes.

        Args:
            start: First frame of the loop; ``None`` or ``0`` means frame 0.
            end: Last frame of the loop; ``None`` or ``0`` means the last frame.
                Negative values count from the end of the video.

        Returns:
            The number of frames between the loop start and the loop end.

        Raises:
            ValueError: If the clamped start lies after the clamped end.
        """
        # NOTE: `or` treats 0 like None, so `end=0` selects the last frame.
        start = start or 0
        end = end or -1
        if end < 0:
            end = len(self.frames) + end
        # Clamp both ends into the valid frame range.
        self.loop_between = (max(0, start), min(len(self.frames) - 1, end))
        if self.loop_between[0] > self.loop_between[1]:
            raise ValueError(f"Invalid loop_between {self.loop_between}")

        # Create and store loop nodes
        self._loop_start_node = NodeID(
            self.video_name, self.video_hash, self.loop_between[0]
        )
        self._loop_end_node = NodeID(
            self.video_name, self.video_hash, self.loop_between[1]
        )

        # Add loop nodes if not already in nodes list
        if self._loop_start_node not in self.nodes:
            self.nodes = sorted(self.nodes + [self._loop_start_node])
        if self._loop_end_node not in self.nodes:
            self.nodes = sorted(self.nodes + [self._loop_end_node])

        return self.loop_between[1] - self.loop_between[0]

    def get_n_frames(
        self, min_n: int, start: NodeID, last_position: Optional[NodeID] = None
    ) -> Tuple[List[Frame], Optional[NodeID]]:
        """Get at least `min_n` frames from the start node.

        Frames are collected node by node, so more than `min_n` frames may be
        returned. Fewer are returned only when the loop cannot make progress
        (for example a loop that consists of a single node).

        Args:
            min_n: The minimum number of frames to get.
            start: The start node.
            last_position: The last position node for determine direction.

        Returns:
            The collected frames and the last position node. ``([], None)`` when
            `start` does not belong to this video, is not one of its nodes, or
            `min_n` is not positive.
        """
        if start.video_hash != self.video_hash:
            logger.warning(
                f"Enter node is not from this video: {start} != {self.video_hash}"
            )
            return [], None

        if min_n <= 0:
            return [], None

        # `list.index` raises ValueError for a missing item (it never returns -1).
        try:
            start_idx = self.nodes.index(start)
        except ValueError:
            logger.warning(f"Node {start} is not in the nodes list")
            return [], None

        # NOTE: loop_end and loop_start are frame index,
        # be careful if the new node index is valid
        loop_start, loop_end = self.loop_between

        # For single_direction, always move forward
        if self.single_direction:
            self.loop_direction = 1
        elif last_position and last_position.video_hash != self.video_hash:
            # reset the direction if the last position is from another video
            # Move to the direction with more frames
            left_frames = start.frame_index - loop_start
            right_frames = loop_end - start.frame_index
            self.loop_direction = 1 if right_frames > left_frames else -1

        frames: List[Frame] = []
        curr_node = start
        curr_idx = start_idx
        while len(frames) < min_n:
            if self.single_direction:
                if curr_node.frame_index >= loop_end:
                    # Jump back to loop_start when reaching loop_end
                    next_node = self._loop_start_node
                    next_idx = self.nodes.index(self._loop_start_node)
                else:
                    next_idx = curr_idx + 1
                    next_node = self.nodes[next_idx]
            else:
                # bidirectional behavior
                if curr_node.frame_index >= loop_end:
                    self.loop_direction = -1
                elif curr_node.frame_index <= loop_start:
                    self.loop_direction = 1
                next_idx = curr_idx + self.loop_direction
                next_node = self.nodes[next_idx]

            step_frames = self.collect_frames(curr_node, next_node)
            if not step_frames and next_node == curr_node:
                # No frames and no movement: every further iteration would be
                # identical, so stop instead of looping forever.
                logger.warning(
                    f"Loop {self.loop_between} of {self.video_name} cannot advance "
                    f"from {curr_node}; returning {len(frames)} frames"
                )
                break

            frames += step_frames
            curr_node = next_node
            curr_idx = next_idx

        return frames, curr_node

    def as_graph(self) -> nx.DiGraph:
        """Create a graph from the video nodes with a loop back edge.

        The zero-distance edge from the loop end to the loop start is only
        added in single-direction mode.
        """
        graph = super().as_graph()

        # Add loop back edge using stored nodes
        if self.single_direction:
            graph.add_edge(self._loop_end_node, self._loop_start_node, distance=0)

        return graph
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
class SingleActionVideo(DriverVideo):
    """A driver video that plays one action and ends on a transition node."""

    def __init__(
        self,
        name: str,
        video_path: str,
        video_data_path: Optional[str] = None,
        num_frames: Optional[int] = None,
        *,
        single_direction: bool = False,
        stop_on_user_speech: bool = False,
        stop_on_agent_speech: bool = False,
        lip_sync_required: bool = True,
        transition_frames: Optional[List[int]] = None,
        action_frame: int = -1,
    ) -> None:
        """A video that plays a single action.

        Args:
            transition_frames: Frame indices that become the transition nodes
                (negative values count from the end). Defaults to ``[0]``. The
                last frame is added automatically when ``single_direction`` is
                True. The passed list is not modified.
            action_frame: Frame index of the action node; negative values count
                from the end of the video.

        The remaining arguments are forwarded to ``DriverVideo``; the stride is
        fixed to ``-1`` so the base class only creates first/last frame nodes.
        """
        super().__init__(
            name,
            video_path=video_path,
            video_data_path=video_data_path,
            num_frames=num_frames,
            stride=-1,
            single_direction=single_direction,
            stop_on_user_speech=stop_on_user_speech,
            stop_on_agent_speech=stop_on_agent_speech,
            lip_sync_required=lip_sync_required,
        )
        # Work on a copy so that appending below never mutates the caller's list.
        transition_frames = list(transition_frames) if transition_frames else [0]
        if single_direction:
            transition_frames.append(-1)
        # Resolve negative indices relative to the end of the video.
        transition_frames = [
            frame if frame >= 0 else len(self.frames) + frame
            for frame in transition_frames
        ]
        transition_nodes = [
            NodeID(self.video_name, self.video_hash, i) for i in set(transition_frames)
        ]

        action_frame = (
            action_frame if action_frame >= 0 else len(self.frames) + action_frame
        )
        self._action_node = NodeID(self.video_name, self.video_hash, action_frame)

        self.nodes = sorted(set(self.nodes + transition_nodes + [self._action_node]))
        self.transition_nodes = sorted(transition_nodes)

    @property
    def action_node(self) -> NodeID:
        """The node located at the action frame."""
        return self._action_node

    def as_graph(self) -> nx.DiGraph:
        """Create a graph from the video nodes.

        The base implementation already emits forward-only edges when
        ``single_direction`` is set and bidirectional edges otherwise, so both
        modes delegate to it.
        """
        return super().as_graph()

    def get_frames(self, start: NodeID) -> Tuple[List[Frame], Optional[NodeID]]:
        """Collect the frames of the whole action beginning at `start`.

        The path is ``start -> action node -> last transition node``.

        Returns:
            The collected frames and the final node of the path, or
            ``([], None)`` when `start` is not a node of this video.
        """
        if start.video_hash != self.video_hash:
            logger.warning(
                f"Enter node is not from this video: {start} != {self.video_hash}"
            )
            return ([], None)

        # `list.index` raises ValueError for a missing item (it never returns -1),
        # so a plain membership test is the reliable check here.
        if start not in self.nodes:
            logger.warning(f"Node {start} is not in the nodes list")
            return ([], None)

        # Collect all frames of the action
        # start -> last frame -> last transition frame
        path = [start, self.action_node, self.transition_nodes[-1]]
        frames: List[Frame] = []
        for prev_node, next_node in zip(path, path[1:]):
            frames += self.collect_frames(prev_node, next_node, allow_multi_steps=True)
        return frames, path[-1]
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
def _find_video_data_path(video_path: str, file_hash: str) -> Optional[str]:
    """Find the data file that accompanies a video, if there is one.

    Looks in the directory of ``video_path`` for a file named
    ``<video file name>.<anything>_<first 8 chars of file_hash>.<suffix>``,
    trying the ``h5`` suffix before ``pth``.

    Args:
        video_path: Path of the video file.
        file_hash: Hex digest of the video file.

    Returns:
        The POSIX-style path of the first match (in sorted order), or ``None``
        when no matching file exists.
    """
    import glob  # local import: only needed to escape the file name

    video_path = Path(video_path)
    # Escape the file name so characters such as "[" or "*" in it are matched
    # literally instead of being interpreted as glob syntax.
    name = f"{glob.escape(video_path.name)}.*_{file_hash[:8]}"
    for suffix in ["h5", "pth"]:
        # Sort so the choice among several matches does not depend on the
        # directory listing order.
        files = sorted(video_path.parent.glob(name + f".{suffix}"))
        if files:
            return files[0].as_posix()
    return None
|