bithuman-1.0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. bithuman/__init__.py +13 -0
  2. bithuman/_version.py +1 -0
  3. bithuman/api.py +164 -0
  4. bithuman/audio/__init__.py +19 -0
  5. bithuman/audio/audio.py +396 -0
  6. bithuman/audio/hparams.py +108 -0
  7. bithuman/audio/utils.py +255 -0
  8. bithuman/config.py +88 -0
  9. bithuman/engine/__init__.py +15 -0
  10. bithuman/engine/auth.py +335 -0
  11. bithuman/engine/compression.py +257 -0
  12. bithuman/engine/enums.py +16 -0
  13. bithuman/engine/image_ops.py +192 -0
  14. bithuman/engine/inference.py +108 -0
  15. bithuman/engine/knn.py +58 -0
  16. bithuman/engine/video_data.py +391 -0
  17. bithuman/engine/video_reader.py +168 -0
  18. bithuman/lib/__init__.py +1 -0
  19. bithuman/lib/audio_encoder.onnx +45631 -28
  20. bithuman/lib/generator.py +763 -0
  21. bithuman/lib/pth2h5.py +106 -0
  22. bithuman/plugins/__init__.py +0 -0
  23. bithuman/plugins/stt.py +185 -0
  24. bithuman/runtime.py +1004 -0
  25. bithuman/runtime_async.py +469 -0
  26. bithuman/service/__init__.py +9 -0
  27. bithuman/service/client.py +788 -0
  28. bithuman/service/messages.py +210 -0
  29. bithuman/service/server.py +759 -0
  30. bithuman/utils/__init__.py +43 -0
  31. bithuman/utils/agent.py +359 -0
  32. bithuman/utils/fps_controller.py +90 -0
  33. bithuman/utils/image.py +41 -0
  34. bithuman/utils/unzip.py +38 -0
  35. bithuman/video_graph/__init__.py +16 -0
  36. bithuman/video_graph/action_trigger.py +83 -0
  37. bithuman/video_graph/driver_video.py +482 -0
  38. bithuman/video_graph/navigator.py +736 -0
  39. bithuman/video_graph/trigger.py +90 -0
  40. bithuman/video_graph/video_script.py +344 -0
  41. bithuman-1.0.2.dist-info/METADATA +37 -0
  42. bithuman-1.0.2.dist-info/RECORD +44 -0
  43. bithuman-1.0.2.dist-info/WHEEL +5 -0
  44. bithuman-1.0.2.dist-info/top_level.txt +1 -0
bithuman/utils/__init__.py
@@ -0,0 +1,43 @@
+ import hashlib
+ from pathlib import Path
+ from typing import Optional
+
+ from loguru import logger
+
+ from .fps_controller import FPSController
+
+ __all__ = ["FPSController"]
+
+
+ def calculate_file_hash(file_path: str) -> Optional[str]:
+     """Calculate an MD5 hash of a file.
+
+     This function reads the file in chunks to efficiently handle large files
+     and calculates an MD5 hash, which is returned as a hexadecimal string.
+
+     Args:
+         file_path: Path to the file to be hashed
+
+     Returns:
+         A hexadecimal string representing the file hash, or None if the file doesn't exist
+
+     Raises:
+         IOError: If there's an error reading the file
+     """
+     try:
+         path = Path(file_path)
+         if not path.is_file():
+             logger.warning(f"Cannot calculate hash for non-file: {file_path}")
+             return None
+
+         md5_hash = hashlib.md5()
+
+         # Read the file in chunks of 4K to avoid loading large files into memory
+         with open(path, "rb") as f:
+             for byte_block in iter(lambda: f.read(4096), b""):
+                 md5_hash.update(byte_block)
+
+         return md5_hash.hexdigest()
+     except Exception as e:
+         logger.error(f"Error calculating file hash for {file_path}: {e}")
+         raise
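
A minimal usage sketch for the hash helper above; the file path is hypothetical, and the function is imported directly from bithuman.utils:

    from bithuman.utils import calculate_file_hash

    # Hypothetical path; any readable local file works.
    digest = calculate_file_hash("models/avatar_model.bin")
    if digest is None:
        print("not a file, nothing to hash")
    else:
        print(f"md5: {digest}")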
bithuman/utils/agent.py
@@ -0,0 +1,359 @@
+ from __future__ import annotations
+
+ import asyncio
+ import time
+ from abc import ABC, abstractmethod
+ from typing import Any, Optional
+
+ import cv2
+ import numpy as np
+
+ try:
+     from livekit import rtc
+     from livekit.agents import utils
+     from livekit.agents.voice import AgentSession, io
+     from livekit.agents.voice.avatar import (
+         AudioReceiver,
+         AudioSegmentEnd,
+         AvatarOptions,
+     )
+     from livekit.agents.voice.chat_cli import ChatCLI
+ except ImportError:
+     raise ImportError(
+         "livekit-agents is required, please install it with `pip install livekit-agents[openai,silero,deepgram,cartesia]~=1.0rc`"
+     )
+ from loguru import logger
+
+ from bithuman import AsyncBithuman, AudioChunk, VideoFrame
+ from bithuman.utils import FPSController
+
+
+ class AudioOutput(ABC):
+     @abstractmethod
+     async def capture_frame(self, audio_chunk: AudioChunk) -> None:
+         pass
+
+     @abstractmethod
+     def clear_buffer(self) -> None:
+         pass
+
+
+ class VideoOutput(ABC):
+     @abstractmethod
+     async def capture_frame(
+         self, frame: VideoFrame, fps: float, exp_time: float
+     ) -> None:
+         pass
+
+     @abstractmethod
+     def buffer_empty(self) -> bool:
+         pass
+
+
+ class LocalAudioIO(ChatCLI, AudioOutput):
+     """Chat interface that redirects audio output to a custom destination."""
+
+     def __init__(
+         self,
+         session: AgentSession,
+         agent_audio_output: io.AudioOutput,
+         *,
+         buffer_size: int = 0,
+         loop: Optional[asyncio.AbstractEventLoop] = None,
+     ) -> None:
+         super().__init__(agent_session=session, loop=loop)
+         self._redirected_audio_output = agent_audio_output
+         self._input_buffer = utils.aio.Chan[rtc.AudioFrame](maxsize=buffer_size)
+         self._forward_audio_atask: Optional[asyncio.Task] = None
+
+         self._sample_rate = self._audio_sink.sample_rate
+         self._resampler: Optional[rtc.AudioResampler] = None
+
+     async def start(self) -> None:
+         await super().start()
+         self._forward_audio_atask = asyncio.create_task(self._forward_audio())
+
+     async def capture_frame(self, audio_chunk: AudioChunk) -> None:
+         audio_frame = rtc.AudioFrame(
+             data=audio_chunk.bytes,
+             sample_rate=audio_chunk.sample_rate,
+             num_channels=1,
+             samples_per_channel=len(audio_chunk.array),
+         )
+
+         if not self._resampler and self._sample_rate != audio_chunk.sample_rate:
+             self._resampler = rtc.AudioResampler(
+                 input_rate=audio_chunk.sample_rate,
+                 output_rate=self._sample_rate,
+                 num_channels=1,
+             )
+
+         if self._resampler:
+             for f in self._resampler.push(audio_frame):
+                 await self._input_buffer.send(f)
+         else:
+             await self._input_buffer.send(audio_frame)
+
+     def clear_buffer(self) -> None:
+         while not self._input_buffer.empty():
+             self._input_buffer.recv_nowait()
+         with self._audio_sink.lock:
+             self._audio_sink.audio_buffer.clear()
+
+     @utils.log_exceptions(logger=logger)
+     async def _forward_audio(self) -> None:
+         async for frame in self._input_buffer:
+             await self._audio_sink.capture_frame(frame)
+
+     def _update_speaker(self, *, enable: bool) -> None:
+         super()._update_speaker(enable=enable)
+
+         # redirect the agent's audio output
+         if enable:
+             self._session.output.audio = self._redirected_audio_output
+         else:
+             self._session.output.audio = None
+
+     async def aclose(self) -> None:
+         if not self._done_fut.done():
+             self._done_fut.set_result(None)
+         if self._main_atask:
+             await utils.aio.cancel_and_wait(self._main_atask)
+
+         self._input_buffer.close()
+         if self._forward_audio_atask:
+             await utils.aio.cancel_and_wait(self._forward_audio_atask)
+
+
+ class LocalVideoPlayer(VideoOutput):
+     """Video display for rendering avatar frames with debug information."""
+
+     def __init__(
+         self,
+         window_size: tuple[int, int],
+         window_name: str = "BitHuman Avatar",
+         buffer_size: int = 0,
+     ) -> None:
+         self.window_name: str = window_name
+         self.start_time: Optional[float] = None
+         self._input_buffer = utils.aio.Chan[tuple[VideoFrame, float, float]](
+             maxsize=buffer_size
+         )
+         self._display_atask: Optional[asyncio.Task] = None
+
+         cv2.namedWindow(self.window_name, cv2.WINDOW_NORMAL)
+         cv2.resizeWindow(self.window_name, window_size[0], window_size[1])
+
+         self.start_time = asyncio.get_event_loop().time()
+         self._display_atask = asyncio.create_task(self._display_frame())
+
+     async def aclose(self) -> None:
+         cv2.destroyAllWindows()
+         if self._display_atask:
+             await utils.aio.cancel_and_wait(self._display_atask)
+
+     async def capture_frame(
+         self, frame: VideoFrame, fps: float = 0.0, exp_time: float = 0.0
+     ) -> None:
+         if not frame.has_image:
+             return
+         await self._input_buffer.send((frame, fps, exp_time))
+
+     def buffer_empty(self) -> bool:
+         return self._input_buffer.empty()
+
+     @utils.log_exceptions(logger=logger)
+     async def _display_frame(self) -> None:
+         async for frame, fps, exp_time in self._input_buffer:
+             image = await self.render_image(frame, fps, exp_time)
+             cv2.imshow(self.window_name, image)
+             cv2.waitKey(1)
+
+     async def render_image(
+         self, frame: VideoFrame, fps: float = 0.0, exp_time: float = 0.0
+     ) -> np.ndarray:
+         image = frame.bgr_image.copy()
+
+         # Add overlay information
+         self._add_debug_info(image, fps, exp_time)
+
+         return image
+
+     def _add_debug_info(self, image: np.ndarray, fps: float, exp_time: float) -> None:
+         # Add FPS information
+         cv2.putText(
+             image,
+             f"FPS: {fps:.1f}",
+             (10, 30),
+             cv2.FONT_HERSHEY_SIMPLEX,
+             1,
+             (0, 255, 0),
+             2,
+         )
+
+         # Add elapsed time
+         current_time = asyncio.get_event_loop().time()
+         if self.start_time is not None:
+             elapsed = current_time - self.start_time
+             cv2.putText(
+                 image,
+                 f"Time: {elapsed:.1f}s",
+                 (10, 70),
+                 cv2.FONT_HERSHEY_SIMPLEX,
+                 1,
+                 (0, 255, 0),
+                 2,
+             )
+
+         # Add expiration time if available
+         if exp_time > 0:
+             exp_in_seconds = exp_time - time.time()
+             cv2.putText(
+                 image,
+                 f"Exp in: {exp_in_seconds:.1f}s",
+                 (10, 110),
+                 cv2.FONT_HERSHEY_SIMPLEX,
+                 1,
+                 (0, 255, 0),
+                 2,
+             )
+
+
+ class LocalAvatarRunner:
+     """Controls and synchronizes avatar audio and video playback."""
+
+     def __init__(
+         self,
+         *,
+         bithuman_runtime: AsyncBithuman,
+         audio_input: AudioReceiver,
+         audio_output: AudioOutput,
+         video_output: VideoOutput,
+         options: AvatarOptions,
+         runtime_kwargs: dict[str, Any] | None = None,
+     ) -> None:
+         self._bithuman_runtime = bithuman_runtime
+         self._runtime_kwargs = runtime_kwargs or {}
+         self._options = options
+
+         self._audio_recv = audio_input
+         self._audio_output = audio_output
+         self._video_output = video_output
+         self._stop_event = asyncio.Event()
+
+         # State management
+         self._playback_position: float = 0.0
+         self._audio_playing: bool = False
+         self._tasks: set[asyncio.Task] = set()
+         self._read_audio_atask: Optional[asyncio.Task] = None
+         self._publish_video_atask: Optional[asyncio.Task] = None
+
+         # FPS control
+         self._fps_controller = FPSController(target_fps=options.video_fps)
+
+     async def start(self) -> None:
+         await self._audio_recv.start()
+
+         # Setup event handler
+         self._audio_recv.on("clear_buffer", self._create_clear_buffer_task)
+
+         # Start processing tasks
+         self._read_audio_atask = asyncio.create_task(self._read_audio())
+         self._publish_video_atask = asyncio.create_task(self._publish_video())
+
+     def _create_clear_buffer_task(self) -> None:
+         """Create a task to handle clear buffer events."""
+         task = asyncio.create_task(self._handle_clear_buffer())
+         self._tasks.add(task)
+         task.add_done_callback(self._tasks.discard)
+
+     @utils.log_exceptions(logger=logger)
+     async def _read_audio(self) -> None:
+         """Process incoming audio frames."""
+         async for frame in self._audio_recv:
+             if self._stop_event.is_set():
+                 break
+
+             if not self._audio_playing and isinstance(frame, rtc.AudioFrame):
+                 self._audio_playing = True
+             if isinstance(frame, AudioSegmentEnd):
+                 await self._bithuman_runtime.flush()
+                 continue
+             await self._bithuman_runtime.push_audio(
+                 bytes(frame.data), frame.sample_rate, last_chunk=False
+             )
+
+     @utils.log_exceptions(logger=logger)
+     async def _publish_video(self) -> None:
+         """Process and display video frames."""
+         async for frame in self._bithuman_runtime.run(
+             out_buffer_empty=self._video_output.buffer_empty,
+             **self._runtime_kwargs,
+         ):
+             # Control frame rate
+             sleep_time = self._fps_controller.wait_next_frame(sleep=False)
+             if sleep_time > 0:
+                 await asyncio.sleep(sleep_time)
+
+             # Send video frame
+             if frame.has_image:
+                 await self._video_output.capture_frame(
+                     frame,
+                     fps=self._fps_controller.average_fps,
+                     exp_time=self._bithuman_runtime.get_expiration_time(),
+                 )
+
+             # Send audio chunk
+             audio_chunk = frame.audio_chunk
+             if audio_chunk is not None:
+                 await self._audio_output.capture_frame(audio_chunk)
+                 self._playback_position += audio_chunk.duration
+
+             # Handle end of speech
+             if frame.end_of_speech:
+                 await self._handle_end_of_speech()
+
+             self._fps_controller.update()
+
+     async def _handle_end_of_speech(self) -> None:
+         """Handle end of speech event."""
+         if self._audio_playing:
+             notify_task = self._audio_recv.notify_playback_finished(
+                 playback_position=self._playback_position,
+                 interrupted=False,
+             )
+             if asyncio.iscoroutine(notify_task):
+                 await notify_task
+
+         self._playback_position = 0.0
+         self._audio_playing = False
+
+     async def _handle_clear_buffer(self) -> None:
+         """Handle clearing the buffer and notify about interrupted playback."""
+         tasks = []
+         self._bithuman_runtime.interrupt()
+         self._audio_output.clear_buffer()
+
+         # Handle interrupted playback
+         if self._audio_playing:
+             notify_task = self._audio_recv.notify_playback_finished(
+                 playback_position=self._playback_position,
+                 interrupted=True,
+             )
+             if asyncio.iscoroutine(notify_task):
+                 tasks.append(notify_task)
+             self._playback_position = 0.0
+             self._audio_playing = False
+
+         await asyncio.gather(*tasks)
+
+     async def aclose(self) -> None:
+         """Close the avatar controller and clean up resources."""
+         if self._read_audio_atask:
+             await utils.aio.cancel_and_wait(self._read_audio_atask)
+         if self._publish_video_atask:
+             await utils.aio.cancel_and_wait(self._publish_video_atask)
+         await utils.aio.cancel_and_wait(*self._tasks)
+
+     def stop(self) -> None:
+         self._stop_event.set()
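
A hedged wiring sketch for the classes above. The runtime, audio receiver, session, and redirected audio output are assumed to be constructed elsewhere (this diff does not show how), and the AvatarOptions field values are assumptions, not documented defaults:

    from livekit.agents.voice.avatar import AvatarOptions

    async def run_local(runtime, audio_receiver, session, agent_audio_output) -> None:
        # Assumed pre-built: an AsyncBithuman runtime, an AudioReceiver, an
        # AgentSession, and an io.AudioOutput that feeds audio to the avatar.
        options = AvatarOptions(
            video_width=512,  # assumed resolution and rates
            video_height=512,
            video_fps=25,
            audio_sample_rate=16000,
            audio_channels=1,
        )
        video_out = LocalVideoPlayer(window_size=(options.video_width, options.video_height))
        audio_out = LocalAudioIO(session=session, agent_audio_output=agent_audio_output)
        runner = LocalAvatarRunner(
            bithuman_runtime=runtime,
            audio_input=audio_receiver,
            audio_output=audio_out,
            video_output=video_out,
            options=options,
        )
        await audio_out.start()
        await runner.start()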
bithuman/utils/fps_controller.py
@@ -0,0 +1,90 @@
+ from __future__ import annotations
+
+ import time
+ from collections import deque
+
+ from loguru import logger
+
+
+ class FPSController:
+     """Controls frame rate for synchronous processing.
+
+     Maintains target FPS by calculating appropriate sleep times and adjusting
+     for processing delays.
+
+     Attributes:
+         target_fps: Target frames per second
+         frame_interval: Time interval between frames in seconds
+         average_fps: Current average FPS
+     """
+
+     def __init__(
+         self, target_fps: int, max_frame_count: int = 10, disabled: bool = False
+     ) -> None:
+         """Initialize FPS controller.
+
+         Args:
+             target_fps: Target frames per second
+             max_frame_count: Number of frames to keep for FPS calculation
+             disabled: If True, the FPS controller will be disabled.
+         """
+         self.target_fps = target_fps
+         self.frame_interval = 1.0 / target_fps
+         self.max_frame_count = max_frame_count
+         self.disabled = disabled
+
+         # Timing control
+         self.next_frame_time = None
+         self.display_ts: deque[float] = deque(maxlen=max_frame_count)
+         self.average_fps = 0
+
+     def wait_next_frame(self, *, sleep: bool = True) -> float:
+         """Wait until it's time for the next frame.
+
+         Adjusts sleep time based on actual FPS to maintain target rate.
+         """
+         current_time = time.time()
+
+         # Initialize next_frame_time if needed
+         if self.next_frame_time is None:
+             self.next_frame_time = current_time
+             self.display_ts.clear()
+
+         # Calculate sleep time to maintain target FPS
+         sleep_time = self.next_frame_time - current_time
+
+         if sleep_time > 0 and not self.disabled:
+             # Adjust sleep time based on actual FPS
+             if len(self.display_ts) >= 2:
+                 self.average_fps = (len(self.display_ts) - 1) / (
+                     self.display_ts[-1] - self.display_ts[0]
+                 )
+                 scale = min(1.1, max(0.9, self.average_fps / self.target_fps))
+                 sleep_time *= scale
+             if sleep:
+                 time.sleep(sleep_time)
+             return sleep_time
+         else:
+             # Check if significantly behind schedule
+             if -sleep_time > self.frame_interval * 8:
+                 logger.warning(
+                     f"Frame processing was behind schedule for "
+                     f"{-sleep_time * 1000:.2f} ms"
+                 )
+                 self.next_frame_time = time.time()
+             return sleep_time
+
+     def update(self) -> None:
+         """Update timing information after processing a frame."""
+         current_time = time.time()
+
+         # Update timing information (deque auto-evicts oldest when maxlen exceeded)
+         self.display_ts.append(current_time)
+
+         # Calculate next frame time
+         self.next_frame_time += self.frame_interval
+
+     @property
+     def fps(self) -> float:
+         """Get current average FPS."""
+         return self.average_fps
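
A short synchronous pacing sketch mirroring how LocalAvatarRunner drives the controller; render_frame here is a hypothetical stand-in for real per-frame work:

    import time

    from bithuman.utils import FPSController

    def render_frame() -> None:
        time.sleep(0.01)  # stand-in for real per-frame work

    ctrl = FPSController(target_fps=25)
    for _ in range(100):
        ctrl.wait_next_frame()  # sleeps just long enough to hold ~25 FPS
        render_frame()
        ctrl.update()  # record the timestamp and advance the schedule
    print(f"measured: {ctrl.fps:.1f} FPS")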
bithuman/utils/image.py
@@ -0,0 +1,41 @@
+ from __future__ import annotations
+
+ import cv2
+ import numpy as np
+
+ try:
+     from turbojpeg import TurboJPEG
+
+     jpeg_encoder = TurboJPEG()
+ except (ImportError, ModuleNotFoundError, RuntimeError):
+     jpeg_encoder = None
+
+
+ def encode_image(image: np.ndarray, quality: int = 85) -> bytes:
+     """Encode the image to bytes."""
+     if jpeg_encoder is not None:
+         return jpeg_encoder.encode(image, quality=quality)
+     return cv2.imencode(".jpg", image, [int(cv2.IMWRITE_JPEG_QUALITY), quality])[
+         1
+     ].tobytes()
+
+
+ def decode_image(image_bytes: bytes) -> np.ndarray:
+     """Decode the image from bytes."""
+     if jpeg_encoder is not None:
+         return jpeg_encoder.decode(image_bytes)
+     return cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
+
+
+ class CompressedImage:
+     """A compressed image."""
+
+     def __init__(self, data: bytes | np.ndarray) -> None:
+         """Initialize the compressed image."""
+         if isinstance(data, np.ndarray):
+             data = encode_image(data)
+         self.data = data
+
+     def as_numpy(self) -> np.ndarray:
+         """Get the image data as a numpy array."""
+         return decode_image(self.data)
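
A round-trip sketch for the helpers above using a synthetic frame; JPEG encoding is lossy, so only the shape is asserted:

    import numpy as np

    from bithuman.utils.image import CompressedImage, decode_image, encode_image

    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # synthetic BGR frame

    jpg = encode_image(frame, quality=90)  # TurboJPEG if installed, else OpenCV
    restored = decode_image(jpg)
    assert restored.shape == frame.shape

    img = CompressedImage(frame)  # arrays are encoded on construction
    print(len(img.data), img.as_numpy().shape)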
bithuman/utils/unzip.py
@@ -0,0 +1,38 @@
+ from __future__ import annotations
+
+ import tarfile
+ from pathlib import Path
+ from tempfile import TemporaryDirectory
+ from typing import Optional
+
+
+ def unzip_tarfile(
+     file_path: str, extract_to_local: bool = False
+ ) -> tuple[str, Optional[TemporaryDirectory]]:
+     """Extract the workspace archive if the given path is a tar file."""
+     file_path: Path = Path(file_path)
+     if file_path.is_dir():
+         return str(file_path), None
+
+     # Extract the workspace
+     if not extract_to_local:
+         temp_dir_handle = TemporaryDirectory()
+         dest_dir = temp_dir_handle.name
+     else:
+         temp_dir_handle = None
+         dest_dir = str(file_path.parent / file_path.stem)
+         if dest_dir.endswith(".tar"):
+             dest_dir = dest_dir[:-4]  # Remove .tar suffix
+
+     if temp_dir_handle is not None or not Path(dest_dir).exists():
+         Path(dest_dir).mkdir(parents=True, exist_ok=True)
+         mode = "r:gz" if file_path.name.endswith("gz") else "r"
+         with tarfile.open(file_path, mode) as tar:
+             tar.extractall(dest_dir)
+     file_path = dest_dir
+
+     # Enter the dir if there is only one directory in the tar file
+     files = list(Path(dest_dir).iterdir())
+     if len(files) == 1 and files[0].is_dir():
+         file_path = str(files[0])
+     return file_path, temp_dir_handle
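
A usage sketch for unzip_tarfile; the archive path is hypothetical, and the caller owns cleanup of the TemporaryDirectory handle when one is returned:

    from bithuman.utils.unzip import unzip_tarfile

    workspace, tmp_handle = unzip_tarfile("avatar_workspace.tar.gz")
    try:
        print("workspace directory:", workspace)
        # ... load files from `workspace` ...
    finally:
        if tmp_handle is not None:
            tmp_handle.cleanup()  # delete the temporary extraction directory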
bithuman/video_graph/__init__.py
@@ -0,0 +1,16 @@
+ from . import trigger
+ from .driver_video import DriverVideo, Frame, LoopingVideo, SingleActionVideo
+ from .navigator import VideoGraphNavigator
+ from .video_script import VideoConfig, VideoConfigs, VideoScript
+
+ __all__ = [
+     "DriverVideo",
+     "LoopingVideo",
+     "SingleActionVideo",
+     "VideoConfigs",
+     "VideoConfig",
+     "VideoScript",
+     "VideoGraphNavigator",
+     "Frame",
+     "trigger",
+ ]
bithuman/video_graph/action_trigger.py
@@ -0,0 +1,83 @@
+ from __future__ import annotations
+
+ import json
+ from typing import Any, List, Literal, Optional
+
+ from loguru import logger
+ from pydantic import BaseModel, Field
+
+
+ class TriggerData(BaseModel):
+     """Data to be sent when a trigger is activated"""
+
+     target_video: Optional[str] = None
+     actions: List[str] | str = Field(default_factory=list)
+     description: str = ""
+
+
+ class VideoActionTrigger(BaseModel):
+     """Base class for video action triggers"""
+
+     trigger_data: TriggerData = Field(
+         description="Data to be sent when trigger conditions are met"
+     )
+
+     def check_trigger(self, condition: Any) -> Optional[TriggerData]:
+         """
+         Base method to check if trigger conditions are met
+         Args:
+             condition: The condition to check against (type varies by trigger type)
+         Returns:
+             TriggerData if triggered, None otherwise
+         """
+         return None
+
+     @classmethod
+     def from_json(cls, json_str: str) -> List["VideoActionTrigger"]:
+         """
+         Create trigger instances from a JSON string using Pydantic validation
+         Args:
+             json_str: JSON string containing trigger configurations
+         Returns:
+             List of validated trigger instances
+         """
+         if not json_str:
+             return []
+         try:
+             triggers_data = json.loads(json_str)
+             return [
+                 cls.model_validate_json(json.dumps(trigger))
+                 for trigger in triggers_data
+             ]
+         except Exception as e:
+             logger.exception(f"Error parsing KeywordTrigger: {e}")
+             return []
+
+
+ class KeywordTrigger(VideoActionTrigger):
+     """Trigger that activates when specific keywords are detected"""
+
+     keywords: List[str] = Field(
+         description="List of keywords that can trigger this action"
+     )
+     trigger_source: Literal["user", "agent", "both"] = Field(
+         default="both", description="Who can trigger this action - user, agent, or both"
+     )
+
+     def check_trigger(
+         self, text: str, source: Literal["user", "agent"]
+     ) -> Optional[TriggerData]:
+         """
+         Check if the given text and source triggers this keyword
+         Args:
+             text: The text to check
+             source: The source of the text - either "user" or "agent"
+         Returns:
+             TriggerData if triggered, None otherwise
+         """
+         if self.trigger_source != "both" and source != self.trigger_source:
+             return None
+
+         if any(keyword.lower() in text.lower() for keyword in self.keywords):
+             return self.trigger_data
+         return None
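
A small sketch showing how a KeywordTrigger is built and checked; the keyword, video, and action names are hypothetical:

    from bithuman.video_graph.action_trigger import KeywordTrigger, TriggerData

    trigger = KeywordTrigger(
        keywords=["wave", "hello"],
        trigger_source="user",
        trigger_data=TriggerData(target_video="greeting", actions=["wave_hand"]),
    )

    data = trigger.check_trigger("Well, hello there!", source="user")
    print(data.target_video if data else "no match")  # -> greeting

    # Agent speech does not fire a user-only trigger.
    assert trigger.check_trigger("hello", source="agent") is None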