auto-agent-kit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- auto_agent_kit/__init__.py +18 -0
- auto_agent_kit/core/__init__.py +1 -0
- auto_agent_kit/core/access_control.py +202 -0
- auto_agent_kit/core/dashboard.py +145 -0
- auto_agent_kit/core/error_reflection.py +210 -0
- auto_agent_kit/core/mcp_server.py +239 -0
- auto_agent_kit/core/plan_mode.py +174 -0
- auto_agent_kit/core/tool_router.py +130 -0
- auto_agent_kit/examples/demo.py +181 -0
- auto_agent_kit-0.1.0.dist-info/METADATA +264 -0
- auto_agent_kit-0.1.0.dist-info/RECORD +14 -0
- auto_agent_kit-0.1.0.dist-info/WHEEL +5 -0
- auto_agent_kit-0.1.0.dist-info/licenses/LICENSE +21 -0
- auto_agent_kit-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""AutoAgentKit — 生产级 AI Agent 工具包"""
|
|
2
|
+
__version__ = "0.1.0"

# Re-export the public building blocks so callers can import them straight
# from the package root.
from auto_agent_kit.core.plan_mode import PlanMode
from auto_agent_kit.core.error_reflection import (
    ErrorReflection,
    ErrorCategory,
    RecoveryStrategy,
)
from auto_agent_kit.core.tool_router import ToolRouter, ToolPhase
from auto_agent_kit.core.dashboard import Dashboard
from auto_agent_kit.core.access_control import AccessControl, PermissionLevel
from auto_agent_kit.core.mcp_server import MCPServer

# Names exported by ``from auto_agent_kit import *``.
__all__ = [
    "PlanMode",
    "ErrorReflection",
    "ErrorCategory",
    "RecoveryStrategy",
    "ToolRouter",
    "ToolPhase",
    "Dashboard",
    "AccessControl",
    "PermissionLevel",
    "MCPServer",
]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""AutoAgentKit 核心模块"""
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
"""AccessControl — 访问控制模块
|
|
2
|
+
|
|
3
|
+
4 级权限策略 + 操作审批。
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from enum import Enum, auto
|
|
10
|
+
from typing import Any, Optional
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class PermissionLevel(Enum):
    """Permission tier assigned to an operation."""

    # Safe operation: allowed automatically.
    SAFE = auto()
    # Normal operation: allowed automatically and recorded in the log.
    NORMAL = auto()
    # Sensitive operation: requires approval.
    SENSITIVE = auto()
    # Dangerous operation: intended to require approval plus a second
    # confirmation.
    DANGEROUS = auto()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class OperationType(Enum):
    """Kind of action an operation performs."""

    READ = auto()
    WRITE = auto()
    DELETE = auto()
    EXECUTE = auto()
    NETWORK = auto()
    CONFIG = auto()
    ADMIN = auto()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Operation type -> default permission level.
DEFAULT_LEVELS: dict[OperationType, PermissionLevel] = {
    # Read-only access is considered safe.
    OperationType.READ: PermissionLevel.SAFE,
    # Writes and network traffic are allowed but treated as normal-risk.
    OperationType.WRITE: PermissionLevel.NORMAL,
    # Deletion is the highest-risk tier.
    OperationType.DELETE: PermissionLevel.DANGEROUS,
    OperationType.EXECUTE: PermissionLevel.SENSITIVE,
    OperationType.NETWORK: PermissionLevel.NORMAL,
    OperationType.CONFIG: PermissionLevel.SENSITIVE,
    OperationType.ADMIN: PermissionLevel.DANGEROUS,
}
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class AccessRule:
    """A single access rule: which operations it covers and at what level."""

    # Path/command pattern; a trailing ``*`` acts as a wildcard.
    pattern: str
    # Permission level assigned to operations matching ``pattern``.
    level: PermissionLevel
    # Optional operation type this rule is restricted to (None = any type).
    operation: Optional[OperationType] = None
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass
class AccessLog:
    """One entry of the access log."""

    # Time the entry was recorded.
    timestamp: float
    # Operation string the decision was about.
    operation: str
    # Permission level recorded for the operation.
    level: PermissionLevel
    # Whether the operation was allowed.
    granted: bool
    # Optional free-text explanation of the decision.
    reason: str = ""
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class AccessControl:
    """Rule-based access control with four permission levels.

    Rules are evaluated in insertion order and the first matching rule decides
    the level of an operation; operations matching no rule are treated as
    ``SENSITIVE``.  ``SAFE`` and ``NORMAL`` operations are granted
    immediately, ``SENSITIVE`` and ``DANGEROUS`` ones are flagged as needing
    approval.  Every decision is appended to an in-memory log that keeps only
    the most recent ``_max_logs`` entries.
    """

    def __init__(self):
        self._rules: list[AccessRule] = []
        self._logs: list[AccessLog] = []
        self._max_logs: int = 1000
        self._pending_approvals: list[dict] = []

    def add_rule(self, pattern: str, level: PermissionLevel,
                 operation: Optional[OperationType] = None):
        """Append an access rule.

        Args:
            pattern: Operation pattern; a trailing ``*`` matches any suffix.
            level: Permission level assigned to matching operations.
            operation: Optional operation type the rule is restricted to.
        """
        self._rules.append(AccessRule(pattern=pattern, level=level, operation=operation))

    def add_default_rules(self):
        """Register the built-in rule set."""
        # Safe operations
        self.add_rule("read:*", PermissionLevel.SAFE, OperationType.READ)
        self.add_rule("list:*", PermissionLevel.SAFE, OperationType.READ)
        self.add_rule("search:*", PermissionLevel.SAFE, OperationType.READ)

        # Normal operations
        self.add_rule("write:workspace/*", PermissionLevel.NORMAL, OperationType.WRITE)
        self.add_rule("write:memory/*", PermissionLevel.NORMAL, OperationType.WRITE)
        self.add_rule("network:fetch:*", PermissionLevel.NORMAL, OperationType.NETWORK)

        # Sensitive operations
        self.add_rule("execute:python*", PermissionLevel.SENSITIVE, OperationType.EXECUTE)
        self.add_rule("write:config/*", PermissionLevel.SENSITIVE, OperationType.WRITE)
        self.add_rule("network:send:*", PermissionLevel.SENSITIVE, OperationType.NETWORK)

        # Dangerous operations
        self.add_rule("delete:*", PermissionLevel.DANGEROUS, OperationType.DELETE)
        self.add_rule("execute:rm*", PermissionLevel.DANGEROUS, OperationType.EXECUTE)
        self.add_rule("execute:format*", PermissionLevel.DANGEROUS, OperationType.EXECUTE)
        self.add_rule("config:system:*", PermissionLevel.DANGEROUS, OperationType.CONFIG)

    def check(self, operation: str, op_type: Optional[OperationType] = None) -> dict:
        """Decide whether ``operation`` is allowed and log the decision.

        Args:
            operation: Operation string, e.g. ``"read:notes.txt"``.
            op_type: Optional operation type; when given, rules bound to a
                different type are skipped.

        Returns:
            A dict with the keys ``operation``, ``level`` (level name),
            ``granted``, ``needs_approval`` and ``matched_rule`` (the pattern
            of the deciding rule, or ``None``).
        """
        matched_rule = self._find_rule(operation, op_type)
        # Unknown operations fall back to SENSITIVE, i.e. they need approval.
        level = matched_rule.level if matched_rule else PermissionLevel.SENSITIVE

        # Only SAFE and NORMAL are granted without approval.
        granted = level in (PermissionLevel.SAFE, PermissionLevel.NORMAL)
        result = {
            "operation": operation,
            "level": level.name,
            "granted": granted,
            "needs_approval": level in (PermissionLevel.SENSITIVE, PermissionLevel.DANGEROUS),
            "matched_rule": matched_rule.pattern if matched_rule else None,
        }

        self._log(operation, level, granted)
        return result

    def request_approval(self, operation: str, reason: str = "") -> dict:
        """Queue ``operation`` for approval and return the approval record.

        The record carries ``id``, ``operation``, ``reason``, ``timestamp``,
        ``status`` (initially ``"pending"``) and ``level``, the name of the
        permission level resolved from the current rules, so that the later
        approve/reject log entry reflects the real level of the operation.
        """
        import time

        rule = self._find_rule(operation)
        level = rule.level if rule else PermissionLevel.SENSITIVE
        now = time.time()
        approval = {
            # Records are never removed from the list, so its length keeps
            # ids created within the same second distinct.
            "id": f"app_{int(now)}_{len(self._pending_approvals)}",
            "operation": operation,
            "reason": reason,
            "timestamp": now,
            "status": "pending",
            "level": level.name,
        }
        self._pending_approvals.append(approval)
        return approval

    def approve(self, approval_id: str) -> bool:
        """Approve a pending request; return False if no such pending request exists."""
        return self._settle_approval(approval_id, "approved", True)

    def reject(self, approval_id: str) -> bool:
        """Reject a pending request; return False if no such pending request exists."""
        return self._settle_approval(approval_id, "rejected", False)

    def get_pending_approvals(self) -> list[dict]:
        """Return the approval records whose status is still ``"pending"``."""
        return [a for a in self._pending_approvals if a["status"] == "pending"]

    def get_logs(self, limit: int = 50) -> list[dict]:
        """Return up to ``limit`` of the most recent log entries as dicts.

        A non-positive ``limit`` yields an empty list.
        """
        # Guard explicitly: ``self._logs[-0:]`` would be the whole list.
        if limit <= 0:
            return []
        return [
            {"timestamp": entry.timestamp, "operation": entry.operation,
             "level": entry.level.name, "granted": entry.granted, "reason": entry.reason}
            for entry in self._logs[-limit:]
        ]

    def get_stats(self) -> dict:
        """Return counters derived from the retained log entries and current rules."""
        total = len(self._logs)
        granted = sum(1 for entry in self._logs if entry.granted)
        return {
            "total_checks": total,
            "granted": granted,
            "denied": total - granted,
            "rules_count": len(self._rules),
            "pending_approvals": len(self.get_pending_approvals()),
        }

    def _find_rule(self, operation: str,
                   op_type: Optional[OperationType] = None) -> Optional[AccessRule]:
        """Return the first rule matching ``operation`` (and ``op_type``), or None."""
        for rule in self._rules:
            if not self._match_pattern(rule.pattern, operation):
                continue
            # A rule applies when either side leaves the type unspecified or
            # the types agree.
            if op_type is None or rule.operation is None or rule.operation == op_type:
                return rule
        return None

    def _settle_approval(self, approval_id: str, status: str, granted: bool) -> bool:
        """Move a pending approval to ``status`` and log the outcome."""
        for approval in self._pending_approvals:
            if approval["id"] == approval_id and approval["status"] == "pending":
                approval["status"] = status
                # Records lacking a stored level are logged as SENSITIVE.
                level_name = approval.get("level", PermissionLevel.SENSITIVE.name)
                self._log(approval["operation"], PermissionLevel[level_name], granted, status)
                return True
        return False

    def _log(self, operation: str, level: PermissionLevel, granted: bool, reason: str = ""):
        """Append a log entry and drop the oldest ones beyond ``_max_logs``."""
        import time

        self._logs.append(AccessLog(
            timestamp=time.time(),
            operation=operation,
            level=level,
            granted=granted,
            reason=reason,
        ))
        if len(self._logs) > self._max_logs:
            self._logs = self._logs[-self._max_logs:]

    @staticmethod
    def _match_pattern(pattern: str, operation: str) -> bool:
        """Match ``operation`` against ``pattern``.

        Only a trailing ``*`` is treated as a wildcard (prefix match); any
        other pattern must equal the operation exactly.
        """
        if pattern.endswith("*"):
            prefix = pattern[:-1]
            return operation.startswith(prefix)
        return pattern == operation
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""Dashboard — 仪表板
|
|
2
|
+
|
|
3
|
+
实时监控 Agent 运行指标。
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import time
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from typing import Any, Optional
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class MetricPoint:
    """A single sample of a metric."""

    # Time the sample was taken.
    timestamp: float
    # Sampled value.
    value: float
    # Optional free-text label attached to the sample.
    label: str = ""
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class MetricSeries:
    """A named time series of metric samples, trimmed to ``max_points``."""

    name: str
    points: list[MetricPoint] = field(default_factory=list)
    max_points: int = 1000

    def add(self, value: float, label: str = ""):
        """Append a sample stamped with the current time, then trim old samples."""
        sample = MetricPoint(timestamp=time.time(), value=value, label=label)
        self.points.append(sample)
        # Keep only the newest ``max_points`` samples.
        if len(self.points) > self.max_points:
            self.points = self.points[-self.max_points:]

    @property
    def last(self) -> Optional[float]:
        """Value of the most recent sample, or None when the series is empty."""
        if not self.points:
            return None
        return self.points[-1].value

    @property
    def avg(self) -> float:
        """Arithmetic mean of the stored values (0.0 when empty)."""
        if not self.points:
            return 0.0
        total = sum(p.value for p in self.points)
        return total / len(self.points)

    @property
    def min(self) -> float:
        """Smallest stored value (0.0 when empty)."""
        if not self.points:
            return 0.0
        return min(p.value for p in self.points)

    @property
    def max(self) -> float:
        """Largest stored value (0.0 when empty)."""
        if not self.points:
            return 0.0
        return max(p.value for p in self.points)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class Dashboard:
    """Collects metric series and events describing a running agent."""

    def __init__(self, max_series_points: int = 1000):
        self._started_at: float = time.time()
        self._max_series_points = max_series_points
        self._max_events: int = 500
        self._series: dict[str, MetricSeries] = {}
        self._events: list[dict] = []

    def _get_series(self, name: str) -> MetricSeries:
        """Return the series called ``name``, creating it on first use."""
        series = self._series.get(name)
        if series is None:
            series = MetricSeries(name=name, max_points=self._max_series_points)
            self._series[name] = series
        return series

    def record(self, metric: str, value: float, label: str = ""):
        """Add one sample to the series named ``metric``."""
        series = self._get_series(metric)
        series.add(value, label)

    def record_event(self, event_type: str, data: Optional[dict] = None):
        """Store an event, keeping only the newest ``_max_events`` entries."""
        event = {
            "timestamp": time.time(),
            "type": event_type,
            "data": data or {},
        }
        self._events.append(event)
        if len(self._events) > self._max_events:
            self._events = self._events[-self._max_events:]

    def record_tool_call(self, tool_name: str, success: bool, duration_ms: float):
        """Record duration, call count, error count and an event for a tool call."""
        self.record(f"tool.{tool_name}.duration", duration_ms)
        self.record("tool.calls", 1)
        # Failures additionally feed the error series.
        if not success:
            self.record("tool.errors", 1)
        payload = {
            "tool": tool_name,
            "success": success,
            "duration_ms": duration_ms,
        }
        self.record_event("tool_call", payload)

    def get_metric(self, name: str) -> Optional[MetricSeries]:
        """Return the series called ``name``, or None if nothing was recorded for it."""
        return self._series.get(name)

    def get_snapshot(self) -> dict:
        """Return uptime, per-series aggregates and the 20 most recent events."""
        uptime = time.time() - self._started_at
        metrics = {
            name: {
                "last": series.last,
                "avg": series.avg,
                "min": series.min,
                "max": series.max,
                "count": len(series.points),
            }
            for name, series in self._series.items()
        }
        return {
            "uptime_seconds": uptime,
            "uptime_formatted": self._format_duration(uptime),
            "metrics": metrics,
            "recent_events": self._events[-20:],
        }

    def get_summary(self) -> str:
        """Return a multi-line text summary built from the current snapshot."""
        snapshot = self.get_snapshot()
        header = f"📊 Dashboard — 运行 {snapshot['uptime_formatted']}"
        rows = [
            f" {name}: last={m['last']:.2f} avg={m['avg']:.2f} count={m['count']}"
            for name, m in snapshot["metrics"].items()
        ]
        return "\n".join([header] + rows)

    def to_json(self, path: str):
        """Write the current snapshot to ``path`` as UTF-8 JSON."""
        snapshot = self.get_snapshot()
        with open(path, "w", encoding="utf-8") as f:
            json.dump(snapshot, f, ensure_ascii=False, indent=2)

    def reset(self):
        """Drop all series and events and restart the uptime clock."""
        self._series.clear()
        self._events.clear()
        self._started_at = time.time()

    @staticmethod
    def _format_duration(seconds: float) -> str:
        """Format a duration as ``XhYmZs``, omitting leading zero units."""
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = int(seconds % 60)
        if hours > 0:
            return f"{hours}h{minutes}m{secs}s"
        if minutes > 0:
            return f"{minutes}m{secs}s"
        return f"{secs}s"
|
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
"""ErrorReflection — 错误反射模块
|
|
2
|
+
|
|
3
|
+
工具失败自动分类(20+类型),精确恢复策略。
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import time
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from enum import Enum
|
|
11
|
+
from typing import Any, Optional
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ErrorCategory(Enum):
    """Error categories (20+ kinds), grouped by origin."""

    # --- Authentication / authorization ---
    AUTH_INVALID = "auth_invalid"
    AUTH_EXPIRED = "auth_expired"
    PERMISSION_DENIED = "permission_denied"

    # --- Billing / quota ---
    BILLING_ERROR = "billing_error"
    RATE_LIMIT = "rate_limit"
    QUOTA_EXCEEDED = "quota_exceeded"

    # --- Network / timeout ---
    TIMEOUT = "timeout"
    NETWORK_ERROR = "network_error"
    DNS_FAILURE = "dns_failure"
    CONNECTION_REFUSED = "connection_refused"

    # --- Server side (5xx) ---
    SERVER_ERROR = "server_error"
    SERVICE_UNAVAILABLE = "service_unavailable"

    # --- Client side (4xx) ---
    BAD_REQUEST = "bad_request"
    NOT_FOUND = "not_found"
    VALIDATION_ERROR = "validation_error"

    # --- Data ---
    PARSE_ERROR = "parse_error"
    ENCODING_ERROR = "encoding_error"
    DATA_INTEGRITY = "data_integrity"

    # --- Context ---
    CONTEXT_OVERFLOW = "context_overflow"
    CONTENT_FILTER = "content_filter"

    # --- System ---
    RESOURCE_EXHAUSTED = "resource_exhausted"
    INTERNAL_ERROR = "internal_error"
    UNKNOWN = "unknown"
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class RecoveryStrategy(Enum):
    """Recovery actions that can be recommended for a classified error."""

    # Retry immediately.
    RETRY = "retry"
    # Retry with exponential backoff.
    EXPONENTIAL_BACKOFF = "exponential_backoff"
    # Rotate credentials.
    ROTATE_CREDENTIAL = "rotate_credential"
    # Compress the context.
    COMPRESS_CONTEXT = "compress_context"
    # Switch to a fallback model.
    FALLBACK_MODEL = "fallback_model"
    # Switch to a fallback provider.
    FALLBACK_PROVIDER = "fallback_provider"
    # Degrade and return partial results.
    DEGRADE = "degrade"
    # Give up.
    ABORT = "abort"
    # Retry after a delay.
    RETRY_WITH_DELAY = "retry_with_delay"
    # Clear the cache.
    CLEAR_CACHE = "clear_cache"
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@dataclass
class ErrorRecord:
    """One classified error together with the recovery chosen for it."""

    # Creation time; defaults to the moment the record is instantiated.
    timestamp: float = field(default_factory=time.time)
    # Category the error was classified into.
    category: ErrorCategory = ErrorCategory.UNKNOWN
    # Error message text.
    message: str = ""
    # Caller-supplied identifier of where the error came from.
    source: str = ""
    # Arbitrary caller-supplied context.
    context: dict[str, Any] = field(default_factory=dict)
    # Recovery strategy selected for this error, if any.
    recovery_used: Optional[RecoveryStrategy] = None
    # Whether the recovery was reported as successful.
    recovered: bool = False
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class ErrorReflection:
    """Classify errors by keyword and recommend a recovery strategy.

    Every classified error is appended to ``history``.  Failures are counted
    per ``source``/category pair; once that count reaches ``max_retries``, a
    plain retry or backoff recommendation is escalated to
    ``FALLBACK_PROVIDER``.
    """

    # Error category -> default recovery strategy.
    STRATEGY_MAP: dict[ErrorCategory, RecoveryStrategy] = {
        ErrorCategory.AUTH_INVALID: RecoveryStrategy.ROTATE_CREDENTIAL,
        ErrorCategory.AUTH_EXPIRED: RecoveryStrategy.ROTATE_CREDENTIAL,
        ErrorCategory.PERMISSION_DENIED: RecoveryStrategy.ABORT,
        ErrorCategory.BILLING_ERROR: RecoveryStrategy.ABORT,
        ErrorCategory.RATE_LIMIT: RecoveryStrategy.EXPONENTIAL_BACKOFF,
        ErrorCategory.QUOTA_EXCEEDED: RecoveryStrategy.FALLBACK_PROVIDER,
        ErrorCategory.TIMEOUT: RecoveryStrategy.EXPONENTIAL_BACKOFF,
        ErrorCategory.NETWORK_ERROR: RecoveryStrategy.RETRY,
        ErrorCategory.DNS_FAILURE: RecoveryStrategy.RETRY_WITH_DELAY,
        ErrorCategory.CONNECTION_REFUSED: RecoveryStrategy.RETRY_WITH_DELAY,
        ErrorCategory.SERVER_ERROR: RecoveryStrategy.EXPONENTIAL_BACKOFF,
        ErrorCategory.SERVICE_UNAVAILABLE: RecoveryStrategy.FALLBACK_PROVIDER,
        ErrorCategory.BAD_REQUEST: RecoveryStrategy.ABORT,
        ErrorCategory.NOT_FOUND: RecoveryStrategy.ABORT,
        ErrorCategory.VALIDATION_ERROR: RecoveryStrategy.ABORT,
        ErrorCategory.PARSE_ERROR: RecoveryStrategy.RETRY,
        ErrorCategory.ENCODING_ERROR: RecoveryStrategy.RETRY,
        ErrorCategory.DATA_INTEGRITY: RecoveryStrategy.ABORT,
        ErrorCategory.CONTEXT_OVERFLOW: RecoveryStrategy.COMPRESS_CONTEXT,
        ErrorCategory.CONTENT_FILTER: RecoveryStrategy.DEGRADE,
        ErrorCategory.RESOURCE_EXHAUSTED: RecoveryStrategy.FALLBACK_MODEL,
        ErrorCategory.INTERNAL_ERROR: RecoveryStrategy.RETRY,
        ErrorCategory.UNKNOWN: RecoveryStrategy.RETRY,
    }

    # Keyword groups -> error category.  Order matters: ``classify`` returns
    # the category of the first group containing a keyword that occurs in the
    # lower-cased message.
    KEYWORD_MAP: list[tuple[list[str], ErrorCategory]] = [
        (["rate limit", "too many requests", "429"], ErrorCategory.RATE_LIMIT),
        (["timeout", "timed out", "deadline exceeded"], ErrorCategory.TIMEOUT),
        (["auth", "unauthorized", "401", "403", "api key"], ErrorCategory.AUTH_INVALID),
        (["expired", "token expired"], ErrorCategory.AUTH_EXPIRED),
        (["permission", "forbidden", "not allowed"], ErrorCategory.PERMISSION_DENIED),
        (["billing", "quota", "insufficient"], ErrorCategory.BILLING_ERROR),
        (["context length", "context window", "token limit"], ErrorCategory.CONTEXT_OVERFLOW),
        (["content filter", "content policy", "safety"], ErrorCategory.CONTENT_FILTER),
        (["not found", "404"], ErrorCategory.NOT_FOUND),
        (["server error", "500", "502", "503"], ErrorCategory.SERVER_ERROR),
        (["connection refused", "econnrefused"], ErrorCategory.CONNECTION_REFUSED),
        (["dns", "enotfound", "getaddrinfo"], ErrorCategory.DNS_FAILURE),
        (["parse", "json decode", "unexpected token"], ErrorCategory.PARSE_ERROR),
        (["encoding", "unicode", "utf"], ErrorCategory.ENCODING_ERROR),
        (["validation", "invalid"], ErrorCategory.VALIDATION_ERROR),
        (["resource exhausted", "memory", "disk"], ErrorCategory.RESOURCE_EXHAUSTED),
    ]

    def __init__(self, max_retries: int = 3):
        """Create a reflector.

        Args:
            max_retries: Number of failures recorded for the same
                source/category pair at which retry-style strategies are
                escalated to ``FALLBACK_PROVIDER``.
        """
        self.max_retries = max_retries
        self.history: list[ErrorRecord] = []
        self._consecutive_failures: dict[str, int] = {}

    def classify(self, error: Exception | str, source: str = "") -> ErrorCategory:
        """Map an error to a category by substring search in its message.

        Args:
            error: Exception or message; ``str(error)`` is lower-cased and
                searched for the keywords in ``KEYWORD_MAP``.
            source: Accepted for interface symmetry; not used here.

        Returns:
            The first matching category, or ``ErrorCategory.UNKNOWN``.
        """
        msg = str(error).lower()
        for keywords, category in self.KEYWORD_MAP:
            if any(kw in msg for kw in keywords):
                return category
        return ErrorCategory.UNKNOWN

    def get_strategy(self, category: ErrorCategory) -> RecoveryStrategy:
        """Return the default strategy for ``category`` (``RETRY`` if unmapped)."""
        return self.STRATEGY_MAP.get(category, RecoveryStrategy.RETRY)

    def classify_and_recover(
        self,
        error: Exception | str,
        source: str = "",
        context: Optional[dict] = None,
    ) -> dict:
        """Classify ``error``, record it and return a recovery recommendation.

        Args:
            error: Exception or message to classify.
            source: Identifier of the failing component; part of the key used
                to count repeated failures.
            context: Optional extra data stored on the history record.

        Returns:
            A dict with ``category`` and ``strategy`` (enum values),
            ``consecutive_failures``, ``message`` and ``upgraded`` (True when
            the strategy was escalated to ``FALLBACK_PROVIDER``).
        """
        category = self.classify(error, source)
        strategy = self.get_strategy(category)
        message = str(error)

        record = ErrorRecord(
            category=category,
            message=message,
            source=source,
            context=context or {},
            recovery_used=strategy,
        )
        self.history.append(record)

        # Count failures per source/category pair.
        key = f"{source}:{category.value}"
        consecutive = self._consecutive_failures.get(key, 0) + 1
        self._consecutive_failures[key] = consecutive

        # Escalate retry-style strategies once the configured limit is hit;
        # retrying the same provider again is unlikely to help.
        upgraded = (
            consecutive >= self.max_retries
            and strategy in (RecoveryStrategy.RETRY, RecoveryStrategy.EXPONENTIAL_BACKOFF)
        )
        if upgraded:
            strategy = RecoveryStrategy.FALLBACK_PROVIDER
            record.recovery_used = strategy

        return {
            "category": category.value,
            "strategy": strategy.value,
            "consecutive_failures": consecutive,
            "message": message,
            "upgraded": upgraded,
        }

    def report_recovery(self, success: bool, error_ref: Optional[str] = None):
        """Record the outcome of the recovery for the most recent error.

        On success the failure counter for that error's source/category pair
        is reset to zero.  ``error_ref`` is accepted but not used; the outcome
        always applies to the last history entry.
        """
        if not self.history:
            return
        latest = self.history[-1]
        latest.recovered = success
        if success:
            key = f"{latest.source}:{latest.category.value}"
            self._consecutive_failures[key] = 0

    def get_stats(self) -> dict:
        """Return totals, per-category counts and the recovery rate.

        The same keys are returned for an empty history, so callers can read
        them unconditionally.
        """
        total = len(self.history)
        by_category: dict[str, int] = {}
        recovered = 0
        for record in self.history:
            name = record.category.value
            by_category[name] = by_category.get(name, 0) + 1
            if record.recovered:
                recovered += 1
        return {
            "total": total,
            "by_category": by_category,
            "recovered": recovered,
            "recovery_rate": recovered / total if total > 0 else 0,
        }
|