auto-coder 0.1.209__py3-none-any.whl → 0.1.212__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auto-coder might be problematic. Click here for more details.

@@ -0,0 +1,37 @@
1
import json
import pkg_resources
from tokenizers import Tokenizer
from typing import Optional


class BuildinTokenizer:
    """Singleton wrapper around the tokenizer.json bundled in package data.

    Lazily loads the Hugging Face ``tokenizers`` model shipped at
    ``autocoder/data/tokenizer.json`` on first instantiation; every later
    instantiation returns the same object.
    """

    _instance: Optional['BuildinTokenizer'] = None
    _tokenizer: Optional[Tokenizer] = None

    def __new__(cls) -> 'BuildinTokenizer':
        # Classic singleton: construct and initialize exactly once.
        if cls._instance is None:
            cls._instance = super(BuildinTokenizer, cls).__new__(cls)
            cls._instance._initialize()
        return cls._instance

    def _initialize(self) -> None:
        """Locate and load tokenizer.json from the package data directory.

        Raises:
            ValueError: if the tokenizer file cannot be found.
        """
        # NOTE(review): pkg_resources is deprecated upstream; consider
        # migrating to importlib.resources when the minimum Python allows.
        try:
            tokenizer_path = pkg_resources.resource_filename(
                "autocoder", "data/tokenizer.json"
            )
        except FileNotFoundError:
            tokenizer_path = None

        if tokenizer_path:
            self._tokenizer = Tokenizer.from_file(tokenizer_path)
        else:
            raise ValueError("Cannot find tokenizer.json file in package data directory")

    def count_tokens(self, text: str) -> int:
        """Count tokens for *text* wrapped in a JSON user-message envelope.

        The envelope approximates what a chat request would actually send,
        so the count includes the message scaffolding.

        Raises:
            ValueError: if the tokenizer was never initialized.
        """
        if not self._tokenizer:
            raise ValueError("Tokenizer is not initialized")

        # Fix: build the envelope with json.dumps so quotes, backslashes and
        # newlines inside *text* are properly escaped (the previous raw
        # concatenation produced malformed JSON for such inputs). Compact
        # separators plus ensure_ascii=False keep the byte layout -- and hence
        # the token count for already-well-formed text -- identical to the old
        # '{"role":"user","content":"..."}' template.
        payload = json.dumps(
            {"role": "user", "content": text},
            separators=(",", ":"),
            ensure_ascii=False,
        )
        encoded = self._tokenizer.encode(payload)
        return len(encoded.ids)

    @property
    def tokenizer(self) -> Optional[Tokenizer]:
        """The underlying Tokenizer instance, or None if never loaded."""
        return self._tokenizer
@@ -195,9 +195,7 @@ class CodeAutoGenerate:
195
195
  for result in results:
196
196
  conversations_list.append(
197
197
  conversations + [{"role": "assistant", "content": result}])
198
- else:
199
- results = []
200
- conversations_list = []
198
+ else:
201
199
  for _ in range(self.args.human_model_num):
202
200
  v = self.llms[0].chat_oai(
203
201
  conversations=conversations, llm_config=llm_config)
@@ -341,9 +341,7 @@ class CodeAutoGenerateDiff:
341
341
  for result in results:
342
342
  conversations_list.append(
343
343
  conversations + [{"role": "assistant", "content": result}])
344
- else:
345
- results = []
346
- conversations_list = []
344
+ else:
347
345
  for _ in range(self.args.human_model_num):
348
346
  v = self.llms[0].chat_oai(
349
347
  conversations=conversations, llm_config=llm_config)
@@ -423,9 +423,7 @@ class CodeAutoGenerateEditBlock:
423
423
  for result in results:
424
424
  conversations_list.append(
425
425
  conversations + [{"role": "assistant", "content": result}])
426
- else:
427
- results = []
428
- conversations_list = []
426
+ else:
429
427
  for _ in range(self.args.human_model_num):
430
428
  v = self.llms[0].chat_oai(
431
429
  conversations=conversations, llm_config=llm_config)
@@ -311,9 +311,7 @@ class CodeAutoGenerateStrictDiff:
311
311
  for result in results:
312
312
  conversations_list.append(
313
313
  conversations + [{"role": "assistant", "content": result}])
314
- else:
315
- results = []
316
- conversations_list = []
314
+ else:
317
315
  for _ in range(self.args.human_model_num):
318
316
  v = self.llms[0].chat_oai(
319
317
  conversations=conversations, llm_config=llm_config)
@@ -45,12 +45,15 @@ class CodeModificationRanker:
45
45
  }
46
46
  ```
47
47
 
48
- 注意,只输出前面要求的 Json 格式就好,不要输出其他内容,Json 需要使用 ```json ```包裹。
48
+ 注意:
49
+ 1. 只输出前面要求的 Json 格式就好,不要输出其他内容,Json 需要使用 ```json ```包裹
49
50
  '''
50
51
 
51
52
 
52
53
  def rank_modifications(self, generate_result: CodeGenerateResult) -> CodeGenerateResult:
53
54
  import time
55
+ from collections import defaultdict
56
+
54
57
  start_time = time.time()
55
58
 
56
59
  # 如果只有一个候选,直接返回
@@ -72,29 +75,44 @@ class CodeModificationRanker:
72
75
  ) for _ in range(generate_times)
73
76
  ]
74
77
 
75
- # Process results as they complete
78
+ # Collect all results
79
+ results = []
76
80
  for future in as_completed(futures):
77
81
  try:
78
82
  v = future.result()
79
- # If we get a valid result, use it and cancel other tasks
80
- for f in futures:
81
- f.cancel()
82
-
83
- elapsed = time.time() - start_time
84
- logger.info(f"Ranking completed in {elapsed:.2f}s, best candidate index: {v.rank_result[0]}")
85
-
86
- rerank_contents = [generate_result.contents[i] for i in v.rank_result]
87
- rerank_conversations = [generate_result.conversations[i] for i in v.rank_result]
88
- return CodeGenerateResult(contents=rerank_contents,conversations=rerank_conversations)
83
+ results.append(v.rank_result)
89
84
  except Exception as e:
90
85
  logger.warning(f"Ranking request failed: {str(e)}")
91
86
  logger.debug(traceback.format_exc())
92
87
  continue
88
+
89
+ if not results:
90
+ raise Exception("All ranking requests failed")
91
+
92
+ # Calculate scores for each candidate
93
+ candidate_scores = defaultdict(float)
94
+ for rank_result in results:
95
+ for idx, candidate_id in enumerate(rank_result):
96
+ # Score is 1/(position + 1) since position starts from 0
97
+ candidate_scores[candidate_id] += 1.0 / (idx + 1)
98
+
99
+ # Sort candidates by score in descending order
100
+ sorted_candidates = sorted(candidate_scores.keys(),
101
+ key=lambda x: candidate_scores[x],
102
+ reverse=True)
103
+
104
+ elapsed = time.time() - start_time
105
+ # Format scores for logging
106
+ score_details = ", ".join([f"candidate {i}: {candidate_scores[i]:.2f}" for i in sorted_candidates])
107
+ logger.info(f"Ranking completed in {elapsed:.2f}s, best candidate index: {sorted_candidates[0]}, scores: {score_details}")
108
+
109
+ rerank_contents = [generate_result.contents[i] for i in sorted_candidates]
110
+ rerank_conversations = [generate_result.conversations[i] for i in sorted_candidates]
111
+ return CodeGenerateResult(contents=rerank_contents,conversations=rerank_conversations)
112
+
93
113
  except Exception as e:
94
114
  logger.error(f"Ranking process failed: {str(e)}")
95
115
  logger.debug(traceback.format_exc())
96
-
97
- # If all requests failed, use the original codes
98
- elapsed = time.time() - start_time
99
- logger.warning(f"All ranking requests failed in {elapsed:.2f}s, using original order")
100
- return generate_result
116
+ elapsed = time.time() - start_time
117
+ logger.warning(f"Ranking failed in {elapsed:.2f}s, using original order")
118
+ return generate_result
@@ -0,0 +1,326 @@
1
+ import os
2
+ import json
3
+ import asyncio
4
+ from datetime import datetime
5
+ from typing import Dict, List, Optional, Any, Set, Optional
6
+ from pathlib import Path
7
+ from pydantic import BaseModel, Field
8
+
9
+ from mcp import ClientSession
10
+ from mcp.client.stdio import stdio_client, StdioServerParameters
11
+ import mcp.types as mcp_types
12
+ from loguru import logger
13
+
14
class McpTool(BaseModel):
    """Represents an MCP tool exposed by a connected server."""

    # Tool identifier as reported by the server.
    name: str
    # Human-readable description, if the server provides one.
    description: Optional[str] = None
    # JSON Schema describing the tool's expected arguments.
    input_schema: dict = Field(default_factory=dict)
20
+
21
+
22
class McpResource(BaseModel):
    """Represents an MCP resource exposed by a connected server."""

    # URI used to read the resource via the server.
    uri: str
    name: str
    description: Optional[str] = None
    # MIME type of the resource content, if the server reports one.
    mime_type: Optional[str] = None
29
+
30
+
31
class McpResourceTemplate(BaseModel):
    """Represents an MCP resource template advertised by a server."""

    # URI template (RFC 6570-style, as sent by the server) for the resource.
    uri_template: str
    name: str
    description: Optional[str] = None
    mime_type: Optional[str] = None
38
+
39
+
40
class McpServer(BaseModel):
    """Represents an MCP server configuration and its live status."""

    name: str
    config: str  # JSON string of server config (command/args/env)
    status: str = "disconnected"  # connected, disconnected, connecting
    # Last connection error message, if any.
    error: Optional[str] = None
    # Capabilities fetched from the server after a successful connect.
    tools: List[McpTool] = Field(default_factory=list)
    resources: List[McpResource] = Field(default_factory=list)
    resource_templates: List[McpResourceTemplate] = Field(default_factory=list)
50
+
51
+
52
class McpConnection:
    """Represents an active MCP server connection.

    Bundles the server descriptor, the live client session, and the
    transport context manager so the transport can be __aexit__'ed when
    the connection is torn down.
    """

    def __init__(self, server: McpServer, session: ClientSession, transport_manager):
        self.server = server
        self.session = session
        self.transport_manager = (
            transport_manager  # Will hold transport context manager
        )
61
+
62
+
63
class McpHub:
    """
    Manages MCP server connections and interactions.
    Similar to the TypeScript McpHub but adapted for Python/asyncio.

    Implemented as a process-wide singleton. Server definitions are
    persisted in a JSON settings file (default:
    ``~/.auto-coder/mcp/settings.json``).
    """

    _instance = None

    def __new__(cls, settings_path: Optional[str] = None):
        # Singleton: first call creates the instance; later calls return it
        # (their settings_path argument is effectively ignored).
        if cls._instance is None:
            cls._instance = super(McpHub, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self, settings_path: Optional[str] = None):
        # Guard so repeated construction of the singleton is a no-op.
        if self._initialized:
            return
        """Initialize the MCP Hub with a path to settings file"""
        if settings_path is None:
            self.settings_path = Path.home() / ".auto-coder" / "mcp" / "settings.json"
        else:
            self.settings_path = Path(settings_path)
        # Active connections, keyed by server name.
        self.connections: Dict[str, McpConnection] = {}
        self.is_connecting = False

        # Ensure settings directory exists
        self.settings_path.parent.mkdir(parents=True, exist_ok=True)
        if not self.settings_path.exists():
            self._write_default_settings()

        self._initialized = True

    def _write_default_settings(self):
        """Write default MCP settings file (an empty ``mcpServers`` map)."""
        default_settings = {"mcpServers": {}}
        with open(self.settings_path, "w") as f:
            json.dump(default_settings, f, indent=2)

    async def initialize(self):
        """Initialize MCP server connections from the settings file.

        Raises:
            Exception: re-raises any failure after logging it.
        """
        try:
            config = self._read_settings()
            await self.update_server_connections(config.get("mcpServers", {}))
        except Exception as e:
            logger.error(f"Failed to initialize MCP servers: {e}")
            raise

    def get_servers(self) -> List[McpServer]:
        """Get list of all configured servers (their live descriptors)."""
        return [conn.server for conn in self.connections.values()]

    async def connect_to_server(self, name: str, config: dict) -> None:
        """
        Establish connection to an MCP server with proper resource management.

        Spawns the server process over stdio, initializes a ClientSession,
        and fetches its tools/resources/templates. Any existing connection
        under the same name is closed first.

        Raises:
            Exception: re-raises connection/initialization failures after
                recording them on the server descriptor.
        """
        # Remove existing connection if present
        if name in self.connections:
            await self.delete_connection(name)

        try:
            server = McpServer(
                name=name, config=json.dumps(config), status="connecting"
            )

            # Setup transport parameters. PATH is forced into the child env so
            # the configured command can be resolved.
            server_params = StdioServerParameters(
                command=config["command"],
                args=config.get("args", []),
                env={**config.get("env", {}), "PATH": os.environ.get("PATH", "")},
            )

            # Create transport using context manager. The manager is entered
            # manually here and exited later in delete_connection.
            transport_manager = stdio_client(server_params)
            transport = await transport_manager.__aenter__()
            try:
                # transport is a (read_stream, write_stream) pair.
                session = await ClientSession(transport[0], transport[1]).__aenter__()
                await session.initialize()

                # Store connection with transport manager
                connection = McpConnection(server, session, transport_manager)
                self.connections[name] = connection

                # Update server status and fetch capabilities
                server.status = "connected"
                server.tools = await self._fetch_tools(name)
                server.resources = await self._fetch_resources(name)
                server.resource_templates = await self._fetch_resource_templates(name)

            except Exception as e:
                # Clean up transport if session initialization fails.
                # NOTE(review): the session's own __aexit__ is not called on
                # this path -- confirm the transport teardown is sufficient.
                await transport_manager.__aexit__(None, None, None)
                raise

        except Exception as e:
            error_msg = str(e)
            logger.error(f"Failed to connect to server {name}: {error_msg}")
            # The descriptor only exists in self.connections if the session
            # was stored before the failure; update it if so.
            if name in self.connections:
                self.connections[name].server.status = "disconnected"
                self.connections[name].server.error = error_msg
            raise

    async def delete_connection(self, name: str) -> None:
        """
        Close and remove a server connection with proper cleanup.
        Safe to call for names with no active connection (no-op).
        """
        if name in self.connections:
            try:
                connection = self.connections[name]
                # Clean up in reverse order of creation: session first,
                # then the underlying stdio transport.
                await connection.session.__aexit__(None, None, None)
                await connection.transport_manager.__aexit__(None, None, None)
                del self.connections[name]
            except Exception as e:
                logger.error(f"Error closing connection to {name}: {e}")
                # Continue with deletion even if cleanup fails
                if name in self.connections:
                    del self.connections[name]

    async def update_server_connections(self, new_servers: Dict[str, Any]) -> None:
        """
        Update server connections based on new configuration: drop servers
        missing from *new_servers*, connect new ones, and reconnect servers
        whose config JSON has changed.
        """
        self.is_connecting = True
        try:
            current_names = set(self.connections.keys())
            new_names = set(new_servers.keys())

            # Remove deleted servers
            for name in current_names - new_names:
                await self.delete_connection(name)
                logger.info(f"Deleted MCP server: {name}")

            # Add or update servers
            for name, config in new_servers.items():
                current_conn = self.connections.get(name)

                if not current_conn:
                    # New server
                    await self.connect_to_server(name, config)
                    logger.info(f"Connected to new MCP server: {name}")
                elif current_conn.server.config != json.dumps(config):
                    # Updated configuration (compared via serialized JSON)
                    await self.connect_to_server(name, config)
                    logger.info(f"Reconnected MCP server with updated config: {name}")

        finally:
            self.is_connecting = False

    async def _fetch_tools(self, server_name: str) -> List[McpTool]:
        """Fetch available tools from a server; [] on error or no connection."""
        try:
            connection = self.connections.get(server_name)
            if not connection:
                return []

            response = await connection.session.list_tools()
            return [
                McpTool(
                    name=tool.name,
                    description=tool.description,
                    input_schema=tool.inputSchema,
                )
                for tool in response.tools
            ]
        except Exception as e:
            logger.error(f"Failed to fetch tools for {server_name}: {e}")
            return []

    async def _fetch_resources(self, server_name: str) -> List[McpResource]:
        """Fetch available resources from a server; [] on error or no connection."""
        try:
            connection = self.connections.get(server_name)
            if not connection:
                return []

            response = await connection.session.list_resources()
            return [
                McpResource(
                    uri=resource.uri,
                    name=resource.name,
                    description=resource.description,
                    mime_type=resource.mimeType,
                )
                for resource in response.resources
            ]
        except Exception as e:
            logger.error(f"Failed to fetch resources for {server_name}: {e}")
            return []

    async def _fetch_resource_templates(
        self, server_name: str
    ) -> List[McpResourceTemplate]:
        """Fetch available resource templates; [] on error or no connection.

        Uses a raw ``send_request`` because the session has no dedicated
        helper for ``resources/templates/list``.
        """
        try:
            connection = self.connections.get(server_name)
            if not connection:
                return []

            response = await connection.session.send_request(
                mcp_types.ClientRequest(mcp_types.ListResourceTemplatesRequest(
                    method="resources/templates/list",
                )),
                mcp_types.ListResourceTemplatesResult,
            )
            return [
                McpResourceTemplate(
                    uri_template=template.uriTemplate,
                    name=template.name,
                    description=template.description,
                    mime_type=template.mimeType,
                )
                for template in response.resourceTemplates
            ]
        except Exception as e:
            logger.error(f"Failed to fetch resource templates for {server_name}: {e}")
            return []

    def _read_settings(self) -> dict:
        """Read the MCP settings file; empty config on any read/parse error."""
        try:
            with open(self.settings_path) as f:
                return json.load(f)
        except Exception as e:
            logger.error(f"Failed to read MCP settings: {e}")
            return {"mcpServers": {}}

    async def call_tool(
        self, server_name: str, tool_name: str, tool_arguments: Optional[Dict] = None
    ) -> mcp_types.CallToolResult:
        """
        Call an MCP tool with arguments.

        Raises:
            ValueError: if there is no connection for *server_name*.
        """
        connection = self.connections.get(server_name)
        if not connection:
            raise ValueError(f"No connection found for server: {server_name}")

        return await connection.session.call_tool(tool_name, tool_arguments or {})

    async def read_resource(self, server_name: str, uri: str) -> mcp_types.ReadResourceResult:
        """
        Read an MCP resource.

        Raises:
            ValueError: if there is no connection for *server_name*.
        """
        connection = self.connections.get(server_name)
        if not connection:
            raise ValueError(f"No connection found for server: {server_name}")

        return await connection.session.read_resource(uri)

    async def shutdown(self):
        """
        Clean shutdown of all connections.
        """
        # Iterate over a snapshot since delete_connection mutates the dict.
        for name in list(self.connections.keys()):
            await self.delete_connection(name)
@@ -0,0 +1,83 @@
1
+ import asyncio
2
+ from asyncio import Queue as AsyncQueue
3
+ import threading
4
+ from typing import List, Dict, Any, Optional
5
+ from dataclasses import dataclass
6
+ import byzerllm
7
+ from autocoder.common.mcp_hub import McpHub
8
+ from autocoder.common.mcp_tools import McpExecutor
9
+
10
@dataclass
class McpRequest:
    """A query to route through the MCP executor."""

    query: str
    # Model name passed to ByzerLLM.from_default_model; None uses its default.
    model: Optional[str] = None
14
+
15
@dataclass
class McpResponse:
    """Result of processing one McpRequest."""

    result: str
    # Error message when processing failed; result is "" in that case.
    error: Optional[str] = None
19
+
20
class McpServer:
    """Background MCP request processor.

    Runs an asyncio event loop on a daemon thread and serializes
    McpRequest objects through a queue so synchronous callers can reach
    the async MCP stack via send_request().
    """

    def __init__(self):
        self._request_queue = AsyncQueue()
        self._response_queue = AsyncQueue()
        self._running = False
        self._task = None
        self._loop = None

    def start(self):
        """Start the worker thread and its event loop (idempotent)."""
        if self._running:
            return

        self._running = True
        self._loop = asyncio.new_event_loop()
        threading.Thread(target=self._run_event_loop, daemon=True).start()

    def stop(self):
        """Stop the event loop from any thread.

        Fix: loop.stop() is not thread-safe and closing a *running* loop
        raises RuntimeError, so the stop is scheduled onto the loop's own
        thread via call_soon_threadsafe and the loop is closed by
        _run_event_loop after run_forever() returns.
        """
        if self._running:
            self._running = False
            if self._loop:
                self._loop.call_soon_threadsafe(self._loop.stop)

    def _run_event_loop(self):
        # Worker-thread entry point: this thread owns the loop for its
        # whole lifetime, including closing it after the loop stops.
        asyncio.set_event_loop(self._loop)
        self._task = self._loop.create_task(self._process_request())
        try:
            self._loop.run_forever()
        finally:
            self._loop.close()

    async def _process_request(self):
        """Consume requests until stopped or a None sentinel arrives."""
        hub = McpHub()
        await hub.initialize()

        while self._running:
            try:
                request = await self._request_queue.get()
                if request is None:
                    break

                llm = byzerllm.ByzerLLM.from_default_model(model=request.model)
                mcp_executor = McpExecutor(hub, llm)
                conversations = [{"role": "user", "content": request.query}]
                _, results = await mcp_executor.run(conversations)
                results_str = "\n\n".join(mcp_executor.format_mcp_result(result) for result in results)
                await self._response_queue.put(McpResponse(result=results_str))
            except Exception as e:
                # Surface failures to the caller instead of crashing the loop.
                await self._response_queue.put(McpResponse(result="", error=str(e)))

    def send_request(self, request: "McpRequest") -> "McpResponse":
        """Submit a request from a foreign thread and block for its response.

        NOTE(review): responses are paired with requests only by queue
        order -- concurrent callers could receive each other's responses;
        confirm callers serialize access.
        """
        async def _send():
            await self._request_queue.put(request)
            return await self._response_queue.get()

        future = asyncio.run_coroutine_threadsafe(_send(), self._loop)
        return future.result()
74
+
75
# Global MCP server instance (module-level singleton, created lazily)
_mcp_server = None

def get_mcp_server():
    """Return the process-wide McpServer, starting it on first use.

    NOTE(review): not guarded by a lock -- concurrent first calls from
    multiple threads could race and create two servers; confirm first
    use is single-threaded.
    """
    global _mcp_server
    if _mcp_server is None:
        _mcp_server = McpServer()
        _mcp_server.start()
    return _mcp_server
+ return _mcp_server