loom-agent 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of loom-agent might be problematic. Click here for more details.
- loom/__init__.py +51 -0
- loom/api/__init__.py +19 -0
- loom/api/v0_0_3.py +300 -0
- loom/builtin/retriever/faiss_store.py +403 -0
- loom/core/agent_executor.py +212 -26
- loom/core/events.py +3 -0
- loom/core/recursion_control.py +298 -0
- loom/core/turn_state.py +58 -6
- loom/retrieval/__init__.py +61 -0
- loom/retrieval/domain_adapter.py +195 -0
- loom/retrieval/embedding_retriever.py +393 -0
- loom_agent-0.0.5.dist-info/METADATA +561 -0
- {loom_agent-0.0.3.dist-info → loom_agent-0.0.5.dist-info}/RECORD +15 -8
- loom_agent-0.0.3.dist-info/METADATA +0 -292
- {loom_agent-0.0.3.dist-info → loom_agent-0.0.5.dist-info}/WHEEL +0 -0
- {loom_agent-0.0.3.dist-info → loom_agent-0.0.5.dist-info}/licenses/LICENSE +0 -0
loom/core/events.py
CHANGED
|
@@ -114,6 +114,9 @@ class AgentEventType(Enum):
|
|
|
114
114
|
RECURSION = "recursion"
|
|
115
115
|
"""Recursive call initiated (tt mode)"""
|
|
116
116
|
|
|
117
|
+
RECURSION_TERMINATED = "recursion_terminated"
|
|
118
|
+
"""Recursion terminated due to loop detection or limits (Phase 2 optimization)"""
|
|
119
|
+
|
|
117
120
|
AGENT_FINISH = "agent_finish"
|
|
118
121
|
"""Agent execution finished successfully"""
|
|
119
122
|
|
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Recursion Control for Agent Executor
|
|
3
|
+
|
|
4
|
+
Provides generic recursion termination detection to prevent infinite loops
|
|
5
|
+
in agent execution. This is a framework-level capability that doesn't depend
|
|
6
|
+
on specific business logic.
|
|
7
|
+
|
|
8
|
+
New in Loom 0.0.4: Phase 2 - Execution Layer Optimization
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
from enum import Enum
|
|
15
|
+
from typing import List, Optional, Any
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TerminationReason(str, Enum):
    """Why the recursion monitor decided to stop recursive execution.

    Members mix in ``str``, so each one compares equal to its plain
    string value (e.g. ``TerminationReason.LOOP_DETECTED == "loop_detected"``).
    """

    # The hard cap on the number of iterations was hit.
    MAX_ITERATIONS = "max_iterations"

    # The same tool was invoked several times back to back.
    DUPLICATE_TOOLS = "duplicate_tools"

    # Recent outputs repeat in a pattern.
    LOOP_DETECTED = "loop_detected"

    # The share of failing iterations rose above the configured limit.
    ERROR_THRESHOLD = "error_threshold"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class RecursionState:
    """Snapshot of the counters that RecursionMonitor inspects.

    Deliberately kept apart from TurnState so that recursion-control
    logic is not coupled to turn-management logic.
    """

    # Current iteration number (0-based).
    iteration: int

    # Tool names called so far, most recent last; feeds duplicate detection.
    tool_call_history: List[str]

    # How many errors have been encountered during execution so far.
    error_count: int

    # Recent outputs, most recent last; feeds loop-pattern detection.
    last_outputs: List[Any]
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class RecursionMonitor:
    """
    Generic recursion monitoring and termination detection.

    This monitor provides framework-level recursion control without
    depending on any specific business logic. It detects common
    infinite loop patterns:

    1. Maximum iteration limit
    2. Repeated tool calls (same tool called N times in a row)
    3. Loop patterns in outputs (the most recent ``2 * loop_detection_window``
       outputs repeat with some period of at most ``loop_detection_window``)
    4. High error rates

    A non-positive ``duplicate_threshold`` or ``loop_detection_window``
    disables the corresponding check rather than triggering it spuriously.

    Example:
        ```python
        # Create monitor with custom thresholds
        monitor = RecursionMonitor(
            max_iterations=50,
            duplicate_threshold=3,
            loop_detection_window=5,
            error_threshold=0.5
        )

        # Check if should terminate
        state = RecursionState(
            iteration=10,
            tool_call_history=["search", "search", "search"],
            error_count=2,
            last_outputs=[]
        )

        reason = monitor.check_termination(state)
        if reason:
            message = monitor.build_termination_message(reason)
            print(f"Terminating: {message}")
        ```
    """

    def __init__(
        self,
        max_iterations: int = 50,
        duplicate_threshold: int = 3,
        loop_detection_window: int = 5,
        error_threshold: float = 0.5
    ):
        """
        Initialize recursion monitor.

        Args:
            max_iterations: Maximum number of recursive iterations allowed
            duplicate_threshold: Number of consecutive identical tool calls
                before terminating (values <= 0 disable the check)
            loop_detection_window: Maximum loop period considered by loop
                detection; ``2 * loop_detection_window`` outputs are needed
                before the check can fire (values <= 0 disable the check).
                NOTE: TurnState.next_turn in this package keeps only the last
                10 outputs, so windows above 5 never fire when fed from it.
            error_threshold: Maximum error rate (errors/iterations) before terminating
        """
        self.max_iterations = max_iterations
        self.duplicate_threshold = duplicate_threshold
        self.loop_detection_window = loop_detection_window
        self.error_threshold = error_threshold

    def check_termination(
        self,
        state: RecursionState
    ) -> Optional[TerminationReason]:
        """
        Check if recursive execution should terminate.

        This method runs multiple checks in priority order:
        1. Max iterations (highest priority - hard limit)
        2. Duplicate tool calls (likely stuck)
        3. Loop patterns (repeating behavior)
        4. Error threshold (too many failures)

        Args:
            state: Current recursion state

        Returns:
            TerminationReason if should terminate, None to continue
        """
        # Check 1: Maximum iterations (hard limit)
        if state.iteration >= self.max_iterations:
            return TerminationReason.MAX_ITERATIONS

        # Check 2: Duplicate tool calls (likely stuck)
        if self._detect_duplicate_tools(state.tool_call_history):
            return TerminationReason.DUPLICATE_TOOLS

        # Check 3: Loop patterns in outputs
        if self._detect_loop_pattern(state.last_outputs):
            return TerminationReason.LOOP_DETECTED

        # Check 4: Error rate threshold
        if self._check_error_threshold(state):
            return TerminationReason.ERROR_THRESHOLD

        return None

    def _detect_duplicate_tools(self, tool_history: List[str]) -> bool:
        """
        Detect if the same tool has been called too many times in a row.

        This indicates the agent is stuck in a loop, repeatedly trying
        the same tool without making progress.

        Args:
            tool_history: List of tool names (most recent last)

        Returns:
            True if the last ``duplicate_threshold`` calls all used the same
            tool; always False when ``duplicate_threshold <= 0``.
        """
        threshold = self.duplicate_threshold
        # ``xs[-0:]`` is the whole list, so a non-positive threshold would
        # compare the entire history; treat it as "check disabled" instead.
        if threshold <= 0 or len(tool_history) < threshold:
            return False

        # Check last N tool calls
        recent = tool_history[-threshold:]

        # All the same? -> Stuck in loop
        return len(set(recent)) == 1

    def _detect_loop_pattern(self, outputs: List[Any]) -> bool:
        """
        Detect if outputs are repeating in a pattern.

        The last ``2 * loop_detection_window`` outputs are examined. A loop is
        reported when that span is periodic with some period ``p`` in
        ``1..loop_detection_window``, i.e. ``recent[i] == recent[i + p]`` for
        every valid ``i``. Period ``p == window`` is exactly "first half equals
        second half"; shorter periods catch e.g. A-B-A-B alternation even when
        the window length is not a multiple of the period.

        Args:
            outputs: Recent output values (most recent last)

        Returns:
            True if loop pattern detected; always False when
            ``loop_detection_window <= 0``.
        """
        window_size = self.loop_detection_window
        # A non-positive window would make two empty halves compare equal.
        if window_size <= 0 or len(outputs) < window_size * 2:
            return False

        recent = list(outputs[-window_size * 2:])
        span = len(recent)

        # Try every candidate period up to the window size (cheap: window is small).
        for period in range(1, window_size + 1):
            if all(recent[i] == recent[i + period] for i in range(span - period)):
                return True
        return False

    def _check_error_threshold(self, state: RecursionState) -> bool:
        """
        Check if error rate exceeds acceptable threshold.

        Too many errors indicate the agent cannot complete the task
        and should stop trying.

        Args:
            state: Current recursion state

        Returns:
            True if ``error_count / iteration`` exceeds ``error_threshold``
            (never at iteration 0).
        """
        if state.iteration == 0:
            return False

        # NOTE(review): the rate divides by ``state.iteration``; this assumes
        # the caller's iteration equals the number of iterations whose errors
        # are counted in ``error_count`` -- confirm against the executor.
        error_rate = state.error_count / state.iteration
        return error_rate > self.error_threshold

    def build_termination_message(
        self,
        reason: TerminationReason
    ) -> str:
        """
        Build a user-friendly termination message.

        This message is injected into the conversation to prompt
        the LLM to complete the task with available information.

        Args:
            reason: The termination reason

        Returns:
            Formatted termination message (a generic fallback for unknown reasons)
        """
        messages = {
            TerminationReason.DUPLICATE_TOOLS: (
                "⚠️ Detected repeated tool calls. "
                "Please proceed with available information."
            ),
            TerminationReason.LOOP_DETECTED: (
                "⚠️ Detected execution loop. "
                "Please break the pattern and complete the task."
            ),
            TerminationReason.MAX_ITERATIONS: (
                "⚠️ Maximum iterations reached. "
                "Please provide the best answer with current information."
            ),
            TerminationReason.ERROR_THRESHOLD: (
                "⚠️ Too many errors occurred. "
                "Please complete the task with current information."
            )
        }

        return messages.get(reason, "Please complete the task now.")

    def should_add_warning(
        self,
        state: RecursionState,
        warning_threshold: float = 0.8
    ) -> Optional[str]:
        """
        Check if a warning should be added before termination.

        This provides early warning when approaching limits,
        giving the agent a chance to wrap up gracefully.

        Args:
            state: Current recursion state
            warning_threshold: Fraction of limit at which to warn (0.0-1.0)

        Returns:
            Warning message if applicable, None otherwise
        """
        # Check if approaching max iterations. A non-positive limit is already
        # exhausted; treat it as full progress instead of dividing by zero.
        if self.max_iterations > 0:
            progress = state.iteration / self.max_iterations
        else:
            progress = 1.0
        if progress >= warning_threshold:
            # Clamp so an overrun never reports a negative count.
            remaining = max(0, self.max_iterations - state.iteration)
            return (
                f"⚠️ Approaching iteration limit ({remaining} remaining). "
                "Please work towards completing the task."
            )

        # Warn one call before the duplicate detector would fire. With a
        # threshold <= 1 there is no "one call before", so no warning.
        window = self.duplicate_threshold - 1
        if window >= 1 and len(state.tool_call_history) >= window:
            recent = state.tool_call_history[-window:]
            if len(set(recent)) == 1:
                return (
                    f"⚠️ You've called '{recent[0]}' multiple times. "
                    "Consider trying a different approach or completing the task."
                )

        return None
|
loom/core/turn_state.py
CHANGED
|
@@ -8,7 +8,7 @@ Inspired by Claude Code's recursive conversation management.
|
|
|
8
8
|
from __future__ import annotations
|
|
9
9
|
|
|
10
10
|
from dataclasses import dataclass, field
|
|
11
|
-
from typing import Dict, Any, Optional
|
|
11
|
+
from typing import Dict, Any, Optional, List
|
|
12
12
|
from uuid import uuid4
|
|
13
13
|
|
|
14
14
|
|
|
@@ -32,6 +32,9 @@ class TurnState:
|
|
|
32
32
|
compacted: Whether conversation history was compacted this turn
|
|
33
33
|
parent_turn_id: ID of the parent turn (None for initial turn)
|
|
34
34
|
metadata: Additional turn-specific data
|
|
35
|
+
tool_call_history: History of tool names called (for recursion control)
|
|
36
|
+
error_count: Number of errors encountered (for recursion control)
|
|
37
|
+
last_outputs: Recent outputs for loop detection (for recursion control)
|
|
35
38
|
|
|
36
39
|
Example:
|
|
37
40
|
```python
|
|
@@ -54,6 +57,16 @@ class TurnState:
|
|
|
54
57
|
parent_turn_id: Optional[str] = None
|
|
55
58
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
56
59
|
|
|
60
|
+
# Phase 2: Recursion control tracking
|
|
61
|
+
tool_call_history: List[str] = field(default_factory=list)
|
|
62
|
+
"""History of tool names called (for recursion control)"""
|
|
63
|
+
|
|
64
|
+
error_count: int = 0
|
|
65
|
+
"""Number of errors encountered during execution"""
|
|
66
|
+
|
|
67
|
+
last_outputs: List[str] = field(default_factory=list)
|
|
68
|
+
"""Recent outputs for loop detection (limited to last 10)"""
|
|
69
|
+
|
|
57
70
|
@staticmethod
|
|
58
71
|
def initial(max_iterations: int = 10, **metadata) -> TurnState:
|
|
59
72
|
"""
|
|
@@ -75,7 +88,14 @@ class TurnState:
|
|
|
75
88
|
metadata=metadata
|
|
76
89
|
)
|
|
77
90
|
|
|
78
|
-
def next_turn(
|
|
91
|
+
def next_turn(
|
|
92
|
+
self,
|
|
93
|
+
compacted: bool = False,
|
|
94
|
+
tool_calls: Optional[List[str]] = None,
|
|
95
|
+
had_error: bool = False,
|
|
96
|
+
output: Optional[str] = None,
|
|
97
|
+
**metadata_updates
|
|
98
|
+
) -> TurnState:
|
|
79
99
|
"""
|
|
80
100
|
Create next turn state (immutable update).
|
|
81
101
|
|
|
@@ -84,6 +104,9 @@ class TurnState:
|
|
|
84
104
|
|
|
85
105
|
Args:
|
|
86
106
|
compacted: Whether history was compacted in the next turn
|
|
107
|
+
tool_calls: New tool calls to add to history
|
|
108
|
+
had_error: Whether an error occurred in this turn
|
|
109
|
+
output: Output content to add for loop detection
|
|
87
110
|
**metadata_updates: Updates to metadata (merged with existing)
|
|
88
111
|
|
|
89
112
|
Returns:
|
|
@@ -98,13 +121,33 @@ class TurnState:
|
|
|
98
121
|
"""
|
|
99
122
|
new_metadata = {**self.metadata, **metadata_updates}
|
|
100
123
|
|
|
124
|
+
# Update tool call history
|
|
125
|
+
new_tool_history = list(self.tool_call_history)
|
|
126
|
+
if tool_calls:
|
|
127
|
+
new_tool_history.extend(tool_calls)
|
|
128
|
+
# Keep only last 20 tool calls
|
|
129
|
+
new_tool_history = new_tool_history[-20:]
|
|
130
|
+
|
|
131
|
+
# Update error count
|
|
132
|
+
new_error_count = self.error_count + (1 if had_error else 0)
|
|
133
|
+
|
|
134
|
+
# Update output history for loop detection
|
|
135
|
+
new_outputs = list(self.last_outputs)
|
|
136
|
+
if output:
|
|
137
|
+
new_outputs.append(output)
|
|
138
|
+
# Keep only last 10 outputs
|
|
139
|
+
new_outputs = new_outputs[-10:]
|
|
140
|
+
|
|
101
141
|
return TurnState(
|
|
102
142
|
turn_counter=self.turn_counter + 1,
|
|
103
143
|
turn_id=str(uuid4()), # New unique ID
|
|
104
144
|
max_iterations=self.max_iterations,
|
|
105
145
|
compacted=compacted,
|
|
106
146
|
parent_turn_id=self.turn_id, # Link to parent
|
|
107
|
-
metadata=new_metadata
|
|
147
|
+
metadata=new_metadata,
|
|
148
|
+
tool_call_history=new_tool_history,
|
|
149
|
+
error_count=new_error_count,
|
|
150
|
+
last_outputs=new_outputs
|
|
108
151
|
)
|
|
109
152
|
|
|
110
153
|
def with_metadata(self, **kwargs) -> TurnState:
|
|
@@ -125,7 +168,10 @@ class TurnState:
|
|
|
125
168
|
max_iterations=self.max_iterations,
|
|
126
169
|
compacted=self.compacted,
|
|
127
170
|
parent_turn_id=self.parent_turn_id,
|
|
128
|
-
metadata=new_metadata
|
|
171
|
+
metadata=new_metadata,
|
|
172
|
+
tool_call_history=self.tool_call_history,
|
|
173
|
+
error_count=self.error_count,
|
|
174
|
+
last_outputs=self.last_outputs
|
|
129
175
|
)
|
|
130
176
|
|
|
131
177
|
@property
|
|
@@ -156,7 +202,10 @@ class TurnState:
|
|
|
156
202
|
"max_iterations": self.max_iterations,
|
|
157
203
|
"compacted": self.compacted,
|
|
158
204
|
"parent_turn_id": self.parent_turn_id,
|
|
159
|
-
"metadata": self.metadata
|
|
205
|
+
"metadata": self.metadata,
|
|
206
|
+
"tool_call_history": self.tool_call_history,
|
|
207
|
+
"error_count": self.error_count,
|
|
208
|
+
"last_outputs": self.last_outputs
|
|
160
209
|
}
|
|
161
210
|
|
|
162
211
|
@staticmethod
|
|
@@ -176,7 +225,10 @@ class TurnState:
|
|
|
176
225
|
max_iterations=data.get("max_iterations", 10),
|
|
177
226
|
compacted=data.get("compacted", False),
|
|
178
227
|
parent_turn_id=data.get("parent_turn_id"),
|
|
179
|
-
metadata=data.get("metadata", {})
|
|
228
|
+
metadata=data.get("metadata", {}),
|
|
229
|
+
tool_call_history=data.get("tool_call_history", []),
|
|
230
|
+
error_count=data.get("error_count", 0),
|
|
231
|
+
last_outputs=data.get("last_outputs", [])
|
|
180
232
|
)
|
|
181
233
|
|
|
182
234
|
def __repr__(self) -> str:
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Loom Retrieval Module
|
|
3
|
+
|
|
4
|
+
Provides embedding-based semantic retrieval with lazy loading and caching.
|
|
5
|
+
|
|
6
|
+
Key components:
|
|
7
|
+
- EmbeddingRetriever: Core retrieval system using embeddings
|
|
8
|
+
- DomainAdapter: Interface for adapting domain-specific data
|
|
9
|
+
- IndexStrategy: Indexing strategies (EAGER/LAZY/INCREMENTAL)
|
|
10
|
+
- RetrievalConfig: Configuration for retrieval behavior
|
|
11
|
+
|
|
12
|
+
Example:
|
|
13
|
+
from loom.retrieval import (
|
|
14
|
+
EmbeddingRetriever,
|
|
15
|
+
DomainAdapter,
|
|
16
|
+
IndexStrategy,
|
|
17
|
+
RetrievalConfig
|
|
18
|
+
)
|
|
19
|
+
from loom.builtin.embeddings import OpenAIEmbedding
|
|
20
|
+
from loom.builtin.retriever import FAISSVectorStore
|
|
21
|
+
|
|
22
|
+
# Create retriever
|
|
23
|
+
retriever = EmbeddingRetriever(
|
|
24
|
+
embedding=OpenAIEmbedding(model="text-embedding-3-small"),
|
|
25
|
+
vector_store=FAISSVectorStore(dimension=1536),
|
|
26
|
+
domain_adapter=my_adapter,
|
|
27
|
+
config=RetrievalConfig(
|
|
28
|
+
index_strategy=IndexStrategy.LAZY,
|
|
29
|
+
top_k=5
|
|
30
|
+
)
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# Initialize
|
|
34
|
+
await retriever.initialize()
|
|
35
|
+
|
|
36
|
+
# Retrieve
|
|
37
|
+
results = await retriever.retrieve("user query", top_k=5)
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
from loom.retrieval.embedding_retriever import (
|
|
41
|
+
EmbeddingRetriever,
|
|
42
|
+
IndexStrategy,
|
|
43
|
+
RetrievalConfig
|
|
44
|
+
)
|
|
45
|
+
from loom.retrieval.domain_adapter import (
|
|
46
|
+
DomainAdapter,
|
|
47
|
+
SimpleDomainAdapter
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
# Public API of loom.retrieval.
__all__ = [
    "EmbeddingRetriever",   # core classes
    "DomainAdapter",
    "SimpleDomainAdapter",
    "IndexStrategy",        # enums
    "RetrievalConfig",      # config
]
|
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Domain Adapter Interface
|
|
3
|
+
|
|
4
|
+
Defines the interface for adapting domain-specific data to the retrieval system.
|
|
5
|
+
Users implement this interface to support their specific domains (Schema, Code, Docs, etc.)
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from abc import ABC, abstractmethod
|
|
11
|
+
from typing import Any, List, Optional
|
|
12
|
+
|
|
13
|
+
from loom.interfaces.retriever import Document
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DomainAdapter(ABC):
    """
    Bridge between a domain-specific data source and the retrieval system.

    Implementations convert whatever the domain holds into generic
    ``Document`` objects for retrieval. Typical domains include:
    - SQL Schema
    - Code repositories
    - Documentation
    - API specifications
    - etc.

    Example:
        class MySchemaDomainAdapter(DomainAdapter):
            async def extract_documents(self, source, **kwargs):
                tables = await self._get_tables()
                return [
                    Document(
                        doc_id=f"table_{table}",
                        content=f"Table: {table}",
                        metadata={"table": table}
                    )
                    for table in tables
                ]

            async def load_document_details(self, document_id):
                table_name = document_id.replace("table_", "")
                schema = await self._get_table_schema(table_name)
                return Document(
                    doc_id=document_id,
                    content=schema,
                    metadata={"table": table_name}
                )
    """

    @abstractmethod
    async def extract_documents(
        self,
        source: Any = None,
        metadata_only: bool = False,
        **kwargs
    ) -> List[Document]:
        """
        Pull documents out of ``source``.

        Args:
            source: Where the data lives (database connection, file path, etc.)
            metadata_only: When True, return only lightweight metadata suitable
                for lazy loading instead of full documents
            **kwargs: Domain-specific parameters

        Returns:
            The extracted documents

        Raises:
            NotImplementedError: Always, in this abstract base; subclasses
                must provide the implementation.

        Example:
            # Full documents
            docs = await adapter.extract_documents(source=db_connection)

            # Lightweight metadata only (for lazy loading)
            docs = await adapter.extract_documents(
                source=db_connection,
                metadata_only=True
            )
        """
        raise NotImplementedError

    @abstractmethod
    async def load_document_details(
        self,
        document_id: str,
        **kwargs
    ) -> Document:
        """
        Fetch the complete document for ``document_id`` on demand.

        Used when lazy loading is enabled and the full document is needed.

        Args:
            document_id: Identifier of the document to load
            **kwargs: Domain-specific parameters

        Returns:
            The full document including its details

        Raises:
            NotImplementedError: Always, in this abstract base; subclasses
                must provide the implementation.

        Example:
            doc = await adapter.load_document_details("table_users")
        """
        raise NotImplementedError

    def format_for_embedding(self, document: Document) -> str:
        """
        Produce the text that gets embedded for ``document``.

        The base implementation simply hands back ``document.content``;
        override it to control what text is embedded.

        Args:
            document: Document to format

        Returns:
            Text to feed to the embedding model
        """
        text = document.content
        return text

    def should_index(self, document: Document) -> bool:
        """
        Decide whether ``document`` is added to the index.

        The base implementation accepts every document; override it to
        filter documents during indexing.

        Args:
            document: Candidate document

        Returns:
            True when the document should be indexed
        """
        return True
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class SimpleDomainAdapter(DomainAdapter):
    """
    Simple in-memory domain adapter

    Useful for testing and simple use cases where documents are
    already in memory.

    Example:
        documents = [
            Document(doc_id="1", content="Document 1"),
            Document(doc_id="2", content="Document 2"),
        ]

        adapter = SimpleDomainAdapter(documents)
        retriever = EmbeddingRetriever(
            embedding=...,
            vector_store=...,
            domain_adapter=adapter
        )
    """

    def __init__(self, documents: List[Document], preview_chars: int = 100):
        """
        Args:
            documents: List of documents to index. A later document replaces
                an earlier one with the same ``doc_id``.
            preview_chars: Number of leading ``content`` characters kept in
                the lightweight documents returned for ``metadata_only=True``
                (longer content is truncated and suffixed with "...").
        """
        self.documents = {doc.doc_id: doc for doc in documents}
        self.preview_chars = preview_chars

    async def extract_documents(
        self,
        source: Any = None,
        metadata_only: bool = False,
        **kwargs
    ) -> List[Document]:
        """
        Extract all documents.

        Args:
            source: Ignored; documents were supplied at construction.
            metadata_only: When True, return lightweight copies whose content
                is truncated to ``preview_chars`` characters.
            **kwargs: Ignored.

        Returns:
            The stored documents themselves, or newly built previews when
            ``metadata_only`` is True.
        """
        if not metadata_only:
            return list(self.documents.values())

        import copy

        limit = self.preview_chars
        previews = []
        for doc in self.documents.values():
            content = doc.content
            if len(content) > limit:
                content = content[:limit] + "..."
            previews.append(
                Document(
                    doc_id=doc.doc_id,
                    content=content,
                    # Shallow-copy so callers mutating a preview's metadata do
                    # not silently alter the stored original document.
                    metadata=copy.copy(doc.metadata)
                )
            )
        return previews

    async def load_document_details(
        self,
        document_id: str,
        **kwargs
    ) -> Document:
        """
        Load the full stored document.

        Args:
            document_id: Identifier of the document to return
            **kwargs: Ignored.

        Returns:
            The document stored under ``document_id``.

        Raises:
            ValueError: If no document with that id was supplied.
        """
        try:
            return self.documents[document_id]
        except KeyError:
            raise ValueError(f"Document not found: {document_id}") from None
|