loom-agent 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of loom-agent might be problematic. Click here for more details.
- loom/__init__.py +77 -0
- loom/agent.py +217 -0
- loom/agents/__init__.py +10 -0
- loom/agents/refs.py +28 -0
- loom/agents/registry.py +50 -0
- loom/builtin/compression/__init__.py +4 -0
- loom/builtin/compression/structured.py +79 -0
- loom/builtin/embeddings/__init__.py +9 -0
- loom/builtin/embeddings/openai_embedding.py +135 -0
- loom/builtin/embeddings/sentence_transformers_embedding.py +145 -0
- loom/builtin/llms/__init__.py +8 -0
- loom/builtin/llms/mock.py +34 -0
- loom/builtin/llms/openai.py +168 -0
- loom/builtin/llms/rule.py +102 -0
- loom/builtin/memory/__init__.py +5 -0
- loom/builtin/memory/in_memory.py +21 -0
- loom/builtin/memory/persistent_memory.py +278 -0
- loom/builtin/retriever/__init__.py +9 -0
- loom/builtin/retriever/chroma_store.py +265 -0
- loom/builtin/retriever/in_memory.py +106 -0
- loom/builtin/retriever/milvus_store.py +307 -0
- loom/builtin/retriever/pinecone_store.py +237 -0
- loom/builtin/retriever/qdrant_store.py +274 -0
- loom/builtin/retriever/vector_store.py +128 -0
- loom/builtin/retriever/vector_store_config.py +217 -0
- loom/builtin/tools/__init__.py +32 -0
- loom/builtin/tools/calculator.py +49 -0
- loom/builtin/tools/document_search.py +111 -0
- loom/builtin/tools/glob.py +27 -0
- loom/builtin/tools/grep.py +56 -0
- loom/builtin/tools/http_request.py +86 -0
- loom/builtin/tools/python_repl.py +73 -0
- loom/builtin/tools/read_file.py +32 -0
- loom/builtin/tools/task.py +158 -0
- loom/builtin/tools/web_search.py +64 -0
- loom/builtin/tools/write_file.py +31 -0
- loom/callbacks/base.py +9 -0
- loom/callbacks/logging.py +12 -0
- loom/callbacks/metrics.py +27 -0
- loom/callbacks/observability.py +248 -0
- loom/components/agent.py +107 -0
- loom/core/agent_executor.py +450 -0
- loom/core/circuit_breaker.py +178 -0
- loom/core/compression_manager.py +329 -0
- loom/core/context_retriever.py +185 -0
- loom/core/error_classifier.py +193 -0
- loom/core/errors.py +66 -0
- loom/core/message_queue.py +167 -0
- loom/core/permission_store.py +62 -0
- loom/core/permissions.py +69 -0
- loom/core/scheduler.py +125 -0
- loom/core/steering_control.py +47 -0
- loom/core/structured_logger.py +279 -0
- loom/core/subagent_pool.py +232 -0
- loom/core/system_prompt.py +141 -0
- loom/core/system_reminders.py +283 -0
- loom/core/tool_pipeline.py +113 -0
- loom/core/types.py +269 -0
- loom/interfaces/compressor.py +59 -0
- loom/interfaces/embedding.py +51 -0
- loom/interfaces/llm.py +33 -0
- loom/interfaces/memory.py +29 -0
- loom/interfaces/retriever.py +179 -0
- loom/interfaces/tool.py +27 -0
- loom/interfaces/vector_store.py +80 -0
- loom/llm/__init__.py +14 -0
- loom/llm/config.py +228 -0
- loom/llm/factory.py +111 -0
- loom/llm/model_health.py +235 -0
- loom/llm/model_pool_advanced.py +305 -0
- loom/llm/pool.py +170 -0
- loom/llm/registry.py +201 -0
- loom/mcp/__init__.py +4 -0
- loom/mcp/client.py +86 -0
- loom/mcp/registry.py +58 -0
- loom/mcp/tool_adapter.py +48 -0
- loom/observability/__init__.py +5 -0
- loom/patterns/__init__.py +5 -0
- loom/patterns/multi_agent.py +123 -0
- loom/patterns/rag.py +262 -0
- loom/plugins/registry.py +55 -0
- loom/resilience/__init__.py +5 -0
- loom/tooling.py +72 -0
- loom/utils/agent_loader.py +218 -0
- loom/utils/token_counter.py +19 -0
- loom_agent-0.0.1.dist-info/METADATA +457 -0
- loom_agent-0.0.1.dist-info/RECORD +89 -0
- loom_agent-0.0.1.dist-info/WHEEL +4 -0
- loom_agent-0.0.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
"""CompressionManager: AU2 8-segment context compression (US2)
|
|
2
|
+
|
|
3
|
+
Implements intelligent context compression using LLM-based 8-segment summarization
|
|
4
|
+
with automatic fallback to sliding window on failure.
|
|
5
|
+
|
|
6
|
+
Features:
|
|
7
|
+
- 92% threshold detection
|
|
8
|
+
- 70-80% token reduction via 8-segment structured summarization
|
|
9
|
+
- Retry logic with exponential backoff (1s, 2s, 4s)
|
|
10
|
+
- Sliding window fallback after 3 failures
|
|
11
|
+
- System message preservation
|
|
12
|
+
- Compression metadata tracking
|
|
13
|
+
|
|
14
|
+
Architecture:
|
|
15
|
+
- LLM-based compression (primary): Structured 8-segment summary
|
|
16
|
+
- Sliding window (fallback): Keep last N messages
|
|
17
|
+
- Token counting: tiktoken (cl100k_base for GPT-4/Claude)
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import asyncio
|
|
23
|
+
import time
|
|
24
|
+
from typing import List, Tuple, Optional
|
|
25
|
+
|
|
26
|
+
from loom.core.types import Message, CompressionMetadata
|
|
27
|
+
from loom.interfaces.llm import BaseLLM
|
|
28
|
+
from loom.utils.token_counter import count_messages_tokens
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class CompressionManager:
    """Manages context compression with 8-segment LLM-based summarization.

    ``compress`` keeps every system message, asks ``llm`` to fold all
    user/assistant/tool messages into one structured summary, and falls back
    to keeping only the most recent ``sliding_window_size`` of those messages
    when every LLM attempt fails.
    """

    # 8-segment compression prompt template. Placeholders (filled by
    # ``_llm_compress``): message_count, token_count, target_tokens, conversation.
    COMPRESSION_PROMPT_TEMPLATE = """You are a context compression expert. Your task is to compress a long conversation history into a concise 8-segment structured summary while preserving all critical information.

**Input Conversation** ({message_count} messages, {token_count} tokens):
{conversation}

**Compression Requirements**:
1. Extract and preserve ALL critical information (decisions, blockers, data, context)
2. Reduce token count by 70-80% (target: {target_tokens} tokens)
3. Structure output using exactly 8 segments below
4. Use markdown formatting for readability
5. Be concise but comprehensive - no detail too small if relevant

**Output Format** (8 Segments):

1. **Task Overview**: What is the user trying to accomplish? (1-2 sentences)
2. **Key Decisions**: What important decisions or approaches were chosen? (bullet points)
3. **Progress**: What has been completed so far? (bullet points with specific data/results)
4. **Blockers**: What issues or errors occurred? How were they resolved? (bullet points, "None" if none)
5. **Open Items**: What still needs to be done? What questions remain unanswered? (bullet points)
6. **Context**: What background information or domain knowledge is relevant? (1-2 sentences)
7. **Next Steps**: What should happen next based on the conversation? (bullet points)
8. **Metadata**: Compression statistics and key topics (format: "Compressed {message_count} messages → 1 summary. Topics: topic1, topic2, topic3")

**Example Output**:
```
**Compressed Context**

1. **Task Overview**: User is implementing a REST API for user authentication with JWT tokens.

2. **Key Decisions**:
- Using PostgreSQL for user storage
- JWT with 7-day expiration
- Bcrypt for password hashing (cost factor 12)

3. **Progress**:
- Created User model with email/password fields
- Implemented /register endpoint (working)
- Implemented /login endpoint (returns JWT)
- Added password validation (min 8 chars, 1 special char)

4. **Blockers**:
- Initial JWT verification failed due to incorrect secret key → Fixed by using consistent SECRET_KEY env var
- Database connection timeout → Fixed by increasing pool size to 20

5. **Open Items**:
- Add /refresh endpoint for token renewal
- Implement rate limiting (5 login attempts per minute)
- Add email verification flow

6. **Context**: This is part of a larger e-commerce platform migration from Django to FastAPI. Authentication needs to be compatible with existing mobile app using JWT.

7. **Next Steps**:
- Implement /refresh endpoint with refresh token logic
- Add Redis for rate limiting
- Write integration tests for auth flow

8. **Metadata**: Compressed 45 messages → 1 summary. Topics: authentication, JWT, PostgreSQL, FastAPI, API_design
```

Now compress the conversation above following this exact structure:"""

    # Per-message character cap applied when rendering history into the prompt.
    MAX_MESSAGE_CHARS = 500

    def __init__(
        self,
        llm: BaseLLM,
        max_retries: int = 3,
        compression_threshold: float = 0.92,
        target_reduction: float = 0.75,  # 75% reduction
        sliding_window_size: int = 20,
    ):
        """Initialize CompressionManager.

        Args:
            llm: LLM instance for compression (should support long context)
            max_retries: Number of LLM attempts before falling back (default: 3)
            compression_threshold: Token usage fraction that triggers compression (default: 0.92)
            target_reduction: Fraction of tokens the summary should remove
                (default: 0.75 = 75% reduction, i.e. keep roughly 25%)
            sliding_window_size: Fallback window size in messages (default: 20)
        """
        self.llm = llm
        self.max_retries = max_retries
        self.compression_threshold = compression_threshold
        self.target_reduction = target_reduction
        self.sliding_window_size = sliding_window_size

    def should_compress(self, current_tokens: int, max_tokens: int) -> bool:
        """Check if compression should be triggered.

        Args:
            current_tokens: Current context token count
            max_tokens: Maximum allowed context tokens

        Returns:
            True if current_tokens >= int(threshold * max_tokens)
        """
        threshold_tokens = int(max_tokens * self.compression_threshold)
        return current_tokens >= threshold_tokens

    def _compute_target_tokens(self, original_tokens: int) -> int:
        """Return the token budget requested from the LLM for the summary.

        ``target_reduction`` is the share of tokens to *remove*, so the budget
        is the remaining ``1 - target_reduction`` share (clamped to [0, 1]).
        """
        remaining_share = min(max(1.0 - self.target_reduction, 0.0), 1.0)
        return int(original_tokens * remaining_share)

    async def compress(
        self, messages: List[Message]
    ) -> Tuple[List[Message], CompressionMetadata]:
        """Compress conversation history using 8-segment LLM summarization.

        Process:
        1. Separate system messages (never compressed)
        2. Collect user/assistant/tool messages for compression
        3. Attempt LLM compression, retrying with exponential backoff
        4. Fall back to a sliding window once every attempt has failed
        5. Return compressed messages + metadata

        Messages whose role is not system/user/assistant/tool are not included
        in the compressed result.

        Args:
            messages: Full conversation history

        Returns:
            Tuple of (compressed_messages, compression_metadata). When
            compression happens, the metadata counts cover only the non-system
            (compressible) messages.
        """
        if not messages:
            return messages, CompressionMetadata(
                original_message_count=0,
                compressed_message_count=0,
                original_tokens=0,
                compressed_tokens=0,
                compression_ratio=0.0,
                key_topics=[],
            )

        # Separate system messages (preserve) from compressible messages
        system_messages = [m for m in messages if m.role == "system"]
        compressible = [m for m in messages if m.role in ("user", "assistant", "tool")]

        if not compressible:
            # Nothing to compress: return the input unchanged.
            token_count = count_messages_tokens(messages)
            return messages, CompressionMetadata(
                original_message_count=len(messages),
                compressed_message_count=len(messages),
                original_tokens=token_count,
                compressed_tokens=token_count,
                compression_ratio=1.0,
                key_topics=[],
            )

        original_tokens = count_messages_tokens(compressible)
        target_tokens = self._compute_target_tokens(original_tokens)

        llm_result = await self._llm_compress_with_retries(
            compressible, original_tokens, target_tokens
        )

        if llm_result is None:
            # Every LLM attempt failed: keep only the most recent messages.
            windowed_messages = self.sliding_window_fallback(
                compressible, self.sliding_window_size
            )
            compressed_tokens = count_messages_tokens(windowed_messages)
            metadata = CompressionMetadata(
                original_message_count=len(compressible),
                compressed_message_count=len(windowed_messages),
                original_tokens=original_tokens,
                compressed_tokens=compressed_tokens,
                compression_ratio=compressed_tokens / original_tokens if original_tokens > 0 else 0.0,
                key_topics=["fallback"],
            )
            return system_messages + windowed_messages, metadata

        # LLM compression succeeded - represent the summary as one system message.
        compressed_summary, key_topics = llm_result
        compressed_message = Message(
            role="system",
            content=compressed_summary,
            metadata={"type": "compressed_context", "original_count": len(compressible)},
        )
        compressed_tokens = count_messages_tokens([compressed_message])

        metadata = CompressionMetadata(
            original_message_count=len(compressible),
            compressed_message_count=1,
            original_tokens=original_tokens,
            compressed_tokens=compressed_tokens,
            compression_ratio=compressed_tokens / original_tokens if original_tokens > 0 else 0.0,
            key_topics=key_topics,
        )
        return system_messages + [compressed_message], metadata

    async def _llm_compress_with_retries(
        self, messages: List[Message], original_tokens: int, target_tokens: int
    ) -> Optional[Tuple[str, List[str]]]:
        """Run ``_llm_compress`` up to ``max_retries`` times.

        Waits 1s, 2s, 4s, ... between attempts (no wait after the final one).

        Returns:
            (summary, key_topics) from the first successful attempt, or None when
            every attempt raised (or ``max_retries`` < 1), so the caller can fall back.
        """
        for attempt in range(1, self.max_retries + 1):
            try:
                return await self._llm_compress(messages, original_tokens, target_tokens)
            except Exception:
                # Best-effort by design: any failure is retried, then handed to
                # the sliding-window fallback by the caller.
                if attempt < self.max_retries:
                    await asyncio.sleep(2 ** (attempt - 1))
        return None

    async def _llm_compress(
        self, messages: List[Message], original_tokens: int, target_tokens: int
    ) -> Tuple[str, List[str]]:
        """Use LLM to compress messages into 8-segment summary.

        Args:
            messages: Messages to compress
            original_tokens: Original token count
            target_tokens: Target token count after compression

        Returns:
            Tuple of (compressed_summary_str, key_topics_list)

        Raises:
            ValueError: If the LLM returns an empty summary (accepting it would
                replace the whole history with an empty message)
            Exception: Whatever ``self.llm.generate`` raises
        """
        conversation_text = self._format_messages_for_prompt(messages)

        prompt = self.COMPRESSION_PROMPT_TEMPLATE.format(
            message_count=len(messages),
            token_count=original_tokens,
            target_tokens=target_tokens,
            conversation=conversation_text,
        )

        compressed_summary = await self.llm.generate([{"role": "user", "content": prompt}])
        if not compressed_summary or not compressed_summary.strip():
            raise ValueError("LLM returned an empty compression summary")

        key_topics = self._extract_key_topics(compressed_summary)
        return compressed_summary, key_topics

    def _format_messages_for_prompt(self, messages: List[Message]) -> str:
        """Format messages as a numbered, readable conversation for the LLM prompt.

        Each message becomes ``"[i] ROLE: content"``. Content longer than
        ``MAX_MESSAGE_CHARS`` is cut and marked so the LLM knows it is partial;
        missing (None) content is rendered as an empty string.

        Args:
            messages: Messages to format

        Returns:
            Formatted conversation string (one line per message)
        """
        limit = self.MAX_MESSAGE_CHARS
        lines = []
        for i, msg in enumerate(messages, 1):
            content = msg.content or ""
            if len(content) > limit:
                content = content[:limit] + " ...[truncated]"
            lines.append(f"[{i}] {msg.role.upper()}: {content}")
        return "\n".join(lines)

    def _extract_key_topics(self, compressed_summary: str) -> List[str]:
        """Extract key topics from the summary's "Topics:" metadata line.

        Only the text on the same line as the last "Topics:" marker is parsed,
        so anything that follows (e.g. a closing code fence) is ignored.

        Args:
            compressed_summary: 8-segment compressed summary

        Returns:
            Up to 10 non-empty, comma-separated topics, or
            ["general_compression"] when none can be found
        """
        marker = "Topics:"
        topics: List[str] = []
        if marker in compressed_summary:
            tail = compressed_summary.rsplit(marker, 1)[-1]
            topics_line = tail.lstrip().split("\n", 1)[0]
            topics = [t.strip() for t in topics_line.split(",") if t.strip()][:10]

        if not topics:
            topics = ["general_compression"]
        return topics

    def sliding_window_fallback(
        self, messages: List[Message], window_size: int
    ) -> List[Message]:
        """Fallback compression using sliding window (keep last N messages).

        Args:
            messages: Messages to compress
            window_size: Number of recent messages to keep; values <= 0 keep none

        Returns:
            The last ``window_size`` messages (the input list itself when it
            already fits)
        """
        if window_size <= 0:
            # messages[-0:] would return the whole list, not an empty window.
            return []
        if len(messages) <= window_size:
            return messages
        return messages[-window_size:]
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
"""上下文检索器 - AgentExecutor 的核心组件"""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
from loom.interfaces.retriever import BaseRetriever, Document
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ContextRetriever:
    """
    Context retriever - fetches documents relevant to a query.

    Used as a core component of AgentExecutor: before the LLM call it retrieves
    relevant documents so they can be injected into the context.

    Use cases:
    - Knowledge-base question answering
    - Document assistants
    - Tasks that need external knowledge

    Example:
        retriever = VectorStoreRetriever(vector_store)
        context_retriever = ContextRetriever(
            retriever=retriever,
            top_k=3,
            auto_retrieve=True
        )

        agent = Agent(llm=llm, context_retriever=context_retriever)
        # Every query automatically retrieves relevant documents
    """

    def __init__(
        self,
        retriever: BaseRetriever,
        top_k: int = 3,
        similarity_threshold: float = 0.0,
        auto_retrieve: bool = True,
        inject_as: str = "system",  # "system" or "user_prefix"
    ) -> None:
        """
        Parameters:
            retriever: Retriever instance used to fetch documents.
            top_k: Default number of documents to retrieve.
            similarity_threshold: Relevance cutoff (0-1). Documents whose score
                is below it are dropped; documents without a score are kept.
                Values <= 0 disable filtering.
            auto_retrieve: Whether retrieval runs automatically. When False,
                ``retrieve_for_query`` returns [] unless called with ``force=True``.
            inject_as: Injection mode ("system" = separate system message,
                "user_prefix" = prefix of the user message). Only stored here;
                nothing in this class reads it.
        """
        self.retriever = retriever
        self.top_k = top_k
        self.similarity_threshold = similarity_threshold
        self.auto_retrieve = auto_retrieve
        self.inject_as = inject_as

    async def retrieve_for_query(
        self,
        query: str,
        top_k: Optional[int] = None,
        filters: Optional[Dict[str, Any]] = None,
        *,
        force: bool = False,
    ) -> List[Document]:
        """
        Retrieve documents relevant to a query.

        Parameters:
            query: User query.
            top_k: Overrides the default ``top_k`` when given.
            filters: Metadata filter conditions passed to the retriever.
            force: Retrieve even when ``auto_retrieve`` is False (manual retrieval).

        Returns:
            Relevant documents with low-relevance ones filtered out. Returns []
            when retrieval is disabled or the retriever raises (the failure is
            logged as a warning instead of blocking the caller).
        """
        if not (self.auto_retrieve or force):
            return []

        k = self.top_k if top_k is None else top_k

        try:
            docs = await self.retriever.retrieve(
                query=query,
                top_k=k,
                filters=filters,
            )

            # Drop low-relevance documents; unscored documents are kept.
            if self.similarity_threshold > 0:
                docs = [
                    doc for doc in docs
                    if doc.score is None or doc.score >= self.similarity_threshold
                ]

            return docs

        except Exception as exc:
            # Retrieval failures must not block the main flow; report through
            # logging rather than printing to stdout from library code.
            import logging

            logging.getLogger(__name__).warning("Document retrieval failed: %s", exc)
            return []

    def format_documents(
        self,
        documents: List[Document],
        max_length: int = 1000,
    ) -> str:
        """
        Format documents as a markdown string for context injection.

        Parameters:
            documents: Documents to format.
            max_length: Maximum characters kept per document; longer content is
                cut and marked "[truncated]".

        Returns:
            The formatted string, or "" when there are no documents.
        """
        if not documents:
            return ""

        lines = ["## Retrieved Context\n"]
        lines.append(f"Found {len(documents)} relevant document(s):\n")

        for i, doc in enumerate(documents, 1):
            lines.append(f"### Document {i}")

            # Metadata
            if doc.metadata:
                source = doc.metadata.get("source", "Unknown")
                lines.append(f"**Source**: {source}")

            if doc.score is not None:
                lines.append(f"**Relevance**: {doc.score:.2%}")

            # Content (truncated)
            content = doc.content
            if len(content) > max_length:
                content = content[:max_length] + "...\n[truncated]"

            lines.append(f"\n{content}\n")

        lines.append("---\n")
        lines.append("Please answer the question based on the above context.\n")

        return "\n".join(lines)

    def format_as_user_prefix(
        self,
        documents: List[Document],
        user_query: str,
        max_length: int = 1000,
    ) -> str:
        """
        Format documents as a prefix of the user message.

        Useful when an extra system message is not wanted.

        Returns:
            "<formatted documents>\nQuestion: {user_query}", or the query
            unchanged when there are no documents.
        """
        if not documents:
            return user_query

        doc_text = self.format_documents(documents, max_length)
        return f"{doc_text}\nQuestion: {user_query}"

    def get_metadata_summary(self, documents: List[Document]) -> Dict[str, Any]:
        """
        Summarize metadata of retrieved documents (for logging and debugging).

        Returns:
            {"count": 3, "avg_score": 0.85, "sources": [...]}; ``sources`` is
            de-duplicated in first-seen order so output is deterministic, and
            ``avg_score`` is None when no document has a score.
        """
        if not documents:
            return {"count": 0}

        scores = [doc.score for doc in documents if doc.score is not None]
        sources = list(dict.fromkeys(
            doc.metadata.get("source", "Unknown")
            for doc in documents
            if doc.metadata
        ))

        return {
            "count": len(documents),
            "avg_score": sum(scores) / len(scores) if scores else None,
            "sources": sources,
        }
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
"""US5: Error Classification and Retry Logic
|
|
2
|
+
|
|
3
|
+
Classifies errors into retryable/non-retryable categories and provides
|
|
4
|
+
actionable recovery guidance.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
from typing import Optional, Callable, Any, TypeVar
|
|
11
|
+
from loom.core.errors import ErrorCategory, LoomException
|
|
12
|
+
|
|
13
|
+
# Generic type variable. NOTE(review): nothing in this module references it;
# remove it or use it to type RetryPolicy.execute_with_retry's result.
T = TypeVar('T')
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ErrorClassifier:
    """Classifies errors and determines retry strategy."""

    @staticmethod
    def _mentions_status_code(message: str, codes: tuple[str, ...]) -> bool:
        """Return True if *message* contains one of *codes* as a standalone number.

        A plain substring test misfires on unrelated numbers: ``'500'`` is in
        ``'you requested 5003 tokens'``, which would turn a context-length
        error into a retryable service error. Requiring that no digit touches
        the code on either side avoids that, while still matching
        ``'HTTP 503'``, ``'status_code=500'`` or ``'error500'``.
        """
        import re

        if not codes:
            return False
        alternatives = "|".join(re.escape(code) for code in codes)
        return re.search(rf"(?<!\d)(?:{alternatives})(?!\d)", message) is not None

    @staticmethod
    def classify(error: Exception) -> ErrorCategory:
        """Classify an error into a category.

        Type-based checks run first, then message-based heuristics
        (HTTP status codes and keywords).

        Args:
            error: The exception to classify

        Returns:
            ErrorCategory enum value (UNKNOWN_ERROR when nothing matches)
        """
        # A LoomException already carries its category.
        if isinstance(error, LoomException):
            return error.category

        error_type = type(error).__name__.lower()
        error_msg = str(error).lower()
        has_code = ErrorClassifier._mentions_status_code

        # Timeout / network errors (retryable). Name checks cover third-party
        # types (e.g. "ReadTimeout", "ConnectError"); isinstance checks cover
        # builtins whose names lack the keywords (e.g. BrokenPipeError).
        if isinstance(error, TimeoutError) or 'timeout' in error_type:
            return ErrorCategory.TIMEOUT_ERROR
        if isinstance(error, ConnectionError) or 'connect' in error_type or 'network' in error_type:
            return ErrorCategory.NETWORK_ERROR

        # Builtin OS errors with a direct category (non-retryable).
        if isinstance(error, PermissionError):
            return ErrorCategory.PERMISSION_ERROR
        if isinstance(error, FileNotFoundError):
            return ErrorCategory.NOT_FOUND_ERROR

        # Rate limiting (retryable with backoff)
        if has_code(error_msg, ('429',)) or 'rate limit' in error_msg:
            return ErrorCategory.RATE_LIMIT_ERROR

        # Service errors (5xx - retryable)
        if has_code(error_msg, ('500', '502', '503', '504')):
            return ErrorCategory.SERVICE_ERROR

        # Authentication errors (non-retryable). NOTE: the bare 'auth'
        # substring also matches words like "author"; kept for compatibility.
        if has_code(error_msg, ('401', '403')) or 'auth' in error_msg:
            return ErrorCategory.AUTHENTICATION_ERROR

        # Not found errors (non-retryable)
        if has_code(error_msg, ('404',)) or 'not found' in error_msg:
            return ErrorCategory.NOT_FOUND_ERROR

        # Validation errors (non-retryable)
        if 'validation' in error_type:
            return ErrorCategory.VALIDATION_ERROR

        # Default: unknown (non-retryable)
        return ErrorCategory.UNKNOWN_ERROR

    @staticmethod
    def is_retryable(category: ErrorCategory) -> bool:
        """Determine if an error category is retryable.

        Args:
            category: The error category

        Returns:
            True for network, timeout, rate-limit and service errors; False otherwise
        """
        retryable_categories = {
            ErrorCategory.NETWORK_ERROR,
            ErrorCategory.TIMEOUT_ERROR,
            ErrorCategory.RATE_LIMIT_ERROR,
            ErrorCategory.SERVICE_ERROR,
        }
        return category in retryable_categories

    @staticmethod
    def get_recovery_guidance(error: Exception, category: ErrorCategory) -> str:
        """Get actionable recovery guidance for an error.

        Args:
            error: The exception (its type name is used for UNKNOWN_ERROR)
            category: The error category

        Returns:
            Human-readable recovery guidance
        """
        guidance_map = {
            ErrorCategory.NETWORK_ERROR: "Check network connectivity and retry. If problem persists, check service status.",
            ErrorCategory.TIMEOUT_ERROR: "Operation timed out. Try increasing timeout or simplifying the request.",
            ErrorCategory.RATE_LIMIT_ERROR: "Rate limit exceeded. Wait before retrying. Consider implementing backoff.",
            ErrorCategory.SERVICE_ERROR: "Service temporarily unavailable. Retry after a short delay.",
            ErrorCategory.VALIDATION_ERROR: "Invalid input parameters. Review and correct the request.",
            ErrorCategory.PERMISSION_ERROR: "Permission denied. Check access rights and credentials.",
            ErrorCategory.AUTHENTICATION_ERROR: "Authentication failed. Verify credentials and API keys.",
            ErrorCategory.NOT_FOUND_ERROR: "Resource not found. Verify the resource exists and path is correct.",
            ErrorCategory.UNKNOWN_ERROR: f"Unexpected error: {type(error).__name__}. Review error details.",
        }
        return guidance_map.get(category, "Unknown error. Review error details.")
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class RetryPolicy:
    """Retry policy with exponential backoff.

    US5: automatic retry with exponential backoff. With the defaults
    (max_retries=3, base_delay=1.0, exponential_base=2.0) a call is made up to
    4 times (1 attempt + 3 retries), waiting 1s, 2s, then 4s between attempts.
    """

    def __init__(
        self,
        max_retries: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        exponential_base: float = 2.0,
    ):
        """Initialize retry policy.

        Args:
            max_retries: Maximum number of retry attempts after the first call
                (negative values are treated as 0)
            base_delay: Initial delay in seconds
            max_delay: Maximum delay in seconds
            exponential_base: Base for exponential backoff (2.0 = double each time)
        """
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.exponential_base = exponential_base

    def get_delay(self, attempt: int) -> float:
        """Calculate delay for retry attempt.

        Args:
            attempt: The retry attempt number (0-indexed)

        Returns:
            ``base_delay * exponential_base ** attempt`` capped at ``max_delay``
        """
        delay = self.base_delay * (self.exponential_base ** attempt)
        return min(delay, self.max_delay)

    async def execute_with_retry(
        self,
        func: Callable[..., Any],
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Execute a function with retry logic.

        Args:
            func: The function to execute. If calling it returns an awaitable
                (``async def`` functions, objects with an async ``__call__``,
                wrappers returning coroutines), the awaitable is awaited.
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            The result of the function

        Raises:
            The original exception immediately when ErrorClassifier deems it
            non-retryable, or the last exception once all attempts are used.
        """
        import inspect

        total_attempts = max(self.max_retries, 0) + 1

        for attempt in range(total_attempts):
            try:
                result = func(*args, **kwargs)
                # asyncio.iscoroutinefunction misses callables that merely
                # *return* a coroutine; awaiting any awaitable result keeps such
                # failures inside this try block so they are retried too.
                if inspect.isawaitable(result):
                    result = await result
                return result

            except Exception as e:
                category = ErrorClassifier.classify(e)

                # Don't retry if non-retryable
                if not ErrorClassifier.is_retryable(category):
                    raise

                # If this was the last attempt, raise
                if attempt >= total_attempts - 1:
                    raise

                await asyncio.sleep(self.get_delay(attempt))

        # Unreachable: the final attempt always returns or raises.
        raise RuntimeError("Retry logic failed unexpectedly")
|