camel-ai 0.2.75a5__py3-none-any.whl → 0.2.76a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of camel-ai might be problematic; see the registry's advisory page for more details.

Files changed (47)
  1. camel/__init__.py +1 -1
  2. camel/agents/chat_agent.py +298 -130
  3. camel/configs/__init__.py +6 -0
  4. camel/configs/amd_config.py +70 -0
  5. camel/configs/nebius_config.py +103 -0
  6. camel/interpreters/__init__.py +2 -0
  7. camel/interpreters/microsandbox_interpreter.py +395 -0
  8. camel/models/__init__.py +4 -0
  9. camel/models/amd_model.py +101 -0
  10. camel/models/model_factory.py +4 -0
  11. camel/models/nebius_model.py +83 -0
  12. camel/models/ollama_model.py +3 -3
  13. camel/models/openai_model.py +0 -6
  14. camel/runtimes/daytona_runtime.py +11 -12
  15. camel/societies/workforce/task_channel.py +120 -27
  16. camel/societies/workforce/workforce.py +35 -3
  17. camel/toolkits/__init__.py +5 -3
  18. camel/toolkits/code_execution.py +28 -1
  19. camel/toolkits/function_tool.py +6 -1
  20. camel/toolkits/github_toolkit.py +104 -17
  21. camel/toolkits/hybrid_browser_toolkit/config_loader.py +8 -0
  22. camel/toolkits/hybrid_browser_toolkit/hybrid_browser_toolkit.py +12 -0
  23. camel/toolkits/hybrid_browser_toolkit/hybrid_browser_toolkit_ts.py +33 -14
  24. camel/toolkits/hybrid_browser_toolkit/ts/src/browser-session.ts +135 -40
  25. camel/toolkits/hybrid_browser_toolkit/ts/src/config-loader.ts +2 -0
  26. camel/toolkits/hybrid_browser_toolkit/ts/src/hybrid-browser-toolkit.ts +43 -207
  27. camel/toolkits/hybrid_browser_toolkit/ts/src/parent-child-filter.ts +226 -0
  28. camel/toolkits/hybrid_browser_toolkit/ts/src/snapshot-parser.ts +231 -0
  29. camel/toolkits/hybrid_browser_toolkit/ts/src/som-screenshot-injected.ts +543 -0
  30. camel/toolkits/hybrid_browser_toolkit/ts/websocket-server.js +39 -6
  31. camel/toolkits/hybrid_browser_toolkit/ws_wrapper.py +248 -58
  32. camel/toolkits/hybrid_browser_toolkit_py/hybrid_browser_toolkit.py +5 -1
  33. camel/toolkits/{openai_image_toolkit.py → image_generation_toolkit.py} +98 -31
  34. camel/toolkits/math_toolkit.py +64 -10
  35. camel/toolkits/mcp_toolkit.py +39 -14
  36. camel/toolkits/minimax_mcp_toolkit.py +195 -0
  37. camel/toolkits/search_toolkit.py +13 -2
  38. camel/toolkits/terminal_toolkit.py +12 -2
  39. camel/toolkits/video_analysis_toolkit.py +16 -10
  40. camel/types/enums.py +42 -0
  41. camel/types/unified_model_type.py +5 -0
  42. camel/utils/commons.py +2 -0
  43. camel/utils/mcp.py +136 -2
  44. {camel_ai-0.2.75a5.dist-info → camel_ai-0.2.76a0.dist-info}/METADATA +5 -11
  45. {camel_ai-0.2.75a5.dist-info → camel_ai-0.2.76a0.dist-info}/RECORD +47 -38
  46. {camel_ai-0.2.75a5.dist-info → camel_ai-0.2.76a0.dist-info}/WHEEL +0 -0
  47. {camel_ai-0.2.75a5.dist-info → camel_ai-0.2.76a0.dist-info}/licenses/LICENSE +0 -0
@@ -13,6 +13,7 @@
13
13
  # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
14
 
15
15
  import asyncio
16
+ import contextlib
16
17
  import datetime
17
18
  import json
18
19
  import os
@@ -115,24 +116,34 @@ class WebSocketBrowserWrapper:
115
116
  self._pending_responses: Dict[
116
117
  str, asyncio.Future[Dict[str, Any]]
117
118
  ] = {} # Message ID -> Future
119
+ self._server_ready_future = None # Future to track server ready state
118
120
 
119
121
  # Logging configuration
120
122
  self.browser_log_to_file = (config or {}).get(
121
123
  'browser_log_to_file', False
122
124
  )
125
+ self.log_dir = (config or {}).get('log_dir', 'browser_log')
123
126
  self.session_id = (config or {}).get('session_id', 'default')
124
127
  self.log_file_path: Optional[str] = None
125
128
  self.log_buffer: List[Dict[str, Any]] = []
129
+ self.ts_log_file_path: Optional[str] = None
130
+ self.ts_log_file = None # File handle for TypeScript logs
131
+ self._log_reader_task = None # Task for reading and logging stdout
126
132
 
127
- # Set up log file if needed
133
+ # Set up log files if needed
128
134
  if self.browser_log_to_file:
129
- log_dir = "browser_log"
135
+ log_dir = self.log_dir if self.log_dir else "browser_log"
130
136
  os.makedirs(log_dir, exist_ok=True)
131
137
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
132
138
  self.log_file_path = os.path.join(
133
139
  log_dir,
134
140
  f"hybrid_browser_toolkit_ws_{timestamp}_{self.session_id}.log",
135
141
  )
142
+ # Add TypeScript console log file
143
+ self.ts_log_file_path = os.path.join(
144
+ log_dir,
145
+ f"typescript_console_{timestamp}_{self.session_id}.log",
146
+ )
136
147
 
137
148
  async def __aenter__(self):
138
149
  """Async context manager entry."""
@@ -207,39 +218,67 @@ class WebSocketBrowserWrapper:
207
218
  ['node', 'websocket-server.js'],
208
219
  cwd=self.ts_dir,
209
220
  stdout=subprocess.PIPE,
210
- stderr=subprocess.PIPE,
221
+ stderr=subprocess.STDOUT, # Redirect stderr to stdout
211
222
  text=True,
223
+ bufsize=1, # Line buffered
212
224
  )
213
225
 
226
+ # Create a future to wait for server ready (before starting log reader)
227
+ self._server_ready_future = asyncio.get_running_loop().create_future()
228
+
229
+ # Start log reader task immediately after process starts
230
+ self._log_reader_task = asyncio.create_task(
231
+ self._read_and_log_output()
232
+ )
233
+
234
+ if self.browser_log_to_file and self.ts_log_file_path:
235
+ logger.info(
236
+ f"TypeScript console logs will be written to: "
237
+ f"{self.ts_log_file_path}"
238
+ )
239
+
214
240
  # Wait for server to output the port
215
241
  server_ready = False
216
242
  timeout = 10 # 10 seconds timeout
217
- start_time = time.time()
218
-
219
- while not server_ready and time.time() - start_time < timeout:
220
- if self.process.poll() is not None:
221
- # Process died
222
- stderr = self.process.stderr.read()
223
- raise RuntimeError(
224
- f"WebSocket server failed to start: {stderr}"
225
- )
226
243
 
227
- try:
228
- line = self.process.stdout.readline()
229
- if line.startswith('SERVER_READY:'):
230
- self.server_port = int(line.split(':')[1].strip())
231
- server_ready = True
232
- logger.info(
233
- f"WebSocket server ready on port {self.server_port}"
234
- )
235
- except (ValueError, IndexError):
236
- continue
244
+ # Wait for the server to be ready
245
+ try:
246
+ await asyncio.wait_for(self._server_ready_future, timeout=timeout)
247
+ server_ready = True
248
+ except asyncio.TimeoutError:
249
+ server_ready = False
237
250
 
238
251
  if not server_ready:
239
- self.process.kill()
240
- raise RuntimeError(
241
- "WebSocket server failed to start within timeout"
242
- )
252
+ with contextlib.suppress(ProcessLookupError, Exception):
253
+ self.process.kill()
254
+ with contextlib.suppress(Exception):
255
+ # Ensure the process fully exits
256
+ self.process.wait(timeout=2)
257
+ # Cancel and await the log reader task
258
+ if self._log_reader_task and not self._log_reader_task.done():
259
+ self._log_reader_task.cancel()
260
+ with contextlib.suppress(asyncio.CancelledError):
261
+ await self._log_reader_task
262
+ # Close TS log file if open
263
+ if getattr(self, 'ts_log_file', None):
264
+ with contextlib.suppress(Exception):
265
+ self.ts_log_file.close()
266
+ self.ts_log_file = None
267
+ self.process = None
268
+
269
+ error_msg = "WebSocket server failed to start within timeout"
270
+ import psutil
271
+
272
+ mem = psutil.virtual_memory()
273
+ if mem.available < 1024**3: # Less than 1GB available
274
+ error_msg = (
275
+ f"WebSocket server failed to start"
276
+ f"(likely due to insufficient memory). "
277
+ f"Available memory: {mem.available / 1024**3:.2f}GB "
278
+ f"({mem.percent}% used)"
279
+ )
280
+
281
+ raise RuntimeError(error_msg)
243
282
 
244
283
  # Connect to the WebSocket server
245
284
  try:
@@ -251,10 +290,34 @@ class WebSocketBrowserWrapper:
251
290
  )
252
291
  logger.info("Connected to WebSocket server")
253
292
  except Exception as e:
254
- self.process.kill()
255
- raise RuntimeError(
256
- f"Failed to connect to WebSocket server: {e}"
257
- ) from e
293
+ with contextlib.suppress(ProcessLookupError, Exception):
294
+ self.process.kill()
295
+ with contextlib.suppress(Exception):
296
+ self.process.wait(timeout=2)
297
+ if self._log_reader_task and not self._log_reader_task.done():
298
+ self._log_reader_task.cancel()
299
+ with contextlib.suppress(asyncio.CancelledError):
300
+ await self._log_reader_task
301
+ if getattr(self, 'ts_log_file', None):
302
+ with contextlib.suppress(Exception):
303
+ self.ts_log_file.close()
304
+ self.ts_log_file = None
305
+ self.process = None
306
+
307
+ error_msg = f"Failed to connect to WebSocket server: {e}"
308
+ import psutil
309
+
310
+ mem = psutil.virtual_memory()
311
+ if mem.available < 1024**3: # Less than 1GB available
312
+ error_msg = (
313
+ f"Failed to connect to WebSocket server"
314
+ f"(likely due to insufficient memory). "
315
+ f"Available memory: {mem.available / 1024**3:.2f}GB"
316
+ f"({mem.percent}% used). "
317
+ f"Original error: {e}"
318
+ )
319
+
320
+ raise RuntimeError(error_msg) from e
258
321
 
259
322
  # Start the background receiver task
260
323
  self._receive_task = asyncio.create_task(self._receive_loop())
@@ -264,34 +327,59 @@ class WebSocketBrowserWrapper:
264
327
 
265
328
  async def stop(self):
266
329
  """Stop the WebSocket connection and server."""
267
- # Cancel the receiver task
268
- if self._receive_task and not self._receive_task.done():
269
- self._receive_task.cancel()
270
- try:
271
- await self._receive_task
272
- except asyncio.CancelledError:
273
- pass
274
-
330
+ # First, send shutdown command while receive task is still running
275
331
  if self.websocket:
276
- try:
277
- await self._send_command('shutdown', {})
332
+ with contextlib.suppress(asyncio.TimeoutError, Exception):
333
+ # Send shutdown command with a short timeout
334
+ await asyncio.wait_for(
335
+ self._send_command('shutdown', {}),
336
+ timeout=2.0, # 2 second timeout for shutdown
337
+ )
338
+ # Note: TimeoutError is expected as server may close
339
+ # before responding
340
+
341
+ # Close websocket connection
342
+ with contextlib.suppress(Exception):
278
343
  await self.websocket.close()
279
- except Exception as e:
280
- logger.warning(f"Error during websocket shutdown: {e}")
281
- finally:
282
- self.websocket = None
344
+ self.websocket = None
283
345
 
346
+ # Gracefully stop the Node process before cancelling the log reader
284
347
  if self.process:
285
348
  try:
286
- self.process.terminate()
287
- self.process.wait(timeout=5)
349
+ # give the process a short grace period to exit after shutdown
350
+ self.process.wait(timeout=2)
288
351
  except subprocess.TimeoutExpired:
289
- self.process.kill()
290
- self.process.wait()
352
+ try:
353
+ self.process.terminate()
354
+ self.process.wait(timeout=3)
355
+ except subprocess.TimeoutExpired:
356
+ with contextlib.suppress(ProcessLookupError, Exception):
357
+ self.process.kill()
358
+ self.process.wait()
359
+ except Exception as e:
360
+ logger.warning(f"Error terminating process: {e}")
291
361
  except Exception as e:
292
- logger.warning(f"Error terminating process: {e}")
293
- finally:
294
- self.process = None
362
+ logger.warning(f"Error waiting for process: {e}")
363
+
364
+ # Now cancel background tasks (reader won't block on readline)
365
+ tasks_to_cancel = [
366
+ ('_receive_task', self._receive_task),
367
+ ('_log_reader_task', self._log_reader_task),
368
+ ]
369
+ for _, task in tasks_to_cancel:
370
+ if task and not task.done():
371
+ task.cancel()
372
+ with contextlib.suppress(asyncio.CancelledError):
373
+ await task
374
+
375
+ # Close TS log file if open
376
+ if getattr(self, 'ts_log_file', None):
377
+ with contextlib.suppress(Exception):
378
+ self.ts_log_file.close()
379
+ self.ts_log_file = None
380
+
381
+ # Ensure process handle cleared
382
+ self.process = None
295
383
 
296
384
  async def _log_action(
297
385
  self,
@@ -379,16 +467,42 @@ class WebSocketBrowserWrapper:
379
467
  async def _ensure_connection(self) -> None:
380
468
  """Ensure WebSocket connection is alive."""
381
469
  if not self.websocket:
382
- raise RuntimeError("WebSocket not connected")
470
+ error_msg = "WebSocket not connected"
471
+ import psutil
472
+
473
+ mem = psutil.virtual_memory()
474
+ if mem.available < 1024**3: # Less than 1GB available
475
+ error_msg = (
476
+ f"WebSocket not connected "
477
+ f"(likely due to insufficient memory). "
478
+ f"Available memory: {mem.available / 1024**3:.2f}GB "
479
+ f"({mem.percent}% used)"
480
+ )
481
+
482
+ raise RuntimeError(error_msg)
383
483
 
384
484
  # Check if connection is still alive
385
485
  try:
386
- # Send a ping to check connection
387
- await self.websocket.ping()
486
+ # Send a ping and wait for the corresponding pong (bounded wait)
487
+ pong_waiter = await self.websocket.ping()
488
+ await asyncio.wait_for(pong_waiter, timeout=5.0)
388
489
  except Exception as e:
389
490
  logger.warning(f"WebSocket ping failed: {e}")
390
491
  self.websocket = None
391
- raise RuntimeError("WebSocket connection lost")
492
+
493
+ error_msg = "WebSocket connection lost"
494
+ import psutil
495
+
496
+ mem = psutil.virtual_memory()
497
+ if mem.available < 1024**3: # Less than 1GB available
498
+ error_msg = (
499
+ f"WebSocket connection lost "
500
+ f"(likely due to insufficient memory). "
501
+ f"Available memory: {mem.available / 1024**3:.2f}GB "
502
+ f"({mem.percent}% used)"
503
+ )
504
+
505
+ raise RuntimeError(error_msg)
392
506
 
393
507
  async def _send_command(
394
508
  self, command: str, params: Dict[str, Any]
@@ -403,7 +517,8 @@ class WebSocketBrowserWrapper:
403
517
  message = {'id': message_id, 'command': command, 'params': params}
404
518
 
405
519
  # Create a future for this message
406
- future: asyncio.Future[Dict[str, Any]] = asyncio.Future()
520
+ loop = asyncio.get_running_loop()
521
+ future: asyncio.Future[Dict[str, Any]] = loop.create_future()
407
522
  self._pending_responses[message_id] = future
408
523
 
409
524
  try:
@@ -507,9 +622,14 @@ class WebSocketBrowserWrapper:
507
622
  return ToolResult(text=response['text'], images=response['images'])
508
623
 
509
624
  def _ensure_ref_prefix(self, ref: str) -> str:
510
- """Ensure ref has 'e' prefix."""
511
- if ref and not ref.startswith('e'):
625
+ """Ensure ref has proper prefix"""
626
+ if not ref:
627
+ return ref
628
+
629
+ # If ref is purely numeric, add 'e' prefix for main frame
630
+ if ref.isdigit():
512
631
  return f'e{ref}'
632
+
513
633
  return ref
514
634
 
515
635
  def _process_refs_in_params(
@@ -676,3 +796,73 @@ class WebSocketBrowserWrapper:
676
796
  'wait_user', {'timeout': timeout_sec}
677
797
  )
678
798
  return response
799
+
800
+ async def _read_and_log_output(self):
801
+ """Read stdout from Node.js process & handle SERVER_READY + logging."""
802
+ if not self.process:
803
+ return
804
+
805
+ try:
806
+ with contextlib.ExitStack() as stack:
807
+ if self.ts_log_file_path:
808
+ self.ts_log_file = stack.enter_context(
809
+ open(self.ts_log_file_path, 'w', encoding='utf-8')
810
+ )
811
+ self.ts_log_file.write(
812
+ f"TypeScript Console Log - Started at "
813
+ f"{time.strftime('%Y-%m-%d %H:%M:%S')}\n"
814
+ )
815
+ self.ts_log_file.write("=" * 80 + "\n")
816
+ self.ts_log_file.flush()
817
+
818
+ while self.process and self.process.poll() is None:
819
+ try:
820
+ line = (
821
+ await asyncio.get_running_loop().run_in_executor(
822
+ None, self.process.stdout.readline
823
+ )
824
+ )
825
+ if not line: # EOF
826
+ break
827
+
828
+ # Check for SERVER_READY message
829
+ if line.startswith('SERVER_READY:'):
830
+ try:
831
+ self.server_port = int(
832
+ line.split(':', 1)[1].strip()
833
+ )
834
+ logger.info(
835
+ f"WebSocket server ready on port "
836
+ f"{self.server_port}"
837
+ )
838
+ if (
839
+ self._server_ready_future
840
+ and not self._server_ready_future.done()
841
+ ):
842
+ self._server_ready_future.set_result(True)
843
+ except (ValueError, IndexError) as e:
844
+ logger.error(
845
+ f"Failed to parse SERVER_READY: {e}"
846
+ )
847
+
848
+ # Write all output to log file
849
+ if self.ts_log_file:
850
+ timestamp = time.strftime('%H:%M:%S')
851
+ self.ts_log_file.write(f"[{timestamp}] {line}")
852
+ self.ts_log_file.flush()
853
+
854
+ except Exception as e:
855
+ logger.warning(f"Error reading stdout: {e}")
856
+ break
857
+
858
+ # Footer if we had a file
859
+ if self.ts_log_file:
860
+ self.ts_log_file.write("\n" + "=" * 80 + "\n")
861
+ self.ts_log_file.write(
862
+ f"TypeScript Console Log - Ended at "
863
+ f"{time.strftime('%Y-%m-%d %H:%M:%S')}\n"
864
+ )
865
+ # ExitStack closes file; clear handle
866
+ self.ts_log_file = None
867
+ except Exception as e:
868
+ logger.warning(f"Error in _read_and_log_output: {e}")
@@ -95,6 +95,7 @@ class HybridBrowserToolkit(BaseToolkit, RegisteredAgentToolkit):
95
95
  cache_dir: str = "tmp/",
96
96
  enabled_tools: Optional[List[str]] = None,
97
97
  browser_log_to_file: bool = False,
98
+ log_dir: Optional[str] = None,
98
99
  session_id: Optional[str] = None,
99
100
  default_start_url: str = "https://google.com/",
100
101
  default_timeout: Optional[int] = None,
@@ -144,6 +145,8 @@ class HybridBrowserToolkit(BaseToolkit, RegisteredAgentToolkit):
144
145
  and page loading times.
145
146
  Logs are saved to an auto-generated timestamped file.
146
147
  Defaults to `False`.
148
+ log_dir (Optional[str]): Custom directory path for log files.
149
+ If None, defaults to "browser_log". Defaults to `None`.
147
150
  session_id (Optional[str]): A unique identifier for this browser
148
151
  session. When multiple HybridBrowserToolkit instances are
149
152
  used
@@ -201,6 +204,7 @@ class HybridBrowserToolkit(BaseToolkit, RegisteredAgentToolkit):
201
204
  self._web_agent_model = web_agent_model
202
205
  self._cache_dir = cache_dir
203
206
  self._browser_log_to_file = browser_log_to_file
207
+ self._log_dir = log_dir
204
208
  self._default_start_url = default_start_url
205
209
  self._session_id = session_id or "default"
206
210
  self._viewport_limit = viewport_limit
@@ -237,7 +241,7 @@ class HybridBrowserToolkit(BaseToolkit, RegisteredAgentToolkit):
237
241
  # Set up log file if needed
238
242
  if self.log_to_file:
239
243
  # Create log directory if it doesn't exist
240
- log_dir = "browser_log"
244
+ log_dir = self._log_dir if self._log_dir else "browser_log"
241
245
  os.makedirs(log_dir, exist_ok=True)
242
246
 
243
247
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -15,7 +15,7 @@
15
15
  import base64
16
16
  import os
17
17
  from io import BytesIO
18
- from typing import List, Literal, Optional, Union
18
+ from typing import ClassVar, List, Literal, Optional, Tuple, Union
19
19
 
20
20
  from openai import OpenAI
21
21
  from PIL import Image
@@ -29,21 +29,32 @@ logger = get_logger(__name__)
29
29
 
30
30
 
31
31
  @MCPServer()
32
- class OpenAIImageToolkit(BaseToolkit):
33
- r"""A class toolkit for image generation using OpenAI's
34
- Image Generation API.
35
- """
36
-
37
- @api_keys_required(
38
- [
39
- ("api_key", "OPENAI_API_KEY"),
40
- ]
41
- )
32
+ class ImageGenToolkit(BaseToolkit):
33
+ r"""A class toolkit for image generation using Grok and OpenAI models."""
34
+
35
+ GROK_MODELS: ClassVar[List[str]] = [
36
+ "grok-2-image",
37
+ "grok-2-image-latest",
38
+ "grok-2-image-1212",
39
+ ]
40
+ OPENAI_MODELS: ClassVar[List[str]] = [
41
+ "gpt-image-1",
42
+ "dall-e-3",
43
+ "dall-e-2",
44
+ ]
45
+
42
46
  def __init__(
43
47
  self,
44
48
  model: Optional[
45
- Literal["gpt-image-1", "dall-e-3", "dall-e-2"]
46
- ] = "gpt-image-1",
49
+ Literal[
50
+ "gpt-image-1",
51
+ "dall-e-3",
52
+ "dall-e-2",
53
+ "grok-2-image",
54
+ "grok-2-image-latest",
55
+ "grok-2-image-1212",
56
+ ]
57
+ ] = "dall-e-3",
47
58
  timeout: Optional[float] = None,
48
59
  api_key: Optional[str] = None,
49
60
  url: Optional[str] = None,
@@ -72,12 +83,12 @@ class OpenAIImageToolkit(BaseToolkit):
72
83
  # NOTE: Some arguments are set in the constructor to prevent the agent
73
84
  # from making invalid API calls with model-specific parameters. For
74
85
  # example, the 'style' argument is only supported by 'dall-e-3'.
75
- r"""Initializes a new instance of the OpenAIImageToolkit class.
86
+ r"""Initializes a new instance of the ImageGenToolkit class.
76
87
 
77
88
  Args:
78
89
  api_key (Optional[str]): The API key for authenticating
79
- with the OpenAI service. (default: :obj:`None`)
80
- url (Optional[str]): The url to the OpenAI service.
90
+ with the image model service. (default: :obj:`None`)
91
+ url (Optional[str]): The url to the image model service.
81
92
  (default: :obj:`None`)
82
93
  model (Optional[str]): The model to use.
83
94
  (default: :obj:`"dall-e-3"`)
@@ -103,9 +114,23 @@ class OpenAIImageToolkit(BaseToolkit):
103
114
  image.(default: :obj:`"image_save"`)
104
115
  """
105
116
  super().__init__(timeout=timeout)
106
- api_key = api_key or os.environ.get("OPENAI_API_KEY")
107
- url = url or os.environ.get("OPENAI_API_BASE_URL")
108
- self.client = OpenAI(api_key=api_key, base_url=url)
117
+ if model not in self.GROK_MODELS + self.OPENAI_MODELS:
118
+ available_models = sorted(self.OPENAI_MODELS + self.GROK_MODELS)
119
+ raise ValueError(
120
+ f"Unsupported model: {model}. "
121
+ f"Supported models are: {available_models}"
122
+ )
123
+
124
+ # Set default url for Grok models
125
+ url = "https://api.x.ai/v1" if model in self.GROK_MODELS else url
126
+
127
+ api_key, base_url = (
128
+ self.get_openai_credentials(url, api_key)
129
+ if model in self.OPENAI_MODELS
130
+ else self.get_grok_credentials(url, api_key)
131
+ )
132
+
133
+ self.client = OpenAI(api_key=api_key, base_url=base_url)
109
134
  self.model = model
110
135
  self.size = size
111
136
  self.quality = quality
@@ -139,7 +164,7 @@ class OpenAIImageToolkit(BaseToolkit):
139
164
  return None
140
165
 
141
166
  def _build_base_params(self, prompt: str, n: Optional[int] = None) -> dict:
142
- r"""Build base parameters dict for OpenAI API calls.
167
+ r"""Build base parameters dict for Image Model API calls.
143
168
 
144
169
  Args:
145
170
  prompt (str): The text prompt for the image operation.
@@ -153,6 +178,10 @@ class OpenAIImageToolkit(BaseToolkit):
153
178
  # basic parameters supported by all models
154
179
  if n is not None:
155
180
  params["n"] = n # type: ignore[assignment]
181
+
182
+ if self.model in self.GROK_MODELS:
183
+ return params
184
+
156
185
  if self.size is not None:
157
186
  params["size"] = self.size
158
187
 
@@ -179,16 +208,18 @@ class OpenAIImageToolkit(BaseToolkit):
179
208
  params["quality"] = self.quality
180
209
  if self.background is not None:
181
210
  params["background"] = self.background
182
-
183
211
  return params
184
212
 
185
213
  def _handle_api_response(
186
- self, response, image_name: Union[str, List[str]], operation: str
214
+ self,
215
+ response,
216
+ image_name: Union[str, List[str]],
217
+ operation: str,
187
218
  ) -> str:
188
- r"""Handle API response from OpenAI image operations.
219
+ r"""Handle API response from image operations.
189
220
 
190
221
  Args:
191
- response: The response object from OpenAI API.
222
+ response: The response object from image model API.
192
223
  image_name (Union[str, List[str]]): Name(s) for the saved image
193
224
  file(s). If str, the same name is used for all images (will
194
225
  cause error for multiple images). If list, must have exactly
@@ -198,8 +229,9 @@ class OpenAIImageToolkit(BaseToolkit):
198
229
  Returns:
199
230
  str: Success message with image path/URL or error message.
200
231
  """
232
+ source = "Grok" if self.model in self.GROK_MODELS else "OpenAI"
201
233
  if response.data is None or len(response.data) == 0:
202
- error_msg = "No image data returned from OpenAI API."
234
+ error_msg = f"No image data returned from {source} API."
203
235
  logger.error(error_msg)
204
236
  return error_msg
205
237
 
@@ -283,7 +315,7 @@ class OpenAIImageToolkit(BaseToolkit):
283
315
  image_name: Union[str, List[str]] = "image.png",
284
316
  n: int = 1,
285
317
  ) -> str:
286
- r"""Generate an image using OpenAI's Image Generation models.
318
+ r"""Generate an image using image models.
287
319
  The generated image will be saved locally (for ``b64_json`` response
288
320
  formats) or an image URL will be returned (for ``url`` response
289
321
  formats).
@@ -309,15 +341,50 @@ class OpenAIImageToolkit(BaseToolkit):
309
341
  logger.error(error_msg)
310
342
  return error_msg
311
343
 
344
+ @api_keys_required([("api_key", "XAI_API_KEY")])
345
+ def get_grok_credentials(self, url, api_key) -> Tuple[str, str]: # type: ignore[return-value]
346
+ r"""Get API credentials for the specified Grok model.
347
+
348
+ Args:
349
+ url (str): The base URL for the Grok API.
350
+ api_key (str): The API key for the Grok API.
351
+
352
+ Returns:
353
+ tuple: (api_key, base_url)
354
+ """
355
+
356
+ # Get credentials based on model type
357
+ api_key = api_key or os.getenv("XAI_API_KEY")
358
+ return api_key, url
359
+
360
+ @api_keys_required([("api_key", "OPENAI_API_KEY")])
361
+ def get_openai_credentials(self, url, api_key) -> Tuple[str, str | None]: # type: ignore[return-value]
362
+ r"""Get API credentials for the specified OpenAI model.
363
+
364
+ Args:
365
+ url (str): The base URL for the OpenAI API.
366
+ api_key (str): The API key for the OpenAI API.
367
+
368
+ Returns:
369
+ Tuple[str, str | None]: (api_key, base_url)
370
+ """
371
+
372
+ api_key = api_key or os.getenv("OPENAI_API_KEY")
373
+ base_url = url or os.getenv("OPENAI_API_BASE_URL")
374
+ return api_key, base_url
375
+
312
376
  def get_tools(self) -> List[FunctionTool]:
313
- r"""Returns a list of FunctionTool objects representing the
314
- functions in the toolkit.
377
+ r"""Returns a list of FunctionTool objects representing the functions
378
+ in the toolkit.
315
379
 
316
380
  Returns:
317
- List[FunctionTool]: A list of FunctionTool objects
318
- representing the functions in the toolkit.
381
+ List[FunctionTool]: A list of FunctionTool objects representing the
382
+ functions in the toolkit.
319
383
  """
320
384
  return [
321
385
  FunctionTool(self.generate_image),
322
- # could add edit_image function later
323
386
  ]
387
+
388
+
389
+ # Backward compatibility alias
390
+ OpenAIImageToolkit = ImageGenToolkit