bithuman 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bithuman might be problematic. Click here for more details.

Files changed (44) hide show
  1. bithuman/__init__.py +13 -0
  2. bithuman/_version.py +1 -0
  3. bithuman/api.py +164 -0
  4. bithuman/audio/__init__.py +19 -0
  5. bithuman/audio/audio.py +396 -0
  6. bithuman/audio/hparams.py +108 -0
  7. bithuman/audio/utils.py +255 -0
  8. bithuman/config.py +88 -0
  9. bithuman/engine/__init__.py +15 -0
  10. bithuman/engine/auth.py +335 -0
  11. bithuman/engine/compression.py +257 -0
  12. bithuman/engine/enums.py +16 -0
  13. bithuman/engine/image_ops.py +192 -0
  14. bithuman/engine/inference.py +108 -0
  15. bithuman/engine/knn.py +58 -0
  16. bithuman/engine/video_data.py +391 -0
  17. bithuman/engine/video_reader.py +168 -0
  18. bithuman/lib/__init__.py +1 -0
  19. bithuman/lib/audio_encoder.onnx +45631 -28
  20. bithuman/lib/generator.py +763 -0
  21. bithuman/lib/pth2h5.py +106 -0
  22. bithuman/plugins/__init__.py +0 -0
  23. bithuman/plugins/stt.py +185 -0
  24. bithuman/runtime.py +1004 -0
  25. bithuman/runtime_async.py +469 -0
  26. bithuman/service/__init__.py +9 -0
  27. bithuman/service/client.py +788 -0
  28. bithuman/service/messages.py +210 -0
  29. bithuman/service/server.py +759 -0
  30. bithuman/utils/__init__.py +43 -0
  31. bithuman/utils/agent.py +359 -0
  32. bithuman/utils/fps_controller.py +90 -0
  33. bithuman/utils/image.py +41 -0
  34. bithuman/utils/unzip.py +38 -0
  35. bithuman/video_graph/__init__.py +16 -0
  36. bithuman/video_graph/action_trigger.py +83 -0
  37. bithuman/video_graph/driver_video.py +482 -0
  38. bithuman/video_graph/navigator.py +736 -0
  39. bithuman/video_graph/trigger.py +90 -0
  40. bithuman/video_graph/video_script.py +344 -0
  41. bithuman-1.0.2.dist-info/METADATA +37 -0
  42. bithuman-1.0.2.dist-info/RECORD +44 -0
  43. bithuman-1.0.2.dist-info/WHEEL +5 -0
  44. bithuman-1.0.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,736 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ from collections import defaultdict
5
+ from contextlib import nullcontext
6
+ from pathlib import Path
7
+ from tempfile import TemporaryDirectory
8
+ from threading import Lock
9
+ from typing import Dict, List, Optional, Tuple
10
+
11
+ import networkx as nx
12
+ import numpy as np
13
+ from loguru import logger
14
+
15
+ from bithuman.config import Settings
16
+ from bithuman.utils.unzip import unzip_tarfile
17
+ from bithuman.video_graph.driver_video import (
18
+ DriverVideo,
19
+ Frame,
20
+ LoopingVideo,
21
+ NodeID,
22
+ SingleActionVideo,
23
+ )
24
+ from bithuman.video_graph.video_script import VideoConfigs, VideoScript
25
+
26
+
27
class VideoGraphNavigator:
    """Navigate a directed graph built from driver videos.

    Each video contributes nodes (transition frames) to a shared ``nx.DiGraph``.
    Edges connect visually similar frames, both within a video and across
    videos.  The navigator finds shortest paths between the current frame and a
    target video (or action), then collects the frames along that path,
    inserting pre-rendered "filler" frames on edges whose endpoints are not
    similar enough for a hard cut.
    """

    # Extra distance added to edges that jump between two different videos so
    # that pathfinding prefers to stay inside the current video when possible.
    CROSS_VIDEO_PENALTY = 30

    def __init__(
        self,
        avatar_model_path: str | Tuple[str, TemporaryDirectory],
        video_configs: Optional[VideoConfigs] = None,
    ) -> None:
        """Initialize the navigator.

        Args:
            avatar_model_path: Path of the avatar model workspace, or a tuple
                of ``(path, TemporaryDirectory)`` when the workspace was
                extracted to a temporary directory that must be cleaned up.
            video_configs: The video configurations. An empty ``VideoConfigs``
                is used when not provided.
        """
        self.videos: Dict[str, DriverVideo] = {}
        self.filler_videos: Dict[str, DriverVideo] = {}

        self._video_configs = video_configs or VideoConfigs()

        # Workspace directory; keep a handle on the TemporaryDirectory (if
        # any) so it can be cleaned up explicitly or on garbage collection.
        if isinstance(avatar_model_path, tuple):
            self.avatar_model_path, self.temp_dir = avatar_model_path
        else:
            self.avatar_model_path = avatar_model_path
            self.temp_dir = None
        self.filler_frames_dir = Path(self.avatar_model_path) / "filler_videos"
        self.similarity_cache_dir = Path(self.avatar_model_path) / "similarities"

        # Similarity matrices between pairs of videos, keyed by
        # (video_id, video_id); both orientations are cached.
        self.similarity_matrices: Dict[Tuple[str, str], np.ndarray] = {}

        # Graph of transition nodes and the frame buffer for output.
        self.graph = nx.DiGraph()
        self.curr_node = None
        self.frame_buffer: List[Frame] = []

        # Cache of shortest paths: source -> {target: (distance, path)}.
        self.path_cache: Dict[NodeID, Dict[NodeID, Tuple[float, List[NodeID]]]] = (
            defaultdict(dict)
        )

    def cleanup(self) -> None:
        """Clean up the temporary directory if it exists."""
        if self.temp_dir and Path(self.temp_dir.name).exists():
            self.temp_dir.cleanup()
        self.temp_dir = None

    def __del__(self):
        """Clean up the temporary directory if it exists."""
        self.cleanup()

    @property
    def videos_script(self) -> VideoScript:
        """Get the video script."""
        return self._video_configs.videos_script

    def update_runtime_configs(self, settings: Settings) -> None:
        """Update the runtime configs from the settings."""
        self._video_configs.update_runtime_configs(settings)

    def video_exists(self, name: str, is_action: Optional[bool] = None) -> bool:
        """Check if the video exists in the navigator.

        Args:
            name: The video name
            is_action: If True, only check the action videos.
                If False, only check the looping videos.
                None, check both.

        Returns:
            True if the video exists.
        """
        if name not in self.videos:
            return False
        if is_action is None:
            return True
        # An "action" video is a SingleActionVideo; match the requested kind.
        return isinstance(self.videos[name], SingleActionVideo) == is_action

    @property
    def action_videos(self) -> List[SingleActionVideo]:
        """Get all the action videos in the navigator."""
        return [
            video
            for video in self.videos.values()
            if isinstance(video, SingleActionVideo)
        ]

    @property
    def action_video_names(self) -> List[str]:
        """Get all the action video names in the navigator."""
        return [video.video_name for video in self.action_videos]

    @classmethod
    def from_single_video(
        cls, video_file: str, inference_data_file: Optional[str] = None
    ) -> "VideoGraphNavigator":
        """Create a navigator with a single video."""
        video_configs = VideoConfigs.from_videofile(video_file, inference_data_file)
        return cls(
            avatar_model_path=Path(video_file).parent,
            video_configs=video_configs,
        )

    @classmethod
    def from_workspace(
        cls,
        avatar_model_path: str,
        video_config_file: Optional[str] = None,
        extract_to_local: bool = False,
    ) -> "VideoGraphNavigator":
        """Create a navigator from the model.

        Args:
            avatar_model_path: Path to the (possibly archived) avatar model.
            video_config_file: Optional explicit path to the video config
                YAML; defaults to ``<workspace>/videos.yaml``.
            extract_to_local: Passed through to ``unzip_tarfile``.

        Raises:
            FileNotFoundError: If neither a config file nor a ``videos``
                folder is found in the workspace.
        """
        logger.info(f"Loading model from {avatar_model_path}")
        avatar_model_path, temp_dir = unzip_tarfile(
            avatar_model_path, extract_to_local=extract_to_local
        )

        # Load video configs: prefer the YAML file, fall back to scanning a
        # "videos" folder.
        video_config_file = video_config_file or Path(avatar_model_path) / "videos.yaml"
        if Path(video_config_file).exists():
            video_configs = VideoConfigs.from_yaml(video_config_file)
        elif (Path(avatar_model_path) / "videos").exists():
            video_configs = VideoConfigs.from_videofolder(
                Path(avatar_model_path) / "videos"
            )
        else:
            files = list(Path(avatar_model_path).glob("*"))
            raise FileNotFoundError(
                f"model not found in {avatar_model_path}, files: {files}"
            )

        # Update video files to absolute path
        for video in video_configs.videos:
            video.video_file = str(
                Path(avatar_model_path).absolute() / video.video_file
            )

        return cls(
            avatar_model_path=(avatar_model_path, temp_dir),
            video_configs=video_configs,
        )

    def load_workspace(
        self, prepare_filler_frames: bool = True
    ) -> "VideoGraphNavigator":
        """Load the videos from workspace.

        Args:
            prepare_filler_frames: If True, prepare filler frames for all the edges
        """
        # Load videos
        videos = self._video_configs.load_videos(video_root=self.avatar_model_path)

        # Init the navigator
        for video, config in zip(videos, self._video_configs.videos):
            self.add_video(video, **config.adding_kwargs)
        if prepare_filler_frames:
            self.load_filler_frames_for_allnodes()
        return self

    def update_path_cache(self) -> None:
        """Update the path cache for all the nodes in the graph."""
        self.path_cache.clear()
        for source in self.graph.nodes:
            distance, path = nx.single_source_dijkstra(
                self.graph, source, weight="distance"
            )
            self.path_cache[source] = {
                target: (distance[target], path[target]) for target in path
            }

    def single_source_multi_target_dijkstra(
        self, source: NodeID, targets: List[NodeID]
    ) -> Tuple[float, List[NodeID]]:
        """Find the shortest path from the source node to any of target nodes.

        The method uses the path cache to speed up the computation.
        Make sure to update the path cache after the graph is updated.

        Args:
            source: Node to start the search from
            targets: List of target nodes

        Raises:
            nx.NetworkXNoPath: If no path is found from the source
                to any of the target nodes

        Returns:
            Tuple[float, List[NodeID]]: The distance and the shortest path
                from the source to one of the target nodes
        """
        paths = [
            self.path_cache[source][target]
            for target in targets
            if target in self.path_cache[source]
        ]
        if not paths:
            raise nx.NetworkXNoPath(f"No path found from {source} to {targets}")
        # Each entry is (distance, path); pick the minimum-distance one.
        shortest_path = min(paths, key=lambda x: x[0])
        return shortest_path

    def reset_buffer(self) -> None:
        """Reset the frame buffer."""
        self.curr_node = None
        self.frame_buffer = []

    def add_edge(
        self,
        source_node: NodeID,
        target_node: NodeID,
        distance: float,
        num_filler_frames: int = 0,
        single_direction: bool = False,
        cross_video: bool = False,
    ) -> None:
        """Add an edge to the graph.

        Set two metadata for the edge: distance and num_filler_frames.

        Args:
            source_node: The source node of the edge
            target_node: The target node of the edge
            distance: The distance between the two nodes, added a penalty for
                cross-video edges with filler frames
            num_filler_frames: The number of filler frames between the two nodes
            single_direction: If the edge is only in one direction
            cross_video: Whether the edge connects two different videos
        """
        if cross_video:
            distance += self.CROSS_VIDEO_PENALTY
        metadata = {
            "distance": distance,
            "num_filler_frames": num_filler_frames,
            "cross_video": cross_video,
        }
        self.graph.add_edge(source_node, target_node, **metadata)
        if not single_direction:
            # Mirror the edge so the connection is usable in both directions.
            self.graph.add_edge(target_node, source_node, **metadata)

    def get_first_frame(self, output_size: Optional[int] = None) -> np.ndarray:
        """Return the first frame of the first loaded video.

        Raises:
            ValueError: If no videos have been added yet.
        """
        if not self.videos:
            raise ValueError("No videos is added.")
        video = next(iter(self.videos.values()))
        return video.get_first_frame(output_size)

    def get_frame_wh(self, output_size: Optional[int] = None) -> Tuple[int, int]:
        """Return the (width, height) of frames from the first loaded video.

        Raises:
            ValueError: If no videos have been added yet.
        """
        if not self.videos:
            raise ValueError("No videos is added.")
        video = next(iter(self.videos.values()))
        return video.get_frame_wh(output_size)

    @property
    def num_frames(self) -> int:
        """Total number of frames across all loaded videos."""
        return sum(len(video.frames) for video in self.videos.values())

    @property
    def num_nodes(self) -> int:
        """Number of nodes in the graph."""
        return self.graph.number_of_nodes()

    @property
    def num_edges(self) -> int:
        """Number of edges in the graph."""
        return self.graph.number_of_edges()

    @property
    def edges_with_filler_frames(self) -> List[Tuple[NodeID, NodeID]]:
        """Edges that carry filler frames, deduplicated across directions."""
        edges = []
        seen = set()
        for source, target in self.graph.edges:
            if not self.graph[source][target].get("num_filler_frames"):
                continue
            # Treat (a, b) and (b, a) as the same undirected connection.
            key = tuple(sorted([source, target]))
            if key in seen:
                continue
            seen.add(key)
            edges.append((source, target))
        return edges

    @property
    def num_filler_frames(self) -> int:
        """Total number of filler frames declared on the graph's edges."""
        return sum(
            self.graph[source][target]["num_filler_frames"]
            for source, target in self.edges_with_filler_frames
        )

    def load_similarity_matrix(
        self, video1: DriverVideo, video2: DriverVideo
    ) -> np.ndarray:
        """Load the cached similarity matrix between two videos.

        The cache file name is derived from the sorted video ids plus a hash of
        the transition frame indices, so a change in transition nodes
        invalidates the cache.  The matrix is transposed when the requested
        order differs from the sorted (cached) order.

        Raises:
            FileNotFoundError: If no cached similarity matrix exists.
        """
        sorted_videos = sorted([video1, video2], key=lambda x: x.video_id)
        transition_frames = [
            [node.frame_index for node in video.transition_nodes]
            for video in sorted_videos
        ]
        nodes_hash = hashlib.md5(str(transition_frames).encode()).hexdigest()
        sim_cache_file = (
            Path(self.similarity_cache_dir)
            / f"{sorted_videos[0].video_id}-{sorted_videos[1].video_id}"
            f"-{nodes_hash[:8]}.npy"
        )
        if not Path(sim_cache_file).exists():
            raise FileNotFoundError(
                f"Similarity matrix not found between {video1.video_name} and {video2.video_name}."
            )
        sim_matrix = np.load(sim_cache_file)
        if sorted_videos[0].video_id == video1.video_id:
            return sim_matrix
        else:
            return sim_matrix.T

    def load_filler_frames_for_allnodes(self) -> "VideoGraphNavigator":
        """Generate filler frames for all the edges in the graph."""
        logger.trace(
            f"Load filler frames for {len(self.edges_with_filler_frames)} edges, "
            f"total {self.num_filler_frames} frames.",
        )
        # Load videos sequentially without ThreadPool
        for source, target in self.edges_with_filler_frames:
            self.load_filler_frames(source, target)

        self.update_path_cache()
        return self

    def load_filler_frames(
        self, source_node: NodeID, target_node: NodeID, lock: Optional[Lock] = None
    ) -> int:
        """Generate filler frames between two nodes.

        Returns the number of filler frames loaded.
        """
        edge = self.graph.get_edge_data(source_node, target_node)
        if not edge or not edge.get("num_filler_frames"):
            return 0

        # When no lock is supplied, use a no-op context manager so the
        # graph-update below can always be wrapped in a ``with``.
        lock = lock or nullcontext()
        num_filler_frames = edge["num_filler_frames"]

        filler_name = self._get_filler_video_name(source_node, target_node)
        filler_video_path = self._get_filler_video_path(source_node, target_node)
        if not Path(filler_video_path).exists():
            logger.warning(
                f"Filler video {filler_video_path} not found, "
                f"skip the {num_filler_frames} filler frames between "
                f"{source_node} and {target_node}."
            )
            return 0

        # Load the filler video
        filler_video = DriverVideo(name=filler_name, video_path=filler_video_path)
        self.filler_videos[filler_name] = filler_video

        # Update edge distance to reflect the actual number of frames
        # traversed (transition frame + fillers + cross-video penalty).
        with lock:
            self.graph[source_node][target_node]["distance"] = (
                1 + filler_video.num_frames + self.CROSS_VIDEO_PENALTY
            )
        return filler_video.num_frames

    def _get_filler_video_name(self, source: NodeID, target: NodeID) -> str:
        """Get the filler video name."""
        # Sort the endpoints so both directions map to one filler video.
        key = tuple(sorted([source, target]))
        n1, n2 = key
        return f"Filler_{n1}-{n2}"

    def _get_filler_video_path(self, source: NodeID, target: NodeID) -> str:
        """Get the cache path for the filler frames."""
        key = tuple(sorted([source, target]))
        n1, n2 = key
        video_file = str(self.filler_frames_dir / f"{n1}-{n2}.mp4")
        return video_file

    def get_filler_frames(
        self, source_node: NodeID, target_node: NodeID
    ) -> List[Frame]:
        """Get the filler frames between two nodes."""
        edge = self.graph.get_edge_data(source_node, target_node)
        if edge is None or not edge.get("num_filler_frames"):
            # No filler frames between the two nodes
            return []

        filler_name = self._get_filler_video_name(source_node, target_node)
        if filler_name not in self.filler_videos:
            logger.warning(
                f"Filler video not found between {source_node} and {target_node}, "
                f"expected {edge['num_filler_frames']} frames. "
                "Please generate the filler frames first."
            )
            return []

        frames = self.filler_videos[filler_name].frames.copy()
        if source_node > target_node:
            # reverse the frames if the source node is after the target node
            frames = frames[::-1]
        return frames

    def add_video(
        self,
        video: DriverVideo,
        edge_threshold: float = 0.7,
        connects_to: Optional[List[str]] = None,
        num_filler_frames: Optional[int] = None,
    ) -> None:
        """Add a video to the navigator.

        Connect the video with all the existing videos in the navigator.

        Args:
            video: The video to add
            edge_threshold: The threshold for the edge weight
            connects_to: The video names to connect to the new video,
                if None, connect to all the existing videos
            num_filler_frames: The number of filler frames between the two videos

        Raises:
            ValueError: If a video with the same name was already added.
        """
        if video.video_name in self.videos:
            raise ValueError(f"Video {video.video_name} is already added.")

        self.videos[video.video_name] = video
        self.graph.update(video.as_graph())

        for target_video in self.videos.values():
            if target_video.video_hash == video.video_hash:
                # Skip self-connection (the new video itself).
                continue
            if connects_to and target_video.video_name not in connects_to:
                continue
            self.connect_two_videos(
                target_video,
                video,
                edge_threshold=edge_threshold,
                num_filler_frames=num_filler_frames,
            )
        self.update_path_cache()

    def connect_two_videos(
        self,
        video1: DriverVideo,
        video2: DriverVideo,
        edge_threshold: float = 0.7,
        num_filler_frames: Optional[int] = None,
    ) -> int:
        """Connect two videos via their most similar transition nodes.

        For each transition node of one video, an edge is created to the most
        similar transition node of the other video when the similarity exceeds
        ``edge_threshold``.  Action-video direction constraints are applied by
        masking invalid rows/columns of the similarity matrix.

        Returns:
            The number of (undirected) connections created.
        """
        if video1.video_hash == video2.video_hash:
            # TODO: support the same video with different frames
            return 0

        # Cache both orientations of the similarity matrix.
        key = (video1.video_id, video2.video_id)
        if key not in self.similarity_matrices:
            sim_matrix = self.load_similarity_matrix(video1, video2)
            self.similarity_matrices[key] = sim_matrix
            self.similarity_matrices[key[::-1]] = sim_matrix.T

        sim_matrix = self.similarity_matrices[key]

        # Map a similarity score to the number of filler frames needed:
        # near-identical frames need none; less similar frames need more.
        def get_num_fillers(similarity: float) -> int:
            if similarity > 0.98:
                return 0
            if similarity > 0.80:
                return 3
            return 7

        new_edges = set()

        def connect(video1: DriverVideo, video2: DriverVideo, sim_matrix):
            # Work on a copy: the masking below must not corrupt the cache.
            sim_matrix = sim_matrix.copy()
            if isinstance(video1, SingleActionVideo) and video1.single_direction:
                # Only nodes after the action node are out-nodes
                invalid_indices = [
                    i
                    for i, n in enumerate(video1.transition_nodes)
                    if n < video1.action_node
                ]
                sim_matrix[invalid_indices] = -1
            if isinstance(video2, SingleActionVideo):
                # Only nodes before the action node are in-nodes
                invalid_indices = [
                    i
                    for i, n in enumerate(video2.transition_nodes)
                    if n >= video2.action_node
                ]
                sim_matrix[:, invalid_indices] = -1

            argmax = np.argmax(sim_matrix, axis=1)  # video1 -> video2
            indices = [(i, j, sim_matrix[i, j]) for i, j in enumerate(argmax)]
            indices = list(filter(lambda x: x[2] > edge_threshold, indices))

            for i, j, score in indices:
                edge_filler_frames = (
                    get_num_fillers(score)
                    if num_filler_frames is None
                    else num_filler_frames
                )
                self.add_edge(
                    video1.transition_nodes[i],
                    video2.transition_nodes[j],
                    distance=edge_filler_frames + 1,
                    num_filler_frames=edge_filler_frames,
                    single_direction=True,
                    cross_video=video1.video_id != video2.video_id,
                )
                key = tuple(
                    sorted([video1.transition_nodes[i], video2.transition_nodes[j]])
                )
                new_edges.add(key)

        # Connect in both directions (the matrix is transposed for the reverse).
        connect(video1, video2, sim_matrix)
        connect(video2, video1, sim_matrix.T)

        logger.trace(
            f"Connect {video1.video_name} and "
            f"{video2.video_name}, {len(new_edges)} edges."
        )
        # Fix: the method is annotated -> int but previously fell through
        # returning None; report the number of connections made.
        return len(new_edges)

    def find_path(
        self, source: NodeID, target: str | NodeID
    ) -> Tuple[float, List[NodeID]]:
        """Find the shortest path from the source node to the target video or node.

        Args:
            source: The source node
            target: The target video name or node

        Returns:
            The distance and the shortest path from the source to the target
        """

        def count_cross_video_penalty(path: List[NodeID]) -> int:
            # Sum the penalties baked into cross-video edges so they can be
            # subtracted, leaving the true frame-count distance.
            penalty = 0
            for i in range(1, len(path)):
                edge = self.graph.get_edge_data(path[i - 1], path[i])
                if edge is None or not edge.get("cross_video"):
                    continue
                penalty += self.CROSS_VIDEO_PENALTY
            return penalty

        if isinstance(target, NodeID):
            distance, path = self.single_source_multi_target_dijkstra(source, [target])
            return distance - count_cross_video_penalty(path), path

        target_video = self.videos[target]
        if target_video.video_hash == source.video_hash:
            # Already in the target video
            return 0, [source]

        # find the shortest path to a node of the target video
        distance, path = self.single_source_multi_target_dijkstra(
            source, target_video.transition_nodes
        )
        return distance - count_cross_video_penalty(path), path

    def collect_path_frames(self, path: List[NodeID]) -> List[Frame]:
        """Collect frames from the path."""
        frames = []
        for i in range(len(path) - 1):
            source, target = path[i], path[i + 1]
            frames += self.videos[source.video_name].collect_frames(source, target)
            # Add filler frames if they exist
            frames += self.get_filler_frames(source, target)
        return frames

    def collect_n_frames(
        self,
        min_n: int,
        target_video_name: Optional[str] = None,
        actions_name: Optional[List[str] | str] = None,
    ) -> List[Frame]:
        """Collect at least min_n frames from the navigator.

        The frames are collected from the current node to the target video.

        Args:
            min_n: Minimum number of frames to collect.
            target_video_name: The looping video to end in; keeps the current
                video when None.
            actions_name: Action video name(s) to play before reaching the
                target video.
        """
        target_video_name, actions_name = self.filter_video_names(
            target_video_name, actions_name
        )
        if min_n <= 0 and not actions_name and not target_video_name:
            return []

        # Establish a starting node when none is set yet.
        if self.curr_node is None:
            if target_video_name is None:
                self.curr_node = list(self.videos.values())[0].nodes[0]
                logger.trace(
                    f"Current node is not set, set to the first node of "
                    f"{self.curr_node.video_name}."
                )
            else:
                self.curr_node = self.videos[target_video_name].nodes[0]

        target_video_name = target_video_name or self.curr_node.video_name
        target_videos = [self.videos[target_video_name]]

        if isinstance(actions_name, str):
            actions_name = [actions_name]
        if actions_name:
            # Play the actions first, then settle in the target video.
            target_videos = [self.videos[name] for name in actions_name] + target_videos

        assert isinstance(target_videos[-1], LoopingVideo), (
            "The target video must be a LoopingVideo"
        )

        # [action_1, action_2, ..., target_video]
        total_path = [self.curr_node]
        total_distance = 0
        for video in target_videos:
            target = (
                video.action_node
                if isinstance(video, SingleActionVideo)
                else video.video_name
            )
            distance, path = self.find_path(total_path[-1], target)
            # Drop the first node of each sub-path: it is the last node of
            # the previous one.
            total_path += path[1:]
            total_distance += distance

        frames = self.collect_path_frames(total_path)
        if total_distance != len(frames):
            logger.warning(
                f"Distance mismatch: {total_distance} != {len(frames)}, "
                f"Path: {total_path}"
            )
        # Collect frames from the target video
        n_left = min_n - len(frames)
        last_node = total_path[-2] if len(total_path) > 1 else None
        target_frames, last_node = target_videos[-1].get_n_frames(
            n_left, start=total_path[-1], last_position=last_node
        )
        frames += target_frames
        self.curr_node = last_node or total_path[-1]

        logger.trace(
            f"Path: {total_path} -> {self.curr_node}, path len: {total_distance}, "
            f"rest: {len(target_frames)}, total: {len(frames)}/{min_n}"
        )

        return frames

    def next_n_frames(
        self,
        num_frames: int,
        target_video_name: Optional[str] = None,
        actions_name: Optional[List[str] | str] = None,
        on_user_speech: bool = False,
        on_agent_speech: bool = False,
        stop_on_user_speech_override: Optional[bool] = None,
        stop_on_agent_speech_override: Optional[bool] = None,
    ) -> List[Frame]:
        """Get the next n frames from the navigator.

        Args:
            num_frames: The number of frames to get
            target_video_name: The target video name. Keep the current video if None.
            actions_name: The actions before the target video.
            on_user_speech: Whether user is currently speaking
            on_agent_speech: Whether agent is currently speaking
            stop_on_user_speech_override: Override stop_on_user_speech from video config if provided
            stop_on_agent_speech_override: Override stop_on_agent_speech from video config if provided
        """
        if self.frame_buffer:
            video = self.videos.get(self.frame_buffer[0].video_name)
            if video:
                # Use override values if provided, otherwise use video's default values
                stop_on_user = (
                    stop_on_user_speech_override
                    if stop_on_user_speech_override is not None
                    else video.stop_on_user_speech
                )
                stop_on_agent = (
                    stop_on_agent_speech_override
                    if stop_on_agent_speech_override is not None
                    else video.stop_on_agent_speech
                )

                if (on_user_speech and stop_on_user) or (
                    on_agent_speech and stop_on_agent
                ):
                    # Interrupt the buffered video when speech should stop it.
                    self.reset_buffer()
                    logger.trace(
                        f"Stop on {video.video_name} because of {on_user_speech=} or "
                        f"{on_agent_speech=} (stop_on_user={stop_on_user}, stop_on_agent={stop_on_agent})"
                    )

        if num_frames <= 0:
            return []

        # Top up the buffer, then serve exactly num_frames from it.
        min_n = num_frames - len(self.frame_buffer)
        self.frame_buffer += self.collect_n_frames(
            min_n, target_video_name=target_video_name, actions_name=actions_name
        )
        frames = self.frame_buffer[:num_frames]
        self.frame_buffer = self.frame_buffer[num_frames:]
        return frames

    def filter_video_names(
        self, target_video: Optional[str], actions: Optional[List[str] | str]
    ) -> Tuple[str, List[str]]:
        """Filter the target video and actions.

        Invalid names are dropped with a warning rather than raising, so a bad
        request degrades to the current video instead of failing.
        """
        if target_video:
            if not self.video_exists(target_video, is_action=False):
                logger.warning(f"Invalid video name: {target_video}")
                target_video = None

        if actions:
            if isinstance(actions, str):
                actions = [actions]
            valid_actions = []
            for action in actions:
                if self.video_exists(action, is_action=True):
                    valid_actions.append(action)
                else:
                    logger.warning(f"Invalid action name: {action}")
            actions = valid_actions
        return target_video, actions