wuwei 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wuwei/__init__.py +30 -0
- wuwei/agent/__init__.py +15 -0
- wuwei/agent/agent.py +47 -0
- wuwei/agent/base.py +93 -0
- wuwei/agent/plan_agent.py +73 -0
- wuwei/agent/session.py +23 -0
- wuwei/llm/__init__.py +13 -0
- wuwei/llm/adapters/__init__.py +3 -0
- wuwei/llm/adapters/base.py +26 -0
- wuwei/llm/adapters/openai.py +109 -0
- wuwei/llm/gateway.py +244 -0
- wuwei/llm/types.py +32 -0
- wuwei/memory/__init__.py +3 -0
- wuwei/memory/context.py +36 -0
- wuwei/planning/__init__.py +4 -0
- wuwei/planning/planner.py +65 -0
- wuwei/planning/task.py +26 -0
- wuwei/runtime/__init__.py +4 -0
- wuwei/runtime/agent_runner.py +114 -0
- wuwei/runtime/planner_executor_runner.py +291 -0
- wuwei/tools/__init__.py +10 -0
- wuwei/tools/executor.py +97 -0
- wuwei/tools/registry.py +104 -0
- wuwei/tools/tool.py +42 -0
- wuwei-0.1.0.dist-info/METADATA +155 -0
- wuwei-0.1.0.dist-info/RECORD +29 -0
- wuwei-0.1.0.dist-info/WHEEL +5 -0
- wuwei-0.1.0.dist-info/licenses/LICENSE +201 -0
- wuwei-0.1.0.dist-info/top_level.txt +1 -0
wuwei/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from wuwei.agent import Agent, AgentSession, BaseAgent, BaseSessionAgent, PlanAgent
|
|
2
|
+
from wuwei.llm import FunctionCall, LLMGateway, LLMResponse, LLMResponseChunk, Message, ToolCall
|
|
3
|
+
from wuwei.memory import Context
|
|
4
|
+
from wuwei.planning import Planner, Task, TaskList
|
|
5
|
+
from wuwei.runtime import AgentRunner, PlannerExecutorRunner
|
|
6
|
+
from wuwei.tools import Tool, ToolExecutor, ToolParameters, ToolRegistry
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"Agent",
|
|
10
|
+
"AgentRunner",
|
|
11
|
+
"AgentSession",
|
|
12
|
+
"BaseAgent",
|
|
13
|
+
"BaseSessionAgent",
|
|
14
|
+
"Context",
|
|
15
|
+
"FunctionCall",
|
|
16
|
+
"LLMGateway",
|
|
17
|
+
"LLMResponse",
|
|
18
|
+
"LLMResponseChunk",
|
|
19
|
+
"Message",
|
|
20
|
+
"PlanAgent",
|
|
21
|
+
"Planner",
|
|
22
|
+
"PlannerExecutorRunner",
|
|
23
|
+
"Task",
|
|
24
|
+
"TaskList",
|
|
25
|
+
"Tool",
|
|
26
|
+
"ToolCall",
|
|
27
|
+
"ToolExecutor",
|
|
28
|
+
"ToolParameters",
|
|
29
|
+
"ToolRegistry",
|
|
30
|
+
]
|
wuwei/agent/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from wuwei.agent.agent import Agent
|
|
2
|
+
from wuwei.agent.base import BaseAgent, BaseSessionAgent
|
|
3
|
+
from wuwei.agent.plan_agent import PlanAgent
|
|
4
|
+
from wuwei.agent.session import AgentSession
|
|
5
|
+
from wuwei.runtime import AgentRunner, PlannerExecutorRunner
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"Agent",
|
|
9
|
+
"AgentRunner",
|
|
10
|
+
"AgentSession",
|
|
11
|
+
"BaseAgent",
|
|
12
|
+
"BaseSessionAgent",
|
|
13
|
+
"PlanAgent",
|
|
14
|
+
"PlannerExecutorRunner",
|
|
15
|
+
]
|
wuwei/agent/agent.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from wuwei.agent.base import BaseSessionAgent
|
|
4
|
+
from wuwei.agent.session import AgentSession
|
|
5
|
+
from wuwei.llm import LLMGateway
|
|
6
|
+
from wuwei.runtime import AgentRunner
|
|
7
|
+
from wuwei.tools import Tool, ToolRegistry
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Agent(BaseSessionAgent):
    """Facade for a plain single-agent conversation.

    The LLM gateway, tool registry, tool executor and session defaults are
    all set up by ``BaseSessionAgent``; this class only selects
    ``AgentRunner`` as the execution strategy.
    """

    def __init__(
        self,
        llm: LLMGateway,
        tools: list[Tool] | ToolRegistry | None = None,
        default_system_prompt: str = "你是一个有用的助手",
        default_max_steps: int = 10,
        default_parallel_tool_calls: bool = False,
    ) -> None:
        """Forward every argument unchanged to ``BaseSessionAgent``."""
        base_options = {
            "llm": llm,
            "tools": tools,
            "default_system_prompt": default_system_prompt,
            "default_max_steps": default_max_steps,
            "default_parallel_tool_calls": default_parallel_tool_calls,
        }
        super().__init__(**base_options)

    def create_runner(self, session: AgentSession) -> AgentRunner:
        """Build an ``AgentRunner`` bound to ``session``.

        The runner receives the agent's gateway, the tools currently listed
        by the registry, and the shared tool executor.
        """
        runner_options = {
            "llm": self.llm,
            "tools": self.tool_registry.list_tools(),
            "tool_executor": self.tool_executor,
            "session": session,
        }
        return AgentRunner(**runner_options)

    async def run(
        self,
        user_input: str,
        session: AgentSession | None = None,
        stream: bool = False,
    ) -> Any:
        """Run the agent once on ``user_input``.

        When no session is supplied a new one is created through
        ``create_or_get_session()``. The return value is whatever the
        runner's ``run`` returns for the requested ``stream`` mode.
        """
        active_session = session if session else self.create_or_get_session()
        return await self.create_runner(active_session).run(user_input, stream=stream)
|
wuwei/agent/base.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any
|
|
3
|
+
from uuid import uuid4
|
|
4
|
+
|
|
5
|
+
from wuwei.agent.session import AgentSession
|
|
6
|
+
from wuwei.llm import LLMGateway
|
|
7
|
+
from wuwei.tools import Tool, ToolExecutor, ToolRegistry
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseAgent(ABC):
|
|
11
|
+
"""所有 agent 的最小抽象基类。"""
|
|
12
|
+
|
|
13
|
+
@abstractmethod
|
|
14
|
+
async def run(self, user_input: str, session: Any | None = None, stream: bool = False) -> Any:
|
|
15
|
+
"""运行一次 agent。"""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BaseSessionAgent(BaseAgent):
    """Common base class for agents that keep conversation sessions.

    It gathers the logic that ``Agent`` and ``PlanAgent`` would otherwise
    duplicate:

    - wiring of ``llm``, the tool registry and the tool executor
    - default ``system_prompt`` / ``max_steps`` / ``parallel_tool_calls``
    - creation and reuse of sessions
    """

    def __init__(
        self,
        llm: LLMGateway,
        tools: list[Tool] | ToolRegistry | None = None,
        default_system_prompt: str = "你是一个有用的助手",
        default_max_steps: int = 10,
        default_parallel_tool_calls: bool = False,
    ) -> None:
        """Store shared dependencies and the per-session defaults.

        ``tools`` may be an existing ``ToolRegistry`` (used as-is), a list of
        tools (each registered into a fresh registry), or ``None`` (an empty
        registry is created).
        """
        if isinstance(tools, ToolRegistry):
            registry = tools
        else:
            registry = ToolRegistry()
            for item in tools or ():
                registry.register(item)

        self.llm = llm
        self.default_system_prompt = default_system_prompt
        self.default_max_steps = default_max_steps
        self.default_parallel_tool_calls = default_parallel_tool_calls
        self.tool_registry = registry
        self.tool_executor = ToolExecutor(registry)
        # Sessions created by this agent, keyed by session id.
        self._sessions: dict[str, AgentSession] = {}

    def create_session(
        self,
        session_id: str | None = None,
        system_prompt: str | None = None,
        max_steps: int | None = None,
        parallel_tool_calls: bool | None = None,
    ) -> AgentSession:
        """Create a new session, register it, and return it.

        A falsy ``session_id`` is replaced by a random hex id and a falsy
        ``system_prompt`` by the agent default. ``max_steps`` and
        ``parallel_tool_calls`` fall back to the agent defaults only when
        they are ``None``. An existing session with the same id is replaced.
        """
        if max_steps is None:
            max_steps = self.default_max_steps
        if parallel_tool_calls is None:
            parallel_tool_calls = self.default_parallel_tool_calls

        new_session = AgentSession(
            session_id=session_id or uuid4().hex,
            system_prompt=system_prompt or self.default_system_prompt,
            max_steps=max_steps,
            parallel_tool_calls=parallel_tool_calls,
        )
        self._sessions[new_session.session_id] = new_session
        return new_session

    def create_or_get_session(
        self,
        session_id: str | None = None,
        system_prompt: str | None = None,
        max_steps: int | None = None,
        parallel_tool_calls: bool | None = None,
    ) -> AgentSession:
        """Return the session registered under ``session_id``, or create one.

        When an existing session is reused, the other arguments are ignored;
        they only apply to a newly created session.
        """
        if session_id is None or session_id not in self._sessions:
            return self.create_session(
                session_id=session_id,
                system_prompt=system_prompt,
                max_steps=max_steps,
                parallel_tool_calls=parallel_tool_calls,
            )
        return self._sessions[session_id]

    @abstractmethod
    def create_runner(self, session: AgentSession) -> Any:
        """Return the runner that executes ``session``; provided by subclasses."""
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from wuwei.agent.base import BaseSessionAgent
|
|
4
|
+
from wuwei.agent.session import AgentSession
|
|
5
|
+
from wuwei.llm import LLMGateway
|
|
6
|
+
from wuwei.planning import Planner, Task
|
|
7
|
+
from wuwei.runtime import PlannerExecutorRunner
|
|
8
|
+
from wuwei.tools import Tool, ToolRegistry
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class PlanAgent(BaseSessionAgent):
    """Facade for the plan-and-execute workflow.

    Planning and execution are delegated to a ``PlannerExecutorRunner``
    built per call; this class only resolves the session and forwards.
    """

    def __init__(
        self,
        llm: LLMGateway,
        tools: list[Tool] | ToolRegistry | None = None,
        planner: Planner | None = None,
        default_system_prompt: str = "你是一个有用的助手",
        default_max_steps: int = 10,
        default_parallel_tool_calls: bool = False,
    ) -> None:
        """Set up the shared agent state and pick the planner.

        When ``planner`` is not supplied, one is created from the agent's
        gateway via ``Planner.create_planner``.
        """
        super().__init__(
            llm=llm,
            tools=tools,
            default_system_prompt=default_system_prompt,
            default_max_steps=default_max_steps,
            default_parallel_tool_calls=default_parallel_tool_calls,
        )
        if not planner:
            planner = Planner.create_planner(llm=self.llm)
        self.planner = planner

    def create_runner(self, session: AgentSession) -> PlannerExecutorRunner:
        """Build a ``PlannerExecutorRunner`` bound to ``session``."""
        runner_options = {
            "llm": self.llm,
            "tools": self.tool_registry.list_tools(),
            "tool_executor": self.tool_executor,
            "session": session,
            "planner": self.planner,
        }
        return PlannerExecutorRunner(**runner_options)

    def _runner_for(self, session: AgentSession | None) -> PlannerExecutorRunner:
        """Resolve the session (creating one if none was given) and build its runner."""
        active_session = session if session else self.create_or_get_session()
        return self.create_runner(active_session)

    async def plan(
        self,
        goal: str,
        session: AgentSession | None = None,
    ) -> list[Task]:
        """Produce the task plan for ``goal`` without executing it."""
        return await self._runner_for(session).plan(goal)

    async def execute(
        self,
        goal: str,
        tasks: list[Task],
        session: AgentSession | None = None,
        stream: bool = False,
    ) -> Any:
        """Execute an already planned list of ``tasks`` for ``goal``."""
        return await self._runner_for(session).execute(goal, tasks, stream=stream)

    async def run(
        self,
        user_input: str,
        session: AgentSession | None = None,
        stream: bool = False,
    ) -> Any:
        """Unified public entry point: plan first, then execute."""
        return await self._runner_for(session).run(user_input, stream=stream)
|
wuwei/agent/session.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
|
|
3
|
+
from wuwei.memory import Context
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
class AgentSession:
    """Configuration and conversation context of one session."""

    session_id: str
    system_prompt: str = "你是一个有用的助手"
    max_steps: int = 10
    parallel_tool_calls: bool = False
    # Not an __init__ argument; it is built by reset() right after construction.
    context: Context = field(init=False)

    def __post_init__(self) -> None:
        """Give the freshly constructed session its initial context."""
        self.reset()

    def reset(self) -> None:
        """Replace the context with a new one seeded with the system prompt."""
        self.context = fresh_context = Context()
        fresh_context.add_system_message(self.system_prompt)
|
wuwei/llm/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from abc import ABC,abstractmethod
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from ..types import Message, LLMResponse,LLMResponseChunk
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class BaseAdapter(ABC):
    """Contract every provider adapter implements.

    An adapter converts the framework's messages into a provider request,
    sends it, and converts the provider's answer back into framework types.
    """

    @abstractmethod
    def build_request(
        self,
        messages: list[Message],
        tools: list[Any] | None = None,
        stream: bool | None = False,
        **kwargs: Any,
    ) -> Any:
        """Build the provider-specific request object.

        Args:
            messages: Conversation history to send.
            tools: Tools offered to the model, or ``None`` for no tools.
                The element type is left open here because it is the
                adapter that decides how a tool is serialised.
            stream: Whether the request asks for a streamed response.
            **kwargs: Extra provider parameters for this single request.

        Returns:
            Whatever object ``call`` accepts for this provider.
        """

    @abstractmethod
    async def call(self, request: Any) -> Any:
        """Send ``request`` to the provider and return the raw response.

        Declared as a coroutine because callers await it. For a streamed
        request the awaited result is consumed with ``async for``.
        """

    @abstractmethod
    def parse_response(self, raw_response: Any) -> LLMResponse:
        """Convert a complete (non-streamed) raw response into ``LLMResponse``."""

    @abstractmethod
    def parse_stream_chunk(self, chunk: Any) -> dict[str, Any] | None:
        """Convert one streamed chunk into a plain dict.

        Returns ``None`` when the chunk carries nothing to report.
        """
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from openai import AsyncOpenAI
|
|
5
|
+
|
|
6
|
+
from .base import BaseAdapter
|
|
7
|
+
from ..types import LLMResponse, Message, ToolCall, FunctionCall
|
|
8
|
+
from ...tools import Tool
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class OpenAIAdapter(BaseAdapter):
    """Adapter between the framework's message types and the OpenAI chat-completions API."""

    def __init__(
        self,
        api_key: str,
        model: str | None = "gpt-5.4",
        base_url: str | None = "https://api.openai.com/v1",
        **kwargs,
    ):
        """Create the async client and remember per-request defaults.

        Args:
            api_key: Credential passed to ``AsyncOpenAI``.
            model: Model name placed in every request.
            base_url: API endpoint passed to ``AsyncOpenAI``.
            **kwargs: Default request parameters merged into every request.
        """
        self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        self.model = model
        self.default_params = kwargs

    @staticmethod
    def _parse_arguments(raw_arguments: str | None) -> Any:
        """Decode a tool call's JSON argument string.

        An empty or missing string means "no arguments" and decodes to an
        empty dict instead of raising. Malformed JSON still raises
        ``json.JSONDecodeError``.
        """
        if not raw_arguments:
            return {}
        return json.loads(raw_arguments)

    def build_request(
        self,
        messages: list[Message],
        tools: list[Tool] | None = None,
        stream: bool | None = False,
        **kwargs,
    ) -> dict[str, Any]:
        """Build the keyword arguments for ``chat.completions.create``.

        Args:
            messages: Conversation history; tool calls and tool call ids are
                included on the messages that carry them.
            tools: Tools to advertise; each contributes ``tool.to_schema()``.
            stream: Value of the request's ``stream`` flag.
            **kwargs: Per-request parameters; they override the defaults
                given to the constructor.

        Returns:
            The request dict.
        """
        openai_messages = []
        for msg in messages:
            m = {"role": msg.role, "content": msg.content}
            if msg.tool_calls:
                m["tool_calls"] = [
                    {
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            # The API expects arguments as a JSON string.
                            "arguments": json.dumps(tc.function.arguments),
                        },
                    }
                    for tc in msg.tool_calls
                ]

            if msg.tool_call_id:
                m["tool_call_id"] = msg.tool_call_id
            openai_messages.append(m)

        request = {
            "model": self.model,
            "messages": openai_messages,
            "stream": stream,
            **self.default_params,
            **kwargs,
        }
        if tools:
            request["tools"] = [tool.to_schema() for tool in tools]

        return request

    async def call(self, request: dict[str, Any]) -> Any:
        """Send the request through the async client and return its result."""
        return await self.client.chat.completions.create(**request)

    def parse_response(self, raw_response: Any) -> LLMResponse:
        """Convert a non-streamed completion into ``LLMResponse``.

        Only the first choice is read. ``latency_ms`` is set to 0 here and
        is expected to be filled in by the gateway.
        """
        choice = raw_response.choices[0]
        message = choice.message

        tool_calls = None
        if message.tool_calls:
            tool_calls = [
                ToolCall(
                    id=tc.id,
                    type="function",
                    function=FunctionCall(
                        name=tc.function.name,
                        arguments=self._parse_arguments(tc.function.arguments),
                    ),
                )
                for tc in message.tool_calls
            ]

        internal_msg = Message(
            role="assistant",
            content=message.content,
            tool_calls=tool_calls,
        )

        # The SDK types `usage` as optional; report zeros rather than
        # failing with AttributeError when a provider omits it.
        usage = getattr(raw_response, "usage", None)

        return LLMResponse(
            message=internal_msg,
            finish_reason=choice.finish_reason,
            usage={
                "prompt_tokens": usage.prompt_tokens if usage else 0,
                "completion_tokens": usage.completion_tokens if usage else 0,
                "total_tokens": usage.total_tokens if usage else 0,
            },
            model=raw_response.model,
            latency_ms=0,  # filled in by the gateway
        )

    def parse_stream_chunk(self, chunk: Any) -> dict[str, Any] | None:
        """Convert one streamed chunk into a plain dict.

        Returns:
            ``None`` when the chunk has no choices. Otherwise a dict with
            ``content`` (never ``None``), ``finish_reason`` and, when the
            delta carries tool-call fragments, ``tool_calls_delta``: a list
            of dicts holding ``index`` plus whichever of ``id``, ``name``
            and ``arguments`` are present in this fragment.
        """
        if not chunk.choices:
            return None
        first_choice = chunk.choices[0]
        delta = first_choice.delta
        result = {
            "content": delta.content or "",
            "finish_reason": first_choice.finish_reason,
        }
        if delta.tool_calls:
            fragments = []
            for tc in delta.tool_calls:
                item = {"index": tc.index}
                if tc.id:
                    item["id"] = tc.id
                # `function` is optional on a tool-call delta, so it must be
                # checked before its fields are read.
                function = tc.function
                if function is not None:
                    if function.name:
                        item["name"] = function.name
                    if function.arguments:
                        item["arguments"] = function.arguments
                fragments.append(item)
            result["tool_calls_delta"] = fragments
        return result
|
wuwei/llm/gateway.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, AsyncIterator, Union
|
|
7
|
+
|
|
8
|
+
from .adapters import OpenAIAdapter
|
|
9
|
+
from .adapters.base import BaseAdapter
|
|
10
|
+
from .types import FunctionCall, LLMResponse, LLMResponseChunk, Message, ToolCall
|
|
11
|
+
from ..tools import Tool
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class LLMGateway:
    """Single entry point for sending chat requests to an LLM provider.

    The gateway picks a provider adapter from its configuration and adds
    timeout and retry handling for non-streamed requests, plus assembly of
    streamed tool-call fragments into complete tool calls.
    """

    _DEFAULT_ENV_SEARCH_DEPTH = 3
    _DEFAULT_ENV_FILES = (".env", "env")
    _DEFAULT_ENV_PREFIX = "OPENAI"

    def __init__(self, config: dict[str, Any]):
        """Initialise the gateway from an explicit configuration dict.

        Recognised keys: ``provider`` (default ``"openai"``), ``api_key``
        (required), ``model``, ``temperature``, ``max_tokens``, ``base_url``,
        ``retry`` (a mapping with ``max_attempts``) and ``timeout``.

        Raises:
            KeyError: If ``api_key`` is missing.
            ValueError: If ``provider`` is not supported.
        """
        self.adapter: BaseAdapter
        provider = config.get("provider", "openai")

        if provider == "openai":
            adapter_kwargs = {
                "api_key": config["api_key"],
                "model": config.get("model", "gpt-5.4"),
                "temperature": config.get("temperature", 0.2),
                "max_tokens": config.get("max_tokens", 4096),
            }
            # Only pass base_url when set so the adapter default applies otherwise.
            if config.get("base_url"):
                adapter_kwargs["base_url"] = config["base_url"]

            self.adapter = OpenAIAdapter(**adapter_kwargs)
        else:
            raise ValueError(f"Unsupported provider: {provider}")

        self.retry_policy = config.get("retry", {"max_attempts": 3})
        self.timeout = config.get("timeout", 60)

    @classmethod
    def from_env(
        cls,
        *,
        env_prefix: str | None = None,
        env_file: str | None = None,
        base_url: str | None = None,
        model: str | None = None,
        **config: Any,
    ) -> "LLMGateway":
        """Create a gateway from environment variables.

        Variable names are fixed as ``{PREFIX}_API_KEY``,
        ``{PREFIX}_BASE_URL`` and ``{PREFIX}_MODEL``, where the prefix
        defaults to ``OPENAI`` and is upper-cased.

        Args:
            env_prefix: Selects which group of variables to read.
            env_file: Explicit env file to read instead of searching.
            base_url: Explicit override for the base URL.
            model: Explicit override for the model.
            **config: Extra configuration keys; applied last, so they
                override anything resolved from the environment.

        For each value the process environment takes precedence over the
        env file; for ``base_url`` and ``model`` the explicit argument takes
        precedence over both.

        Raises:
            ValueError: If no API key is found.
        """
        prefix = (env_prefix or cls._DEFAULT_ENV_PREFIX).upper()
        api_key_env = f"{prefix}_API_KEY"
        base_url_env = f"{prefix}_BASE_URL"
        model_env = f"{prefix}_MODEL"

        file_values = cls._load_env_file(env_file=env_file)

        api_key = os.getenv(api_key_env) or file_values.get(api_key_env)
        if not api_key:
            raise ValueError(f"Missing required environment variable: {api_key_env}")

        env_config: dict[str, Any] = {
            "provider": "openai",
            "api_key": api_key,
        }

        resolved_base_url = base_url or os.getenv(base_url_env) or file_values.get(base_url_env)
        if resolved_base_url:
            env_config["base_url"] = resolved_base_url

        resolved_model = model or os.getenv(model_env) or file_values.get(model_env)
        if resolved_model:
            env_config["model"] = resolved_model

        env_config.update(config)
        return cls(env_config)

    @staticmethod
    def _parse_env_text(text: str) -> dict[str, str]:
        """Parse ``KEY=VALUE`` lines into a dict.

        Blank lines, lines starting with ``#``, lines without ``=`` and
        lines with an empty key are skipped. Keys and values are stripped;
        a value wrapped in one matching pair of quotes loses that pair.
        """
        values: dict[str, str] = {}
        for raw_line in text.splitlines():
            line = raw_line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue

            key, value = line.split("=", 1)
            key = key.strip()
            value = value.strip()

            if not key:
                continue

            # Require two characters so a value that is a lone quote
            # character is kept instead of being reduced to "".
            if len(value) >= 2 and value[0] == value[-1] and value[0] in {"'", '"'}:
                value = value[1:-1]

            values[key] = value
        return values

    @staticmethod
    def _load_env_file(env_file: str | None = None) -> dict[str, str]:
        """Read variables from an env file without touching ``os.environ``.

        Rules:
            1. If ``env_file`` is given, only that file is considered.
            2. Otherwise ``.env`` / ``env`` are looked up in the current
               directory and up to three parent directories.

        The first candidate that exists and is a regular file is parsed and
        returned. An empty dict is returned when no candidate exists. This
        is a lightweight parser that does not depend on ``python-dotenv``.
        """
        candidate_paths: list[Path] = []
        if env_file:
            candidate_paths.append(Path(env_file))
        else:
            cwd = Path.cwd()
            directories = [cwd, *cwd.parents[: LLMGateway._DEFAULT_ENV_SEARCH_DEPTH]]
            candidate_paths.extend(
                directory / filename
                for directory in directories
                for filename in LLMGateway._DEFAULT_ENV_FILES
            )

        for path in candidate_paths:
            if path.is_file():
                return LLMGateway._parse_env_text(path.read_text(encoding="utf-8"))

        return {}

    async def generate(
        self,
        messages: list[Message],
        tools: list[Tool] | None = None,
        stream: bool = False,
        **kwargs,
    ) -> Union[LLMResponse, AsyncIterator[LLMResponseChunk]]:
        """Handle both single-shot and streamed generation.

        Returns an ``LLMResponse`` when ``stream`` is false, otherwise an
        async iterator of ``LLMResponseChunk``. The streamed request is not
        sent until the iterator is first advanced.
        """
        if stream:
            return self._generate_stream(messages=messages, tools=tools, **kwargs)
        return await self._generate_single(messages=messages, tools=tools, **kwargs)

    async def _generate_single(
        self,
        messages: list[Message],
        tools: list[Tool] | None,
        **kwargs,
    ) -> LLMResponse:
        """Send one non-streamed request with timeout and retries.

        Each attempt is bounded by ``self.timeout`` seconds. A failed
        attempt is retried after ``2 ** attempt`` seconds; the last
        failure is re-raised. ``latency_ms`` on the result covers the whole
        call, including earlier failed attempts and their back-off.
        """
        request = self.adapter.build_request(messages=messages, tools=tools, stream=False, **kwargs)
        # Tolerate a retry mapping without "max_attempts" and never run zero
        # attempts, which would leave nothing to return or raise.
        max_attempts = max(1, int(self.retry_policy.get("max_attempts", 3)))
        start = time.monotonic()

        for attempt in range(max_attempts):
            try:
                raw = await asyncio.wait_for(self.adapter.call(request), timeout=self.timeout)
                response = self.adapter.parse_response(raw)
            except Exception:
                if attempt == max_attempts - 1:
                    raise
                await asyncio.sleep(2**attempt)
            else:
                response.latency_ms = int((time.monotonic() - start) * 1000)
                return response

        # Not reachable: the loop runs at least once and either returns or raises.
        raise RuntimeError("retry loop finished without a response")

    @staticmethod
    def _extract_usage(chunk: Any) -> dict[str, Any] | None:
        """Return the token usage carried by a raw stream chunk, if any."""
        usage = getattr(chunk, "usage", None)
        if not usage:
            return None
        return {
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
            "total_tokens": usage.total_tokens,
        }

    @staticmethod
    def _assemble_tool_calls(pending: dict[int, dict[str, Any]]) -> list[ToolCall]:
        """Turn accumulated tool-call fragments into complete ``ToolCall`` objects.

        Entries without an id or a name are dropped. Arguments that are
        empty or not valid JSON become an empty dict (best effort). Calls
        are returned in ascending index order.
        """
        complete: list[ToolCall] = []
        for _, item in sorted(pending.items()):
            if not item["id"] or not item["name"]:
                continue
            try:
                args = json.loads(item["arguments"]) if item["arguments"] else {}
            except json.JSONDecodeError:
                args = {}
            complete.append(
                ToolCall(
                    id=item["id"],
                    type="function",
                    function=FunctionCall(name=item["name"], arguments=args),
                )
            )
        return complete

    async def _generate_stream(
        self,
        messages: list[Message],
        tools: list[Tool] | None,
        **kwargs,
    ) -> AsyncIterator[LLMResponseChunk]:
        """Send one streamed request and yield framework chunks.

        Tool-call fragments are accumulated by index and emitted as complete
        tool calls on the chunk whose finish reason is ``"tool_calls"``.
        No timeout or retry is applied to streamed requests.
        """
        request = self.adapter.build_request(messages, tools, stream=True, **kwargs)
        stream = await self.adapter.call(request)

        # Tool-call fragments accumulated by their index.
        pending: dict[int, dict[str, Any]] = {}

        async for chunk in stream:
            usage = self._extract_usage(chunk)
            data = self.adapter.parse_stream_chunk(chunk)
            if not data:
                # A chunk without choices can still carry the usage totals
                # (the OpenAI API sends them this way), so forward them
                # instead of dropping the chunk.
                if usage is not None:
                    usage_chunk = LLMResponseChunk(content="")
                    usage_chunk.usage = usage
                    yield usage_chunk
                continue

            content = data.get("content", "")
            finish_reason = data.get("finish_reason")
            tool_calls_delta = data.get("tool_calls_delta")

            if tool_calls_delta:
                for delta_item in tool_calls_delta:
                    slot = pending.setdefault(
                        delta_item["index"], {"id": "", "name": "", "arguments": ""}
                    )
                    if "id" in delta_item:
                        slot["id"] = delta_item["id"]
                    if "name" in delta_item:
                        slot["name"] = delta_item["name"]
                    if "arguments" in delta_item:
                        # Argument JSON arrives in pieces; concatenate them.
                        slot["arguments"] += delta_item["arguments"]

            out_chunk = LLMResponseChunk(content=content)

            if finish_reason == "tool_calls":
                out_chunk.tool_calls_complete = self._assemble_tool_calls(pending)
            if finish_reason:
                # Forward every finish reason, as the non-streamed path
                # does, so consumers can tell a truncated answer from a
                # normal stop.
                out_chunk.finish_reason = finish_reason

            if usage is not None:
                out_chunk.usage = usage

            yield out_chunk
|