ctrlcode-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ctrlcode/__init__.py +8 -0
- ctrlcode/agents/__init__.py +29 -0
- ctrlcode/agents/cleanup.py +388 -0
- ctrlcode/agents/communication.py +439 -0
- ctrlcode/agents/observability.py +421 -0
- ctrlcode/agents/react_loop.py +297 -0
- ctrlcode/agents/registry.py +211 -0
- ctrlcode/agents/result_parser.py +242 -0
- ctrlcode/agents/workflow.py +723 -0
- ctrlcode/analysis/__init__.py +28 -0
- ctrlcode/analysis/ast_diff.py +163 -0
- ctrlcode/analysis/bug_detector.py +149 -0
- ctrlcode/analysis/code_graphs.py +329 -0
- ctrlcode/analysis/semantic.py +205 -0
- ctrlcode/analysis/static.py +183 -0
- ctrlcode/analysis/synthesizer.py +281 -0
- ctrlcode/analysis/tests.py +189 -0
- ctrlcode/cleanup/__init__.py +16 -0
- ctrlcode/cleanup/auto_merge.py +350 -0
- ctrlcode/cleanup/doc_gardening.py +388 -0
- ctrlcode/cleanup/pr_automation.py +330 -0
- ctrlcode/cleanup/scheduler.py +356 -0
- ctrlcode/config.py +380 -0
- ctrlcode/embeddings/__init__.py +6 -0
- ctrlcode/embeddings/embedder.py +192 -0
- ctrlcode/embeddings/vector_store.py +213 -0
- ctrlcode/fuzzing/__init__.py +24 -0
- ctrlcode/fuzzing/analyzer.py +280 -0
- ctrlcode/fuzzing/budget.py +112 -0
- ctrlcode/fuzzing/context.py +665 -0
- ctrlcode/fuzzing/context_fuzzer.py +506 -0
- ctrlcode/fuzzing/derived_orchestrator.py +732 -0
- ctrlcode/fuzzing/oracle_adapter.py +135 -0
- ctrlcode/linters/__init__.py +11 -0
- ctrlcode/linters/hand_rolled_utils.py +221 -0
- ctrlcode/linters/yolo_parsing.py +217 -0
- ctrlcode/metrics/__init__.py +6 -0
- ctrlcode/metrics/dashboard.py +283 -0
- ctrlcode/metrics/tech_debt.py +663 -0
- ctrlcode/paths.py +68 -0
- ctrlcode/permissions.py +179 -0
- ctrlcode/providers/__init__.py +15 -0
- ctrlcode/providers/anthropic.py +138 -0
- ctrlcode/providers/base.py +77 -0
- ctrlcode/providers/openai.py +197 -0
- ctrlcode/providers/parallel.py +104 -0
- ctrlcode/server.py +871 -0
- ctrlcode/session/__init__.py +6 -0
- ctrlcode/session/baseline.py +57 -0
- ctrlcode/session/manager.py +967 -0
- ctrlcode/skills/__init__.py +10 -0
- ctrlcode/skills/builtin/commit.toml +29 -0
- ctrlcode/skills/builtin/docs.toml +25 -0
- ctrlcode/skills/builtin/refactor.toml +33 -0
- ctrlcode/skills/builtin/review.toml +28 -0
- ctrlcode/skills/builtin/test.toml +28 -0
- ctrlcode/skills/loader.py +111 -0
- ctrlcode/skills/registry.py +139 -0
- ctrlcode/storage/__init__.py +19 -0
- ctrlcode/storage/history_db.py +708 -0
- ctrlcode/tools/__init__.py +220 -0
- ctrlcode/tools/bash.py +112 -0
- ctrlcode/tools/browser.py +352 -0
- ctrlcode/tools/executor.py +153 -0
- ctrlcode/tools/explore.py +486 -0
- ctrlcode/tools/mcp.py +108 -0
- ctrlcode/tools/observability.py +561 -0
- ctrlcode/tools/registry.py +193 -0
- ctrlcode/tools/todo.py +291 -0
- ctrlcode/tools/update.py +266 -0
- ctrlcode/tools/webfetch.py +147 -0
- ctrlcode-0.1.0.dist-info/METADATA +93 -0
- ctrlcode-0.1.0.dist-info/RECORD +75 -0
- ctrlcode-0.1.0.dist-info/WHEEL +4 -0
- ctrlcode-0.1.0.dist-info/entry_points.txt +3 -0
ctrlcode/paths.py
ADDED
@@ -0,0 +1,68 @@
"""Cross-platform path resolution for ctrl+code.

Provides platform-aware directory paths following OS conventions:
- Linux: XDG Base Directory Specification
- macOS: ~/Library/Application Support and ~/Library/Caches
- Windows: AppData directories

Environment variable overrides:
- CTRLCODE_CONFIG_DIR: Override config directory
- CTRLCODE_DATA_DIR: Override data directory
- CTRLCODE_CACHE_DIR: Override cache directory
"""

import os
from pathlib import Path

from platformdirs import user_cache_dir, user_config_dir, user_data_dir

APP_NAME = "ctrlcode"
APP_AUTHOR = "ctrlcode"  # Required for Windows paths


def get_config_dir() -> Path:
    """Get platform-specific config directory.

    Returns:
        Path to config directory (contains config.toml, skills/, etc.)

    Platform defaults:
        Linux: ~/.config/ctrlcode/
        macOS: ~/Library/Application Support/ctrlcode/
        Windows: %APPDATA%\\ctrlcode\\
    """
    if override := os.getenv("CTRLCODE_CONFIG_DIR"):
        return Path(override).expanduser()
    return Path(user_config_dir(APP_NAME, APP_AUTHOR))


def get_data_dir() -> Path:
    """Get platform-specific data directory.

    Returns:
        Path to data directory (contains sessions/, persistent data)

    Platform defaults:
        Linux: ~/.local/share/ctrlcode/
        macOS: ~/Library/Application Support/ctrlcode/
        Windows: %LOCALAPPDATA%\\ctrlcode\\
    """
    if override := os.getenv("CTRLCODE_DATA_DIR"):
        return Path(override).expanduser()
    return Path(user_data_dir(APP_NAME, APP_AUTHOR))


def get_cache_dir() -> Path:
    """Get platform-specific cache directory.

    Returns:
        Path to cache directory (contains conversations/, temp files)

    Platform defaults:
        Linux: ~/.cache/ctrlcode/
        macOS: ~/Library/Caches/ctrlcode/
        Windows: %LOCALAPPDATA%\\ctrlcode\\Cache\\
    """
    if override := os.getenv("CTRLCODE_CACHE_DIR"):
        return Path(override).expanduser()
    return Path(user_cache_dir(APP_NAME, APP_AUTHOR))
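
The override check runs before platformdirs on every call, so tests or scripts can redirect resolution without monkeypatching. A minimal usage sketch, assuming platformdirs is installed and using an illustrative directory name:

import os
from pathlib import Path

from ctrlcode.paths import get_config_dir

# With the override set, resolution short-circuits to the expanded env value.
os.environ["CTRLCODE_CONFIG_DIR"] = "~/ctrlcode-test-config"  # illustrative path
assert get_config_dir() == Path("~/ctrlcode-test-config").expanduser()

# Without it, resolution falls through to platformdirs' OS convention.
del os.environ["CTRLCODE_CONFIG_DIR"]
print(get_config_dir())  # e.g. ~/.config/ctrlcode on Linux
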
ctrlcode/permissions.py
ADDED
@@ -0,0 +1,179 @@
"""User permission system for file operations."""

import asyncio
import logging
from dataclasses import dataclass
from typing import Awaitable, Callable, Optional
from uuid import uuid4

logger = logging.getLogger(__name__)


@dataclass
class PermissionRequest:
    """Request for user permission."""

    operation: str  # "write_file", "read_file", etc.
    path: str  # Path being accessed
    reason: str  # Why permission is needed
    details: Optional[dict] = None  # Additional context


class PermissionManager:
    """Manages user permissions for file operations."""

    def __init__(self, approval_callback: Optional[Callable[[str, PermissionRequest], Awaitable[None]]] = None):
        """Initialize permission manager.

        Args:
            approval_callback: Async function to send permission request to TUI
                Called with (request_id, request)
                If None, all requests are auto-approved (dev mode)
        """
        self.approval_callback = approval_callback
        self._approved_paths: set[str] = set()  # Paths approved this session
        self._pending_requests: dict[str, asyncio.Future] = {}  # request_id -> Future[bool]

    async def request_permission_async(
        self,
        operation: str,
        path: str,
        reason: str,
        details: Optional[dict] = None,
        remember: bool = False,
        timeout: float = 60.0
    ) -> bool:
        """Request user permission for an operation (async).

        Args:
            operation: Operation name
            path: File/directory path
            reason: Human-readable reason
            details: Additional context
            remember: If True, remember approval for this path
            timeout: Timeout in seconds (default 60)

        Returns:
            True if approved, False if denied
        """
        # Check if already approved this session
        cache_key = f"{operation}:{path}"
        if cache_key in self._approved_paths:
            logger.debug(f"Permission cached for {operation} on {path}")
            return True

        # No callback = auto-approve (dev mode)
        if not self.approval_callback:
            logger.warning(f"Auto-approving {operation} on {path} (no callback)")
            return True

        # Create permission request
        request = PermissionRequest(
            operation=operation,
            path=path,
            reason=reason,
            details=details
        )

        # Generate unique request ID
        request_id = str(uuid4())

        # Create future for response
        future: asyncio.Future[bool] = asyncio.Future()
        self._pending_requests[request_id] = future

        try:
            # Send request to TUI via callback
            await self.approval_callback(request_id, request)

            # Wait for response with timeout
            approved = await asyncio.wait_for(future, timeout=timeout)

            if approved and remember:
                self._approved_paths.add(cache_key)
                logger.info(f"Cached approval for {operation} on {path}")

            return approved

        except asyncio.TimeoutError:
            logger.warning(f"Permission request timed out for {operation} on {path}")
            return False  # Deny on timeout

        except Exception as e:
            logger.error(f"Permission request error: {e}")
            return False  # Fail closed

        finally:
            # Clean up pending request
            self._pending_requests.pop(request_id, None)

    def handle_permission_response(self, request_id: str, approved: bool) -> None:
        """Handle permission response from TUI.

        Args:
            request_id: Request ID
            approved: Whether user approved
        """
        future = self._pending_requests.get(request_id)

        if not future:
            logger.warning(f"No pending request for ID {request_id}")
            return

        if future.done():
            logger.warning(f"Request {request_id} already completed")
            return

        # Set result to unblock waiting task
        future.set_result(approved)
        logger.info(f"Permission {'approved' if approved else 'denied'} for request {request_id}")

    def clear_cache(self):
        """Clear all cached approvals."""
        self._approved_paths.clear()


# Global instance (set by server on init)
_permission_manager: Optional[PermissionManager] = None


def set_permission_manager(manager: PermissionManager):
    """Set global permission manager."""
    global _permission_manager
    _permission_manager = manager


def get_permission_manager() -> Optional[PermissionManager]:
    """Get global permission manager."""
    return _permission_manager


async def request_permission(
    operation: str,
    path: str,
    reason: str,
    details: Optional[dict] = None,
    remember: bool = False,
    timeout: float = 60.0
) -> bool:
    """Request permission using global manager (async).

    Args:
        operation: Operation name
        path: File/directory path
        reason: Human-readable reason
        details: Additional context
        remember: Cache approval for session
        timeout: Timeout in seconds

    Returns:
        True if approved, False if denied
    """
    manager = get_permission_manager()

    if not manager:
        # No manager = dev mode, auto-approve
        logger.warning(f"No permission manager, auto-approving {operation} on {path}")
        return True

    return await manager.request_permission_async(operation, path, reason, details, remember, timeout)
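
The handshake is split across two calls: request_permission_async parks on a Future, and the UI layer later resolves it with handle_permission_response. A minimal sketch, where send_to_ui is a hypothetical stand-in for a real TUI prompt:

import asyncio

from ctrlcode.permissions import PermissionManager, PermissionRequest

async def main() -> None:
    manager = PermissionManager()

    async def send_to_ui(request_id: str, request: PermissionRequest) -> None:
        # Hypothetical UI hook: a real TUI would prompt the user; here we
        # approve shortly afterwards, from "outside" the waiting coroutine.
        print(f"[ui] {request.operation} {request.path}: {request.reason}")
        asyncio.get_running_loop().call_later(
            0.1, manager.handle_permission_response, request_id, True
        )

    manager.approval_callback = send_to_ui

    approved = await manager.request_permission_async(
        operation="write_file",
        path="src/app.py",  # illustrative path
        reason="Apply refactor",
        remember=True,      # a second request for this key then hits the session cache
    )
    print("approved" if approved else "denied")

asyncio.run(main())
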
ctrlcode/providers/__init__.py
ADDED
@@ -0,0 +1,15 @@
"""Provider implementations for ctrl-code."""

from .base import Provider, StreamEvent
from .anthropic import AnthropicProvider
from .openai import OpenAIProvider
from .parallel import ParallelExecutor, ProviderOutput

__all__ = [
    "Provider",
    "StreamEvent",
    "AnthropicProvider",
    "OpenAIProvider",
    "ParallelExecutor",
    "ProviderOutput",
]
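
For consumers, this package root is the intended import surface; for example:

from ctrlcode.providers import AnthropicProvider, OpenAIProvider, ParallelExecutor
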
ctrlcode/providers/anthropic.py
ADDED
@@ -0,0 +1,138 @@
"""Anthropic (Claude) provider implementation."""

import logging
from typing import Any, AsyncIterator
from anthropic import AsyncAnthropic

from .base import StreamEvent

logger = logging.getLogger(__name__)


class AnthropicProvider:
    """Anthropic Claude provider with streaming support."""

    def __init__(
        self,
        api_key: str,
        model: str = "claude-sonnet-4-5-20250929",
        base_url: str | None = None
    ):
        """
        Initialize Anthropic provider.

        Args:
            api_key: Anthropic API key
            model: Model to use (default: claude-sonnet-4.5)
            base_url: Optional base URL for API endpoint (for proxies/alternative endpoints)
        """
        self.client = AsyncAnthropic(api_key=api_key, base_url=base_url)
        self.model = model

    async def stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        **kwargs: Any
    ) -> AsyncIterator[StreamEvent]:
        """Stream completion with normalized events."""
        logger.info(f"AnthropicProvider.stream called with {len(tools or [])} tools")
        if tools:
            logger.debug(f"Tool names: {[t.get('name') for t in tools]}")

        async with self.client.messages.stream(  # type: ignore[arg-type]
            model=self.model,
            messages=messages,  # type: ignore[arg-type]
            tools=tools or [],  # type: ignore[arg-type]
            max_tokens=kwargs.pop("max_tokens", 4096),
            **kwargs
        ) as stream:
            async for event in stream:
                if event.type == "content_block_delta":
                    if event.delta.type == "text_delta":
                        yield StreamEvent(
                            type="text",
                            data={"text": event.delta.text}
                        )
                    elif event.delta.type == "input_json_delta":
                        # Tool call in progress
                        yield StreamEvent(
                            type="tool_call_delta",
                            data={"delta": event.delta.partial_json}
                        )

                elif event.type == "content_block_start":
                    if event.content_block.type == "tool_use":
                        yield StreamEvent(
                            type="tool_call_start",
                            data={
                                "tool": event.content_block.name,
                                "call_id": event.content_block.id
                            }
                        )
                    elif event.content_block.type == "text":
                        # Text block starting
                        yield StreamEvent(
                            type="content_block_start",
                            data={"block_type": "text"}
                        )

                elif event.type == "content_block_stop":
                    # Content block (text or tool_use) is complete
                    yield StreamEvent(
                        type="content_block_stop",
                        data={"index": event.index}
                    )

                elif event.type == "message_stop":
                    # Final message with usage
                    message = await stream.get_final_message()
                    yield StreamEvent(
                        type="usage",
                        data={
                            "usage": {
                                "prompt_tokens": message.usage.input_tokens,
                                "completion_tokens": message.usage.output_tokens,
                                "total_tokens": message.usage.input_tokens + message.usage.output_tokens
                            }
                        }
                    )

    async def generate(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        **kwargs: Any
    ) -> dict[str, Any]:
        """Generate non-streaming completion."""
        response = await self.client.messages.create(  # type: ignore[arg-type]
            model=self.model,
            messages=messages,  # type: ignore[arg-type]
            tools=tools or [],  # type: ignore[arg-type]
            max_tokens=kwargs.pop("max_tokens", 4096),
            **kwargs
        )

        # Extract text content
        text_content = ""
        for block in response.content:
            if block.type == "text":
                text_content += block.text

        return {
            "text": text_content,
            "usage": {
                "prompt_tokens": response.usage.input_tokens,
                "completion_tokens": response.usage.output_tokens,
                "total_tokens": response.usage.input_tokens + response.usage.output_tokens
            },
            "model": response.model,
        }

    def normalize_tool_call(self, raw: dict[str, Any]) -> dict[str, Any]:
        """Normalize Anthropic tool call format."""
        return {
            "name": raw.get("name"),
            "call_id": raw.get("id"),
            "input": raw.get("input", {})
        }
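
A minimal consumer sketch for the normalized event stream, assuming ANTHROPIC_API_KEY is set; the tool definition follows the {name, description, input_schema} shape this codebase passes through, and text and tool-argument deltas are accumulated until the stream ends:

import asyncio
import json
import os

from ctrlcode.providers import AnthropicProvider

async def main() -> None:
    provider = AnthropicProvider(api_key=os.environ["ANTHROPIC_API_KEY"])
    text_parts: list[str] = []
    tool_json_parts: list[str] = []
    async for event in provider.stream(
        messages=[{"role": "user", "content": "Use the calc tool to compute 2+2."}],
        tools=[{
            "name": "calc",  # illustrative tool
            "description": "Evaluate an arithmetic expression",
            "input_schema": {"type": "object", "properties": {"expr": {"type": "string"}}},
        }],
    ):
        if event.type == "text":
            text_parts.append(event.data["text"])
        elif event.type == "tool_call_start":
            print("tool call:", event.data["tool"], event.data["call_id"])
        elif event.type == "tool_call_delta":
            tool_json_parts.append(event.data["delta"])
        elif event.type == "usage":
            print("usage:", event.data["usage"])
    print("".join(text_parts))
    if tool_json_parts:
        print("tool input:", json.loads("".join(tool_json_parts)))

asyncio.run(main())
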
ctrlcode/providers/base.py
ADDED
@@ -0,0 +1,77 @@
"""Base provider protocol and types for LLM providers."""

from abc import ABC, abstractmethod
from typing import Any, AsyncIterator
from dataclasses import dataclass


@dataclass
class StreamEvent:
    """Normalized event from provider streaming."""

    type: str  # "text", "tool_call", "usage", "error"
    data: dict[str, Any]

    def to_dict(self) -> dict[str, Any]:
        """Convert to JSON-serializable dict."""
        return {
            "type": self.type,
            "data": self.data
        }


class Provider(ABC):
    """Base class for LLM providers with streaming support."""

    @abstractmethod
    async def stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        **kwargs: Any
    ) -> AsyncIterator[StreamEvent]:
        """
        Stream completion with normalized events.

        Args:
            messages: Conversation messages in provider format
            tools: Optional tool definitions
            **kwargs: Provider-specific parameters (temperature, etc.)

        Yields:
            StreamEvent: Normalized streaming events
        """
        ...

    @abstractmethod
    async def generate(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        **kwargs: Any
    ) -> dict[str, Any]:
        """
        Generate non-streaming completion.

        Args:
            messages: Conversation messages
            tools: Optional tool definitions
            **kwargs: Provider-specific parameters

        Returns:
            Complete response with text and metadata
        """
        ...

    @abstractmethod
    def normalize_tool_call(self, raw: dict[str, Any]) -> dict[str, Any]:
        """
        Normalize provider-specific tool call format.

        Args:
            raw: Provider-specific tool call data

        Returns:
            Normalized tool call dict with keys: name, call_id, input
        """
        ...
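
The ABC pins down three methods; a stub that satisfies the contract (a canned echo backend, not a real provider) makes the expected shape concrete:

import asyncio
from typing import Any, AsyncIterator

from ctrlcode.providers.base import Provider, StreamEvent

class EchoProvider(Provider):
    """Illustrative stub: replays the last user message as text events."""

    async def stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        **kwargs: Any
    ) -> AsyncIterator[StreamEvent]:
        last = messages[-1].get("content", "") if messages else ""
        for word in str(last).split():
            yield StreamEvent(type="text", data={"text": word + " "})
        yield StreamEvent(type="usage", data={"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}})

    async def generate(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        **kwargs: Any
    ) -> dict[str, Any]:
        last = messages[-1].get("content", "") if messages else ""
        return {"text": str(last), "usage": {}, "model": "echo"}

    def normalize_tool_call(self, raw: dict[str, Any]) -> dict[str, Any]:
        return {"name": raw.get("name"), "call_id": raw.get("id"), "input": raw.get("input", {})}

async def demo() -> None:
    async for ev in EchoProvider().stream([{"role": "user", "content": "hello world"}]):
        print(ev.to_dict())

asyncio.run(demo())
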
ctrlcode/providers/openai.py
ADDED
@@ -0,0 +1,197 @@
"""OpenAI provider implementation."""

import json
import logging
from typing import Any, AsyncIterator
from openai import AsyncOpenAI

from .base import StreamEvent

logger = logging.getLogger(__name__)


class OpenAIProvider:
    """OpenAI GPT provider with streaming support."""

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-4",
        base_url: str | None = None
    ):
        """
        Initialize OpenAI provider.

        Args:
            api_key: OpenAI API key
            model: Model to use (default: gpt-4)
            base_url: Optional base URL for API endpoint (for proxies/alternative endpoints)
        """
        self.api_key = api_key
        self.base_url = base_url
        self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        self.model = model

    def _convert_tools_to_openai_format(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert Anthropic tool format to OpenAI format."""
        openai_tools = []
        for tool in tools:
            # Anthropic format: {name, description, input_schema}
            # OpenAI format: {type: "function", function: {name, description, parameters}}
            openai_tools.append({
                "type": "function",
                "function": {
                    "name": tool["name"],
                    "description": tool["description"],
                    "parameters": tool["input_schema"]  # OpenAI calls it "parameters" not "input_schema"
                }
            })
        return openai_tools

    async def stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        **kwargs: Any
    ) -> AsyncIterator[StreamEvent]:
        """Stream completion with normalized events."""
        # Convert tools to OpenAI format
        openai_tools = []
        if tools:
            openai_tools = self._convert_tools_to_openai_format(tools)
            logger.info(f"OpenAIProvider.stream called with {len(openai_tools)} tools")
            logger.debug(f"Tool names: {[t['function']['name'] for t in openai_tools]}")

        # Log message structure and first tool schema for diagnostics
        for i, msg in enumerate(messages):
            role = msg.get("role", "?")
            content = msg.get("content", "") or ""
            logger.info(f"Message[{i}] role={role} len={len(content)} preview={repr(content[:80])}")
        if openai_tools:
            logger.info(f"First tool schema: {json.dumps(openai_tools[0])}")

        create_kwargs: dict[str, Any] = dict(
            model=self.model,
            messages=messages,  # type: ignore[arg-type]
            stream=True,
            stream_options={"include_usage": True},
            **kwargs
        )
        if openai_tools:
            create_kwargs["tools"] = openai_tools  # type: ignore[assignment]
            create_kwargs["tool_choice"] = "auto"

        # Qwen3 models default to thinking mode which interferes with tool calling.
        # Disable thinking mode so vLLM's tool call parser sees clean JSON output.
        if "qwen" in self.model.lower() and "extra_body" not in create_kwargs:
            create_kwargs["extra_body"] = {"chat_template_kwargs": {"enable_thinking": False}}
            logger.info("Qwen model detected: disabled thinking mode for reliable tool calling")

        stream = await self.client.chat.completions.create(**create_kwargs)  # type: ignore[arg-type]

        tool_calls_seen = set()

        async for chunk in stream:  # type: ignore[union-attr]
            # Check for usage info (OpenAI sends this in a final chunk with no choices)
            if chunk.usage:
                yield StreamEvent(
                    type="usage",
                    data={
                        "usage": {
                            "prompt_tokens": chunk.usage.prompt_tokens,
                            "completion_tokens": chunk.usage.completion_tokens,
                            "total_tokens": chunk.usage.total_tokens
                        }
                    }
                )

            if not chunk.choices:
                continue

            delta = chunk.choices[0].delta

            # Text content
            if delta.content:
                yield StreamEvent(
                    type="text",
                    data={"text": delta.content}
                )

            # Tool calls
            if delta.tool_calls:
                logger.info(f"tool_call delta received: {delta.tool_calls}")
                for tool_call in delta.tool_calls:
                    # OpenAI sends multiple deltas for same tool call
                    # Track which ones we've seen to emit start event only once
                    if tool_call.function and tool_call.function.name:
                        if tool_call.id not in tool_calls_seen:
                            tool_calls_seen.add(tool_call.id)
                            # Tool call start
                            yield StreamEvent(
                                type="tool_call_start",
                                data={
                                    "tool": tool_call.function.name,
                                    "call_id": tool_call.id
                                }
                            )

                    if tool_call.function and tool_call.function.arguments:
                        # Tool arguments (may be partial)
                        yield StreamEvent(
                            type="tool_call_delta",
                            data={"delta": tool_call.function.arguments}
                        )

            # Completion
            if chunk.choices[0].finish_reason:
                logger.info(f"finish_reason: {chunk.choices[0].finish_reason}")
                # Emit content_block_stop for tool calls
                if chunk.choices[0].finish_reason == "tool_calls":
                    yield StreamEvent(
                        type="content_block_stop",
                        data={}
                    )

    async def generate(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        **kwargs: Any
    ) -> dict[str, Any]:
        """Generate non-streaming completion."""
        # Convert tools to OpenAI format
        openai_tools = None
        if tools:
            openai_tools = self._convert_tools_to_openai_format(tools)

        response = await self.client.chat.completions.create(  # type: ignore[arg-type]
            model=self.model,
            messages=messages,  # type: ignore[arg-type]
            tools=openai_tools,  # type: ignore[arg-type]
            **kwargs
        )

        choice = response.choices[0]
        text_content = choice.message.content or ""

        usage_dict = {}
        if response.usage:
            usage_dict = {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens
            }

        return {
            "text": text_content,
            "usage": usage_dict,
            "model": response.model,
        }

    def normalize_tool_call(self, raw: dict[str, Any]) -> dict[str, Any]:
        """Normalize OpenAI tool call format."""
        return {
            "name": raw.get("function", {}).get("name"),
            "call_id": raw.get("id"),
            "input": json.loads(raw.get("function", {}).get("arguments", "{}"))
        }
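
Both translation helpers are pure functions of their input, so they can be sanity-checked without a network call; a small sketch with an illustrative tool definition and raw call (the placeholder key is never used):

from ctrlcode.providers import OpenAIProvider

provider = OpenAIProvider(api_key="sk-placeholder", model="gpt-4")

# Anthropic-style definition in, OpenAI function-calling schema out.
[converted] = provider._convert_tools_to_openai_format([{
    "name": "read_file",
    "description": "Read a file from disk",
    "input_schema": {"type": "object", "properties": {"path": {"type": "string"}}},
}])
assert converted["function"]["parameters"]["properties"]["path"] == {"type": "string"}

# Raw OpenAI tool call in, normalized {name, call_id, input} dict out.
call = provider.normalize_tool_call({
    "id": "call_123",
    "function": {"name": "read_file", "arguments": "{\"path\": \"README.md\"}"},
})
assert call == {"name": "read_file", "call_id": "call_123", "input": {"path": "README.md"}}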