camel-ai 0.2.45__py3-none-any.whl → 0.2.47__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of camel-ai might be problematic.
- camel/__init__.py +1 -1
- camel/configs/__init__.py +6 -0
- camel/configs/bedrock_config.py +73 -0
- camel/configs/lmstudio_config.py +94 -0
- camel/configs/qwen_config.py +3 -3
- camel/datasets/few_shot_generator.py +19 -3
- camel/datasets/models.py +1 -1
- camel/loaders/__init__.py +2 -0
- camel/loaders/scrapegraph_reader.py +96 -0
- camel/models/__init__.py +4 -0
- camel/models/aiml_model.py +11 -104
- camel/models/anthropic_model.py +11 -76
- camel/models/aws_bedrock_model.py +112 -0
- camel/models/deepseek_model.py +11 -44
- camel/models/gemini_model.py +10 -72
- camel/models/groq_model.py +11 -131
- camel/models/internlm_model.py +11 -61
- camel/models/lmstudio_model.py +82 -0
- camel/models/model_factory.py +7 -1
- camel/models/modelscope_model.py +11 -122
- camel/models/moonshot_model.py +10 -76
- camel/models/nemotron_model.py +4 -60
- camel/models/nvidia_model.py +11 -111
- camel/models/ollama_model.py +12 -205
- camel/models/openai_compatible_model.py +51 -12
- camel/models/openai_model.py +3 -1
- camel/models/openrouter_model.py +12 -131
- camel/models/ppio_model.py +10 -99
- camel/models/qwen_model.py +11 -122
- camel/models/reka_model.py +1 -1
- camel/models/sglang_model.py +5 -3
- camel/models/siliconflow_model.py +10 -58
- camel/models/togetherai_model.py +10 -177
- camel/models/vllm_model.py +11 -218
- camel/models/volcano_model.py +1 -15
- camel/models/yi_model.py +11 -98
- camel/models/zhipuai_model.py +11 -102
- camel/storages/__init__.py +2 -0
- camel/storages/vectordb_storages/__init__.py +2 -0
- camel/storages/vectordb_storages/oceanbase.py +458 -0
- camel/toolkits/__init__.py +4 -0
- camel/toolkits/browser_toolkit.py +4 -7
- camel/toolkits/jina_reranker_toolkit.py +231 -0
- camel/toolkits/pyautogui_toolkit.py +428 -0
- camel/toolkits/search_toolkit.py +167 -0
- camel/toolkits/video_analysis_toolkit.py +215 -80
- camel/toolkits/video_download_toolkit.py +10 -3
- camel/types/enums.py +70 -0
- camel/types/unified_model_type.py +10 -0
- camel/utils/token_counting.py +7 -3
- {camel_ai-0.2.45.dist-info → camel_ai-0.2.47.dist-info}/METADATA +13 -1
- {camel_ai-0.2.45.dist-info → camel_ai-0.2.47.dist-info}/RECORD +54 -46
- {camel_ai-0.2.45.dist-info → camel_ai-0.2.47.dist-info}/WHEEL +0 -0
- {camel_ai-0.2.45.dist-info → camel_ai-0.2.47.dist-info}/licenses/LICENSE +0 -0
camel/toolkits/search_toolkit.py
CHANGED
@@ -1064,6 +1064,172 @@ class SearchToolkit(BaseToolkit):
         except Exception as e:
             return {"error": f"Exa search failed: {e!s}"}
 
+    @api_keys_required([(None, 'TONGXIAO_API_KEY')])
+    def search_alibaba_tongxiao(
+        self,
+        query: str,
+        time_range: Literal[
+            "OneDay", "OneWeek", "OneMonth", "OneYear", "NoLimit"
+        ] = "NoLimit",
+        industry: Optional[
+            Literal[
+                "finance",
+                "law",
+                "medical",
+                "internet",
+                "tax",
+                "news_province",
+                "news_center",
+            ]
+        ] = None,
+        page: int = 1,
+        return_main_text: bool = False,
+        return_markdown_text: bool = True,
+        enable_rerank: bool = True,
+    ) -> Dict[str, Any]:
+        r"""Query the Alibaba Tongxiao search API and return search results.
+
+        A powerful search API optimized for Chinese language queries with
+        features:
+        - Enhanced Chinese language understanding
+        - Industry-specific filtering (finance, law, medical, etc.)
+        - Structured data with markdown formatting
+        - Result reranking for relevance
+        - Time-based filtering
+
+        Args:
+            query (str): The search query string (length >= 1 and <= 100).
+            time_range (Literal["OneDay", "OneWeek", "OneMonth", "OneYear",
+                "NoLimit"]): Time frame filter for search results.
+                (default: :obj:`"NoLimit"`)
+            industry (Optional[Literal["finance", "law", "medical",
+                "internet", "tax", "news_province", "news_center"]]):
+                Industry-specific search filter. When specified, only returns
+                results from sites in the specified industries. Multiple
+                industries can be comma-separated.
+                (default: :obj:`None`)
+            page (int): Page number for results pagination.
+                (default: :obj:`1`)
+            return_main_text (bool): Whether to include the main text of the
+                webpage in results. (default: :obj:`True`)
+            return_markdown_text (bool): Whether to include markdown formatted
+                content in results. (default: :obj:`True`)
+            enable_rerank (bool): Whether to enable result reranking. If
+                response time is critical, setting this to False can reduce
+                response time by approximately 140ms. (default: :obj:`True`)
+
+        Returns:
+            Dict[str, Any]: A dictionary containing either search results with
+                'requestId' and 'results' keys, or an 'error' key with error
+                message. Each result contains title, snippet, url and other
+                metadata.
+        """
+        TONGXIAO_API_KEY = os.getenv("TONGXIAO_API_KEY")
+
+        # Validate query length
+        if not query or len(query) > 100:
+            return {
+                "error": "Query length must be between 1 and 100 characters"
+            }
+
+        # API endpoint and parameters
+        base_url = "https://cloud-iqs.aliyuncs.com/search/genericSearch"
+        headers = {
+            "X-API-Key": TONGXIAO_API_KEY,
+        }
+
+        # Convert boolean parameters to string for compatibility with requests
+        params: Dict[str, Union[str, int]] = {
+            "query": query,
+            "timeRange": time_range,
+            "page": page,
+            "returnMainText": str(return_main_text).lower(),
+            "returnMarkdownText": str(return_markdown_text).lower(),
+            "enableRerank": str(enable_rerank).lower(),
+        }
+
+        # Only add industry parameter if specified
+        if industry is not None:
+            params["industry"] = industry
+
+        try:
+            # Send GET request with proper typing for params
+            response = requests.get(
+                base_url, headers=headers, params=params, timeout=10
+            )
+
+            # Check response status
+            if response.status_code != 200:
+                return {
+                    "error": (
+                        f"Alibaba Tongxiao API request failed with status "
+                        f"code {response.status_code}: {response.text}"
+                    )
+                }
+
+            # Parse JSON response
+            data = response.json()
+
+            # Extract and format pageItems
+            page_items = data.get("pageItems", [])
+            results = []
+            for idx, item in enumerate(page_items):
+                # Create a simplified result structure
+                result = {
+                    "result_id": idx + 1,
+                    "title": item.get("title", ""),
+                    "snippet": item.get("snippet", ""),
+                    "url": item.get("link", ""),
+                    "hostname": item.get("hostname", ""),
+                }
+
+                # Only include additional fields if they exist and are
+                # requested
+                if "summary" in item and item.get("summary"):
+                    result["summary"] = item["summary"]
+                elif (
+                    return_main_text
+                    and "mainText" in item
+                    and item.get("mainText")
+                ):
+                    result["summary"] = item["mainText"]
+
+                if (
+                    return_main_text
+                    and "mainText" in item
+                    and item.get("mainText")
+                ):
+                    result["main_text"] = item["mainText"]
+
+                if (
+                    return_markdown_text
+                    and "markdownText" in item
+                    and item.get("markdownText")
+                ):
+                    result["markdown_text"] = item["markdownText"]
+
+                if "score" in item:
+                    result["score"] = item["score"]
+
+                if "publishTime" in item:
+                    result["publish_time"] = item["publishTime"]
+
+                results.append(result)
+
+            # Return a simplified structure
+            return {
+                "request_id": data.get("requestId", ""),
+                "results": results,
+            }
+
+        except requests.exceptions.RequestException as e:
+            return {"error": f"Alibaba Tongxiao search request failed: {e!s}"}
+        except Exception as e:
+            return {
+                "error": f"Unexpected error during Alibaba Tongxiao "
+                f"search: {e!s}"
+            }
+
     def get_tools(self) -> List[FunctionTool]:
         r"""Returns a list of FunctionTool objects representing the
         functions in the toolkit.
@@ -1084,4 +1250,5 @@ class SearchToolkit(BaseToolkit):
             FunctionTool(self.search_baidu),
             FunctionTool(self.search_bing),
             FunctionTool(self.search_exa),
+            FunctionTool(self.search_alibaba_tongxiao),
         ]
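The new search method is also registered in get_tools, so agents pick it up automatically. Below is a minimal sketch (not part of the diff) of calling it directly; it assumes SearchToolkit remains importable from camel.toolkits as in earlier releases, that TONGXIAO_API_KEY is already exported in the environment, and the query string is only a placeholder.

# Sketch only: exercising the search method added in this release.
# Assumes SearchToolkit is exported from camel.toolkits and that
# TONGXIAO_API_KEY is set in the environment.
from camel.toolkits import SearchToolkit

toolkit = SearchToolkit()
result = toolkit.search_alibaba_tongxiao(
    query="新能源汽车政策",   # Chinese-language query, 1-100 characters
    time_range="OneMonth",    # one of the Literal values in the signature
    industry="finance",       # optional industry filter
    return_markdown_text=True,
    enable_rerank=True,
)

if "error" in result:
    print(result["error"])
else:
    for item in result["results"]:
        # Each item carries result_id, title, snippet, url, hostname, ...
        print(item["result_id"], item["title"], item["url"])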
camel/toolkits/video_analysis_toolkit.py
CHANGED
@@ -12,6 +12,7 @@
 # limitations under the License.
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
 
+import io
 import os
 import tempfile
 from pathlib import Path
@@ -92,9 +93,15 @@ class VideoAnalysisToolkit(BaseToolkit):
             transcription using OpenAI's audio models. Requires a valid OpenAI
             API key. When disabled, video analysis will be based solely on
             visual content. (default: :obj:`False`)
+        frame_interval (float, optional): Interval in seconds between frames
+            to extract from the video. (default: :obj:`4.0`)
+        output_language (str, optional): The language for output responses.
+            (default: :obj:`"English"`)
+        cookies_path (Optional[str]): The path to the cookies file
+            for the video service in Netscape format. (default: :obj:`None`)
         timeout (Optional[float]): The timeout value for API requests
-
-
+            in seconds. If None, no timeout is applied.
+            (default: :obj:`None`)
     """
 
     @dependencies_required("ffmpeg", "scenedetect")
@@ -103,27 +110,29 @@ class VideoAnalysisToolkit(BaseToolkit):
         download_directory: Optional[str] = None,
         model: Optional[BaseModelBackend] = None,
         use_audio_transcription: bool = False,
+        frame_interval: float = 4.0,
+        output_language: str = "English",
+        cookies_path: Optional[str] = None,
         timeout: Optional[float] = None,
     ) -> None:
         super().__init__(timeout=timeout)
         self._cleanup = download_directory is None
         self._temp_files: list[str] = []  # Track temporary files for cleanup
         self._use_audio_transcription = use_audio_transcription
+        self.output_language = output_language
+        self.frame_interval = frame_interval
 
         self._download_directory = Path(
             download_directory or tempfile.mkdtemp()
         ).resolve()
 
         self.video_downloader_toolkit = VideoDownloaderToolkit(
-            download_directory=str(self._download_directory)
+            download_directory=str(self._download_directory),
+            cookies_path=cookies_path,
         )
 
         try:
             self._download_directory.mkdir(parents=True, exist_ok=True)
-        except FileExistsError:
-            raise ValueError(
-                f"{self._download_directory} is not a valid directory."
-            )
         except OSError as e:
             raise ValueError(
                 f"Error creating directory {self._download_directory}: {e}"
@@ -137,16 +146,18 @@ class VideoAnalysisToolkit(BaseToolkit):
             # Import ChatAgent at runtime to avoid circular imports
             from camel.agents import ChatAgent
 
-            self.vl_agent = ChatAgent(
+            self.vl_agent = ChatAgent(
+                model=self.vl_model, output_language=self.output_language
+            )
         else:
             # If no model is provided, use default model in ChatAgent
             # Import ChatAgent at runtime to avoid circular imports
             from camel.agents import ChatAgent
 
-            self.vl_agent = ChatAgent()
+            self.vl_agent = ChatAgent(output_language=self.output_language)
             logger.warning(
-                "No vision-language model provided. Using default model in"
-                "
+                "No vision-language model provided. Using default model in "
+                "ChatAgent."
             )
 
         # Initialize audio models only if audio transcription is enabled
@@ -179,16 +190,22 @@ class VideoAnalysisToolkit(BaseToolkit):
         # Clean up temporary directory if needed
         if self._cleanup and os.path.exists(self._download_directory):
             try:
-                import
+                import sys
 
-
-
-
-
+                if getattr(sys, 'modules', None) is not None:
+                    import shutil
+
+                    shutil.rmtree(self._download_directory)
+                    logger.debug(
+                        f"Removed temp directory: {self._download_directory}"
+                    )
+            except (ImportError, AttributeError):
+                # Skip cleanup if interpreter is shutting down
+                pass
             except OSError as e:
                 logger.warning(
-                    f"Failed to remove temporary directory"
-                    f"
+                    f"Failed to remove temporary directory "
+                    f"{self._download_directory}: {e}"
                 )
 
     def _extract_audio_from_video(
@@ -242,88 +259,217 @@ class VideoAnalysisToolkit(BaseToolkit):
             logger.error(f"Audio transcription failed: {e}")
             return "Audio transcription failed."
 
-    def _extract_keyframes(
-
-
-        r"""Extract keyframes from a video based on scene changes
-        and return them as PIL.Image.Image objects.
+    def _extract_keyframes(self, video_path: str) -> List[Image.Image]:
+        r"""Extract keyframes from a video based on scene changes and
+        regular intervals,and return them as PIL.Image.Image objects.
 
         Args:
             video_path (str): Path to the video file.
-            num_frames (int): Number of keyframes to extract.
-            threshold (float): The threshold value for scene change detection.
 
         Returns:
-
+            List[Image.Image]: A list of PIL.Image.Image objects representing
                 the extracted keyframes.
+
+        Raises:
+            ValueError: If no frames could be extracted from the video.
         """
+        import cv2
+        import numpy as np
         from scenedetect import (  # type: ignore[import-untyped]
             SceneManager,
-
+            open_video,
        )
         from scenedetect.detectors import (  # type: ignore[import-untyped]
             ContentDetector,
         )
 
-
+        # Get video information
+        cap = cv2.VideoCapture(video_path)
+        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        duration = total_frames / fps if fps > 0 else 0
+        cap.release()
+
+        frame_interval = self.frame_interval  # seconds
+        # Maximum number of frames to extract to avoid memory issues
+        MAX_FRAMES = 100
+        # Minimum time difference (in seconds) to consider frames as distinct
+        TIME_THRESHOLD = 1.0
+
+        # Calculate the total number of frames to extract
+        if duration <= 0 or fps <= 0:
             logger.warning(
-
+                "Invalid video duration or fps, using default frame count"
             )
-            num_frames =
+            num_frames = 10
+        else:
+            num_frames = max(int(duration / frame_interval), 1)
+
+            if num_frames > MAX_FRAMES:
+                frame_interval = duration / MAX_FRAMES
+                num_frames = MAX_FRAMES
 
-
+        logger.info(
+            f"Video duration: {duration:.2f}s, target frames: {num_frames}"
+            f"at {frame_interval:.2f}s intervals"
+        )
+
+        # Use scene detection to extract keyframes
+        # Use open_video instead of VideoManager
+        video = open_video(video_path)
         scene_manager = SceneManager()
-        scene_manager.add_detector(ContentDetector(
+        scene_manager.add_detector(ContentDetector())
 
-
-
-        scene_manager.detect_scenes(video_manager)
+        # Detect scenes using the modern API
+        scene_manager.detect_scenes(video)
 
         scenes = scene_manager.get_scene_list()
         keyframes: List[Image.Image] = []
 
-        #
-        if
+        # If scene detection is successful, prioritize scene change points
+        if scenes:
+            logger.info(f"Detected {len(scenes)} scene changes")
+
+            if len(scenes) > num_frames:
+                scene_indices = np.linspace(
+                    0, len(scenes) - 1, num_frames, dtype=int
+                )
+                selected_scenes = [scenes[i] for i in scene_indices]
+            else:
+                selected_scenes = scenes
+
+            # Extract frames from scenes
+            for scene in selected_scenes:
+                try:
+                    # Get start time in seconds
+                    start_time = scene[0].get_seconds()
+                    frame = _capture_screenshot(video_path, start_time)
+                    keyframes.append(frame)
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to capture frame at scene change"
+                        f" {scene[0].get_seconds()}s: {e}"
+                    )
+
+        if len(keyframes) < num_frames and duration > 0:
+            logger.info(
+                f"Scene detection provided {len(keyframes)} frames, "
+                f"supplementing with regular interval frames"
+            )
+
+            existing_times = []
+            if scenes:
+                existing_times = [scene[0].get_seconds() for scene in scenes]
+
+            regular_frames = []
+            for i in range(num_frames):
+                time_sec = i * frame_interval
+
+                is_duplicate = False
+                for existing_time in existing_times:
+                    if abs(existing_time - time_sec) < TIME_THRESHOLD:
+                        is_duplicate = True
+                        break
+
+                if not is_duplicate:
+                    try:
+                        frame = _capture_screenshot(video_path, time_sec)
+                        regular_frames.append(frame)
+                    except Exception as e:
+                        logger.warning(
+                            f"Failed to capture frame at {time_sec}s: {e}"
+                        )
+
+            frames_needed = num_frames - len(keyframes)
+            if frames_needed > 0 and regular_frames:
+                if len(regular_frames) > frames_needed:
+                    indices = np.linspace(
+                        0, len(regular_frames) - 1, frames_needed, dtype=int
+                    )
+                    selected_frames = [regular_frames[i] for i in indices]
+                else:
+                    selected_frames = regular_frames
+
+                keyframes.extend(selected_frames)
+
+        if not keyframes:
             logger.warning(
-                "No
-                "
+                "No frames extracted, falling back to simple interval"
+                "extraction"
             )
-
-
-
-
-
-            duration = total_frames / fps if fps > 0 else 0
-
-            if duration > 0 and total_frames > 0:
-                # Extract frames at regular intervals
-                interval = duration / min(num_frames, total_frames)
-                for i in range(min(num_frames, total_frames)):
-                    time_sec = i * interval
+            for i in range(
+                min(num_frames, 10)
+            ):  # Limit to a maximum of 10 frames to avoid infinite loops
+                time_sec = i * (duration / 10 if duration > 0 else 6.0)
+                try:
                     frame = _capture_screenshot(video_path, time_sec)
                     keyframes.append(frame)
-
-
-
-
-            for start_time, _ in scenes:
-                if len(keyframes) >= num_frames:
-                    break
-                frame = _capture_screenshot(video_path, start_time)
-                keyframes.append(frame)
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to capture frame at {time_sec}s: {e}"
+                    )
 
         if not keyframes:
-
-
+            error_msg = (
+                f"Failed to extract any keyframes from video: {video_path}"
+            )
+            logger.error(error_msg)
+            raise ValueError(error_msg)
+
+        # Normalize image sizes
+        normalized_keyframes = self._normalize_frames(keyframes)
+
+        logger.info(
+            f"Extracted and normalized {len(normalized_keyframes)} keyframes"
+        )
+        return normalized_keyframes
+
+    def _normalize_frames(
+        self, frames: List[Image.Image], target_width: int = 512
+    ) -> List[Image.Image]:
+        r"""Normalize the size of extracted frames.
 
-
-
+        Args:
+            frames (List[Image.Image]): List of frames to normalize.
+            target_width (int): Target width for normalized frames.
+
+        Returns:
+            List[Image.Image]: List of normalized frames.
+        """
+        normalized_frames: List[Image.Image] = []
+
+        for frame in frames:
+            # Get original dimensions
+            width, height = frame.size
+
+            # Calculate new height, maintaining aspect ratio
+            aspect_ratio = width / height
+            new_height = int(target_width / aspect_ratio)
+
+            # Resize image
+            resized_frame = frame.resize(
+                (target_width, new_height), Image.Resampling.LANCZOS
+            )
+
+            # Ensure the image has a proper format
+            if resized_frame.mode != 'RGB':
+                resized_frame = resized_frame.convert('RGB')
+
+            # Create a new image with explicit format
+            with io.BytesIO() as buffer:
+                resized_frame.save(buffer, format='JPEG')
+                buffer.seek(0)
+                formatted_frame = Image.open(buffer)
+                formatted_frame.load()  # Load the image data
+
+            normalized_frames.append(formatted_frame)
+
+        return normalized_frames
 
     def ask_question_about_video(
         self,
         video_path: str,
         question: str,
-        num_frames: int = 28,
     ) -> str:
         r"""Ask a question about the video.
 
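The new _normalize_frames helper above resizes every extracted frame to a fixed width, preserves the aspect ratio, forces RGB mode, and round-trips the image through an in-memory JPEG so each frame carries an explicit format. A self-contained sketch of the same transformation using plain Pillow is shown below; the function name normalize_frame is illustrative and not part of the toolkit's API.

# Illustrative sketch of the normalization step (assumes Pillow is installed).
import io

from PIL import Image


def normalize_frame(frame: Image.Image, target_width: int = 512) -> Image.Image:
    width, height = frame.size
    # Keep the aspect ratio while fixing the width
    new_height = int(target_width * height / width)
    resized = frame.resize((target_width, new_height), Image.Resampling.LANCZOS)
    # JPEG requires RGB, so drop alpha or palette modes first
    if resized.mode != 'RGB':
        resized = resized.convert('RGB')
    # Round-trip through an in-memory JPEG so the frame has an explicit format
    with io.BytesIO() as buffer:
        resized.save(buffer, format='JPEG')
        buffer.seek(0)
        out = Image.open(buffer)
        out.load()  # force-decode before the buffer is closed
    return out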
@@ -331,24 +477,12 @@ class VideoAnalysisToolkit(BaseToolkit):
             video_path (str): The path to the video file.
                 It can be a local file or a URL (such as Youtube website).
             question (str): The question to ask about the video.
-            num_frames (int): The number of frames to extract from the video.
-                To be adjusted based on the length of the video.
-                (default: :obj:`28`)
 
         Returns:
             str: The answer to the question.
         """
         from urllib.parse import urlparse
 
-        if not question:
-            raise ValueError("Question cannot be empty")
-
-        if num_frames <= 0:
-            logger.warning(
-                f"Invalid num_frames: {num_frames}, using default of 28"
-            )
-            num_frames = 28
-
         parsed_url = urlparse(video_path)
         is_url = all([parsed_url.scheme, parsed_url.netloc])
 
@@ -374,7 +508,7 @@ class VideoAnalysisToolkit(BaseToolkit):
             audio_path = self._extract_audio_from_video(video_path)
             audio_transcript = self._transcribe_audio(audio_path)
 
-            video_frames = self._extract_keyframes(video_path
+            video_frames = self._extract_keyframes(video_path)
             prompt = VIDEO_QA_PROMPT.format(
                 audio_transcription=audio_transcript,
                 question=question,
@@ -385,7 +519,8 @@ class VideoAnalysisToolkit(BaseToolkit):
                 content=prompt,
                 image_list=video_frames,
             )
-
+            # Reset the agent to clear previous state
+            self.vl_agent.reset()
             response = self.vl_agent.step(msg)
             if not response or not response.msgs:
                 logger.error("Model returned empty response")
@@ -398,7 +533,7 @@ class VideoAnalysisToolkit(BaseToolkit):
             return answer
 
         except Exception as e:
-            error_message = f"Error processing video: {e
+            error_message = f"Error processing video: {e}"
             logger.error(error_message)
             return f"Error: {error_message}"
 
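Taken together, the constructor gains frame_interval, output_language, and cookies_path, while ask_question_about_video drops num_frames. A hedged usage sketch follows; the download directory and video URL are placeholders, VideoAnalysisToolkit is assumed to remain importable from camel.toolkits, and ffmpeg plus scenedetect must be installed per the @dependencies_required decorator.

# Sketch only: constructing the toolkit with the 0.2.47 parameters.
from camel.toolkits import VideoAnalysisToolkit

toolkit = VideoAnalysisToolkit(
    download_directory="/tmp/camel_videos",  # placeholder path
    use_audio_transcription=False,           # keep visual-only analysis
    frame_interval=4.0,                      # seconds between sampled frames
    output_language="English",               # language of the model's answer
    cookies_path=None,                       # Netscape cookies file, if needed
)

# num_frames is gone from ask_question_about_video; sampling is now driven by
# frame_interval and the MAX_FRAMES cap inside _extract_keyframes.
answer = toolkit.ask_question_about_video(
    video_path="https://www.youtube.com/watch?v=example",  # placeholder URL
    question="What happens in the opening scene?",
)
print(answer)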
camel/toolkits/video_download_toolkit.py
CHANGED
@@ -102,10 +102,17 @@ class VideoDownloaderToolkit(BaseToolkit):
         Cleans up the downloaded video if they are stored in a temporary
         directory.
         """
-        import shutil
-
         if self._cleanup:
-
+            try:
+                import sys
+
+                if getattr(sys, 'modules', None) is not None:
+                    import shutil
+
+                    shutil.rmtree(self._download_directory, ignore_errors=True)
+            except (ImportError, AttributeError):
+                # Skip cleanup if interpreter is shutting down
+                pass
 
     def download_video(self, url: str) -> str:
         r"""Download the video and optionally split it into chunks.
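Both __del__ methods now share the same shutdown-safe cleanup pattern: defer the shutil import, guard on sys.modules, and swallow ImportError/AttributeError so that garbage collection during interpreter teardown cannot raise. A standalone sketch of the pattern is below; the TempDirOwner class is illustrative and not part of camel.

# Illustrative sketch of the shutdown-safe __del__ cleanup pattern.
import tempfile


class TempDirOwner:
    def __init__(self) -> None:
        self._download_directory = tempfile.mkdtemp()
        self._cleanup = True

    def __del__(self):
        if self._cleanup:
            try:
                import sys

                # During interpreter shutdown sys.modules may already be
                # cleared; importing shutil then could fail or misbehave.
                if getattr(sys, 'modules', None) is not None:
                    import shutil

                    shutil.rmtree(self._download_directory, ignore_errors=True)
            except (ImportError, AttributeError):
                # Skip cleanup if the interpreter is shutting down
                pass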