bithuman 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bithuman might be problematic. Click here for more details.
- bithuman/__init__.py +13 -0
- bithuman/_version.py +1 -0
- bithuman/api.py +164 -0
- bithuman/audio/__init__.py +19 -0
- bithuman/audio/audio.py +396 -0
- bithuman/audio/hparams.py +108 -0
- bithuman/audio/utils.py +255 -0
- bithuman/config.py +88 -0
- bithuman/engine/__init__.py +15 -0
- bithuman/engine/auth.py +335 -0
- bithuman/engine/compression.py +257 -0
- bithuman/engine/enums.py +16 -0
- bithuman/engine/image_ops.py +192 -0
- bithuman/engine/inference.py +108 -0
- bithuman/engine/knn.py +58 -0
- bithuman/engine/video_data.py +391 -0
- bithuman/engine/video_reader.py +168 -0
- bithuman/lib/__init__.py +1 -0
- bithuman/lib/audio_encoder.onnx +45631 -28
- bithuman/lib/generator.py +763 -0
- bithuman/lib/pth2h5.py +106 -0
- bithuman/plugins/__init__.py +0 -0
- bithuman/plugins/stt.py +185 -0
- bithuman/runtime.py +1004 -0
- bithuman/runtime_async.py +469 -0
- bithuman/service/__init__.py +9 -0
- bithuman/service/client.py +788 -0
- bithuman/service/messages.py +210 -0
- bithuman/service/server.py +759 -0
- bithuman/utils/__init__.py +43 -0
- bithuman/utils/agent.py +359 -0
- bithuman/utils/fps_controller.py +90 -0
- bithuman/utils/image.py +41 -0
- bithuman/utils/unzip.py +38 -0
- bithuman/video_graph/__init__.py +16 -0
- bithuman/video_graph/action_trigger.py +83 -0
- bithuman/video_graph/driver_video.py +482 -0
- bithuman/video_graph/navigator.py +736 -0
- bithuman/video_graph/trigger.py +90 -0
- bithuman/video_graph/video_script.py +344 -0
- bithuman-1.0.2.dist-info/METADATA +37 -0
- bithuman-1.0.2.dist-info/RECORD +44 -0
- bithuman-1.0.2.dist-info/WHEEL +5 -0
- bithuman-1.0.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,736 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from contextlib import nullcontext
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from tempfile import TemporaryDirectory
|
|
8
|
+
from threading import Lock
|
|
9
|
+
from typing import Dict, List, Optional, Tuple
|
|
10
|
+
|
|
11
|
+
import networkx as nx
|
|
12
|
+
import numpy as np
|
|
13
|
+
from loguru import logger
|
|
14
|
+
|
|
15
|
+
from bithuman.config import Settings
|
|
16
|
+
from bithuman.utils.unzip import unzip_tarfile
|
|
17
|
+
from bithuman.video_graph.driver_video import (
|
|
18
|
+
DriverVideo,
|
|
19
|
+
Frame,
|
|
20
|
+
LoopingVideo,
|
|
21
|
+
NodeID,
|
|
22
|
+
SingleActionVideo,
|
|
23
|
+
)
|
|
24
|
+
from bithuman.video_graph.video_script import VideoConfigs, VideoScript
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class VideoGraphNavigator:
    """Navigate a graph of video frames to produce a continuous frame stream.

    Every added ``DriverVideo`` contributes its own subgraph (``video.as_graph()``).
    Videos are linked to each other through their transition nodes using
    precomputed similarity matrices. Cross-video edges may carry "filler"
    frames that are loaded from ``<avatar_model_path>/filler_videos``.
    """

    # Extra weight added to every cross-video edge. It biases the shortest-path
    # search toward staying inside one video. ``find_path`` subtracts it again
    # so that returned distances can be compared with the number of emitted frames.
    CROSS_VIDEO_PENALTY = 30

    def __init__(
        self,
        avatar_model_path: str | Tuple[str, TemporaryDirectory],
        video_configs: Optional[VideoConfigs] = None,
    ) -> None:
        """Create an empty navigator (call ``load_workspace`` to add videos).

        Args:
            avatar_model_path: Model directory, or a ``(directory, temp_dir)``
                tuple. In the tuple form, ``temp_dir`` is cleaned up by ``cleanup``.
            video_configs: Video configuration. Defaults to ``VideoConfigs()``.
        """
        self.videos: Dict[str, DriverVideo] = {}
        self.filler_videos: Dict[str, DriverVideo] = {}

        self._video_configs = video_configs or VideoConfigs()

        # Workspace directory
        if isinstance(avatar_model_path, tuple):
            self.avatar_model_path, self.temp_dir = avatar_model_path
        else:
            self.avatar_model_path = avatar_model_path
            self.temp_dir = None
        self.filler_frames_dir = Path(self.avatar_model_path) / "filler_videos"
        self.similarity_cache_dir = Path(self.avatar_model_path) / "similarities"

        # Similarity matrices between two videos, keyed by (video_id, video_id)
        self.similarity_matrices: Dict[Tuple[str, str], np.ndarray] = {}

        # Graph and frame buffer for output
        self.graph = nx.DiGraph()
        self.curr_node: Optional[NodeID] = None
        self.frame_buffer: List[Frame] = []

        # path_cache[source][target] = (distance, path), filled by update_path_cache
        self.path_cache: Dict[NodeID, Dict[NodeID, Tuple[float, List[NodeID]]]] = (
            defaultdict(dict)
        )

    def cleanup(self) -> None:
        """Clean up the temporary directory if it exists.

        Safe to call more than once. Also safe on a partially initialised
        instance: ``__del__`` still runs when ``__init__`` raised before
        ``temp_dir`` was assigned.
        """
        temp_dir = getattr(self, "temp_dir", None)
        if temp_dir and Path(temp_dir.name).exists():
            temp_dir.cleanup()
            self.temp_dir = None

    def __del__(self):
        """Clean up the temporary directory if it exists."""
        self.cleanup()

    @property
    def videos_script(self) -> VideoScript:
        """Get the video script."""
        return self._video_configs.videos_script

    def update_runtime_configs(self, settings: Settings):
        """Update the runtime configs from the settings."""
        self._video_configs.update_runtime_configs(settings)

    def video_exists(self, name: str, is_action: Optional[bool] = None) -> bool:
        """Check if the video exists in the navigator.

        Args:
            name: The video name
            is_action: If True, only check the action videos.
                If False, only check the non-action (looping) videos.
                None, check both.

        Returns:
            True if the video exists.
        """
        if name not in self.videos:
            return False
        if is_action is None:
            return True
        return isinstance(self.videos[name], SingleActionVideo) == is_action

    @property
    def action_videos(self) -> List[SingleActionVideo]:
        """Get all the action videos in the navigator."""
        return [
            video
            for video in self.videos.values()
            if isinstance(video, SingleActionVideo)
        ]

    @property
    def action_video_names(self) -> List[str]:
        """Get all the action video names in the navigator."""
        return [video.video_name for video in self.action_videos]

    @classmethod
    def from_single_video(
        cls, video_file: str, inference_data_file: str = None
    ) -> "VideoGraphNavigator":
        """Create a navigator with a single video.

        The video's parent directory is used as the model directory.
        """
        video_configs = VideoConfigs.from_videofile(video_file, inference_data_file)
        return cls(
            avatar_model_path=Path(video_file).parent,
            video_configs=video_configs,
        )

    @classmethod
    def from_workspace(
        cls,
        avatar_model_path: str,
        video_config_file: str = None,
        extract_to_local: bool = False,
    ) -> "VideoGraphNavigator":
        """Create a navigator from the model.

        The model path is passed through ``unzip_tarfile``. Video configs are
        read from ``video_config_file`` (default ``<model>/videos.yaml``). If
        that file is missing, they are built from a ``<model>/videos`` folder.

        Raises:
            FileNotFoundError: If neither the config file nor the videos
                folder exists.
        """
        logger.info(f"Loading model from {avatar_model_path}")
        avatar_model_path, temp_dir = unzip_tarfile(
            avatar_model_path, extract_to_local=extract_to_local
        )

        # Load video configs
        video_config_file = video_config_file or Path(avatar_model_path) / "videos.yaml"
        if Path(video_config_file).exists():
            video_configs = VideoConfigs.from_yaml(video_config_file)
        elif (Path(avatar_model_path) / "videos").exists():
            video_configs = VideoConfigs.from_videofolder(
                Path(avatar_model_path) / "videos"
            )
        else:
            files = list(Path(avatar_model_path).glob("*"))
            raise FileNotFoundError(
                f"model not found in {avatar_model_path}, files: {files}"
            )

        # Update video files to absolute path
        for video in video_configs.videos:
            video.video_file = str(
                Path(avatar_model_path).absolute() / video.video_file
            )

        return cls(
            avatar_model_path=(avatar_model_path, temp_dir),
            video_configs=video_configs,
        )

    def load_workspace(
        self, prepare_filler_frames: bool = True
    ) -> "VideoGraphNavigator":
        """Load the videos from workspace.

        Args:
            prepare_filler_frames: If True, load filler frames for all the edges.

        Returns:
            ``self``, for chaining.
        """
        # Load videos
        videos = self._video_configs.load_videos(video_root=self.avatar_model_path)

        # Init the navigator
        for video, config in zip(videos, self._video_configs.videos):
            self.add_video(video, **config.adding_kwargs)
        if prepare_filler_frames:
            self.load_filler_frames_for_allnodes()
        return self

    def update_path_cache(self):
        """Recompute shortest paths (weight ``distance``) from every node."""
        self.path_cache.clear()
        for source in self.graph.nodes:
            distance, path = nx.single_source_dijkstra(
                self.graph, source, weight="distance"
            )
            self.path_cache[source] = {
                target: (distance[target], path[target]) for target in path
            }

    def single_source_multi_target_dijkstra(
        self, source: NodeID, targets: List[NodeID]
    ) -> Tuple[float, List[NodeID]]:
        """Find the shortest path from the source node to any of target nodes.

        The method uses the path cache to speed up the computation.
        Make sure to update the path cache after the graph is updated.

        Args:
            source: Node to start the search from
            targets: List of target nodes

        Raises:
            nx.NetworkXNoPath: If no path is found from the source
                to any of the target nodes

        Returns:
            Tuple[float, List[NodeID]]: The distance and the shortest path
                from the source to one of the target nodes
        """
        # .get avoids inserting empty entries into the defaultdict cache.
        reachable = self.path_cache.get(source, {})
        paths = [reachable[target] for target in targets if target in reachable]
        if not paths:
            raise nx.NetworkXNoPath(f"No path found from {source} to {targets}")
        return min(paths, key=lambda x: x[0])

    def reset_buffer(self) -> None:
        """Reset the frame buffer and forget the current node."""
        self.curr_node = None
        self.frame_buffer = []

    def add_edge(
        self,
        source_node: NodeID,
        target_node: NodeID,
        distance: float,
        num_filler_frames: int = 0,
        single_direction: bool = False,
        cross_video: bool = False,
    ) -> None:
        """Add an edge to the graph.

        Stores three attributes on the edge: ``distance``,
        ``num_filler_frames`` and ``cross_video``.

        Args:
            source_node: The source node of the edge
            target_node: The target node of the edge
            distance: The distance between the two nodes. CROSS_VIDEO_PENALTY
                is added when ``cross_video`` is True.
            num_filler_frames: The number of filler frames between the two nodes
            single_direction: If True, only add source -> target; otherwise
                also add target -> source with the same attributes.
            cross_video: Whether the edge connects two different videos
        """
        if cross_video:
            distance += self.CROSS_VIDEO_PENALTY
        metadata = {
            "distance": distance,
            "num_filler_frames": num_filler_frames,
            "cross_video": cross_video,
        }
        self.graph.add_edge(source_node, target_node, **metadata)
        if not single_direction:
            self.graph.add_edge(target_node, source_node, **metadata)

    def get_first_frame(self, output_size: Optional[int] = None) -> np.ndarray:
        """Return the first frame of the first added video.

        Raises:
            ValueError: If no video has been added.
        """
        if not self.videos:
            raise ValueError("No videos is added.")
        video = next(iter(self.videos.values()))
        return video.get_first_frame(output_size)

    def get_frame_wh(self, output_size: Optional[int] = None) -> Tuple[int, int]:
        """Return the frame (width, height) of the first added video.

        Raises:
            ValueError: If no video has been added.
        """
        if not self.videos:
            raise ValueError("No videos is added.")
        video = next(iter(self.videos.values()))
        return video.get_frame_wh(output_size)

    @property
    def num_frames(self) -> int:
        """Total number of frames across all added videos (fillers excluded)."""
        return sum(len(video.frames) for video in self.videos.values())

    @property
    def num_nodes(self) -> int:
        """Number of nodes in the graph."""
        return self.graph.number_of_nodes()

    @property
    def num_edges(self) -> int:
        """Number of directed edges in the graph."""
        return self.graph.number_of_edges()

    @property
    def edges_with_filler_frames(self) -> List[Tuple[NodeID, NodeID]]:
        """Edges with filler frames, one representative per unordered node pair."""
        edges = []
        seen = set()
        for source, target in self.graph.edges:
            if not self.graph[source][target].get("num_filler_frames"):
                continue
            key = tuple(sorted([source, target]))
            if key in seen:
                continue
            seen.add(key)
            edges.append((source, target))
        return edges

    @property
    def num_filler_frames(self) -> int:
        """Sum of ``num_filler_frames`` over ``edges_with_filler_frames``."""
        return sum(
            self.graph[source][target]["num_filler_frames"]
            for source, target in self.edges_with_filler_frames
        )

    def load_similarity_matrix(
        self, video1: DriverVideo, video2: DriverVideo
    ) -> np.ndarray:
        """Load the cached similarity matrix between two videos' transition nodes.

        The cache file is named after the two video ids, in ``video_id`` order,
        plus a hash of their transition frame indices. The returned matrix is
        oriented so that rows index ``video1.transition_nodes`` and columns
        index ``video2.transition_nodes``.

        Raises:
            FileNotFoundError: If the cache file does not exist.
        """
        sorted_videos = sorted([video1, video2], key=lambda x: x.video_id)
        transition_frames = [
            [node.frame_index for node in video.transition_nodes]
            for video in sorted_videos
        ]
        # md5 is only a cache key here, not a security measure.
        nodes_hash = hashlib.md5(str(transition_frames).encode()).hexdigest()
        sim_cache_file = (
            Path(self.similarity_cache_dir)
            / f"{sorted_videos[0].video_id}-{sorted_videos[1].video_id}"
            f"-{nodes_hash[:8]}.npy"
        )
        if not Path(sim_cache_file).exists():
            raise FileNotFoundError(
                f"Similarity matrix not found between {video1.video_name} and {video2.video_name}."
            )
        sim_matrix = np.load(sim_cache_file)
        if sorted_videos[0].video_id == video1.video_id:
            return sim_matrix
        return sim_matrix.T

    def load_filler_frames_for_allnodes(self) -> "VideoGraphNavigator":
        """Load filler frames for all the edges in the graph.

        Refreshes the path cache afterwards, because loading can change edge
        distances.
        """
        edges = self.edges_with_filler_frames
        total = sum(self.graph[s][t]["num_filler_frames"] for s, t in edges)
        logger.trace(
            f"Load filler frames for {len(edges)} edges, "
            f"total {total} frames.",
        )
        # Load videos sequentially without ThreadPool
        for source, target in edges:
            self.load_filler_frames(source, target)

        self.update_path_cache()
        return self

    def load_filler_frames(
        self, source_node: NodeID, target_node: NodeID, lock: Lock = None
    ) -> int:
        """Load the filler video between two nodes and fix up edge distances.

        Args:
            source_node: One end of the edge.
            target_node: The other end of the edge.
            lock: Optional lock held while edge attributes are modified.

        Returns:
            The number of filler frames loaded (0 if the edge has no fillers
            or the filler video file is missing).
        """
        edge = self.graph.get_edge_data(source_node, target_node)
        if not edge or not edge.get("num_filler_frames"):
            return 0

        lock = lock or nullcontext()
        num_filler_frames = edge["num_filler_frames"]

        filler_name = self._get_filler_video_name(source_node, target_node)
        filler_video_path = self._get_filler_video_path(source_node, target_node)
        if not Path(filler_video_path).exists():
            logger.warning(
                f"Filler video {filler_video_path} not found, "
                f"skip the {num_filler_frames} filler frames between "
                f"{source_node} and {target_node}."
            )
            return 0

        # Load the filler video
        filler_video = DriverVideo(name=filler_name, video_path=filler_video_path)
        self.filler_videos[filler_name] = filler_video

        # Set edge distances from the real filler length. Both directions of the
        # node pair share this filler video (see get_filler_frames), so the
        # reverse edge is updated too when it also uses fillers. The penalty is
        # added only for cross-video edges, matching add_edge and find_path.
        with lock:
            for u, v in ((source_node, target_node), (target_node, source_node)):
                if not self.graph.has_edge(u, v):
                    continue
                data = self.graph[u][v]
                if not data.get("num_filler_frames"):
                    continue
                penalty = self.CROSS_VIDEO_PENALTY if data.get("cross_video") else 0
                data["distance"] = 1 + filler_video.num_frames + penalty
        return filler_video.num_frames

    def _get_filler_video_name(self, source: NodeID, target: NodeID) -> str:
        """Get the filler video name (independent of edge direction)."""
        n1, n2 = sorted([source, target])
        return f"Filler_{n1}-{n2}"

    def _get_filler_video_path(self, source: NodeID, target: NodeID) -> str:
        """Get the filler video file path (independent of edge direction)."""
        n1, n2 = sorted([source, target])
        return str(self.filler_frames_dir / f"{n1}-{n2}.mp4")

    def get_filler_frames(
        self, source_node: NodeID, target_node: NodeID
    ) -> List[Frame]:
        """Get the filler frames between two nodes.

        The filler video is stored for the sorted node pair. Its frames are
        reversed when traversing from the larger node to the smaller one.
        Returns an empty list when the edge has no fillers or they were not loaded.
        """
        edge = self.graph.get_edge_data(source_node, target_node)
        if edge is None or not edge.get("num_filler_frames"):
            # No filler frames between the two nodes
            return []

        filler_name = self._get_filler_video_name(source_node, target_node)
        if filler_name not in self.filler_videos:
            logger.warning(
                f"Filler video not found between {source_node} and {target_node}, "
                f"expected {edge['num_filler_frames']} frames. "
                "Please generate the filler frames first."
            )
            return []

        frames = self.filler_videos[filler_name].frames.copy()
        if source_node > target_node:
            # reverse the frames if the source node is after the target node
            frames = frames[::-1]
        return frames

    def add_video(
        self,
        video: DriverVideo,
        edge_threshold: float = 0.7,
        connects_to: List[str] = None,
        num_filler_frames: int = None,
    ) -> None:
        """Add a video to the navigator.

        Connect the video with the existing videos in the navigator, then
        refresh the path cache.

        Args:
            video: The video to add
            edge_threshold: Minimum similarity for a cross-video edge
            connects_to: The video names to connect to the new video,
                if None (or empty), connect to all the existing videos
            num_filler_frames: If given, overrides the similarity-based
                filler count for every new cross-video edge

        Raises:
            ValueError: If a video with the same name was already added.
        """
        if video.video_name in self.videos:
            raise ValueError(f"Video {video.video_name} is already added.")

        self.videos[video.video_name] = video
        self.graph.update(video.as_graph())

        for target_video in self.videos.values():
            if target_video.video_hash == video.video_hash:
                continue
            if connects_to and target_video.video_name not in connects_to:
                continue
            self.connect_two_videos(
                target_video,
                video,
                edge_threshold=edge_threshold,
                num_filler_frames=num_filler_frames,
            )
        self.update_path_cache()

    @staticmethod
    def _get_num_fillers(similarity: float) -> int:
        """Map a transition similarity to a filler-frame count."""
        if similarity > 0.98:
            return 0
        if similarity > 0.80:
            return 3
        return 7

    def connect_two_videos(
        self,
        video1: DriverVideo,
        video2: DriverVideo,
        edge_threshold: float = 0.7,
        num_filler_frames: int = None,
    ) -> int:
        """Add edges between the transition nodes of two videos.

        For each transition node of one video, an edge is added to its most
        similar transition node of the other video if the similarity exceeds
        ``edge_threshold``. This is done in both directions.

        Args:
            video1: First video.
            video2: Second video.
            edge_threshold: Minimum similarity required to add an edge.
            num_filler_frames: If given, used for every edge instead of the
                similarity-based count.

        Returns:
            The number of distinct unordered node pairs that were connected
            (0 when both videos have the same hash).
        """
        if video1.video_hash == video2.video_hash:
            # TODO: support the same video with different frames
            return 0

        key = (video1.video_id, video2.video_id)
        if key not in self.similarity_matrices:
            sim_matrix = self.load_similarity_matrix(video1, video2)
            self.similarity_matrices[key] = sim_matrix
            self.similarity_matrices[key[::-1]] = sim_matrix.T

        sim_matrix = self.similarity_matrices[key]

        new_edges = set()

        def connect(src: DriverVideo, dst: DriverVideo, sim: np.ndarray) -> None:
            # Rows of ``sim`` index src.transition_nodes, columns dst's.
            sim = sim.copy()
            if isinstance(src, SingleActionVideo) and src.single_direction:
                # Only nodes after the action node are out-nodes
                invalid_indices = [
                    i for i, n in enumerate(src.transition_nodes) if n < src.action_node
                ]
                sim[invalid_indices] = -1
            if isinstance(dst, SingleActionVideo):
                # Only nodes before the action node are in-nodes
                invalid_indices = [
                    i
                    for i, n in enumerate(dst.transition_nodes)
                    if n >= dst.action_node
                ]
                sim[:, invalid_indices] = -1

            best = np.argmax(sim, axis=1)  # src -> dst
            for i, j in enumerate(best):
                score = sim[i, j]
                if not score > edge_threshold:
                    continue
                edge_filler_frames = (
                    self._get_num_fillers(score)
                    if num_filler_frames is None
                    else num_filler_frames
                )
                self.add_edge(
                    src.transition_nodes[i],
                    dst.transition_nodes[j],
                    distance=edge_filler_frames + 1,
                    num_filler_frames=edge_filler_frames,
                    single_direction=True,
                    cross_video=src.video_id != dst.video_id,
                )
                new_edges.add(
                    tuple(sorted([src.transition_nodes[i], dst.transition_nodes[j]]))
                )

        connect(video1, video2, sim_matrix)
        connect(video2, video1, sim_matrix.T)

        logger.trace(
            f"Connect {video1.video_name} and "
            f"{video2.video_name}, {len(new_edges)} edges."
        )
        return len(new_edges)

    def find_path(
        self, source: NodeID, target: str | NodeID
    ) -> Tuple[float, List[NodeID]]:
        """Find the shortest path from the source node to the target video or node.

        Args:
            source: The source node
            target: The target video name or node

        Returns:
            The distance and the shortest path from the source to the target.
            The cross-video penalties along the path are subtracted from the
            distance. If ``target`` is a video name and ``source`` already
            belongs to it, returns ``(0, [source])``.

        Raises:
            nx.NetworkXNoPath: If the target is unreachable.
        """

        def count_cross_video_penalty(path: List[NodeID]) -> int:
            penalty = 0
            for prev, nxt in zip(path, path[1:]):
                edge = self.graph.get_edge_data(prev, nxt)
                if edge is not None and edge.get("cross_video"):
                    penalty += self.CROSS_VIDEO_PENALTY
            return penalty

        if isinstance(target, NodeID):
            distance, path = self.single_source_multi_target_dijkstra(source, [target])
            return distance - count_cross_video_penalty(path), path

        target_video = self.videos[target]
        if target_video.video_hash == source.video_hash:
            # Already in the target video
            return 0, [source]

        # find the shortest path to a node of the target video
        distance, path = self.single_source_multi_target_dijkstra(
            source, target_video.transition_nodes
        )
        return distance - count_cross_video_penalty(path), path

    def collect_path_frames(self, path: List[NodeID]) -> List[Frame]:
        """Collect frames along ``path``, inserting filler frames between nodes."""
        frames = []
        for source, target in zip(path, path[1:]):
            frames += self.videos[source.video_name].collect_frames(source, target)
            # Add filler frames if they exist
            frames += self.get_filler_frames(source, target)
        return frames

    def collect_n_frames(
        self,
        min_n: int,
        target_video_name: str = None,
        actions_name: List[str] | str = None,
    ) -> List[Frame]:
        """Collect at least min_n frames from the navigator.

        Frames are collected from the current node through each requested
        action (to its action node), then to the target video. The remainder
        comes from the target video's ``get_n_frames``. Invalid video or
        action names are dropped with a warning. Updates ``self.curr_node``.
        """
        target_video_name, actions_name = self.filter_video_names(
            target_video_name, actions_name
        )
        if min_n <= 0 and not actions_name and not target_video_name:
            return []

        if self.curr_node is None:
            if target_video_name is None:
                self.curr_node = list(self.videos.values())[0].nodes[0]
                logger.trace(
                    f"Current node is not set, set to the first node of "
                    f"{self.curr_node.video_name}."
                )
            else:
                self.curr_node = self.videos[target_video_name].nodes[0]

        target_video_name = target_video_name or self.curr_node.video_name
        target_videos = [self.videos[target_video_name]]

        if isinstance(actions_name, str):
            actions_name = [actions_name]
        if actions_name:
            target_videos = [self.videos[name] for name in actions_name] + target_videos

        assert isinstance(target_videos[-1], LoopingVideo), (
            "The target video must be a LoopingVideo"
        )

        # [action_1, action_2, ..., target_video]
        total_path = [self.curr_node]
        total_distance = 0
        for video in target_videos:
            target = (
                video.action_node
                if isinstance(video, SingleActionVideo)
                else video.video_name
            )
            distance, path = self.find_path(total_path[-1], target)

            total_path += path[1:]
            total_distance += distance

        frames = self.collect_path_frames(total_path)
        if total_distance != len(frames):
            logger.warning(
                f"Distance mismatch: {total_distance} != {len(frames)}, "
                f"Path: {total_path}"
            )
        # Collect frames from the target video
        n_left = min_n - len(frames)
        last_node = total_path[-2] if len(total_path) > 1 else None
        target_frames, last_node = target_videos[-1].get_n_frames(
            n_left, start=total_path[-1], last_position=last_node
        )
        frames += target_frames
        self.curr_node = last_node or total_path[-1]

        logger.trace(
            f"Path: {total_path} -> {self.curr_node}, path len: {total_distance}, "
            f"rest: {len(target_frames)}, total: {len(frames)}/{min_n}"
        )

        return frames

    def next_n_frames(
        self,
        num_frames: int,
        target_video_name: str = None,
        actions_name: List[str] | str = None,
        on_user_speech: bool = False,
        on_agent_speech: bool = False,
        stop_on_user_speech_override: Optional[bool] = None,
        stop_on_agent_speech_override: Optional[bool] = None,
    ) -> List[Frame]:
        """Get the next n frames from the navigator.

        If the buffered frames belong to a video configured to stop on user or
        agent speech, and that speech is active, the buffer is dropped first.

        Args:
            num_frames: The number of frames to get
            target_video_name: The target video name. Keep the current video if None.
            actions_name: The actions before the target video.
            on_user_speech: Whether user is currently speaking
            on_agent_speech: Whether agent is currently speaking
            stop_on_user_speech_override: Override stop_on_user_speech from video config if provided
            stop_on_agent_speech_override: Override stop_on_agent_speech from video config if provided
        """
        if self.frame_buffer:
            video = self.videos.get(self.frame_buffer[0].video_name)
            if video:
                # Use override values if provided, otherwise use video's default values
                stop_on_user = (
                    stop_on_user_speech_override
                    if stop_on_user_speech_override is not None
                    else video.stop_on_user_speech
                )
                stop_on_agent = (
                    stop_on_agent_speech_override
                    if stop_on_agent_speech_override is not None
                    else video.stop_on_agent_speech
                )

                if (on_user_speech and stop_on_user) or (
                    on_agent_speech and stop_on_agent
                ):
                    self.reset_buffer()
                    logger.trace(
                        f"Stop on {video.video_name} because of {on_user_speech=} or "
                        f"{on_agent_speech=} (stop_on_user={stop_on_user}, stop_on_agent={stop_on_agent})"
                    )

        if num_frames <= 0:
            return []

        min_n = num_frames - len(self.frame_buffer)
        self.frame_buffer += self.collect_n_frames(
            min_n, target_video_name=target_video_name, actions_name=actions_name
        )
        frames = self.frame_buffer[:num_frames]
        self.frame_buffer = self.frame_buffer[num_frames:]
        return frames

    def filter_video_names(
        self, target_video: Optional[str], actions: Optional[List[str] | str]
    ) -> Tuple[Optional[str], Optional[List[str]]]:
        """Drop unknown names, with a warning.

        ``target_video`` must be a non-action video and each action must be an
        action video. A single action string is wrapped into a list.
        """
        if target_video:
            if not self.video_exists(target_video, is_action=False):
                logger.warning(f"Invalid video name: {target_video}")
                target_video = None

        if actions:
            if isinstance(actions, str):
                actions = [actions]
            valid_actions = []
            for action in actions:
                if self.video_exists(action, is_action=True):
                    valid_actions.append(action)
                else:
                    logger.warning(f"Invalid action name: {action}")
            actions = valid_actions
        return target_video, actions
|