vibesurf 0.1.9a6__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of vibesurf might be problematic.
- vibe_surf/_version.py +2 -2
- vibe_surf/agents/browser_use_agent.py +68 -45
- vibe_surf/agents/prompts/report_writer_prompt.py +73 -0
- vibe_surf/agents/prompts/vibe_surf_prompt.py +85 -172
- vibe_surf/agents/report_writer_agent.py +380 -226
- vibe_surf/agents/vibe_surf_agent.py +878 -814
- vibe_surf/agents/views.py +130 -0
- vibe_surf/backend/api/activity.py +3 -1
- vibe_surf/backend/api/browser.py +70 -0
- vibe_surf/backend/api/config.py +8 -5
- vibe_surf/backend/api/files.py +59 -50
- vibe_surf/backend/api/models.py +2 -2
- vibe_surf/backend/api/task.py +47 -13
- vibe_surf/backend/database/manager.py +24 -18
- vibe_surf/backend/database/queries.py +199 -192
- vibe_surf/backend/database/schemas.py +1 -1
- vibe_surf/backend/main.py +80 -3
- vibe_surf/backend/shared_state.py +30 -35
- vibe_surf/backend/utils/encryption.py +3 -1
- vibe_surf/backend/utils/llm_factory.py +41 -36
- vibe_surf/browser/agent_browser_session.py +308 -62
- vibe_surf/browser/browser_manager.py +71 -100
- vibe_surf/browser/utils.py +5 -3
- vibe_surf/browser/watchdogs/dom_watchdog.py +0 -45
- vibe_surf/chrome_extension/background.js +88 -0
- vibe_surf/chrome_extension/manifest.json +3 -1
- vibe_surf/chrome_extension/scripts/api-client.js +13 -0
- vibe_surf/chrome_extension/scripts/file-manager.js +482 -0
- vibe_surf/chrome_extension/scripts/history-manager.js +658 -0
- vibe_surf/chrome_extension/scripts/modal-manager.js +487 -0
- vibe_surf/chrome_extension/scripts/session-manager.js +52 -11
- vibe_surf/chrome_extension/scripts/settings-manager.js +1214 -0
- vibe_surf/chrome_extension/scripts/ui-manager.js +1530 -3163
- vibe_surf/chrome_extension/sidepanel.html +47 -7
- vibe_surf/chrome_extension/styles/activity.css +934 -0
- vibe_surf/chrome_extension/styles/base.css +76 -0
- vibe_surf/chrome_extension/styles/history-modal.css +791 -0
- vibe_surf/chrome_extension/styles/input.css +568 -0
- vibe_surf/chrome_extension/styles/layout.css +186 -0
- vibe_surf/chrome_extension/styles/responsive.css +454 -0
- vibe_surf/chrome_extension/styles/settings-environment.css +165 -0
- vibe_surf/chrome_extension/styles/settings-forms.css +389 -0
- vibe_surf/chrome_extension/styles/settings-modal.css +141 -0
- vibe_surf/chrome_extension/styles/settings-profiles.css +244 -0
- vibe_surf/chrome_extension/styles/settings-responsive.css +144 -0
- vibe_surf/chrome_extension/styles/settings-utilities.css +25 -0
- vibe_surf/chrome_extension/styles/variables.css +54 -0
- vibe_surf/cli.py +5 -22
- vibe_surf/common.py +35 -0
- vibe_surf/llm/openai_compatible.py +148 -93
- vibe_surf/logger.py +99 -0
- vibe_surf/{controller/vibesurf_tools.py → tools/browser_use_tools.py} +233 -221
- vibe_surf/tools/file_system.py +415 -0
- vibe_surf/{controller → tools}/mcp_client.py +4 -3
- vibe_surf/tools/report_writer_tools.py +21 -0
- vibe_surf/tools/vibesurf_tools.py +657 -0
- vibe_surf/tools/views.py +120 -0
- {vibesurf-0.1.9a6.dist-info → vibesurf-0.1.11.dist-info}/METADATA +23 -3
- vibesurf-0.1.11.dist-info/RECORD +93 -0
- vibe_surf/chrome_extension/styles/main.css +0 -2338
- vibe_surf/chrome_extension/styles/settings.css +0 -1100
- vibe_surf/controller/file_system.py +0 -53
- vibe_surf/controller/views.py +0 -37
- vibesurf-0.1.9a6.dist-info/RECORD +0 -71
- /vibe_surf/{controller → tools}/__init__.py +0 -0
- {vibesurf-0.1.9a6.dist-info → vibesurf-0.1.11.dist-info}/WHEEL +0 -0
- {vibesurf-0.1.9a6.dist-info → vibesurf-0.1.11.dist-info}/entry_points.txt +0 -0
- {vibesurf-0.1.9a6.dist-info → vibesurf-0.1.11.dist-info}/licenses/LICENSE +0 -0
- {vibesurf-0.1.9a6.dist-info → vibesurf-0.1.11.dist-info}/top_level.txt +0 -0
vibe_surf/backend/database/schemas.py
CHANGED

@@ -58,7 +58,7 @@ class McpServerConfig(BaseModel):
     mcpServers: Dict[str, McpServerParams] = Field(default_factory=dict)
 
 class ControllerConfiguration(BaseModel):
-    """Schema for Task.mcp_server_config JSON field (legacy
+    """Schema for Task.mcp_server_config JSON field (legacy tools config)"""
 
     # Action control
     exclude_actions: List[str] = Field(default_factory=list)
vibe_surf/backend/main.py
CHANGED
@@ -11,6 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
 import logging
 import argparse
 import os
+import asyncio
 from datetime import datetime
 
 # Import routers

@@ -18,13 +19,16 @@ from .api.task import router as agents_router
 from .api.files import router as files_router
 from .api.activity import router as activity_router
 from .api.config import router as config_router
+from .api.browser import router as browser_router
 
 # Import shared state
 from . import shared_state
 
 # Configure logging
-
-logger
+
+from vibe_surf.logger import get_logger
+
+logger = get_logger(__name__)
 
 app = FastAPI(
     title="VibeSurf Backend API",
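This hunk replaces the module-level logging setup with a shared vibe_surf.logger.get_logger helper; the same swap appears in llm_factory.py further down. The helper's implementation (vibe_surf/logger.py, new in this release with 99 added lines) is not shown here, so the following is only a minimal sketch of what such a factory typically does, assuming a one-time stream-handler setup:

# Hedged sketch only: the real vibe_surf/logger.py may configure files, colors, levels, etc.
import logging
import sys


def get_logger(name: str) -> logging.Logger:
    """Return a named logger, attaching a stdout handler exactly once."""
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger


logger = get_logger(__name__)
logger.info("logger configured")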
@@ -46,18 +50,86 @@ app.include_router(agents_router, prefix="/api", tags=["tasks"])
 app.include_router(files_router, prefix="/api", tags=["files"])
 app.include_router(activity_router, prefix="/api", tags=["activity"])
 app.include_router(config_router, prefix="/api", tags=["config"])
+app.include_router(browser_router, prefix="/api", tags=["browser"])
+
+# Global variable to control browser monitoring task
+browser_monitor_task = None
+
+async def monitor_browser_connection():
+    """Background task to monitor browser connection"""
+    while True:
+        try:
+            await asyncio.sleep(2)  # Check every 1 second
+
+            if shared_state.browser_manager:
+                is_connected = await shared_state.browser_manager.check_browser_connected()
+                if not is_connected:
+                    logger.error("No Available Browser, Exiting...")
+
+                    # Schedule a graceful shutdown using os.kill in a separate thread
+                    import threading
+                    import signal
+                    import os
+
+                    def trigger_shutdown():
+                        try:
+                            # Give a brief moment for any cleanup
+                            import time
+                            time.sleep(0.5)
+                            # Send SIGTERM to current process for graceful shutdown
+                            os.kill(os.getpid(), signal.SIGTERM)
+                        except Exception as e:
+                            logger.error(f"Error during shutdown trigger: {e}")
+                            # Fallback to SIGKILL if SIGTERM doesn't work
+                            try:
+                                os.kill(os.getpid(), signal.SIGKILL)
+                            except:
+                                pass
+
+                    # Start shutdown in a separate thread to avoid blocking the async loop
+                    shutdown_thread = threading.Thread(target=trigger_shutdown)
+                    shutdown_thread.daemon = True
+                    shutdown_thread.start()
+
+                    # Exit the monitoring loop
+                    break
+
+        except asyncio.CancelledError:
+            logger.info("Browser monitor task cancelled")
+            break
+        except Exception as e:
+            logger.warning(f"Browser monitor error: {e}")
+            # Continue monitoring even if there's an error
 
 @app.on_event("startup")
 async def startup_event():
     """Initialize database and VibeSurf components on startup"""
+    global browser_monitor_task
+
     # Initialize VibeSurf components and update shared state
     await shared_state.initialize_vibesurf_components()
 
+    # Start browser monitoring task
+    browser_monitor_task = asyncio.create_task(monitor_browser_connection())
+    logger.info("🔍 Started browser connection monitor")
+
     logger.info("🚀 VibeSurf Backend API started with single-task execution model")
 
 @app.on_event("shutdown")
 async def shutdown_event():
     """Cleanup on shutdown"""
+    global browser_monitor_task
+
+    logger.info("🛑 Starting graceful shutdown...")
+
+    # Cancel browser monitor task
+    if browser_monitor_task and not browser_monitor_task.done():
+        browser_monitor_task.cancel()
+        try:
+            await asyncio.wait_for(browser_monitor_task, timeout=2.0)
+        except (asyncio.CancelledError, asyncio.TimeoutError):
+            pass
+        logger.info("✅ Browser monitor task stopped")
 
     # Cleanup VibeSurf components
     if shared_state.browser_manager:

@@ -70,7 +142,12 @@ async def shutdown_event():
 
     # Close database
     if shared_state.db_manager:
-
+        try:
+            await shared_state.db_manager.close()
+            logger.info("✅ Database manager closed")
+        except Exception as e:
+            logger.error(f"❌ Error closing database manager: {e}")
+
     logger.info("🛑 VibeSurf Backend API stopped")
 
 # Health check endpoint
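The monitor added above pairs an asyncio polling loop with a helper thread that signals the process itself, so the server's normal SIGTERM handling performs the shutdown. A self-contained sketch of that pattern, using hypothetical names (watchdog, is_healthy) rather than VibeSurf's actual classes, and assuming POSIX signal semantics:

# Illustrative pattern only; the names below are hypothetical stand-ins.
import asyncio
import os
import signal
import threading
import time


async def watchdog(is_healthy, interval: float = 2.0) -> None:
    """Poll is_healthy(); when it fails, send SIGTERM from a helper thread.

    Signalling from a thread keeps the event loop unblocked and lets the
    server's own SIGTERM handling (e.g. uvicorn's) drive the shutdown.
    """
    while True:
        try:
            await asyncio.sleep(interval)
            if await is_healthy():
                continue

            def fire() -> None:
                time.sleep(0.5)  # brief grace period for in-flight work
                os.kill(os.getpid(), signal.SIGTERM)

            threading.Thread(target=fire, daemon=True).start()
            break
        except asyncio.CancelledError:
            break


async def demo() -> None:
    checks = iter([True, True, False])  # healthy twice, then gone

    async def is_healthy() -> bool:
        return next(checks, False)

    # POSIX assumption: install a handler so the demo prints instead of exiting.
    signal.signal(signal.SIGTERM, lambda signum, frame: print("SIGTERM received"))
    await watchdog(is_healthy, interval=0.1)
    await asyncio.sleep(1.0)  # give the helper thread time to fire


if __name__ == "__main__":
    asyncio.run(demo())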
vibe_surf/backend/shared_state.py
CHANGED

@@ -15,7 +15,8 @@ from pathlib import Path
 
 # VibeSurf components
 from vibe_surf.agents.vibe_surf_agent import VibeSurfAgent
-from vibe_surf.
+from vibe_surf.tools.browser_use_tools import BrowserUseTools
+from vibe_surf.tools.vibesurf_tools import VibeSurfTools
 from vibe_surf.browser.browser_manager import BrowserManager
 from browser_use.llm.base import BaseChatModel
 from browser_use.llm.openai.chat import ChatOpenAI

@@ -29,7 +30,7 @@ logger = logging.getLogger(__name__)
 # Global VibeSurf components
 vibesurf_agent: Optional[VibeSurfAgent] = None
 browser_manager: Optional[BrowserManager] = None
-
+vibesurf_tools: Optional[VibeSurfTools] = None
 llm: Optional[BaseChatModel] = None
 db_manager: Optional['DatabaseManager'] = None
 

@@ -50,10 +51,13 @@ active_task: Optional[Dict[str, Any]] = None
 
 def get_all_components():
     """Get all components as a dictionary"""
+    global vibesurf_agent, browser_manager, vibesurf_tools, llm, db_manager
+    global workspace_dir, browser_execution_path, browser_user_data, active_mcp_server, envs
+
     return {
         "vibesurf_agent": vibesurf_agent,
         "browser_manager": browser_manager,
-        "
+        "tools": vibesurf_tools,
         "llm": llm,
         "db_manager": db_manager,
         "workspace_dir": workspace_dir,

@@ -67,15 +71,15 @@ def get_all_components():
 
 def set_components(**kwargs):
     """Update global components"""
-    global vibesurf_agent, browser_manager,
+    global vibesurf_agent, browser_manager, vibesurf_tools, llm, db_manager
     global workspace_dir, browser_execution_path, browser_user_data, active_mcp_server, envs
 
     if "vibesurf_agent" in kwargs:
         vibesurf_agent = kwargs["vibesurf_agent"]
     if "browser_manager" in kwargs:
         browser_manager = kwargs["browser_manager"]
-    if "
-
+    if "tools" in kwargs:
+        vibesurf_tools = kwargs["tools"]
     if "llm" in kwargs:
         llm = kwargs["llm"]
     if "db_manager" in kwargs:
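get_all_components() and set_components() implement a plain module-global registry; the key detail in this hunk is that every name rebound inside set_components() is declared global, so assignments update module state instead of creating locals. An illustrative stand-alone sketch of the pattern (generic names, not VibeSurf's module):

# Generic illustration; module and attribute names are not VibeSurf's actual ones.
from typing import Any, Dict, Optional

tools: Optional[Any] = None
llm: Optional[Any] = None


def get_all_components() -> Dict[str, Any]:
    # Reading module globals needs no `global` statement.
    return {"tools": tools, "llm": llm}


def set_components(**kwargs: Any) -> None:
    # Rebinding module globals from inside a function *does* need `global`,
    # otherwise the assignment would only create function-local names.
    global tools, llm
    if "tools" in kwargs:
        tools = kwargs["tools"]
    if "llm" in kwargs:
        llm = kwargs["llm"]


set_components(tools="my-tools", llm="my-llm")
print(get_all_components())  # {'tools': 'my-tools', 'llm': 'my-llm'}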
@@ -223,8 +227,8 @@ def clear_active_task():
 
 
 async def _check_and_update_mcp_servers(db_session):
-    """Check if MCP server configuration has changed and update
-    global
+    """Check if MCP server configuration has changed and update tools if needed"""
+    global vibesurf_tools, active_mcp_server
 
     try:
         if not db_session:

@@ -238,21 +242,21 @@ async def _check_and_update_mcp_servers(db_session):
 
         # Compare with shared state
         if current_active_servers != active_mcp_server:
-            logger.info(f"MCP server configuration changed. Updating
+            logger.info(f"MCP server configuration changed. Updating tools...")
             logger.info(f"Old config: {active_mcp_server}")
             logger.info(f"New config: {current_active_servers}")
 
             # Update shared state
             active_mcp_server = current_active_servers.copy()
 
-            # Create new MCP server config for
+            # Create new MCP server config for tools
             mcp_server_config = await _build_mcp_server_config(active_profiles)
 
             # Unregister old MCP clients and register new ones
-            if
-                await
-
-                await
+            if vibesurf_tools:
+                await vibesurf_tools.unregister_mcp_clients()
+                vibesurf_tools.mcp_server_config = mcp_server_config
+                await vibesurf_tools.register_mcp_clients()
             logger.info("✅ Controller MCP configuration updated successfully")
 
     except Exception as e:
@@ -310,25 +314,13 @@ async def _load_active_mcp_servers():
 
 async def initialize_vibesurf_components():
     """Initialize VibeSurf components from environment variables and default LLM profile"""
-    global vibesurf_agent, browser_manager,
+    global vibesurf_agent, browser_manager, vibesurf_tools, llm, db_manager
     global workspace_dir, browser_execution_path, browser_user_data, envs
+    from vibe_surf import common
 
     try:
         # Load environment variables
-
-        if not env_workspace_dir or not env_workspace_dir.strip():
-            # Set default workspace directory based on OS
-            if platform.system() == "Windows":
-                default_workspace = os.path.join(os.environ.get("APPDATA", ""), "VibeSurf")
-            elif platform.system() == "Darwin": # macOS
-                default_workspace = os.path.join(os.path.expanduser("~"), "Library", "Application Support", "VibeSurf")
-            else: # Linux and others
-                default_workspace = os.path.join(os.path.expanduser("~"), ".vibesurf")
-            workspace_dir = default_workspace
-        else:
-            workspace_dir = env_workspace_dir
-        workspace_dir = os.path.abspath(workspace_dir)
-        os.makedirs(workspace_dir, exist_ok=True)
+        workspace_dir = common.get_workspace_dir()
         logger.info("WorkSpace directory: {}".format(workspace_dir))
 
         # Load environment configuration from envs.json
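The inline OS-specific workspace logic removed above now lives behind common.get_workspace_dir() (vibe_surf/common.py is new in this release). Reconstructed from the deleted lines, a helper with that behavior could look roughly like the sketch below; the actual vibe_surf.common implementation may differ:

# Reconstructed sketch; the real vibe_surf.common.get_workspace_dir may differ.
import os
import platform
from typing import Optional


def get_workspace_dir(env_workspace_dir: Optional[str] = None) -> str:
    """Return an absolute, existing workspace directory with per-OS defaults."""
    if not env_workspace_dir or not env_workspace_dir.strip():
        if platform.system() == "Windows":
            workspace_dir = os.path.join(os.environ.get("APPDATA", ""), "VibeSurf")
        elif platform.system() == "Darwin":  # macOS
            workspace_dir = os.path.join(
                os.path.expanduser("~"), "Library", "Application Support", "VibeSurf")
        else:  # Linux and others
            workspace_dir = os.path.join(os.path.expanduser("~"), ".vibesurf")
    else:
        workspace_dir = env_workspace_dir
    workspace_dir = os.path.abspath(workspace_dir)
    os.makedirs(workspace_dir, exist_ok=True)
    return workspace_dir


print(get_workspace_dir())  # e.g. ~/.vibesurf on Linux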
@@ -422,6 +414,8 @@ async def initialize_vibesurf_components():
             user_data_dir=browser_user_data,
             headless=False,
             keep_alive=True,
+            auto_download_pdfs=False,
+            highlight_elements=True,
             custom_extensions=[envs["VIBESURF_EXTENSION"]],
             window_size={"width": primary_monitor.width, "height": primary_monitor.height}
         )

@@ -436,19 +430,19 @@ async def initialize_vibesurf_components():
         # Load active MCP servers from database
         mcp_server_config = await _load_active_mcp_servers()
 
-        # Initialize vibesurf
-
+        # Initialize vibesurf tools with MCP server config
+        vibesurf_tools = VibeSurfTools(mcp_server_config=mcp_server_config)
 
         # Register MCP clients if there are any active MCP servers
         if mcp_server_config and mcp_server_config.get("mcpServers"):
-            await
+            await vibesurf_tools.register_mcp_clients()
             logger.info(f"✅ Registered {len(mcp_server_config['mcpServers'])} MCP servers")
 
         # Initialize VibeSurfAgent
         vibesurf_agent = VibeSurfAgent(
             llm=llm,
             browser_manager=browser_manager,
-
+            tools=vibesurf_tools,
             workspace_dir=workspace_dir
         )
 

@@ -530,8 +524,9 @@ async def update_llm_from_profile(profile_name: str):
 
         # Update global state
         llm = new_llm
-        if vibesurf_agent:
-
+        if vibesurf_agent and vibesurf_agent.token_cost_service:
+            # FIX: Register new LLM with token cost service to maintain tracking
+            vibesurf_agent.llm = vibesurf_agent.token_cost_service.register_llm(new_llm)
 
         logger.info(f"✅ LLM updated to profile: {profile_name}")
         return True
vibe_surf/backend/utils/llm_factory.py
CHANGED

@@ -6,26 +6,29 @@ from typing import Optional
 import logging
 from ..llm_config import get_supported_providers, is_provider_supported
 
-logger
+from vibe_surf.logger import get_logger
+
+logger = get_logger(__name__)
+
 
 
 def create_llm_from_profile(llm_profile):
     """Create LLM instance from LLMProfile database record (dict or object)"""
     try:
         # Import LLM classes from browser_use and vibe_surf
         from browser_use.llm import (
-            ChatOpenAI, ChatAnthropic, ChatGoogle, ChatAzureOpenAI,
-            ChatGroq, ChatOllama, ChatOpenRouter, ChatDeepSeek,
+            ChatOpenAI, ChatAnthropic, ChatGoogle, ChatAzureOpenAI,
+            ChatGroq, ChatOllama, ChatOpenRouter, ChatDeepSeek,
             ChatAWSBedrock, ChatAnthropicBedrock
         )
         from vibe_surf.llm import ChatOpenAICompatible
-
+
         # Handle both dict and object access patterns
         def get_attr(obj, key, default=None):
             if isinstance(obj, dict):
                 return obj.get(key, default)
             else:
                 return getattr(obj, key, default)
-
+
         provider = get_attr(llm_profile, 'provider')
         model = get_attr(llm_profile, 'model')
         api_key = get_attr(llm_profile, 'api_key') # Should already be decrypted by queries

@@ -36,11 +39,11 @@ def create_llm_from_profile(llm_profile):
         frequency_penalty = get_attr(llm_profile, 'frequency_penalty')
         seed = get_attr(llm_profile, 'seed')
         provider_config = get_attr(llm_profile, 'provider_config', {})
-
+
         # Validate provider
         if not is_provider_supported(provider):
             raise ValueError(f"Unsupported provider: {provider}. Supported: {get_supported_providers()}")
-
+
         # Define provider-specific parameter support
         provider_param_support = {
             "openai": ["temperature"],

@@ -55,11 +58,11 @@ def create_llm_from_profile(llm_profile):
             "anthropic_bedrock": ["temperature"],
             "openai_compatible": ["temperature"]
         }
-
+
         # Build common parameters based on provider support
         supported_params = provider_param_support.get(provider, [])
         common_params = {}
-
+
         if temperature is not None and "temperature" in supported_params:
             common_params["temperature"] = temperature
         if max_tokens is not None and "max_tokens" in supported_params:

@@ -70,11 +73,11 @@ def create_llm_from_profile(llm_profile):
             common_params["frequency_penalty"] = frequency_penalty
         if seed is not None and "seed" in supported_params:
             common_params["seed"] = seed
-
+
         # Add provider-specific config if available
         if provider_config:
             common_params.update(provider_config)
-
+
         # Create LLM instance based on provider
         if provider == "openai":
             params = {
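create_llm_from_profile() forwards a tuning parameter to the client only when the provider_param_support table above says the provider accepts it. A stand-alone sketch of that gating pattern, with a trimmed table and a hypothetical build_common_params helper for illustration:

# Stand-alone illustration of the gating; table trimmed to a few providers.
from typing import Any, Dict, List, Optional

provider_param_support: Dict[str, List[str]] = {
    "openai": ["temperature"],
    "ollama": ["temperature"],
    "openai_compatible": ["temperature"],
}


def build_common_params(provider: str, **requested: Optional[Any]) -> Dict[str, Any]:
    """Keep only non-None parameters that the provider is declared to support."""
    supported = provider_param_support.get(provider, [])
    return {k: v for k, v in requested.items() if v is not None and k in supported}


print(build_common_params("openai", temperature=0.2, seed=42, max_tokens=None))
# -> {'temperature': 0.2}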
@@ -85,21 +88,21 @@ def create_llm_from_profile(llm_profile):
             if base_url:
                 params["base_url"] = base_url
             return ChatOpenAI(**params)
-
+
         elif provider == "anthropic":
             return ChatAnthropic(
                 model=model,
                 api_key=api_key,
                 **common_params
             )
-
+
         elif provider == "google":
             return ChatGoogle(
                 model=model,
                 api_key=api_key,
                 **common_params
             )
-
+
         elif provider == "azure_openai":
             if not base_url:
                 raise ValueError("Azure OpenAI requires base_url (azure_endpoint)")

@@ -110,14 +113,14 @@ def create_llm_from_profile(llm_profile):
                 azure_endpoint=base_url,
                 **common_params
             )
-
+
         elif provider == "groq":
             return ChatGroq(
                 model=model,
                 api_key=api_key,
                 **common_params
             )
-
+
         elif provider == "ollama":
             params = {
                 "model": model,

@@ -128,21 +131,21 @@ def create_llm_from_profile(llm_profile):
             else:
                 params["host"] = "http://localhost:11434" # Default Ollama URL
             return ChatOllama(**params)
-
+
         elif provider == "openrouter":
             return ChatOpenRouter(
                 model=model,
                 api_key=api_key,
                 **common_params
             )
-
+
         elif provider == "deepseek":
             return ChatDeepSeek(
                 model=model,
                 api_key=api_key,
                 **common_params
             )
-
+
         elif provider == "aws_bedrock":
             params = {
                 "model": model,

@@ -157,7 +160,7 @@ def create_llm_from_profile(llm_profile):
             if 'aws_region' not in params:
                 params["aws_region"] = "us-east-1"
             return ChatAWSBedrock(**params)
-
+
         elif provider == "anthropic_bedrock":
             params = {
                 "model": model,

@@ -170,7 +173,7 @@ def create_llm_from_profile(llm_profile):
             if "region_name" in provider_config:
                 params["region_name"] = provider_config["region_name"]
             return ChatAnthropicBedrock(**params)
-
+
         elif provider == "openai_compatible":
             if not base_url:
                 raise ValueError("OpenAI Compatible provider requires base_url")

@@ -180,61 +183,63 @@ def create_llm_from_profile(llm_profile):
                 base_url=base_url,
                 **common_params
             )
-
+
         else:
             raise ValueError(f"Unsupported provider: {provider}")
-
+
     except Exception as e:
         logger.error(f"Failed to create LLM from profile: {e}")
         raise RuntimeError(f"Failed to create LLM from profile: {str(e)}")
 
+
 def validate_llm_configuration(provider: str, model: str, api_key: str, base_url: Optional[str] = None):
     """Validate LLM configuration parameters"""
     if not provider:
         raise ValueError("Provider is required")
-
+
     if not model:
         raise ValueError("Model is required")
-
+
     if not is_provider_supported(provider):
         raise ValueError(f"Unsupported provider: {provider}. Supported: {get_supported_providers()}")
-
+
     # Provider-specific validation
     from ..llm_config import get_provider_metadata
     metadata = get_provider_metadata(provider)
-
+
     if metadata.get("requires_api_key", True) and not api_key:
         raise ValueError(f"API key is required for provider: {provider}")
-
+
     if metadata.get("requires_base_url", False) and not base_url:
         raise ValueError(f"Base URL is required for provider: {provider}")
-
+
     return True
 
+
 def get_llm_creation_parameters(provider: str):
     """Get the required and optional parameters for creating an LLM instance"""
     from ..llm_config import get_provider_metadata
-
+
     if not is_provider_supported(provider):
         raise ValueError(f"Unsupported provider: {provider}")
-
+
     metadata = get_provider_metadata(provider)
-
+
     required_params = ["model"]
     optional_params = ["temperature", "max_tokens", "top_p", "frequency_penalty", "seed"]
-
+
     if metadata.get("requires_api_key", True):
         required_params.append("api_key")
-
+
     if metadata.get("requires_base_url", False):
         required_params.append("base_url")
     elif metadata.get("supports_base_url", False):
         optional_params.append("base_url")
-
+
     # Special cases for AWS Bedrock
     if provider in ["aws_bedrock", "anthropic_bedrock"]:
         required_params.extend(["aws_secret_access_key", "region_name"])
-
+
     return {
         "required": required_params,
         "optional": optional_params,