camel-ai 0.2.45__py3-none-any.whl → 0.2.47__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai might be problematic. Click here for more details.

Files changed (54) hide show
  1. camel/__init__.py +1 -1
  2. camel/configs/__init__.py +6 -0
  3. camel/configs/bedrock_config.py +73 -0
  4. camel/configs/lmstudio_config.py +94 -0
  5. camel/configs/qwen_config.py +3 -3
  6. camel/datasets/few_shot_generator.py +19 -3
  7. camel/datasets/models.py +1 -1
  8. camel/loaders/__init__.py +2 -0
  9. camel/loaders/scrapegraph_reader.py +96 -0
  10. camel/models/__init__.py +4 -0
  11. camel/models/aiml_model.py +11 -104
  12. camel/models/anthropic_model.py +11 -76
  13. camel/models/aws_bedrock_model.py +112 -0
  14. camel/models/deepseek_model.py +11 -44
  15. camel/models/gemini_model.py +10 -72
  16. camel/models/groq_model.py +11 -131
  17. camel/models/internlm_model.py +11 -61
  18. camel/models/lmstudio_model.py +82 -0
  19. camel/models/model_factory.py +7 -1
  20. camel/models/modelscope_model.py +11 -122
  21. camel/models/moonshot_model.py +10 -76
  22. camel/models/nemotron_model.py +4 -60
  23. camel/models/nvidia_model.py +11 -111
  24. camel/models/ollama_model.py +12 -205
  25. camel/models/openai_compatible_model.py +51 -12
  26. camel/models/openai_model.py +3 -1
  27. camel/models/openrouter_model.py +12 -131
  28. camel/models/ppio_model.py +10 -99
  29. camel/models/qwen_model.py +11 -122
  30. camel/models/reka_model.py +1 -1
  31. camel/models/sglang_model.py +5 -3
  32. camel/models/siliconflow_model.py +10 -58
  33. camel/models/togetherai_model.py +10 -177
  34. camel/models/vllm_model.py +11 -218
  35. camel/models/volcano_model.py +1 -15
  36. camel/models/yi_model.py +11 -98
  37. camel/models/zhipuai_model.py +11 -102
  38. camel/storages/__init__.py +2 -0
  39. camel/storages/vectordb_storages/__init__.py +2 -0
  40. camel/storages/vectordb_storages/oceanbase.py +458 -0
  41. camel/toolkits/__init__.py +4 -0
  42. camel/toolkits/browser_toolkit.py +4 -7
  43. camel/toolkits/jina_reranker_toolkit.py +231 -0
  44. camel/toolkits/pyautogui_toolkit.py +428 -0
  45. camel/toolkits/search_toolkit.py +167 -0
  46. camel/toolkits/video_analysis_toolkit.py +215 -80
  47. camel/toolkits/video_download_toolkit.py +10 -3
  48. camel/types/enums.py +70 -0
  49. camel/types/unified_model_type.py +10 -0
  50. camel/utils/token_counting.py +7 -3
  51. {camel_ai-0.2.45.dist-info → camel_ai-0.2.47.dist-info}/METADATA +13 -1
  52. {camel_ai-0.2.45.dist-info → camel_ai-0.2.47.dist-info}/RECORD +54 -46
  53. {camel_ai-0.2.45.dist-info → camel_ai-0.2.47.dist-info}/WHEEL +0 -0
  54. {camel_ai-0.2.45.dist-info → camel_ai-0.2.47.dist-info}/licenses/LICENSE +0 -0
@@ -1064,6 +1064,172 @@ class SearchToolkit(BaseToolkit):
1064
1064
  except Exception as e:
1065
1065
  return {"error": f"Exa search failed: {e!s}"}
1066
1066
 
1067
+ @api_keys_required([(None, 'TONGXIAO_API_KEY')])
1068
+ def search_alibaba_tongxiao(
1069
+ self,
1070
+ query: str,
1071
+ time_range: Literal[
1072
+ "OneDay", "OneWeek", "OneMonth", "OneYear", "NoLimit"
1073
+ ] = "NoLimit",
1074
+ industry: Optional[
1075
+ Literal[
1076
+ "finance",
1077
+ "law",
1078
+ "medical",
1079
+ "internet",
1080
+ "tax",
1081
+ "news_province",
1082
+ "news_center",
1083
+ ]
1084
+ ] = None,
1085
+ page: int = 1,
1086
+ return_main_text: bool = False,
1087
+ return_markdown_text: bool = True,
1088
+ enable_rerank: bool = True,
1089
+ ) -> Dict[str, Any]:
1090
+ r"""Query the Alibaba Tongxiao search API and return search results.
1091
+
1092
+ A powerful search API optimized for Chinese language queries with
1093
+ features:
1094
+ - Enhanced Chinese language understanding
1095
+ - Industry-specific filtering (finance, law, medical, etc.)
1096
+ - Structured data with markdown formatting
1097
+ - Result reranking for relevance
1098
+ - Time-based filtering
1099
+
1100
+ Args:
1101
+ query (str): The search query string (length >= 1 and <= 100).
1102
+ time_range (Literal["OneDay", "OneWeek", "OneMonth", "OneYear",
1103
+ "NoLimit"]): Time frame filter for search results.
1104
+ (default: :obj:`"NoLimit"`)
1105
+ industry (Optional[Literal["finance", "law", "medical",
1106
+ "internet", "tax", "news_province", "news_center"]]):
1107
+ Industry-specific search filter. When specified, only returns
1108
+ results from sites in the specified industries. Multiple
1109
+ industries can be comma-separated.
1110
+ (default: :obj:`None`)
1111
+ page (int): Page number for results pagination.
1112
+ (default: :obj:`1`)
1113
+ return_main_text (bool): Whether to include the main text of the
1114
+ webpage in results. (default: :obj:`False`)
1115
+ return_markdown_text (bool): Whether to include markdown formatted
1116
+ content in results. (default: :obj:`True`)
1117
+ enable_rerank (bool): Whether to enable result reranking. If
1118
+ response time is critical, setting this to False can reduce
1119
+ response time by approximately 140ms. (default: :obj:`True`)
1120
+
1121
+ Returns:
1122
+ Dict[str, Any]: A dictionary containing either search results with
1123
+ 'requestId' and 'results' keys, or an 'error' key with error
1124
+ message. Each result contains title, snippet, url and other
1125
+ metadata.
1126
+ """
1127
+ TONGXIAO_API_KEY = os.getenv("TONGXIAO_API_KEY")
1128
+
1129
+ # Validate query length
1130
+ if not query or len(query) > 100:
1131
+ return {
1132
+ "error": "Query length must be between 1 and 100 characters"
1133
+ }
1134
+
1135
+ # API endpoint and parameters
1136
+ base_url = "https://cloud-iqs.aliyuncs.com/search/genericSearch"
1137
+ headers = {
1138
+ "X-API-Key": TONGXIAO_API_KEY,
1139
+ }
1140
+
1141
+ # Convert boolean parameters to string for compatibility with requests
1142
+ params: Dict[str, Union[str, int]] = {
1143
+ "query": query,
1144
+ "timeRange": time_range,
1145
+ "page": page,
1146
+ "returnMainText": str(return_main_text).lower(),
1147
+ "returnMarkdownText": str(return_markdown_text).lower(),
1148
+ "enableRerank": str(enable_rerank).lower(),
1149
+ }
1150
+
1151
+ # Only add industry parameter if specified
1152
+ if industry is not None:
1153
+ params["industry"] = industry
1154
+
1155
+ try:
1156
+ # Send GET request with proper typing for params
1157
+ response = requests.get(
1158
+ base_url, headers=headers, params=params, timeout=10
1159
+ )
1160
+
1161
+ # Check response status
1162
+ if response.status_code != 200:
1163
+ return {
1164
+ "error": (
1165
+ f"Alibaba Tongxiao API request failed with status "
1166
+ f"code {response.status_code}: {response.text}"
1167
+ )
1168
+ }
1169
+
1170
+ # Parse JSON response
1171
+ data = response.json()
1172
+
1173
+ # Extract and format pageItems
1174
+ page_items = data.get("pageItems", [])
1175
+ results = []
1176
+ for idx, item in enumerate(page_items):
1177
+ # Create a simplified result structure
1178
+ result = {
1179
+ "result_id": idx + 1,
1180
+ "title": item.get("title", ""),
1181
+ "snippet": item.get("snippet", ""),
1182
+ "url": item.get("link", ""),
1183
+ "hostname": item.get("hostname", ""),
1184
+ }
1185
+
1186
+ # Only include additional fields if they exist and are
1187
+ # requested
1188
+ if "summary" in item and item.get("summary"):
1189
+ result["summary"] = item["summary"]
1190
+ elif (
1191
+ return_main_text
1192
+ and "mainText" in item
1193
+ and item.get("mainText")
1194
+ ):
1195
+ result["summary"] = item["mainText"]
1196
+
1197
+ if (
1198
+ return_main_text
1199
+ and "mainText" in item
1200
+ and item.get("mainText")
1201
+ ):
1202
+ result["main_text"] = item["mainText"]
1203
+
1204
+ if (
1205
+ return_markdown_text
1206
+ and "markdownText" in item
1207
+ and item.get("markdownText")
1208
+ ):
1209
+ result["markdown_text"] = item["markdownText"]
1210
+
1211
+ if "score" in item:
1212
+ result["score"] = item["score"]
1213
+
1214
+ if "publishTime" in item:
1215
+ result["publish_time"] = item["publishTime"]
1216
+
1217
+ results.append(result)
1218
+
1219
+ # Return a simplified structure
1220
+ return {
1221
+ "request_id": data.get("requestId", ""),
1222
+ "results": results,
1223
+ }
1224
+
1225
+ except requests.exceptions.RequestException as e:
1226
+ return {"error": f"Alibaba Tongxiao search request failed: {e!s}"}
1227
+ except Exception as e:
1228
+ return {
1229
+ "error": f"Unexpected error during Alibaba Tongxiao "
1230
+ f"search: {e!s}"
1231
+ }
1232
+
1067
1233
  def get_tools(self) -> List[FunctionTool]:
1068
1234
  r"""Returns a list of FunctionTool objects representing the
1069
1235
  functions in the toolkit.
@@ -1084,4 +1250,5 @@ class SearchToolkit(BaseToolkit):
1084
1250
  FunctionTool(self.search_baidu),
1085
1251
  FunctionTool(self.search_bing),
1086
1252
  FunctionTool(self.search_exa),
1253
+ FunctionTool(self.search_alibaba_tongxiao),
1087
1254
  ]
@@ -12,6 +12,7 @@
12
12
  # limitations under the License.
13
13
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
14
 
15
+ import io
15
16
  import os
16
17
  import tempfile
17
18
  from pathlib import Path
@@ -92,9 +93,15 @@ class VideoAnalysisToolkit(BaseToolkit):
92
93
  transcription using OpenAI's audio models. Requires a valid OpenAI
93
94
  API key. When disabled, video analysis will be based solely on
94
95
  visual content. (default: :obj:`False`)
96
+ frame_interval (float, optional): Interval in seconds between frames
97
+ to extract from the video. (default: :obj:`4.0`)
98
+ output_language (str, optional): The language for output responses.
99
+ (default: :obj:`"English"`)
100
+ cookies_path (Optional[str]): The path to the cookies file
101
+ for the video service in Netscape format. (default: :obj:`None`)
95
102
  timeout (Optional[float]): The timeout value for API requests
96
- in seconds. If None, no timeout is applied.
97
- (default: :obj:`None`)
103
+ in seconds. If None, no timeout is applied.
104
+ (default: :obj:`None`)
98
105
  """
99
106
 
100
107
  @dependencies_required("ffmpeg", "scenedetect")
@@ -103,27 +110,29 @@ class VideoAnalysisToolkit(BaseToolkit):
103
110
  download_directory: Optional[str] = None,
104
111
  model: Optional[BaseModelBackend] = None,
105
112
  use_audio_transcription: bool = False,
113
+ frame_interval: float = 4.0,
114
+ output_language: str = "English",
115
+ cookies_path: Optional[str] = None,
106
116
  timeout: Optional[float] = None,
107
117
  ) -> None:
108
118
  super().__init__(timeout=timeout)
109
119
  self._cleanup = download_directory is None
110
120
  self._temp_files: list[str] = [] # Track temporary files for cleanup
111
121
  self._use_audio_transcription = use_audio_transcription
122
+ self.output_language = output_language
123
+ self.frame_interval = frame_interval
112
124
 
113
125
  self._download_directory = Path(
114
126
  download_directory or tempfile.mkdtemp()
115
127
  ).resolve()
116
128
 
117
129
  self.video_downloader_toolkit = VideoDownloaderToolkit(
118
- download_directory=str(self._download_directory)
130
+ download_directory=str(self._download_directory),
131
+ cookies_path=cookies_path,
119
132
  )
120
133
 
121
134
  try:
122
135
  self._download_directory.mkdir(parents=True, exist_ok=True)
123
- except FileExistsError:
124
- raise ValueError(
125
- f"{self._download_directory} is not a valid directory."
126
- )
127
136
  except OSError as e:
128
137
  raise ValueError(
129
138
  f"Error creating directory {self._download_directory}: {e}"
@@ -137,16 +146,18 @@ class VideoAnalysisToolkit(BaseToolkit):
137
146
  # Import ChatAgent at runtime to avoid circular imports
138
147
  from camel.agents import ChatAgent
139
148
 
140
- self.vl_agent = ChatAgent(model=self.vl_model)
149
+ self.vl_agent = ChatAgent(
150
+ model=self.vl_model, output_language=self.output_language
151
+ )
141
152
  else:
142
153
  # If no model is provided, use default model in ChatAgent
143
154
  # Import ChatAgent at runtime to avoid circular imports
144
155
  from camel.agents import ChatAgent
145
156
 
146
- self.vl_agent = ChatAgent()
157
+ self.vl_agent = ChatAgent(output_language=self.output_language)
147
158
  logger.warning(
148
- "No vision-language model provided. Using default model in"
149
- " ChatAgent."
159
+ "No vision-language model provided. Using default model in "
160
+ "ChatAgent."
150
161
  )
151
162
 
152
163
  # Initialize audio models only if audio transcription is enabled
@@ -179,16 +190,22 @@ class VideoAnalysisToolkit(BaseToolkit):
179
190
  # Clean up temporary directory if needed
180
191
  if self._cleanup and os.path.exists(self._download_directory):
181
192
  try:
182
- import shutil
193
+ import sys
183
194
 
184
- shutil.rmtree(self._download_directory)
185
- logger.debug(
186
- f"Removed temporary directory: {self._download_directory}"
187
- )
195
+ if getattr(sys, 'modules', None) is not None:
196
+ import shutil
197
+
198
+ shutil.rmtree(self._download_directory)
199
+ logger.debug(
200
+ f"Removed temp directory: {self._download_directory}"
201
+ )
202
+ except (ImportError, AttributeError):
203
+ # Skip cleanup if interpreter is shutting down
204
+ pass
188
205
  except OSError as e:
189
206
  logger.warning(
190
- f"Failed to remove temporary directory"
191
- f" {self._download_directory}: {e}"
207
+ f"Failed to remove temporary directory "
208
+ f"{self._download_directory}: {e}"
192
209
  )
193
210
 
194
211
  def _extract_audio_from_video(
@@ -242,88 +259,217 @@ class VideoAnalysisToolkit(BaseToolkit):
242
259
  logger.error(f"Audio transcription failed: {e}")
243
260
  return "Audio transcription failed."
244
261
 
245
- def _extract_keyframes(
246
- self, video_path: str, num_frames: int, threshold: float = 25.0
247
- ) -> List[Image.Image]:
248
- r"""Extract keyframes from a video based on scene changes
249
- and return them as PIL.Image.Image objects.
262
+ def _extract_keyframes(self, video_path: str) -> List[Image.Image]:
263
+ r"""Extract keyframes from a video based on scene changes and
264
+ regular intervals, and return them as PIL.Image.Image objects.
250
265
 
251
266
  Args:
252
267
  video_path (str): Path to the video file.
253
- num_frames (int): Number of keyframes to extract.
254
- threshold (float): The threshold value for scene change detection.
255
268
 
256
269
  Returns:
257
- list: A list of PIL.Image.Image objects representing
270
+ List[Image.Image]: A list of PIL.Image.Image objects representing
258
271
  the extracted keyframes.
272
+
273
+ Raises:
274
+ ValueError: If no frames could be extracted from the video.
259
275
  """
276
+ import cv2
277
+ import numpy as np
260
278
  from scenedetect import ( # type: ignore[import-untyped]
261
279
  SceneManager,
262
- VideoManager,
280
+ open_video,
263
281
  )
264
282
  from scenedetect.detectors import ( # type: ignore[import-untyped]
265
283
  ContentDetector,
266
284
  )
267
285
 
268
- if num_frames <= 0:
286
+ # Get video information
287
+ cap = cv2.VideoCapture(video_path)
288
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
289
+ fps = cap.get(cv2.CAP_PROP_FPS)
290
+ duration = total_frames / fps if fps > 0 else 0
291
+ cap.release()
292
+
293
+ frame_interval = self.frame_interval # seconds
294
+ # Maximum number of frames to extract to avoid memory issues
295
+ MAX_FRAMES = 100
296
+ # Minimum time difference (in seconds) to consider frames as distinct
297
+ TIME_THRESHOLD = 1.0
298
+
299
+ # Calculate the total number of frames to extract
300
+ if duration <= 0 or fps <= 0:
269
301
  logger.warning(
270
- f"Invalid num_frames: {num_frames}, using default of 1"
302
+ "Invalid video duration or fps, using default frame count"
271
303
  )
272
- num_frames = 1
304
+ num_frames = 10
305
+ else:
306
+ num_frames = max(int(duration / frame_interval), 1)
307
+
308
+ if num_frames > MAX_FRAMES:
309
+ frame_interval = duration / MAX_FRAMES
310
+ num_frames = MAX_FRAMES
273
311
 
274
- video_manager = VideoManager([video_path])
312
+ logger.info(
313
+ f"Video duration: {duration:.2f}s, target frames: {num_frames}"
314
+ f"at {frame_interval:.2f}s intervals"
315
+ )
316
+
317
+ # Use scene detection to extract keyframes
318
+ # Use open_video instead of VideoManager
319
+ video = open_video(video_path)
275
320
  scene_manager = SceneManager()
276
- scene_manager.add_detector(ContentDetector(threshold=threshold))
321
+ scene_manager.add_detector(ContentDetector())
277
322
 
278
- video_manager.set_duration()
279
- video_manager.start()
280
- scene_manager.detect_scenes(video_manager)
323
+ # Detect scenes using the modern API
324
+ scene_manager.detect_scenes(video)
281
325
 
282
326
  scenes = scene_manager.get_scene_list()
283
327
  keyframes: List[Image.Image] = []
284
328
 
285
- # Handle case where no scenes are detected
286
- if not scenes:
329
+ # If scene detection is successful, prioritize scene change points
330
+ if scenes:
331
+ logger.info(f"Detected {len(scenes)} scene changes")
332
+
333
+ if len(scenes) > num_frames:
334
+ scene_indices = np.linspace(
335
+ 0, len(scenes) - 1, num_frames, dtype=int
336
+ )
337
+ selected_scenes = [scenes[i] for i in scene_indices]
338
+ else:
339
+ selected_scenes = scenes
340
+
341
+ # Extract frames from scenes
342
+ for scene in selected_scenes:
343
+ try:
344
+ # Get start time in seconds
345
+ start_time = scene[0].get_seconds()
346
+ frame = _capture_screenshot(video_path, start_time)
347
+ keyframes.append(frame)
348
+ except Exception as e:
349
+ logger.warning(
350
+ f"Failed to capture frame at scene change"
351
+ f" {scene[0].get_seconds()}s: {e}"
352
+ )
353
+
354
+ if len(keyframes) < num_frames and duration > 0:
355
+ logger.info(
356
+ f"Scene detection provided {len(keyframes)} frames, "
357
+ f"supplementing with regular interval frames"
358
+ )
359
+
360
+ existing_times = []
361
+ if scenes:
362
+ existing_times = [scene[0].get_seconds() for scene in scenes]
363
+
364
+ regular_frames = []
365
+ for i in range(num_frames):
366
+ time_sec = i * frame_interval
367
+
368
+ is_duplicate = False
369
+ for existing_time in existing_times:
370
+ if abs(existing_time - time_sec) < TIME_THRESHOLD:
371
+ is_duplicate = True
372
+ break
373
+
374
+ if not is_duplicate:
375
+ try:
376
+ frame = _capture_screenshot(video_path, time_sec)
377
+ regular_frames.append(frame)
378
+ except Exception as e:
379
+ logger.warning(
380
+ f"Failed to capture frame at {time_sec}s: {e}"
381
+ )
382
+
383
+ frames_needed = num_frames - len(keyframes)
384
+ if frames_needed > 0 and regular_frames:
385
+ if len(regular_frames) > frames_needed:
386
+ indices = np.linspace(
387
+ 0, len(regular_frames) - 1, frames_needed, dtype=int
388
+ )
389
+ selected_frames = [regular_frames[i] for i in indices]
390
+ else:
391
+ selected_frames = regular_frames
392
+
393
+ keyframes.extend(selected_frames)
394
+
395
+ if not keyframes:
287
396
  logger.warning(
288
- "No scenes detected in video, capturing frames at "
289
- "regular intervals"
397
+ "No frames extracted, falling back to simple interval"
398
+ "extraction"
290
399
  )
291
- import cv2
292
-
293
- cap = cv2.VideoCapture(video_path)
294
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
295
- fps = cap.get(cv2.CAP_PROP_FPS)
296
- duration = total_frames / fps if fps > 0 else 0
297
-
298
- if duration > 0 and total_frames > 0:
299
- # Extract frames at regular intervals
300
- interval = duration / min(num_frames, total_frames)
301
- for i in range(min(num_frames, total_frames)):
302
- time_sec = i * interval
400
+ for i in range(
401
+ min(num_frames, 10)
402
+ ): # Cap the fallback extraction at a maximum of 10 frames
403
+ time_sec = i * (duration / 10 if duration > 0 else 6.0)
404
+ try:
303
405
  frame = _capture_screenshot(video_path, time_sec)
304
406
  keyframes.append(frame)
305
-
306
- cap.release()
307
- else:
308
- # Extract frames from detected scenes
309
- for start_time, _ in scenes:
310
- if len(keyframes) >= num_frames:
311
- break
312
- frame = _capture_screenshot(video_path, start_time)
313
- keyframes.append(frame)
407
+ except Exception as e:
408
+ logger.warning(
409
+ f"Failed to capture frame at {time_sec}s: {e}"
410
+ )
314
411
 
315
412
  if not keyframes:
316
- logger.error("Failed to extract any keyframes from video")
317
- raise ValueError("Failed to extract keyframes from video")
413
+ error_msg = (
414
+ f"Failed to extract any keyframes from video: {video_path}"
415
+ )
416
+ logger.error(error_msg)
417
+ raise ValueError(error_msg)
418
+
419
+ # Normalize image sizes
420
+ normalized_keyframes = self._normalize_frames(keyframes)
421
+
422
+ logger.info(
423
+ f"Extracted and normalized {len(normalized_keyframes)} keyframes"
424
+ )
425
+ return normalized_keyframes
426
+
427
+ def _normalize_frames(
428
+ self, frames: List[Image.Image], target_width: int = 512
429
+ ) -> List[Image.Image]:
430
+ r"""Normalize the size of extracted frames.
318
431
 
319
- logger.info(f"Extracted {len(keyframes)} keyframes")
320
- return keyframes
432
+ Args:
433
+ frames (List[Image.Image]): List of frames to normalize.
434
+ target_width (int): Target width for normalized frames.
435
+
436
+ Returns:
437
+ List[Image.Image]: List of normalized frames.
438
+ """
439
+ normalized_frames: List[Image.Image] = []
440
+
441
+ for frame in frames:
442
+ # Get original dimensions
443
+ width, height = frame.size
444
+
445
+ # Calculate new height, maintaining aspect ratio
446
+ aspect_ratio = width / height
447
+ new_height = int(target_width / aspect_ratio)
448
+
449
+ # Resize image
450
+ resized_frame = frame.resize(
451
+ (target_width, new_height), Image.Resampling.LANCZOS
452
+ )
453
+
454
+ # Ensure the image has a proper format
455
+ if resized_frame.mode != 'RGB':
456
+ resized_frame = resized_frame.convert('RGB')
457
+
458
+ # Create a new image with explicit format
459
+ with io.BytesIO() as buffer:
460
+ resized_frame.save(buffer, format='JPEG')
461
+ buffer.seek(0)
462
+ formatted_frame = Image.open(buffer)
463
+ formatted_frame.load() # Load the image data
464
+
465
+ normalized_frames.append(formatted_frame)
466
+
467
+ return normalized_frames
321
468
 
322
469
  def ask_question_about_video(
323
470
  self,
324
471
  video_path: str,
325
472
  question: str,
326
- num_frames: int = 28,
327
473
  ) -> str:
328
474
  r"""Ask a question about the video.
329
475
 
@@ -331,24 +477,12 @@ class VideoAnalysisToolkit(BaseToolkit):
331
477
  video_path (str): The path to the video file.
332
478
  It can be a local file or a URL (such as Youtube website).
333
479
  question (str): The question to ask about the video.
334
- num_frames (int): The number of frames to extract from the video.
335
- To be adjusted based on the length of the video.
336
- (default: :obj:`28`)
337
480
 
338
481
  Returns:
339
482
  str: The answer to the question.
340
483
  """
341
484
  from urllib.parse import urlparse
342
485
 
343
- if not question:
344
- raise ValueError("Question cannot be empty")
345
-
346
- if num_frames <= 0:
347
- logger.warning(
348
- f"Invalid num_frames: {num_frames}, using default of 28"
349
- )
350
- num_frames = 28
351
-
352
486
  parsed_url = urlparse(video_path)
353
487
  is_url = all([parsed_url.scheme, parsed_url.netloc])
354
488
 
@@ -374,7 +508,7 @@ class VideoAnalysisToolkit(BaseToolkit):
374
508
  audio_path = self._extract_audio_from_video(video_path)
375
509
  audio_transcript = self._transcribe_audio(audio_path)
376
510
 
377
- video_frames = self._extract_keyframes(video_path, num_frames)
511
+ video_frames = self._extract_keyframes(video_path)
378
512
  prompt = VIDEO_QA_PROMPT.format(
379
513
  audio_transcription=audio_transcript,
380
514
  question=question,
@@ -385,7 +519,8 @@ class VideoAnalysisToolkit(BaseToolkit):
385
519
  content=prompt,
386
520
  image_list=video_frames,
387
521
  )
388
-
522
+ # Reset the agent to clear previous state
523
+ self.vl_agent.reset()
389
524
  response = self.vl_agent.step(msg)
390
525
  if not response or not response.msgs:
391
526
  logger.error("Model returned empty response")
@@ -398,7 +533,7 @@ class VideoAnalysisToolkit(BaseToolkit):
398
533
  return answer
399
534
 
400
535
  except Exception as e:
401
- error_message = f"Error processing video: {e!s}"
536
+ error_message = f"Error processing video: {e}"
402
537
  logger.error(error_message)
403
538
  return f"Error: {error_message}"
404
539
 
@@ -102,10 +102,17 @@ class VideoDownloaderToolkit(BaseToolkit):
102
102
  Cleans up the downloaded video if they are stored in a temporary
103
103
  directory.
104
104
  """
105
- import shutil
106
-
107
105
  if self._cleanup:
108
- shutil.rmtree(self._download_directory, ignore_errors=True)
106
+ try:
107
+ import sys
108
+
109
+ if getattr(sys, 'modules', None) is not None:
110
+ import shutil
111
+
112
+ shutil.rmtree(self._download_directory, ignore_errors=True)
113
+ except (ImportError, AttributeError):
114
+ # Skip cleanup if interpreter is shutting down
115
+ pass
109
116
 
110
117
  def download_video(self, url: str) -> str:
111
118
  r"""Download the video and optionally split it into chunks.