bithuman-1.0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. bithuman/__init__.py +13 -0
  2. bithuman/_version.py +1 -0
  3. bithuman/api.py +164 -0
  4. bithuman/audio/__init__.py +19 -0
  5. bithuman/audio/audio.py +396 -0
  6. bithuman/audio/hparams.py +108 -0
  7. bithuman/audio/utils.py +255 -0
  8. bithuman/config.py +88 -0
  9. bithuman/engine/__init__.py +15 -0
  10. bithuman/engine/auth.py +335 -0
  11. bithuman/engine/compression.py +257 -0
  12. bithuman/engine/enums.py +16 -0
  13. bithuman/engine/image_ops.py +192 -0
  14. bithuman/engine/inference.py +108 -0
  15. bithuman/engine/knn.py +58 -0
  16. bithuman/engine/video_data.py +391 -0
  17. bithuman/engine/video_reader.py +168 -0
  18. bithuman/lib/__init__.py +1 -0
  19. bithuman/lib/audio_encoder.onnx +45631 -28
  20. bithuman/lib/generator.py +763 -0
  21. bithuman/lib/pth2h5.py +106 -0
  22. bithuman/plugins/__init__.py +0 -0
  23. bithuman/plugins/stt.py +185 -0
  24. bithuman/runtime.py +1004 -0
  25. bithuman/runtime_async.py +469 -0
  26. bithuman/service/__init__.py +9 -0
  27. bithuman/service/client.py +788 -0
  28. bithuman/service/messages.py +210 -0
  29. bithuman/service/server.py +759 -0
  30. bithuman/utils/__init__.py +43 -0
  31. bithuman/utils/agent.py +359 -0
  32. bithuman/utils/fps_controller.py +90 -0
  33. bithuman/utils/image.py +41 -0
  34. bithuman/utils/unzip.py +38 -0
  35. bithuman/video_graph/__init__.py +16 -0
  36. bithuman/video_graph/action_trigger.py +83 -0
  37. bithuman/video_graph/driver_video.py +482 -0
  38. bithuman/video_graph/navigator.py +736 -0
  39. bithuman/video_graph/trigger.py +90 -0
  40. bithuman/video_graph/video_script.py +344 -0
  41. bithuman-1.0.2.dist-info/METADATA +37 -0
  42. bithuman-1.0.2.dist-info/RECORD +44 -0
  43. bithuman-1.0.2.dist-info/WHEEL +5 -0
  44. bithuman-1.0.2.dist-info/top_level.txt +1 -0
--- /dev/null
+++ b/bithuman/video_graph/driver_video.py
@@ -0,0 +1,482 @@
+from __future__ import annotations
+
+import hashlib
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+import cv2
+import networkx as nx
+import numpy as np
+from loguru import logger
+
+
+@dataclass(frozen=True, order=True)
+class Frame:
+    video_name: str
+    frame_index: int
+
+
+@dataclass(frozen=True, order=True)
+class NodeID:
+    video_name: str
+    video_hash: str
+    frame_index: int
+
+    def __repr__(self) -> str:
+        return f"{self.video_name}_{self.video_hash[:8]}_{self.frame_index}"
+
+
+class DriverVideo:
+    def __init__(
+        self,
+        name: str,
+        video_path: str,
+        video_data_path: Optional[str] = None,
+        num_frames: Optional[int] = None,
+        *,
+        stride: int = 10,
+        single_direction: bool = False,
+        stop_on_user_speech: bool = False,
+        stop_on_agent_speech: bool = False,
+        lip_sync_required: bool = True,
+    ) -> None:
+        self.video_name = name
+
+        self.video_path: str = video_path
+
+        # compute the video hash
+        with open(video_path, "rb") as f:
+            self.video_hash = hashlib.md5(f.read()).hexdigest()
+
+        # resolve the video_data_path
+        self.video_data_path: Optional[str] = video_data_path or _find_video_data_path(
+            video_path, self.video_hash
+        )
+
+        # read the frame count and resolution
+        cap = cv2.VideoCapture(video_path)
+        try:
+            total_num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+            self.resolution = (
+                int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
+                int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
+            )
+        finally:
+            cap.release()
+        if num_frames is not None:
+            assert num_frames <= total_num_frames, (
+                f"num_frames {num_frames} > total_num_frames {total_num_frames}"
+            )
+        else:
+            num_frames = total_num_frames
+        self.num_frames = num_frames
+        self.frames = [Frame(self.video_name, i) for i in range(num_frames)]
+
+        # The transition points, default to all nodes
+        self.single_direction = single_direction
+        self.init_nodes(stride)
+        self.transition_nodes = self.nodes
+
+        # Stop on user speech or agent speech
+        self.stop_on_user_speech = stop_on_user_speech
+        self.stop_on_agent_speech = stop_on_agent_speech
+        self.lip_sync_required = lip_sync_required
+
+    @property
+    def video_id(self) -> str:
+        return f"{self.video_name}_{self.video_hash[:8]}"
+
+    def get_frame_wh(self, scale_size: Optional[int] = None) -> Tuple[int, int]:
+        if scale_size is None:
+            return self.resolution
+
+        # scale the max dimension to `scale_size`
+        scale = scale_size / max(self.resolution)
+        return (
+            int(self.resolution[0] * scale),
+            int(self.resolution[1] * scale),
+        )
+
+    def get_first_frame(self, scale_size: Optional[int] = None) -> np.ndarray:
+        cap = cv2.VideoCapture(str(self.video_path))
+        try:
+            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
+            ret, frame = cap.read()
+            if not ret:
+                raise ValueError("Failed to read the first frame")
+
+            if scale_size is not None:
+                frame = cv2.resize(frame, self.get_frame_wh(scale_size))
+            return frame
+        finally:
+            cap.release()
+
+    def init_nodes(self, stride: int) -> None:
+        if stride <= 0:
+            # a single node for the whole video
+            stride = self.num_frames
+        self.stride = stride
+        self.nodes = [
+            NodeID(self.video_name, self.video_hash, i)
+            for i in range(0, self.num_frames, stride)
+        ]
+        # ensure the last frame has a node
+        if self.nodes[-1].frame_index != self.num_frames - 1:
+            self.nodes.append(
+                NodeID(self.video_name, self.video_hash, self.num_frames - 1)
+            )
+
+    def remove_nodes(
+        self,
+        *,
+        indices: Optional[List[int]] = None,
+        frame_indices: Optional[List[int]] = None,
+    ) -> int:
+        new_nodes = self.nodes
+
+        if indices:
+            new_nodes = [
+                node for idx, node in enumerate(new_nodes) if idx not in set(indices)
+            ]
+
+        if frame_indices:
+            new_nodes = [
+                node for node in new_nodes if node.frame_index not in set(frame_indices)
+            ]
+
+        removed = len(self.nodes) - len(new_nodes)
+        self.nodes = new_nodes
+        self.transition_nodes = self.nodes
+        return removed
+
+    def insert_nodes(self, frame_indices: List[int]) -> None:
+        new_nodes = []
+        for idx in frame_indices:
+            if idx < self.num_frames:
+                new_nodes.append(NodeID(self.video_name, self.video_hash, idx))
+        self.nodes = sorted(set(self.nodes + new_nodes))
+        self.transition_nodes = self.nodes
+
+    @property
+    def transition_frames(self) -> List[Frame]:
+        return [self.frames[node.frame_index] for node in self.transition_nodes]
+
+    def draw_nodes(
+        self, nodes: List[NodeID], n_cols: int = 2, image_width: int = 720
+    ) -> np.ndarray:
+        """Draw the frames of the nodes."""
+        frames = [self.frames[node.frame_index] for node in nodes]
+        node_indices = [self.nodes.index(node) for node in nodes]
+        labels = [f"{idx}: {node}" for idx, node in zip(node_indices, nodes)]
+        # NOTE: relies on a `draw_frames` helper that is not defined in this file
+        return self.draw_frames(frames, labels, n_cols, image_width)
+
+    def as_graph(self) -> nx.DiGraph:
+        """Create a graph from the video nodes.
+
+        In single_direction mode the edges only move forward;
+        otherwise consecutive nodes are connected in both directions.
+        """
+        graph = nx.DiGraph()
+        graph.add_nodes_from(self.nodes)
+        # Add edges between consecutive nodes
+        for i in range(1, len(self.nodes)):
+            first, second = self.nodes[i - 1], self.nodes[i]
+            distance = second.frame_index - first.frame_index
+
+            # Add forward edge
+            graph.add_edge(first, second, distance=distance)
+
+            # Add backward edge if not in single direction mode
+            if not self.single_direction:
+                graph.add_edge(second, first, distance=distance)
+
+        return graph
+
+    def collect_frames(
+        self,
+        source: NodeID | int,
+        target: NodeID | int,
+        allow_multi_steps: bool = True,
+    ) -> List[Frame]:
+        """Collect frames between two nodes.
+
+        Includes the frame of `source` and excludes the frame of `target`.
+        If the distance between the two nodes is larger than one step
+        and `allow_multi_steps` is False, returns only the frame of `source`.
+        # TODO: write filler info to the edge and remove `allow_multi_steps`
+
+        Returns:
+            List[Frame]: The collected frames.
+        """
+        if isinstance(source, int):
+            source = self.nodes[source]
+        if isinstance(target, int):
+            target = self.nodes[target]
+
+        if (
+            source.video_hash != self.video_hash
+            and target.video_hash != self.video_hash
+        ):
+            logger.warning(
+                f"Neither node is from this video: "
+                f"{source} and {target} != {self.video_hash}"
+            )
+            return []
+
+        if source.video_hash != self.video_hash:
+            # Jump from another video to this video, return an empty list
+            return []
+
+        if target.video_hash != self.video_hash or (
+            not allow_multi_steps
+            and abs(source.frame_index - target.frame_index) > self.stride
+        ):
+            # Jump from this video to another video,
+            # or the distance between the two nodes is larger than one step;
+            # assume this is a "jumper connection"
+            return [self.frames[source.frame_index]]
+
+        # Both nodes are from this video
+        stride = 1 if source.frame_index < target.frame_index else -1
+        if self.single_direction and stride == -1:
+            # A single-direction video can only move forward
+            return []
+        return self.frames[source.frame_index : target.frame_index : stride]
+
+
+class LoopingVideo(DriverVideo):
+    def __init__(
+        self,
+        name: str,
+        video_path: str,
+        video_data_path: Optional[str] = None,
+        num_frames: Optional[int] = None,
+        *,
+        stride: int = 10,
+        single_direction: bool = False,
+        stop_on_user_speech: bool = False,
+        stop_on_agent_speech: bool = False,
+        lip_sync_required: bool = True,
+        loop_between: Tuple[Optional[int], Optional[int]] = (0, None),
+    ) -> None:
+        """A video that loops between two frames.
+
+        Args:
+            loop_between: Tuple of (start, end) frame indices to loop between.
+                None means the start or the end of the video.
+            single_direction: If True, the video only plays forward from start
+                to end, then jumps back to start. If False, the video plays
+                back and forth.
+        """
+        super().__init__(
+            name,
+            video_path=video_path,
+            video_data_path=video_data_path,
+            num_frames=num_frames,
+            stride=stride,
+            single_direction=single_direction,
+            stop_on_user_speech=stop_on_user_speech,
+            stop_on_agent_speech=stop_on_agent_speech,
+            lip_sync_required=lip_sync_required,
+        )
+
+        self.loop_direction = 1  # 1 for forward, -1 for backward
+        self._loop_start_node: Optional[NodeID] = None
+        self._loop_end_node: Optional[NodeID] = None
+        self.set_loop_between(*loop_between)
+
+    def set_loop_between(self, start: Optional[int], end: Optional[int]) -> int:
+        # use `is None` checks so that an explicit 0 is not treated as "unset"
+        start = 0 if start is None else start
+        end = -1 if end is None else end
+        if end < 0:
+            end = len(self.frames) + end
+        self.loop_between = (max(0, start), min(len(self.frames) - 1, end))
+        if self.loop_between[0] > self.loop_between[1]:
+            raise ValueError(f"Invalid loop_between {self.loop_between}")
+
+        # Create and store the loop nodes
+        self._loop_start_node = NodeID(
+            self.video_name, self.video_hash, self.loop_between[0]
+        )
+        self._loop_end_node = NodeID(
+            self.video_name, self.video_hash, self.loop_between[1]
+        )
+
+        # Add the loop nodes if they are not already in the nodes list
+        if self._loop_start_node not in self.nodes:
+            self.nodes = sorted(self.nodes + [self._loop_start_node])
+        if self._loop_end_node not in self.nodes:
+            self.nodes = sorted(self.nodes + [self._loop_end_node])
+
+        return self.loop_between[1] - self.loop_between[0]
+
+    def get_n_frames(
+        self, min_n: int, start: NodeID, last_position: Optional[NodeID] = None
+    ) -> Tuple[List[Frame], Optional[NodeID]]:
+        """Get at least `min_n` frames from the start node.
+
+        Args:
+            min_n: The minimum number of frames to get.
+            start: The start node.
+            last_position: The last position node, used to determine the direction.
+
+        Returns:
+            The collected frames and the last position node.
+        """
+        if start.video_hash != self.video_hash:
+            logger.warning(
+                f"Enter node is not from this video: {start} != {self.video_hash}"
+            )
+            return [], None
+
+        if min_n <= 0:
+            return [], None
+
+        if start not in self.nodes:
+            # list.index() raises ValueError for missing nodes, so check first
+            logger.warning(f"Node {start} is not in the nodes list")
+            return [], None
+        start_idx = self.nodes.index(start)
+
+        # NOTE: loop_end and loop_start are frame indices;
+        # be careful that the new node index is valid
+        loop_start, loop_end = self.loop_between
+
+        # For single_direction, always move forward
+        if self.single_direction:
+            self.loop_direction = 1
+        elif last_position and last_position.video_hash != self.video_hash:
+            # Reset the direction if the last position is from another video:
+            # move in the direction with more frames
+            left_frames = start.frame_index - loop_start
+            right_frames = loop_end - start.frame_index
+            self.loop_direction = 1 if right_frames > left_frames else -1
+
+        frames = []
+        curr_node = start
+        curr_idx = start_idx
+        while len(frames) < min_n:
+            if self.single_direction:
+                if curr_node.frame_index >= loop_end:
+                    # Jump back to loop_start when reaching loop_end
+                    next_node = self._loop_start_node
+                    next_idx = self.nodes.index(self._loop_start_node)
+                else:
+                    next_idx = curr_idx + 1
+                    next_node = self.nodes[next_idx]
+            else:
+                # bidirectional behavior
+                if curr_node.frame_index >= loop_end:
+                    self.loop_direction = -1
+                elif curr_node.frame_index <= loop_start:
+                    self.loop_direction = 1
+                next_idx = curr_idx + self.loop_direction
+                next_node = self.nodes[next_idx]
+
+            frames += self.collect_frames(curr_node, next_node)
+            curr_node = next_node
+            curr_idx = next_idx
+
+        return frames, curr_node
+
+    def as_graph(self) -> nx.DiGraph:
+        """Create a graph from the video nodes with a loop-back edge."""
+        graph = super().as_graph()
+
+        # Add the loop-back edge using the stored nodes
+        if self.single_direction:
+            graph.add_edge(self._loop_end_node, self._loop_start_node, distance=0)
+
+        return graph
+
+
+class SingleActionVideo(DriverVideo):
+    def __init__(
+        self,
+        name: str,
+        video_path: str,
+        video_data_path: Optional[str] = None,
+        num_frames: Optional[int] = None,
+        *,
+        single_direction: bool = False,
+        stop_on_user_speech: bool = False,
+        stop_on_agent_speech: bool = False,
+        lip_sync_required: bool = True,
+        transition_frames: Optional[List[int]] = None,
+        action_frame: int = -1,
+    ) -> None:
+        """A video that plays a single action.
+
+        `transition_frames` is a list of frame indices
+        that become the transition nodes.
+        """
+        super().__init__(
+            name,
+            video_path=video_path,
+            video_data_path=video_data_path,
+            num_frames=num_frames,
+            stride=-1,
+            single_direction=single_direction,
+            stop_on_user_speech=stop_on_user_speech,
+            stop_on_agent_speech=stop_on_agent_speech,
+            lip_sync_required=lip_sync_required,
+        )
+        transition_frames = transition_frames or [0]
+        if single_direction:
+            transition_frames.append(-1)
+        transition_frames = [
+            frame if frame >= 0 else len(self.frames) + frame
+            for frame in transition_frames
+        ]
+        transition_nodes = [
+            NodeID(self.video_name, self.video_hash, i) for i in set(transition_frames)
+        ]
+
+        action_frame = (
+            action_frame if action_frame >= 0 else len(self.frames) + action_frame
+        )
+        self._action_node = NodeID(self.video_name, self.video_hash, action_frame)
+
+        self.nodes = sorted(set(self.nodes + transition_nodes + [self._action_node]))
+        self.transition_nodes = sorted(transition_nodes)
+
+    @property
+    def action_node(self) -> NodeID:
+        return self._action_node
+
+    def as_graph(self) -> nx.DiGraph:
+        if not self.single_direction:
+            return super().as_graph()
+        graph = nx.DiGraph()
+        graph.add_nodes_from(self.nodes)
+        # add forward-only edges for a single-direction video
+        for i in range(1, len(self.nodes)):
+            first, second = self.nodes[i - 1], self.nodes[i]
+            graph.add_edge(
+                first, second, distance=second.frame_index - first.frame_index
+            )
+        return graph
+
+    def get_frames(self, start: NodeID) -> Tuple[List[Frame], Optional[NodeID]]:
+        if start.video_hash != self.video_hash:
+            logger.warning(
+                f"Enter node is not from this video: {start} != {self.video_hash}"
+            )
+            return [], None
+        if start not in self.nodes:
+            # list.index() raises ValueError for missing nodes, so check first
+            logger.warning(f"Node {start} is not in the nodes list")
+            return [], None
+
+        # Collect all frames of the action:
+        # start -> action frame -> last transition frame
+        path = [start, self.action_node, self.transition_nodes[-1]]
+        frames = []
+        for i in range(1, len(path)):
+            frames += self.collect_frames(path[i - 1], path[i], allow_multi_steps=True)
+        return frames, path[-1]
+
+
+def _find_video_data_path(video_path: str, file_hash: str) -> Optional[str]:
+    video_path = Path(video_path)
+    name = f"{video_path.name}.*_{file_hash[:8]}"
+    for suffix in ["h5", "pth"]:
+        files = list(video_path.parent.glob(name + f".{suffix}"))
+        if files:
+            return files[0].as_posix()
+    return None
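
The module above is easiest to see end to end in a short driver script. The sketch below is illustrative only and is not part of the wheel: "idle.mp4" and "wave.mp4" are placeholder clips assumed to exist on disk with at least ~100 frames each, and constructing any of these classes reads the file (MD5 hash, cv2 frame count) and probes for an optional .h5/.pth sidecar via _find_video_data_path.

from bithuman.video_graph.driver_video import (
    DriverVideo,
    LoopingVideo,
    SingleActionVideo,
)

# Checkpoint nodes every 10 frames, bidirectional edges between neighbors.
video = DriverVideo("idle", "idle.mp4", stride=10)
graph = video.as_graph()  # nx.DiGraph over NodeID checkpoints
print(video.video_id, len(video.nodes), graph.number_of_edges())

# Frames from node 0 up to, but excluding, node 1 (collect_frames semantics).
segment = video.collect_frames(0, 1)
print(len(segment))  # == 10 for interior nodes spaced by `stride`

# An idle clip that ping-pongs between frames 5 and 95.
loop = LoopingVideo("idle_loop", "idle.mp4", loop_between=(5, 95))
frames, last_node = loop.get_n_frames(30, loop.nodes[1])
print(len(frames) >= 30, last_node)

# A one-shot gesture: enter at a transition node, play through the action
# frame (the last frame by default), and return to a transition node.
wave = SingleActionVideo("wave", "wave.mp4", transition_frames=[0])
frames, end_node = wave.get_frames(wave.transition_nodes[0])

Each as_graph() edge carries a distance attribute measured in frames; presumably bithuman/video_graph/navigator.py (also in this release) uses it to cost paths when stepping between checkpoints of different videos.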